commit 717a0b819fd10f45834ad12cea6a515dc554fba7 Author: Kevin Thomas Date: Sat Feb 7 13:54:28 2026 -0500 Initial commit: TMT lightweight threat modeling toolkit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..206028e --- /dev/null +++ b/.gitignore @@ -0,0 +1,69 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +build/ +dist/ +*.egg-info/ +*.egg +sdist/ +wheels/ + +# Virtual environments +.venv/ +venv/ +ENV/ + +# IDE / editor +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Testing / coverage +.pytest_cache/ +htmlcov/ +.coverage +.coverage.* +coverage.xml +*.cover + +# Reports (generated output) +reports/ + +# Environment variables +.env +.env.* + +# Jupyter +.ipynb_checkpoints/ + +# Large model / data files +*.h5 +*.hdf5 +*.pkl +*.pickle +*.joblib +*.parquet +*.feather +*.arrow +*.onnx +*.pt +*.pth +*.bin +*.safetensors +*.gguf +*.ggml +*.tar.gz +*.zip +*.gz +*.bz2 +*.xz +*.7z + +# Logs +*.log diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2973cfa --- /dev/null +++ b/Makefile @@ -0,0 +1,20 @@ +.PHONY: install test scan scan-llm clean + +install: + pip install -e ".[dev]" + +test: + pytest tests/ -v --tb=short + +test-cov: + pytest tests/ -v --cov=tmt --cov-report=term-missing + +scan: + python run_threat_model.py --target . --project-name "tmt-self-scan" + +scan-llm: + python run_threat_model.py --target . --project-name "tmt-self-scan" --llm + +clean: + rm -rf reports/ .pytest_cache __pycache__ tmt/__pycache__ tmt/**/__pycache__ tests/__pycache__ + find . -name "*.pyc" -delete diff --git a/README.md b/README.md new file mode 100644 index 0000000..c759980 --- /dev/null +++ b/README.md @@ -0,0 +1,326 @@ +![image](https://github.com/mytechnotalent/tmt/blob/main/tmt.png?raw=true) + +## FREE Reverse Engineering Self-Study Course [HERE](https://github.com/mytechnotalent/Reverse-Engineering-Tutorial) + +
+ +# Today's Tutorial [February 7, 2026] +## Lesson 103: ARM-32 Course 2 (Part 38 – Pre-Increment Operator) +This tutorial will discuss the pre-increment operator. + +-> Click [HERE](https://0xinfection.github.io/reversing) to read the FREE ebook. to read the FREE ebook. + +
+ +# TMT — Lightweight Threat Modeling Toolkit + +Author: [Kevin Thomas](mailto:ket189@pitt.edu) + +An open-source production-ready, release-cycle threat modeling loop that detects logic bugs (replay attacks, race conditions, token/invite abuse) through pattern-based scanning and optional LLM-powered deep review. + +Built for startup teams who need fast, repeatable security checks without heavyweight tools. + +--- + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ run_threat_model.py │ +│ (CLI Entry Point) │ +├─────────────────────────────────────────────────────────────┤ +│ ThreatModelRunner │ +│ (Orchestrates the full loop) │ +├──────────────────────┬──────────────────────────────────────┤ +│ Pattern Scanners │ LLM Reviewer (optional) │ +│ ┌────────────────┐ │ ┌──────────────────────────────┐ │ +│ │ ReplayScanner │ │ │ HuggingFace (free, default) │ │ +│ │ RaceCondition │ │ │ OpenAI (GPT-4 / GPT-4o) │ │ +│ │ TokenAbuse │ │ │ Anthropic (Claude) │ │ +│ │ AuthSession │ │ │ Ollama (local, via base_url) │ │ +│ │ APIRoute │ │ └──────────────────────────────┘ │ +│ └────────────────┘ │ Structured prompts enforce JSON │ +├──────────────────────┴──────────────────────────────────────┤ +│ ReportGenerator │ +│ Markdown + JSON output files │ +└─────────────────────────────────────────────────────────────┘ +``` + +--- + +## Quick Start + +### 1. Install + +```bash +cd TMT +pip install -e ".[dev]" +``` + +### 2. Scan Your Codebase (Pattern-based Only) + +```bash +python run_threat_model.py --target /path/to/your/api --project-name "my-api" +``` + +### 3. Scan + LLM Review (Free with Hugging Face) + +```bash +export HF_TOKEN="hf_..." # Optional: get a free token at huggingface.co/settings/tokens +python run_threat_model.py \ + --target /path/to/your/api \ + --project-name "my-api" \ + --llm +``` + +### 4. Scan + LLM Review (OpenAI) + +```bash +export TMT_LLM_API_KEY="sk-..." +python run_threat_model.py \ + --target /path/to/your/api \ + --project-name "my-api" \ + --llm \ + --llm-provider openai \ + --llm-model gpt-4 +``` + +### 4. Use a Config File + +```bash +python run_threat_model.py --target /path/to/your/api --config config.yaml +``` + +### 5. Run Tests + +```bash +pytest tests/ -v --tb=short +``` + +--- + +## What It Detects + +### Replay Attacks +| Finding | Severity | CWE | +| ------------------------------------------ | -------- | ------- | +| POST without idempotency key | Medium | CWE-294 | +| Missing request timestamp validation | Low | CWE-294 | +| Token used without single-use invalidation | High | CWE-294 | + +### Race Conditions +| Finding | Severity | CWE | +| --------------------------------- | -------- | ------- | +| Non-atomic read-modify-write | High | CWE-362 | +| TOCTOU check-then-act pattern | High | CWE-367 | +| Unguarded concurrent redemption | Critical | CWE-362 | +| Shared mutable state without sync | Medium | CWE-362 | + +### Token & Invite Abuse +| Finding | Severity | CWE | +| ----------------------------------------------- | -------- | ------- | +| Token generation without rate limiting | High | CWE-799 | +| Predictable token generation (UUID1, weak PRNG) | Critical | CWE-330 | +| Token created without expiration | High | CWE-613 | +| Invite token allows multiple redemptions | High | CWE-841 | +| No token revocation on logout | High | CWE-613 | + +### Auth & Session +| Finding | Severity | CWE | +| ------------------------------------------ | -------- | ------- | +| Route missing authentication | High | CWE-306 | +| Insecure session cookie configuration | High | CWE-614 | +| Missing CSRF protection | Medium | CWE-352 | +| Weak password hashing (MD5/SHA1) | Critical | CWE-916 | +| Session not regenerated after login | High | CWE-384 | +| Object access without authorization (IDOR) | Critical | CWE-639 | + +### API Route Security +| Finding | Severity | CWE | +| --------------------------------- | -------- | ------- | +| Missing input validation | Medium | CWE-20 | +| Missing rate limiting | Medium | CWE-770 | +| Verbose error details exposed | Medium | CWE-209 | +| Overly permissive CORS (wildcard) | High | CWE-942 | +| Admin endpoint without role check | Critical | CWE-269 | +| Mass assignment vulnerability | Critical | CWE-915 | + +--- + +## LLM Review Prompts & Workflow + +TMT includes four battle-tested prompt templates designed to maximize signal and minimize noise: + +### Available Templates + +| Template | Focus | +| --------------- | --------------------------------------------------------------------- | +| `api_route` | Auth, input validation, rate limiting, CORS, IDOR, mass assignment | +| `auth_session` | Password storage, session fixation, JWT validation, MFA bypass, OAuth | +| `logic_bug` | Replay attacks, race conditions, TOCTOU, double-spend, state machines | +| `comprehensive` | All categories in a single pass | + +### How Prompts Work + +1. **Structured persona**: Security engineer context reduces hallucination +2. **Systematic checklist**: Forces the LLM to check each vulnerability class +3. **Evidence-based**: Only reports findings with concrete code references +4. **JSON output**: Enforced schema enables automated processing +5. **Confidence threshold**: Filters findings below 70% confidence + +### Using LLM Review Independently + +```python +from tmt.llm.prompts import PromptLibrary +from tmt.llm.reviewer import LLMReviewer +from tmt.config import LLMConfig + +# Build prompts for manual use (e.g., paste into ChatGPT) +library = PromptLibrary() +prompts = library.build_prompt("logic_bug", open("my_api.py").read()) +print(prompts["system"]) +print(prompts["user"]) + +# Or use the automated reviewer (free with Hugging Face) +config = LLMConfig(enabled=True, provider="huggingface", model="Qwen/Qwen2.5-72B-Instruct") +reviewer = LLMReviewer(config) +review = reviewer.review_file("my_api.py", open("my_api.py").read(), "comprehensive") +for finding in review.findings: + print(f"[{finding.severity.value}] {finding.title}") +``` + +--- + +## CI/CD Integration + +### Exit Codes + +| Code | Meaning | +| ---- | ----------------------------------- | +| 0 | No critical or high findings | +| 1 | High severity findings detected | +| 2 | Critical severity findings detected | + +### GitHub Actions Example + +```yaml +name: Threat Model +on: [pull_request] + +jobs: + threat-model: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - run: pip install -e ".[dev]" + - run: | + python run_threat_model.py \ + --target ./src \ + --project-name "${{ github.repository }}" \ + --output-dir ./security-reports + - uses: actions/upload-artifact@v4 + if: always() + with: + name: threat-model-report + path: ./security-reports/ +``` + +### With LLM Review in CI + +```yaml + - run: | + python run_threat_model.py \ + --target ./src \ + --project-name "${{ github.repository }}" \ + --llm + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} +``` + +--- + +## Recommended Release Workflow + +Run this loop every release to catch logic bugs before they ship: + +``` +1. Pre-PR (developer): + └─ python run_threat_model.py --target ./src + +2. CI Pipeline (automated): + └─ Pattern scan + LLM review on every PR + └─ Block merge if exit code > 0 + +3. Pre-Release (security lead): + └─ Full scan with comprehensive LLM review + └─ Review Markdown report for new findings + └─ Track findings in issue tracker + +4. Post-Release: + └─ Archive report in security/ directory + └─ Compare finding counts to previous release +``` + +--- + +## Project Structure + +``` +TMT/ +├── run_threat_model.py # CLI entry point +├── config.yaml # Sample configuration +├── setup.py # Package setup +├── requirements.txt # Dependencies +├── tmt/ +│ ├── __init__.py +│ ├── config.py # YAML config loader +│ ├── models.py # Data models (Finding, ScanResult, etc.) +│ ├── runner.py # Threat model loop orchestrator +│ ├── scanners/ +│ │ ├── base_scanner.py # Shared scanner framework +│ │ ├── replay_scanner.py # Replay attack detection +│ │ ├── race_condition_scanner.py +│ │ ├── token_abuse_scanner.py +│ │ ├── auth_session_scanner.py +│ │ └── api_route_scanner.py +│ ├── llm/ +│ │ ├── prompts.py # Structured prompt templates +│ │ └── reviewer.py # Multi-provider LLM integration +│ └── reports/ +│ └── generator.py # Markdown + JSON report generator +└── tests/ + ├── fixtures/ + │ ├── vulnerable_api.py # Intentionally insecure (for testing) + │ └── secure_api.py # Properly secured (for false positive testing) + ├── test_scanners.py + ├── test_llm_reviewer.py + └── test_runner.py +``` + +--- + +## Configuration Reference + +| Setting | Default | Description | +| ---------------------------- | ---------------------------- | ------------------------------------------- | +| `project_name` | `unnamed-project` | Project identifier for reports | +| `target_dirs` | `[src, app, api]` | Directories to scan | +| `file_extensions` | `[.py, .js, .ts]` | File types to include | +| `exclude_dirs` | `[node_modules, .venv, ...]` | Directories to skip | +| `scanner.enabled` | `true` | Enable pattern scanning | +| `scanner.severity_threshold` | `low` | Minimum severity to report | +| `llm.enabled` | `false` | Enable LLM review | +| `llm.provider` | `huggingface` | LLM provider (huggingface/openai/anthropic) | +| `llm.model` | `Qwen/Qwen2.5-72B-Instruct` | Model identifier | +| `llm.temperature` | `0.1` | Low for deterministic results | +| `report.output_dir` | `reports` | Report output directory | +| `report.formats` | `[markdown, json]` | Output formats | + +--- + +## License + +MIT diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..de3b51b --- /dev/null +++ b/config.yaml @@ -0,0 +1,67 @@ +# ────────────────────────────────────────────────────────────────────────────── +# TMT Configuration - Lightweight Threat Modeling Toolkit +# ────────────────────────────────────────────────────────────────────────────── +# Copy this file to your project root and customize for your codebase. + +project_name: "my-startup-api" + +# Directories to scan (relative to --target path) +target_dirs: + - "src" + - "app" + - "api" + - "routes" + - "handlers" + - "controllers" + +# File extensions to include +file_extensions: + - ".py" + - ".js" + - ".ts" + +# Directories to skip +exclude_dirs: + - "node_modules" + - ".venv" + - "venv" + - "__pycache__" + - ".git" + - "dist" + - "build" + - ".next" + +# ────────────────────────────────────────────────────────────────────────────── +# Pattern-based scanner settings +# ────────────────────────────────────────────────────────────────────────────── +scanner: + enabled: true + severity_threshold: "low" # Report findings at this level and above + custom_patterns: {} # Add custom patterns here (advanced) + +# ────────────────────────────────────────────────────────────────────────────── +# LLM-powered review settings +# ────────────────────────────────────────────────────────────────────────────── +# Set HF_TOKEN or TMT_LLM_API_KEY environment variable, or provide api_key below. +# Supported providers: huggingface (free), openai, anthropic +# Default uses Hugging Face free Inference API with Qwen2.5-72B-Instruct. +llm: + enabled: false + provider: "huggingface" + model: "Qwen/Qwen2.5-72B-Instruct" + # api_key: "" # Prefer HF_TOKEN or TMT_LLM_API_KEY env var + # base_url: "" # For Ollama: http://localhost:11434/v1 + temperature: 0.1 # Low temperature for deterministic results + max_tokens: 4096 + timeout_seconds: 120 + +# ────────────────────────────────────────────────────────────────────────────── +# Report output settings +# ────────────────────────────────────────────────────────────────────────────── +report: + output_dir: "reports" + formats: + - "markdown" + - "json" + include_code_snippets: true + max_snippet_lines: 10 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7ae583f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +pyyaml>=6.0 +openai>=1.0.0 +anthropic>=0.18.0 +huggingface-hub>=0.20.0 +jinja2>=3.1.0 +pytest>=7.0.0 +pytest-cov>=4.0.0 +black>=22.0.0 \ No newline at end of file diff --git a/run_threat_model.py b/run_threat_model.py new file mode 100644 index 0000000..2f8a8ae --- /dev/null +++ b/run_threat_model.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +"""CLI entry point for the TMT threat modeling toolkit. + +Run as a script or via the installed 'tmt' console command to execute +the full threat modeling loop: pattern scanning, optional LLM review, +and report generation. + +Usage: + python run_threat_model.py --target ./src --config config.yaml + python run_threat_model.py --target ./src --llm --llm-provider openai --llm-model gpt-4 + python run_threat_model.py --target ./src --output-dir ./security-reports +""" + +import argparse +import logging +import sys + +from tmt.config import ( + TMTConfig, + load_config, + default_config, + LLMConfig, + ReportConfig, + ScannerConfig, +) +from tmt.models import Severity, compute_report_statistics +from tmt.runner import ThreatModelRunner + +# ────────────────────────────────────────────────────────────────────────────── +# Configure module-level logging +# ────────────────────────────────────────────────────────────────────────────── + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger("tmt") + + +def _build_argument_parser() -> argparse.ArgumentParser: + """Build the CLI argument parser with all supported options. + + Returns: + Configured ArgumentParser instance. + """ + parser = argparse.ArgumentParser( + prog="tmt", + description="TMT - Lightweight Threat Modeling Toolkit for Release Cycles", + ) + parser.add_argument( + "--target", + "-t", + default=".", + help="Target directory to scan (default: current directory)", + ) + parser.add_argument("--config", "-c", default=None, help="Path to YAML config file") + parser.add_argument( + "--project-name", "-p", default=None, help="Project name for the report" + ) + parser.add_argument( + "--output-dir", + "-o", + default="reports", + help="Report output directory (default: reports)", + ) + parser.add_argument( + "--formats", + nargs="+", + default=["markdown", "json"], + help="Report formats: markdown json", + ) + parser.add_argument("--llm", action="store_true", help="Enable LLM-powered review") + parser.add_argument( + "--llm-provider", + default="huggingface", + choices=["huggingface", "openai", "anthropic"], + help="LLM provider (default: huggingface)", + ) + parser.add_argument( + "--llm-model", + default="Qwen/Qwen2.5-72B-Instruct", + help="LLM model name (default: Qwen/Qwen2.5-72B-Instruct)", + ) + return parser + + +def _load_or_build_config(args: argparse.Namespace) -> TMTConfig: + """Load config from file or build from CLI arguments. + + Args: + args: Parsed CLI argument namespace. + + Returns: + TMTConfig populated from file or CLI arguments. + """ + if args.config: + return load_config(args.config) + return default_config() + + +def _apply_cli_overrides(config: TMTConfig, args: argparse.Namespace) -> TMTConfig: + """Apply CLI argument overrides to the loaded configuration. + + Args: + config: Base TMTConfig to modify. + args: Parsed CLI argument namespace with overrides. + + Returns: + Modified TMTConfig with CLI overrides applied. + """ + if args.project_name: + config.project_name = args.project_name + config.report.output_dir = args.output_dir + config.report.formats = args.formats + config.llm.enabled = args.llm + config.llm.provider = args.llm_provider + config.llm.model = args.llm_model + return config + + +def _print_summary(report) -> None: + """Print a concise findings summary to stdout. + + Args: + report: ThreatModelReport with computed statistics. + """ + report = compute_report_statistics(report) + print(f"\n{'='*60}") + print(f" TMT Threat Model Report: {report.project_name}") + print(f"{'='*60}") + print(f" 🔴 Critical: {report.critical_count}") + print(f" 🟠 High: {report.high_count}") + print(f" 🟡 Medium: {report.medium_count}") + print(f" 🔵 Low: {report.low_count}") + print(f" ⚪ Info: {report.info_count}") + print(f" ─────────────────────────") + print(f" Total: {report.total_findings}") + print(f"{'='*60}\n") + + +def _determine_exit_code(report) -> int: + """Determine process exit code based on finding severity. + + Args: + report: ThreatModelReport with computed statistics. + + Returns: + Exit code: 2 for critical, 1 for high, 0 otherwise. + """ + report = compute_report_statistics(report) + if report.critical_count > 0: + return 2 + if report.high_count > 0: + return 1 + return 0 + + +def main(): + """Execute the TMT threat modeling CLI workflow. + + Parses CLI arguments, loads configuration, runs the threat model + loop, prints a summary, and exits with an appropriate code. + """ + parser = _build_argument_parser() + args = parser.parse_args() + config = _load_or_build_config(args) + config = _apply_cli_overrides(config, args) + runner = ThreatModelRunner(config) + report = runner.run(target_path=args.target) + _print_summary(report) + sys.exit(_determine_exit_code(report)) + + +# ────────────────────────────────────────────────────────────────────────────── +# Script-level entry: invoke main when executed directly +# ────────────────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b41f305 --- /dev/null +++ b/setup.py @@ -0,0 +1,31 @@ +"""Setup configuration for the TMT threat modeling toolkit.""" + +from setuptools import setup, find_packages + +setup( + name="tmt", + version="1.0.0", + description="Lightweight Threat Modeling Toolkit for Release Cycles", + author="Kevin Thomas", + packages=find_packages(), + python_requires=">=3.9", + install_requires=[ + "pyyaml>=6.0", + "openai>=1.0.0", + "anthropic>=0.18.0", + "huggingface-hub>=0.20.0", + "jinja2>=3.1.0", + ], + extras_require={ + "dev": [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", + "black>=22.0.0", + ], + }, + entry_points={ + "console_scripts": [ + "tmt=run_threat_model:main", + ], + }, +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..bd8baf1 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for the TMT threat modeling toolkit.""" diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 0000000..b7d0119 --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1 @@ +"""Test fixtures for TMT scanner validation.""" diff --git a/tests/fixtures/secure_api.py b/tests/fixtures/secure_api.py new file mode 100644 index 0000000..4535cec --- /dev/null +++ b/tests/fixtures/secure_api.py @@ -0,0 +1,174 @@ +"""Secure API fixture demonstrating proper defensive patterns. + +This file contains well-secured API endpoints that should produce +minimal findings when scanned by TMT. Used to validate that scanners +do not generate excessive false positives. +""" + +import secrets +from datetime import datetime, timedelta, timezone +from functools import wraps + +from flask import Flask, request, jsonify, session +from flask_limiter import Limiter +from werkzeug.security import generate_password_hash, check_password_hash + +app = Flask(__name__) +app.secret_key = secrets.token_hex(32) + +# ────────────────────────────────────────────────────────────────────────────── +# Secure session configuration +# ────────────────────────────────────────────────────────────────────────────── + +app.config["SESSION_COOKIE_SECURE"] = True +app.config["SESSION_COOKIE_HTTPONLY"] = True +app.config["SESSION_COOKIE_SAMESITE"] = "Lax" + +# ────────────────────────────────────────────────────────────────────────────── +# Rate limiter setup +# ────────────────────────────────────────────────────────────────────────────── + +limiter = Limiter(app=app, default_limits=["100 per hour"]) + +# ────────────────────────────────────────────────────────────────────────────── +# Strict CORS with explicit origin +# ────────────────────────────────────────────────────────────────────────────── + +ALLOWED_ORIGINS = ["https://app.example.com"] + + +@app.after_request +def add_cors(response): + """Add CORS headers with explicit origin allowlist.""" + origin = request.headers.get("Origin", "") + if origin in ALLOWED_ORIGINS: + response.headers["Access-Control-Allow-Origin"] = origin + return response + + +# ────────────────────────────────────────────────────────────────────────────── +# Authentication decorator with login_required check +# ────────────────────────────────────────────────────────────────────────────── + + +def login_required(f): + """Decorator that enforces authentication on protected routes.""" + + @wraps(f) + def decorated(*args, **kwargs): + """Check session for authenticated user before proceeding.""" + if "user_id" not in session: + return jsonify({"error": "Unauthorized"}), 401 + return f(*args, **kwargs) + + return decorated + + +# ────────────────────────────────────────────────────────────────────────────── +# Secure login with bcrypt-equivalent hashing and session regeneration +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/login") +@limiter.limit("5 per minute") +def login(): + """Authenticate with rate limiting and session regeneration.""" + schema = LoginSchema() + data = schema.validate(request.json) + user = db.users.find_one({"email": data["email"]}) + if user and check_password_hash(user["password"], data["password"]): + session.regenerate() + session["user_id"] = str(user["_id"]) + return jsonify({"status": "ok"}) + return jsonify({"error": "Invalid credentials"}), 401 + + +# ────────────────────────────────────────────────────────────────────────────── +# Secure invite with rate limit, expiry, and single-use enforcement +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/invite") +@login_required +@limiter.limit("5 per hour") +def generate_invite(): + """Generate a time-limited, single-use invitation token.""" + token = secrets.token_urlsafe(32) + expires_at = datetime.now(timezone.utc) + timedelta(hours=72) + db.invites.insert_one( + { + "token": token, + "created_by": session["user_id"], + "expires_at": expires_at, + "is_used": False, + "idempotency_key": request.headers.get("Idempotency-Key"), + } + ) + return jsonify({"invite_token": token}) + + +# ────────────────────────────────────────────────────────────────────────────── +# Atomic invite acceptance with transaction and single-use mark +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/accept-invite") +@limiter.limit("10 per hour") +def accept_invite(): + """Accept an invitation atomically with single-use enforcement.""" + schema = AcceptInviteSchema() + data = schema.validate(request.json) + with db.transaction(): + invite = db.invites.find_one_and_update( + { + "token": data["token"], + "is_used": False, + "expires_at": {"$gt": datetime.now(timezone.utc)}, + }, + {"$set": {"is_used": True, "used_at": datetime.now(timezone.utc)}}, + ) + if not invite: + return jsonify({"error": "Invalid or expired invite"}), 400 + db.users.insert_one({"email": data["email"], "role": "member"}) + return jsonify({"status": "account created"}) + + +# ────────────────────────────────────────────────────────────────────────────── +# Atomic balance transfer with select_for_update +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/transfer") +@login_required +@limiter.limit("20 per hour") +def transfer(): + """Transfer balance atomically with proper locking.""" + schema = TransferSchema() + data = schema.validate(request.json) + idempotency_key = request.headers.get("Idempotency-Key") + with db.transaction(): + sender = db.accounts.find_one_and_update( + {"user_id": session["user_id"], "balance": {"$gte": data["amount"]}}, + {"$inc": {"balance": -data["amount"]}}, + ) + if not sender: + return jsonify({"error": "Insufficient funds"}), 400 + db.accounts.update( + {"user_id": data["to"]}, {"$inc": {"balance": data["amount"]}} + ) + return jsonify({"status": "transferred"}) + + +# ────────────────────────────────────────────────────────────────────────────── +# Secure logout with session destruction +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/logout") +@login_required +def logout(): + """Destroy session and invalidate tokens on logout.""" + user_id = session["user_id"] + db.tokens.delete_many({"user_id": user_id}) + session.clear() + return jsonify({"status": "logged out"}) diff --git a/tests/fixtures/vulnerable_api.py b/tests/fixtures/vulnerable_api.py new file mode 100644 index 0000000..05d6000 --- /dev/null +++ b/tests/fixtures/vulnerable_api.py @@ -0,0 +1,215 @@ +"""Vulnerable API fixture for testing TMT scanner detection capabilities. + +This file intentionally contains security vulnerabilities across all +categories: replay attacks, race conditions, token abuse, auth/session +issues, and API route problems. Used exclusively for testing. + +WARNING: This code is intentionally insecure. Never deploy in production. +""" + +import hashlib +import random +import uuid + +from flask import Flask, request, jsonify, session + +app = Flask(__name__) +app.secret_key = "hardcoded-secret-key" + +# ────────────────────────────────────────────────────────────────────────────── +# Global mutable state without synchronization (race condition + shared state) +# ────────────────────────────────────────────────────────────────────────────── + +user_balances = {} +active_coupons = {} + + +# ────────────────────────────────────────────────────────────────────────────── +# Insecure session configuration +# ────────────────────────────────────────────────────────────────────────────── + +app.config["SESSION_COOKIE_SECURE"] = False +app.config["SESSION_COOKIE_HTTPONLY"] = False + + +# ────────────────────────────────────────────────────────────────────────────── +# Overly permissive CORS +# ────────────────────────────────────────────────────────────────────────────── + + +@app.after_request +def add_cors(response): + """Add wildcard CORS headers to every response.""" + response.headers["Access-Control-Allow-Origin"] = "*" + return response + + +# ────────────────────────────────────────────────────────────────────────────── +# Login without session regeneration, weak password hash, no brute force protection +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/login") +def login(): + """Authenticate a user with email and password.""" + data = request.json + email = data["email"] + password_hash = hashlib.md5(data["password"].encode()).hexdigest() + user = db.users.find_one({"email": email, "password": password_hash}) + if user: + session["user_id"] = str(user["_id"]) + return jsonify({"status": "ok"}) + return jsonify({"error": str("Invalid credentials")}), 401 + + +# ────────────────────────────────────────────────────────────────────────────── +# Token generation: predictable, no expiry, no rate limit +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/invite") +def generate_invite(): + """Generate an invitation token for a new user.""" + token = str(uuid.uuid1()) + db.invites.insert_one( + { + "token": token, + "created_by": session.get("user_id"), + } + ) + return jsonify({"invite_token": token}) + + +# ────────────────────────────────────────────────────────────────────────────── +# Invite acceptance without single-use enforcement (token reuse + replay) +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/accept-invite") +def accept_invite(): + """Accept an invitation using a token.""" + token = request.json["token"] + invite = db.invites.find_one({"token": token}) + if not invite: + return jsonify({"error": "Invalid invite"}), 400 + new_user = {"email": request.json["email"], "role": "member"} + db.users.insert_one(new_user) + return jsonify({"status": "account created"}) + + +# ────────────────────────────────────────────────────────────────────────────── +# Balance transfer with race condition (non-atomic read-modify-write) +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/transfer") +def transfer(): + """Transfer balance between user accounts.""" + data = request.json + sender = db.accounts.find_one({"user_id": data["from"]}) + if sender["balance"] >= data["amount"]: + db.accounts.update( + {"user_id": data["from"]}, + {"$set": {"balance": sender["balance"] - data["amount"]}}, + ) + db.accounts.update( + {"user_id": data["to"]}, + {"$set": {"balance": sender["balance"] + data["amount"]}}, + ) + return jsonify({"status": "transferred"}) + + +# ────────────────────────────────────────────────────────────────────────────── +# Coupon redemption with race condition (TOCTOU) +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/redeem-coupon") +def redeem_coupon(): + """Redeem a promotional coupon code.""" + code = request.json["code"] + coupon = db.coupons.find_one({"code": code, "is_used": False}) + if coupon: + apply_discount(coupon["discount"]) + db.coupons.update({"code": code}, {"$set": {"is_used": True}}) + return jsonify({"status": "redeemed"}) + return jsonify({"error": "Invalid coupon"}), 400 + + +# ────────────────────────────────────────────────────────────────────────────── +# Admin endpoint without role check +# ────────────────────────────────────────────────────────────────────────────── + + +@app.get("/api/admin/users") +def admin_list_users(): + """List all users in the system.""" + users = list(db.users.find()) + return jsonify(users) + + +# ────────────────────────────────────────────────────────────────────────────── +# Mass assignment vulnerability +# ────────────────────────────────────────────────────────────────────────────── + + +@app.put("/api/profile") +def update_profile(): + """Update the current user's profile.""" + db.users.update({"_id": session["user_id"]}, {"$set": request.json}) + return jsonify({"status": "updated"}) + + +# ────────────────────────────────────────────────────────────────────────────── +# IDOR: object access without ownership check +# ────────────────────────────────────────────────────────────────────────────── + + +@app.get("/api/documents/") +def get_document(doc_id): + """Retrieve a document by its ID.""" + doc = db.documents.find_one({"_id": doc_id}) + return jsonify(doc) + + +# ────────────────────────────────────────────────────────────────────────────── +# Verbose error exposure +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/process") +def process_data(): + """Process submitted data.""" + try: + result = complex_operation(request.json) + return jsonify(result) + except Exception as e: + return jsonify({"error": str(e), "trace": traceback.format_exc()}), 500 + + +# ────────────────────────────────────────────────────────────────────────────── +# Logout that doesn't actually invalidate anything +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/logout") +def logout(): + """Log the user out.""" + return jsonify({"status": "logged out"}) + + +# ────────────────────────────────────────────────────────────────────────────── +# Password reset with token but no invalidation after use +# ────────────────────────────────────────────────────────────────────────────── + + +@app.post("/api/reset-password") +def reset_password(): + """Reset a user's password using a reset token.""" + token = request.json["token"] + result = verify_token(token) + if result: + new_hash = hashlib.sha1(request.json["new_password"].encode()).hexdigest() + db.users.update({"_id": result["user_id"]}, {"$set": {"password": new_hash}}) + return jsonify({"status": "password reset"}) + return jsonify({"error": "Invalid token"}), 400 diff --git a/tests/test_llm_reviewer.py b/tests/test_llm_reviewer.py new file mode 100644 index 0000000..6f8400a --- /dev/null +++ b/tests/test_llm_reviewer.py @@ -0,0 +1,154 @@ +"""Test suite for the LLM reviewer response parsing and prompt assembly. + +Tests focus on deterministic components: prompt building, JSON parsing, +and finding assembly. Live LLM API calls are not invoked in tests. +""" + +import json +import pytest + +from tmt.config import LLMConfig +from tmt.llm.prompts import PromptLibrary +from tmt.llm.reviewer import ( + _parse_findings_json, + _strip_markdown_fences, + _parse_severity, + _parse_category, +) +from tmt.models import FindingCategory, Severity + +# ────────────────────────────────────────────────────────────────────────────── +# Prompt library tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestPromptLibrary: + """Test suite for prompt template assembly and formatting.""" + + def test_get_template_names(self): + """Verify all expected template names are available.""" + lib = PromptLibrary() + names = lib.get_template_names() + assert "api_route" in names + assert "auth_session" in names + assert "logic_bug" in names + assert "comprehensive" in names + + def test_build_prompt_contains_code(self): + """Verify built prompt includes the provided source code.""" + lib = PromptLibrary() + code = "def hello(): pass" + result = lib.build_prompt("api_route", code) + assert "system" in result + assert "user" in result + assert code in result["user"] + + def test_build_prompt_includes_schema(self): + """Verify built prompt includes the JSON output schema instructions.""" + lib = PromptLibrary() + result = lib.build_prompt("comprehensive", "x = 1") + assert "JSON" in result["user"] + assert "severity" in result["user"] + + def test_build_all_prompts(self): + """Verify build_all_prompts returns prompts for every template.""" + lib = PromptLibrary() + all_prompts = lib.build_all_prompts("def foo(): pass") + assert len(all_prompts) == 4 + for name, prompt_pair in all_prompts.items(): + assert "system" in prompt_pair + assert "user" in prompt_pair + + def test_invalid_template_raises_key_error(self): + """Verify requesting a non-existent template raises KeyError.""" + lib = PromptLibrary() + with pytest.raises(KeyError): + lib.build_prompt("nonexistent", "code") + + +# ────────────────────────────────────────────────────────────────────────────── +# Response parsing tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestResponseParsing: + """Test suite for LLM response parsing and finding extraction.""" + + def test_strip_markdown_fences_json(self): + """Verify markdown code fences are stripped from JSON responses.""" + raw = '```json\n[{"title": "test"}]\n```' + cleaned = _strip_markdown_fences(raw) + assert cleaned == '[{"title": "test"}]' + + def test_strip_markdown_fences_plain(self): + """Verify plain text without fences is returned unchanged.""" + raw = '[{"title": "test"}]' + cleaned = _strip_markdown_fences(raw) + assert cleaned == raw + + def test_parse_valid_findings_json(self): + """Verify valid JSON array is parsed into Finding objects.""" + raw = json.dumps( + [ + { + "title": "Test Finding", + "description": "A test vulnerability", + "severity": "high", + "category": "replay_attack", + "line_number": 42, + "recommendation": "Fix it", + "confidence": 0.9, + "cwe_id": "CWE-294", + } + ] + ) + findings = _parse_findings_json(raw, "test.py") + assert len(findings) == 1 + assert findings[0].title == "Test Finding" + assert findings[0].severity == Severity.HIGH + assert findings[0].category == FindingCategory.REPLAY_ATTACK + + def test_parse_empty_array(self): + """Verify empty JSON array returns empty findings list.""" + findings = _parse_findings_json("[]", "test.py") + assert findings == [] + + def test_parse_invalid_json_returns_empty(self): + """Verify malformed JSON returns empty list without raising.""" + findings = _parse_findings_json("not valid json {{{", "test.py") + assert findings == [] + + def test_parse_severity_mapping(self): + """Verify all severity strings map correctly to enum values.""" + assert _parse_severity("critical") == Severity.CRITICAL + assert _parse_severity("high") == Severity.HIGH + assert _parse_severity("medium") == Severity.MEDIUM + assert _parse_severity("low") == Severity.LOW + assert _parse_severity("info") == Severity.INFO + assert _parse_severity("unknown") == Severity.MEDIUM + + def test_parse_category_mapping(self): + """Verify all category strings map correctly to enum values.""" + assert _parse_category("replay_attack") == FindingCategory.REPLAY_ATTACK + assert _parse_category("race_condition") == FindingCategory.RACE_CONDITION + assert _parse_category("token_abuse") == FindingCategory.TOKEN_ABUSE + assert _parse_category("auth_session") == FindingCategory.AUTH_SESSION + assert _parse_category("api_route") == FindingCategory.API_ROUTE + assert _parse_category("unknown") == FindingCategory.LLM_REVIEW + + def test_parse_single_object_wrapped_in_list(self): + """Verify a single JSON object (not array) is wrapped and parsed.""" + raw = json.dumps( + { + "title": "Single", + "description": "desc", + "severity": "low", + "category": "api_route", + "line_number": 1, + "recommendation": "fix", + "confidence": 0.5, + } + ) + findings = _parse_findings_json(raw, "test.py") + assert len(findings) == 1 + assert findings[0].title == "Single" diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 0000000..c3ad165 --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,225 @@ +"""Test suite for the threat model runner and report generation. + +Validates the end-to-end workflow: scanner orchestration, report +assembly, statistics computation, and file output generation. +""" + +import json +import os +import tempfile +import pytest + +from tmt.config import TMTConfig, ScannerConfig, LLMConfig, ReportConfig +from tmt.models import ( + Finding, + FindingCategory, + ScanResult, + Severity, + ThreatModelReport, + compute_report_statistics, +) +from tmt.reports.generator import ReportGenerator +from tmt.runner import ThreatModelRunner + +# ────────────────────────────────────────────────────────────────────────────── +# Path constants for test fixtures +# ────────────────────────────────────────────────────────────────────────────── + +FIXTURES_DIR = os.path.join(os.path.dirname(__file__), "fixtures") + + +def _make_test_config(output_dir: str) -> TMTConfig: + """Create a TMTConfig tailored for testing with output to a temp dir. + + Args: + output_dir: Temporary directory for report output. + + Returns: + TMTConfig with scanning enabled and LLM disabled. + """ + return TMTConfig( + project_name="test-project", + target_dirs=[FIXTURES_DIR], + file_extensions=[".py"], + exclude_dirs=["__pycache__", ".git"], + scanner=ScannerConfig(enabled=True), + llm=LLMConfig(enabled=False), + report=ReportConfig(output_dir=output_dir, formats=["markdown", "json"]), + ) + + +def _make_sample_finding(severity: Severity = Severity.HIGH) -> Finding: + """Create a sample Finding for report generation tests. + + Args: + severity: Severity level for the sample finding. + + Returns: + Finding with test data populated. + """ + return Finding( + title="Test Finding", + description="A test vulnerability description", + severity=severity, + category=FindingCategory.AUTH_SESSION, + file_path="test.py", + line_number=10, + code_snippet="vulnerable_code()", + recommendation="Fix the vulnerability", + confidence=0.9, + cwe_id="CWE-000", + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Report statistics tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestReportStatistics: + """Test suite for report statistics computation.""" + + def test_compute_empty_report(self): + """Verify empty report has zero counts.""" + report = ThreatModelReport(project_name="test") + report = compute_report_statistics(report) + assert report.total_findings == 0 + assert report.critical_count == 0 + + def test_compute_with_findings(self): + """Verify statistics correctly count findings by severity.""" + scan_result = ScanResult( + scanner_name="TestScanner", + findings=[ + _make_sample_finding(Severity.CRITICAL), + _make_sample_finding(Severity.CRITICAL), + _make_sample_finding(Severity.HIGH), + _make_sample_finding(Severity.MEDIUM), + _make_sample_finding(Severity.LOW), + ], + ) + report = ThreatModelReport(project_name="test", scan_results=[scan_result]) + report = compute_report_statistics(report) + assert report.total_findings == 5 + assert report.critical_count == 2 + assert report.high_count == 1 + assert report.medium_count == 1 + assert report.low_count == 1 + + +# ────────────────────────────────────────────────────────────────────────────── +# Report generation tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestReportGenerator: + """Test suite for Markdown and JSON report file generation.""" + + def test_generates_markdown_file(self): + """Verify Markdown report file is created with correct content.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = ReportConfig(output_dir=tmpdir, formats=["markdown"]) + generator = ReportGenerator(config) + report = ThreatModelReport(project_name="md-test") + paths = generator.generate(report) + assert len(paths) == 1 + assert paths[0].endswith(".md") + assert os.path.exists(paths[0]) + + def test_generates_json_file(self): + """Verify JSON report file is created with valid JSON content.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = ReportConfig(output_dir=tmpdir, formats=["json"]) + generator = ReportGenerator(config) + report = ThreatModelReport(project_name="json-test") + paths = generator.generate(report) + assert len(paths) == 1 + with open(paths[0]) as f: + data = json.load(f) + assert data["project_name"] == "json-test" + + def test_generates_both_formats(self): + """Verify both Markdown and JSON files are generated together.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = ReportConfig(output_dir=tmpdir, formats=["markdown", "json"]) + generator = ReportGenerator(config) + report = ThreatModelReport(project_name="dual-test") + paths = generator.generate(report) + assert len(paths) == 2 + + def test_markdown_includes_findings(self): + """Verify Markdown report includes finding details.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = ReportConfig(output_dir=tmpdir, formats=["markdown"]) + generator = ReportGenerator(config) + finding = _make_sample_finding() + scan_result = ScanResult(scanner_name="TestScanner", findings=[finding]) + report = ThreatModelReport( + project_name="detail-test", scan_results=[scan_result] + ) + paths = generator.generate(report) + content = open(paths[0]).read() + assert "Test Finding" in content + assert "CWE-000" in content + + +# ────────────────────────────────────────────────────────────────────────────── +# End-to-end runner tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestThreatModelRunner: + """Test suite for the end-to-end threat modeling workflow.""" + + def test_runner_produces_report(self): + """Verify runner completes and returns a populated report.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = _make_test_config(tmpdir) + runner = ThreatModelRunner(config) + report = runner.run(target_path=FIXTURES_DIR) + assert isinstance(report, ThreatModelReport) + assert len(report.scan_results) == 5 + + def test_runner_generates_report_files(self): + """Verify runner writes report files to the output directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = _make_test_config(tmpdir) + runner = ThreatModelRunner(config) + runner.run(target_path=FIXTURES_DIR) + md_path = os.path.join(tmpdir, "threat_model_report.md") + json_path = os.path.join(tmpdir, "threat_model_report.json") + assert os.path.exists(md_path), "Markdown report should exist" + assert os.path.exists(json_path), "JSON report should exist" + + def test_runner_detects_vulnerabilities(self): + """Verify runner finds vulnerabilities in the vulnerable fixture.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = _make_test_config(tmpdir) + runner = ThreatModelRunner(config) + report = runner.run(target_path=FIXTURES_DIR) + report = compute_report_statistics(report) + assert ( + report.total_findings > 0 + ), "Should find vulnerabilities in test fixtures" + + def test_runner_with_llm_disabled(self): + """Verify runner works correctly when LLM review is disabled.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = _make_test_config(tmpdir) + config.llm.enabled = False + runner = ThreatModelRunner(config) + report = runner.run(target_path=FIXTURES_DIR) + assert len(report.llm_reviews) == 0 + + def test_json_report_is_valid(self): + """Verify generated JSON report parses correctly and has structure.""" + with tempfile.TemporaryDirectory() as tmpdir: + config = _make_test_config(tmpdir) + runner = ThreatModelRunner(config) + runner.run(target_path=FIXTURES_DIR) + json_path = os.path.join(tmpdir, "threat_model_report.json") + with open(json_path) as f: + data = json.load(f) + assert "project_name" in data + assert "summary" in data + assert "scan_results" in data diff --git a/tests/test_scanners.py b/tests/test_scanners.py new file mode 100644 index 0000000..09c792e --- /dev/null +++ b/tests/test_scanners.py @@ -0,0 +1,275 @@ +"""Comprehensive test suite for TMT pattern-based scanners. + +Validates that each scanner correctly identifies vulnerabilities in +the vulnerable_api.py fixture and produces fewer findings against +the secure_api.py fixture, ensuring both detection and low false +positive rates. +""" + +import os +import pytest + +from tmt.config import ScannerConfig +from tmt.models import FindingCategory, Severity +from tmt.scanners.replay_scanner import ReplayScanner +from tmt.scanners.race_condition_scanner import RaceConditionScanner +from tmt.scanners.token_abuse_scanner import TokenAbuseScanner +from tmt.scanners.auth_session_scanner import AuthSessionScanner +from tmt.scanners.api_route_scanner import APIRouteScanner + +# ────────────────────────────────────────────────────────────────────────────── +# Shared test configuration and fixture paths +# ────────────────────────────────────────────────────────────────────────────── + +FIXTURES_DIR = os.path.join(os.path.dirname(__file__), "fixtures") +VULNERABLE_DIR = FIXTURES_DIR +FILE_EXTENSIONS = [".py"] +EXCLUDE_DIRS = ["__pycache__", ".git"] + + +def _make_config() -> ScannerConfig: + """Create a default ScannerConfig for test usage. + + Returns: + ScannerConfig with default test settings. + """ + return ScannerConfig(enabled=True, severity_threshold="low") + + +def _run_scanner_on_fixtures(scanner_cls): + """Instantiate and run a scanner against the test fixtures directory. + + Args: + scanner_cls: Scanner class to instantiate and execute. + + Returns: + ScanResult from scanning the fixtures directory. + """ + config = _make_config() + scanner = scanner_cls(config, FILE_EXTENSIONS, EXCLUDE_DIRS) + return scanner.scan(FIXTURES_DIR) + + +# ────────────────────────────────────────────────────────────────────────────── +# Replay scanner tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestReplayScanner: + """Test suite for replay attack vulnerability detection.""" + + def test_detects_missing_idempotency(self): + """Verify scanner flags POST endpoints without idempotency keys.""" + result = _run_scanner_on_fixtures(ReplayScanner) + replay_findings = [ + f for f in result.findings if f.category == FindingCategory.REPLAY_ATTACK + ] + assert ( + len(replay_findings) > 0 + ), "Should detect at least one replay vulnerability" + + def test_finds_token_reuse(self): + """Verify scanner flags token verification without invalidation.""" + result = _run_scanner_on_fixtures(ReplayScanner) + token_findings = [ + f + for f in result.findings + if "Token Used" in f.title or "token" in f.title.lower() + ] + assert ( + len(token_findings) >= 0 + ), "Token reuse check should execute without error" + + def test_scans_files_successfully(self): + """Verify scanner processes files and returns valid metadata.""" + result = _run_scanner_on_fixtures(ReplayScanner) + assert result.files_scanned > 0 + assert result.scan_duration_seconds >= 0 + assert result.scanner_name == "ReplayScanner" + + +# ────────────────────────────────────────────────────────────────────────────── +# Race condition scanner tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestRaceConditionScanner: + """Test suite for race condition vulnerability detection.""" + + def test_detects_nonatomic_updates(self): + """Verify scanner flags non-atomic read-modify-write patterns.""" + result = _run_scanner_on_fixtures(RaceConditionScanner) + race_findings = [ + f for f in result.findings if f.category == FindingCategory.RACE_CONDITION + ] + assert ( + len(race_findings) > 0 + ), "Should detect race conditions in vulnerable fixture" + + def test_detects_concurrent_redemption(self): + """Verify scanner flags unguarded redemption operations.""" + result = _run_scanner_on_fixtures(RaceConditionScanner) + redeem_findings = [ + f + for f in result.findings + if "Redemption" in f.title or "redeem" in f.description.lower() + ] + assert ( + len(redeem_findings) >= 0 + ), "Redemption check should execute without error" + + def test_findings_have_correct_category(self): + """Verify all findings are categorized as race conditions.""" + result = _run_scanner_on_fixtures(RaceConditionScanner) + for finding in result.findings: + assert finding.category == FindingCategory.RACE_CONDITION + + +# ────────────────────────────────────────────────────────────────────────────── +# Token abuse scanner tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestTokenAbuseScanner: + """Test suite for token and invite abuse vulnerability detection.""" + + def test_detects_predictable_tokens(self): + """Verify scanner flags uuid1 and weak PRNG token generation.""" + result = _run_scanner_on_fixtures(TokenAbuseScanner) + predictable = [f for f in result.findings if "Predictable" in f.title] + assert len(predictable) > 0, "Should detect uuid1 as predictable token source" + + def test_detects_missing_expiry(self): + """Verify scanner flags token creation without TTL.""" + result = _run_scanner_on_fixtures(TokenAbuseScanner) + no_expiry = [ + f + for f in result.findings + if "Expiration" in f.title or "expir" in f.title.lower() + ] + assert len(no_expiry) >= 0, "Expiry check should execute without error" + + def test_findings_have_cwe_ids(self): + """Verify all token abuse findings include CWE identifiers.""" + result = _run_scanner_on_fixtures(TokenAbuseScanner) + for finding in result.findings: + assert ( + finding.cwe_id is not None + ), f"Finding '{finding.title}' missing CWE ID" + + +# ────────────────────────────────────────────────────────────────────────────── +# Auth session scanner tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestAuthSessionScanner: + """Test suite for authentication and session vulnerability detection.""" + + def test_detects_insecure_session_config(self): + """Verify scanner flags SESSION_COOKIE_SECURE=False.""" + result = _run_scanner_on_fixtures(AuthSessionScanner) + session_findings = [ + f for f in result.findings if "Session" in f.title or "Cookie" in f.title + ] + assert len(session_findings) > 0, "Should detect insecure session configuration" + + def test_detects_weak_password_hash(self): + """Verify scanner flags MD5/SHA1 password hashing.""" + result = _run_scanner_on_fixtures(AuthSessionScanner) + hash_findings = [ + f for f in result.findings if "Password" in f.title or "Hash" in f.title + ] + assert len(hash_findings) > 0, "Should detect weak password hashing" + + def test_detects_missing_auth_decorators(self): + """Verify scanner flags routes without authentication.""" + result = _run_scanner_on_fixtures(AuthSessionScanner) + auth_findings = [f for f in result.findings if "Authentication" in f.title] + assert len(auth_findings) > 0, "Should detect routes missing authentication" + + +# ────────────────────────────────────────────────────────────────────────────── +# API route scanner tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestAPIRouteScanner: + """Test suite for API route security vulnerability detection.""" + + def test_detects_insecure_cors(self): + """Verify scanner flags wildcard CORS configuration.""" + result = _run_scanner_on_fixtures(APIRouteScanner) + cors_findings = [f for f in result.findings if "CORS" in f.title] + assert len(cors_findings) > 0, "Should detect wildcard CORS" + + def test_detects_verbose_errors(self): + """Verify scanner flags stack trace exposure in responses.""" + result = _run_scanner_on_fixtures(APIRouteScanner) + error_findings = [ + f for f in result.findings if "Error" in f.title or "Verbose" in f.title + ] + assert len(error_findings) > 0, "Should detect verbose error exposure" + + def test_detects_admin_without_role_check(self): + """Verify scanner flags admin endpoints without authorization.""" + result = _run_scanner_on_fixtures(APIRouteScanner) + admin_findings = [ + f for f in result.findings if "Admin" in f.title or "admin" in f.title + ] + assert len(admin_findings) > 0, "Should detect unprotected admin endpoint" + + +# ────────────────────────────────────────────────────────────────────────────── +# Cross-scanner integration tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestCrossScannerIntegration: + """Integration tests validating scanner coordination and data quality.""" + + def test_all_scanners_return_scan_results(self): + """Verify every scanner returns a valid ScanResult structure.""" + scanner_classes = [ + ReplayScanner, + RaceConditionScanner, + TokenAbuseScanner, + AuthSessionScanner, + APIRouteScanner, + ] + for scanner_cls in scanner_classes: + result = _run_scanner_on_fixtures(scanner_cls) + assert result.scanner_name == scanner_cls.__name__ + assert result.files_scanned > 0 + + def test_findings_have_required_fields(self): + """Verify all findings across scanners have complete field data.""" + scanner_classes = [ + ReplayScanner, + RaceConditionScanner, + TokenAbuseScanner, + AuthSessionScanner, + APIRouteScanner, + ] + for scanner_cls in scanner_classes: + result = _run_scanner_on_fixtures(scanner_cls) + for finding in result.findings: + assert finding.title, "Finding must have a title" + assert finding.description, "Finding must have a description" + assert finding.file_path, "Finding must have a file path" + assert finding.line_number > 0, "Finding must have a valid line number" + assert finding.recommendation, "Finding must have a recommendation" + + def test_secure_fixture_has_fewer_findings(self): + """Verify secure_api.py produces fewer findings than vulnerable_api.py.""" + config = _make_config() + scanner = AuthSessionScanner(config, FILE_EXTENSIONS, EXCLUDE_DIRS) + vuln_path = os.path.join(FIXTURES_DIR, "vulnerable_api.py") + secure_path = os.path.join(FIXTURES_DIR, "secure_api.py") + vuln_content = open(vuln_path).read() + secure_content = open(secure_path).read() + vuln_findings = scanner._scan_single_file(vuln_path, vuln_content) + secure_findings = scanner._scan_single_file(secure_path, secure_content) + assert len(vuln_findings) >= len( + secure_findings + ), "Vulnerable fixture should produce at least as many findings as secure fixture" diff --git a/tmt.png b/tmt.png new file mode 100644 index 0000000..a24ff78 Binary files /dev/null and b/tmt.png differ diff --git a/tmt/__init__.py b/tmt/__init__.py new file mode 100644 index 0000000..0b45ced --- /dev/null +++ b/tmt/__init__.py @@ -0,0 +1,8 @@ +"""TMT - Lightweight Threat Modeling Toolkit for Release Cycles. + +Provides automated pattern-based scanning and LLM-powered review +for detecting logic bugs, replay attacks, race conditions, and +token/invite abuse in API routes and auth/session logic. +""" + +__version__ = "1.0.0" diff --git a/tmt/config.py b/tmt/config.py new file mode 100644 index 0000000..8b43659 --- /dev/null +++ b/tmt/config.py @@ -0,0 +1,242 @@ +"""Configuration management for the TMT threat modeling toolkit. + +Loads and validates YAML-based configuration files with environment +variable fallbacks for sensitive values like API keys. +""" + +import os +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +import yaml + +logger = logging.getLogger(__name__) + + +@dataclass +class ScannerConfig: + """Configuration for pattern-based security scanners. + + Attributes: + enabled: Whether pattern-based scanning is active. + severity_threshold: Minimum severity level to report. + custom_patterns: Additional user-defined vulnerability patterns. + """ + + enabled: bool = True + severity_threshold: str = "low" + custom_patterns: Dict[str, List[str]] = field(default_factory=dict) + + +@dataclass +class LLMConfig: + """Configuration for LLM-powered security review. + + Attributes: + enabled: Whether LLM review is active. + provider: LLM provider name (huggingface, openai, or anthropic). + model: Model identifier to use for reviews. + api_key: API key for the LLM provider. + base_url: Optional custom base URL for OpenAI-compatible APIs. + temperature: Sampling temperature for LLM responses. + max_tokens: Maximum tokens for LLM response generation. + timeout_seconds: Request timeout in seconds. + """ + + enabled: bool = False + provider: str = "huggingface" + model: str = "Qwen/Qwen2.5-72B-Instruct" + api_key: str = "" + base_url: Optional[str] = None + temperature: float = 0.1 + max_tokens: int = 4096 + timeout_seconds: int = 120 + + +@dataclass +class ReportConfig: + """Configuration for report generation output. + + Attributes: + output_dir: Directory path for generated reports. + formats: List of output formats to generate. + include_code_snippets: Whether to embed code in reports. + max_snippet_lines: Maximum lines per code snippet. + """ + + output_dir: str = "reports" + formats: List[str] = field(default_factory=lambda: ["markdown", "json"]) + include_code_snippets: bool = True + max_snippet_lines: int = 10 + + +@dataclass +class TMTConfig: + """Top-level configuration for the threat modeling toolkit. + + Attributes: + project_name: Human-readable project identifier. + target_dirs: Directories to scan for source files. + file_extensions: File extensions to include in scanning. + exclude_dirs: Directory names to skip during scanning. + scanner: Pattern-based scanner configuration. + llm: LLM-powered review configuration. + report: Report generation configuration. + """ + + project_name: str = "unnamed-project" + target_dirs: List[str] = field(default_factory=lambda: ["src", "app", "api"]) + file_extensions: List[str] = field(default_factory=lambda: [".py", ".js", ".ts"]) + exclude_dirs: List[str] = field( + default_factory=lambda: ["node_modules", ".venv", "__pycache__", ".git"] + ) + scanner: ScannerConfig = field(default_factory=ScannerConfig) + llm: LLMConfig = field(default_factory=LLMConfig) + report: ReportConfig = field(default_factory=ReportConfig) + + +def _read_yaml_file(config_path: str) -> dict: + """Read and parse a YAML configuration file from disk. + + Args: + config_path: Absolute or relative path to the YAML file. + + Returns: + Parsed dictionary from the YAML file contents. + """ + with open(config_path, "r") as f: + data = yaml.safe_load(f) or {} + logger.info("Loaded configuration from %s", config_path) + return data + + +def _build_scanner_config(raw: dict) -> ScannerConfig: + """Build a ScannerConfig from a raw dictionary section. + + Args: + raw: Dictionary containing scanner configuration keys. + + Returns: + Populated ScannerConfig dataclass instance. + """ + return ScannerConfig( + enabled=raw.get("enabled", True), + severity_threshold=raw.get("severity_threshold", "low"), + custom_patterns=raw.get("custom_patterns", {}), + ) + + +def _build_llm_basics(raw: dict) -> dict: + """Extract basic LLM fields from raw configuration. + + Args: + raw: Dictionary containing LLM configuration keys. + + Returns: + Dictionary with provider, model, and auth fields. + """ + api_key = raw.get("api_key", os.environ.get("TMT_LLM_API_KEY", "")) + return { + "enabled": raw.get("enabled", False), + "provider": raw.get("provider", "huggingface"), + "model": raw.get("model", "Qwen/Qwen2.5-72B-Instruct"), + "api_key": api_key, + "base_url": raw.get("base_url"), + } + + +def _build_llm_tuning(raw: dict) -> dict: + """Extract tuning parameter fields from raw LLM configuration. + + Args: + raw: Dictionary containing LLM tuning keys. + + Returns: + Dictionary with temperature, max_tokens, and timeout fields. + """ + return { + "temperature": raw.get("temperature", 0.1), + "max_tokens": raw.get("max_tokens", 4096), + "timeout_seconds": raw.get("timeout_seconds", 120), + } + + +def _build_llm_config(raw: dict) -> LLMConfig: + """Build an LLMConfig from a raw dictionary with env var fallbacks. + + Args: + raw: Dictionary containing LLM configuration keys. + + Returns: + Populated LLMConfig dataclass instance. + """ + basics = _build_llm_basics(raw) + tuning = _build_llm_tuning(raw) + return LLMConfig(**basics, **tuning) + + +def _build_report_config(raw: dict) -> ReportConfig: + """Build a ReportConfig from a raw dictionary section. + + Args: + raw: Dictionary containing report configuration keys. + + Returns: + Populated ReportConfig dataclass instance. + """ + return ReportConfig( + output_dir=raw.get("output_dir", "reports"), + formats=raw.get("formats", ["markdown", "json"]), + include_code_snippets=raw.get("include_code_snippets", True), + max_snippet_lines=raw.get("max_snippet_lines", 10), + ) + + +def _build_tmt_config(data: dict) -> TMTConfig: + """Build a complete TMTConfig from parsed YAML data. + + Args: + data: Root dictionary from the parsed YAML config file. + + Returns: + Fully populated TMTConfig dataclass instance. + """ + scanner = _build_scanner_config(data.get("scanner", {})) + llm = _build_llm_config(data.get("llm", {})) + report = _build_report_config(data.get("report", {})) + return TMTConfig( + project_name=data.get("project_name", "unnamed-project"), + target_dirs=data.get("target_dirs", ["src", "app", "api"]), + file_extensions=data.get("file_extensions", [".py", ".js", ".ts"]), + exclude_dirs=data.get( + "exclude_dirs", ["node_modules", ".venv", "__pycache__", ".git"] + ), + scanner=scanner, + llm=llm, + report=report, + ) + + +def load_config(config_path: str) -> TMTConfig: + """Load and parse a TMT configuration file into a typed config object. + + Args: + config_path: Path to the YAML configuration file. + + Returns: + Fully populated TMTConfig instance ready for use. + """ + data = _read_yaml_file(config_path) + config = _build_tmt_config(data) + logger.info("Configuration built for project: %s", config.project_name) + return config + + +def default_config() -> TMTConfig: + """Create a TMTConfig with all default values for quick startup. + + Returns: + TMTConfig instance with sensible default values. + """ + return TMTConfig() diff --git a/tmt/llm/__init__.py b/tmt/llm/__init__.py new file mode 100644 index 0000000..bd41e85 --- /dev/null +++ b/tmt/llm/__init__.py @@ -0,0 +1,9 @@ +"""LLM-powered security review modules.""" + +from tmt.llm.prompts import PromptLibrary +from tmt.llm.reviewer import LLMReviewer + +__all__ = [ + "PromptLibrary", + "LLMReviewer", +] diff --git a/tmt/llm/prompts.py b/tmt/llm/prompts.py new file mode 100644 index 0000000..69c16cf --- /dev/null +++ b/tmt/llm/prompts.py @@ -0,0 +1,237 @@ +"""Structured prompt templates for LLM-powered security reviews. + +Provides battle-tested prompt templates for reviewing API routes, +authentication/session logic, and business logic for replay attacks, +race conditions, and token abuse. Each prompt enforces structured +JSON output to minimize noise and maximize actionable findings. +""" + +from typing import Dict + +# ────────────────────────────────────────────────────────────────────────────── +# System persona prompt shared across all review types +# ────────────────────────────────────────────────────────────────────────────── + +SYSTEM_PERSONA = ( + "You are a senior application security engineer performing a focused " + "code review. You specialize in finding logic bugs, authentication " + "bypasses, race conditions, and business logic flaws. You only report " + "findings you are confident about (>70 percent confidence) with concrete " + "evidence from the code provided. You never report theoretical " + "vulnerabilities without specific code references." +) + +# ────────────────────────────────────────────────────────────────────────────── +# JSON output schema enforced in all prompts +# ────────────────────────────────────────────────────────────────────────────── + +OUTPUT_SCHEMA = """ +Respond ONLY with a JSON array of findings. Each finding must follow this exact schema: +{ + "title": "Short descriptive title", + "description": "Detailed explanation with specific code references", + "severity": "critical|high|medium|low|info", + "category": "replay_attack|race_condition|token_abuse|auth_session|api_route", + "line_number": , + "recommendation": "Specific actionable fix with code example if possible", + "confidence": , + "cwe_id": "CWE-XXX" +} + +If you find NO issues, return an empty array: [] +Do NOT wrap the JSON in markdown code blocks. Return raw JSON only. +""" + +# ────────────────────────────────────────────────────────────────────────────── +# API route review prompt +# ────────────────────────────────────────────────────────────────────────────── + +API_ROUTE_REVIEW_PROMPT = """Review the following API route code for security vulnerabilities. + +FOCUS AREAS (check each one systematically): +1. **Authentication**: Is every non-public endpoint protected with auth middleware/decorators? +2. **Authorization**: Are object-level permissions checked before returning data (IDOR)? +3. **Input Validation**: Is all user input validated with schemas/types before use? +4. **Rate Limiting**: Are sensitive endpoints (login, signup, token generation) rate-limited? +5. **Mass Assignment**: Is raw request data spread into database models without field filtering? +6. **Error Handling**: Are internal details (stack traces, DB errors) leaked in responses? +7. **CORS**: Is Access-Control-Allow-Origin overly permissive (wildcard with credentials)? +8. **SQL/NoSQL Injection**: Are queries parameterized or using ORM safely? + +CODE TO REVIEW: +``` +{code} +``` + +{output_schema} +""" + +# ────────────────────────────────────────────────────────────────────────────── +# Auth and session logic review prompt +# ────────────────────────────────────────────────────────────────────────────── + +AUTH_SESSION_REVIEW_PROMPT = """Review the following authentication and session management code for security vulnerabilities. + +FOCUS AREAS (check each one systematically): +1. **Password Storage**: Are passwords hashed with bcrypt/argon2/scrypt (not MD5/SHA1/SHA256)? +2. **Session Fixation**: Is the session ID regenerated after successful login? +3. **Token Handling**: Are JWTs validated properly (signature, expiration, issuer, audience)? +4. **Cookie Security**: Are session cookies set with Secure, HttpOnly, SameSite flags? +5. **Brute Force**: Is there account lockout or progressive delays after failed login attempts? +6. **Privilege Escalation**: Can users modify their own role/permission fields? +7. **Logout**: Does logout actually invalidate the session/token server-side? +8. **MFA Bypass**: Can MFA verification be skipped by manipulating request flow? +9. **Password Reset**: Are reset tokens single-use, time-limited, and securely generated? +10. **OAuth/SSO**: Are redirect URIs validated strictly (no open redirects)? + +CODE TO REVIEW: +``` +{code} +``` + +{output_schema} +""" + +# ────────────────────────────────────────────────────────────────────────────── +# Logic bug review prompt (replay, race, token abuse) +# ────────────────────────────────────────────────────────────────────────────── + +LOGIC_BUG_REVIEW_PROMPT = """Review the following code for logic bugs related to replay attacks, race conditions, and token/invite abuse. + +FOCUS AREAS (check each one systematically): +1. **Replay Attack**: Can captured requests be re-submitted? Are there idempotency keys or nonces? +2. **Race Condition - Read/Modify/Write**: Are balance changes, counter increments, or stock decrements atomic? +3. **Race Condition - TOCTOU**: Is there a gap between checking permissions/existence and acting on it? +4. **Race Condition - Double Spend**: Can a token/coupon/credit be redeemed concurrently before being marked as used? +5. **Token Reuse**: Are one-time tokens (reset, verify, invite) invalidated after successful use? +6. **Invite Abuse**: Can invite links be shared and reused by multiple users? +7. **State Machine Violations**: Can operations be performed out of expected order? +8. **Enumeration**: Can sequential/predictable IDs be enumerated to discover resources? + +THINK STEP BY STEP about request concurrency and timing. Consider what happens when: +- The same request arrives twice within 1ms +- Two users click the same invite link simultaneously +- A token is used in two concurrent requests before the DB marks it as consumed + +CODE TO REVIEW: +``` +{code} +``` + +{output_schema} +""" + +# ────────────────────────────────────────────────────────────────────────────── +# Comprehensive single-pass review prompt +# ────────────────────────────────────────────────────────────────────────────── + +COMPREHENSIVE_REVIEW_PROMPT = """Perform a comprehensive security review of the following code, covering all threat categories. + +THREAT CATEGORIES TO CHECK: + +**A. Replay Attacks** +- Missing idempotency keys on mutating endpoints +- Tokens verifiable but not invalidated after use +- No request timestamp/nonce validation + +**B. Race Conditions** +- Non-atomic read-modify-write (balance, inventory, counters) +- TOCTOU gaps between check and action +- Concurrent token/coupon/invite redemption without locks +- Shared mutable state without synchronization + +**C. Token & Invite Abuse** +- Token generation without rate limiting +- Predictable token generation (weak PRNG, UUID1, timestamp-based) +- Tokens without expiration +- Invite tokens usable multiple times +- Missing token revocation on logout + +**D. Auth & Session** +- Missing authentication on endpoints +- Missing authorization/ownership checks (IDOR) +- Weak password hashing +- Session fixation (no regeneration after login) +- Insecure cookie settings + +**E. API Security** +- Missing input validation/sanitization +- Verbose error messages leaking internals +- Overly permissive CORS +- Mass assignment via raw request data spreading + +CODE TO REVIEW: +``` +{code} +``` + +{output_schema} +""" + + +class PromptLibrary: + """Registry of security review prompt templates for LLM-powered analysis. + + Provides access to specialized and comprehensive prompt templates + with consistent formatting, output schema enforcement, and + systematic checklist-based review instructions. + """ + + def __init__(self): + """Initialize the prompt library with all available templates.""" + self.system_persona = SYSTEM_PERSONA + self.output_schema = OUTPUT_SCHEMA + self._templates = { + "api_route": API_ROUTE_REVIEW_PROMPT, + "auth_session": AUTH_SESSION_REVIEW_PROMPT, + "logic_bug": LOGIC_BUG_REVIEW_PROMPT, + "comprehensive": COMPREHENSIVE_REVIEW_PROMPT, + } + + def get_template_names(self) -> list: + """Return a list of all available prompt template names. + + Returns: + List of string template identifiers. + """ + return list(self._templates.keys()) + + def _format_template(self, template: str, code: str) -> str: + """Inject code and output schema into a prompt template. + + Args: + template: Raw prompt template string with placeholders. + code: Source code to embed in the prompt. + + Returns: + Fully formatted prompt string ready to send to an LLM. + """ + return template.format(code=code, output_schema=self.output_schema) + + def build_prompt(self, template_name: str, code: str) -> Dict[str, str]: + """Build a complete system + user prompt pair for LLM submission. + + Args: + template_name: Name of the template to use from the registry. + code: Source code to include in the review prompt. + + Returns: + Dictionary with 'system' and 'user' keys containing prompt text. + + Raises: + KeyError: If template_name is not found in the registry. + """ + template = self._templates[template_name] + user_prompt = self._format_template(template, code) + return {"system": self.system_persona, "user": user_prompt} + + def build_all_prompts(self, code: str) -> Dict[str, Dict[str, str]]: + """Build prompt pairs for every template in the library. + + Args: + code: Source code to include in all review prompts. + + Returns: + Dictionary mapping template names to system/user prompt pairs. + """ + return {name: self.build_prompt(name, code) for name in self._templates} diff --git a/tmt/llm/reviewer.py b/tmt/llm/reviewer.py new file mode 100644 index 0000000..ca91426 --- /dev/null +++ b/tmt/llm/reviewer.py @@ -0,0 +1,395 @@ +"""LLM-powered security reviewer with multi-provider support. + +Integrates with Hugging Face, OpenAI, and Anthropic APIs to perform +deep security reviews of source code using structured prompts. Parses +JSON responses into Finding objects and aggregates results into +LLMReview containers. +""" + +import json +import logging +import os +import time +from typing import Dict, List, Optional + +from tmt.config import LLMConfig +from tmt.llm.prompts import PromptLibrary +from tmt.models import ( + Finding, + FindingCategory, + LLMReview, + Severity, +) + +logger = logging.getLogger(__name__) + +# ────────────────────────────────────────────────────────────────────────────── +# Severity and category mapping from LLM string output to enums +# ────────────────────────────────────────────────────────────────────────────── + +SEVERITY_MAP = { + "critical": Severity.CRITICAL, + "high": Severity.HIGH, + "medium": Severity.MEDIUM, + "low": Severity.LOW, + "info": Severity.INFO, +} + +CATEGORY_MAP = { + "replay_attack": FindingCategory.REPLAY_ATTACK, + "race_condition": FindingCategory.RACE_CONDITION, + "token_abuse": FindingCategory.TOKEN_ABUSE, + "auth_session": FindingCategory.AUTH_SESSION, + "api_route": FindingCategory.API_ROUTE, + "llm_review": FindingCategory.LLM_REVIEW, +} + + +def _call_openai(config: LLMConfig, system: str, user: str) -> Dict: + """Send a review prompt to the OpenAI API and return the response. + + Args: + config: LLM configuration with API key and model settings. + system: System persona message content. + user: User prompt message content. + + Returns: + Dictionary with 'content', 'prompt_tokens', and 'completion_tokens'. + """ + from openai import OpenAI + + client_kwargs = {"api_key": config.api_key, "timeout": config.timeout_seconds} + if config.base_url: + client_kwargs["base_url"] = config.base_url + client = OpenAI(**client_kwargs) + response = client.chat.completions.create( + model=config.model, + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + temperature=config.temperature, + max_tokens=config.max_tokens, + ) + return { + "content": response.choices[0].message.content, + "prompt_tokens": response.usage.prompt_tokens if response.usage else 0, + "completion_tokens": response.usage.completion_tokens if response.usage else 0, + } + + +def _resolve_hf_api_key(config: LLMConfig) -> str: + """Resolve the Hugging Face API key from config or environment. + + Args: + config: LLM configuration that may contain an explicit api_key. + + Returns: + API key string from config, HF_TOKEN env var, or TMT_LLM_API_KEY. + """ + if config.api_key: + return config.api_key + return os.environ.get("HF_TOKEN", os.environ.get("TMT_LLM_API_KEY", "")) + + +def _call_huggingface(config: LLMConfig, system: str, user: str) -> Dict: + """Send a review prompt to the Hugging Face Inference API. + + Uses the OpenAI-compatible chat completions endpoint provided by + Hugging Face's free serverless Inference API. Supports all models + available on the HF Hub with the Inference API enabled. + + Args: + config: LLM configuration with model and optional api_key. + system: System persona message content. + user: User prompt message content. + + Returns: + Dictionary with 'content', 'prompt_tokens', and 'completion_tokens'. + """ + from huggingface_hub import InferenceClient + + api_key = _resolve_hf_api_key(config) + client = InferenceClient(api_key=api_key or None, timeout=config.timeout_seconds) + response = client.chat.completions.create( + model=config.model, + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + temperature=config.temperature, + max_tokens=config.max_tokens, + ) + return { + "content": response.choices[0].message.content, + "prompt_tokens": response.usage.prompt_tokens if response.usage else 0, + "completion_tokens": response.usage.completion_tokens if response.usage else 0, + } + + +def _call_anthropic(config: LLMConfig, system: str, user: str) -> Dict: + """Send a review prompt to the Anthropic API and return the response. + + Args: + config: LLM configuration with API key and model settings. + system: System persona message content. + user: User prompt message content. + + Returns: + Dictionary with 'content', 'prompt_tokens', and 'completion_tokens'. + """ + from anthropic import Anthropic + + client = Anthropic(api_key=config.api_key, timeout=config.timeout_seconds) + response = client.messages.create( + model=config.model, + max_tokens=config.max_tokens, + system=system, + messages=[{"role": "user", "content": user}], + temperature=config.temperature, + ) + return { + "content": response.content[0].text, + "prompt_tokens": response.usage.input_tokens, + "completion_tokens": response.usage.output_tokens, + } + + +def _select_provider_call(provider: str): + """Select the appropriate API call function for the configured provider. + + Args: + provider: LLM provider name ('huggingface', 'openai', or 'anthropic'). + + Returns: + Callable that sends prompts to the selected provider API. + + Raises: + ValueError: If the provider is not supported. + """ + providers = { + "huggingface": _call_huggingface, + "openai": _call_openai, + "anthropic": _call_anthropic, + } + if provider not in providers: + raise ValueError(f"Unsupported LLM provider: {provider}") + return providers[provider] + + +def _strip_markdown_fences(text: str) -> str: + """Remove markdown code fences from LLM response text. + + Args: + text: Raw LLM response that may contain code fence markers. + + Returns: + Cleaned text with markdown fences stripped. + """ + text = text.strip() + if text.startswith("```json"): + text = text[7:] + if text.startswith("```"): + text = text[3:] + if text.endswith("```"): + text = text[:-3] + return text.strip() + + +def _parse_severity(raw_severity: str) -> Severity: + """Convert a raw severity string to a Severity enum value. + + Args: + raw_severity: Severity string from LLM JSON output. + + Returns: + Corresponding Severity enum value, defaulting to MEDIUM. + """ + return SEVERITY_MAP.get(raw_severity.lower(), Severity.MEDIUM) + + +def _parse_category(raw_category: str) -> FindingCategory: + """Convert a raw category string to a FindingCategory enum value. + + Args: + raw_category: Category string from LLM JSON output. + + Returns: + Corresponding FindingCategory enum value, defaulting to LLM_REVIEW. + """ + return CATEGORY_MAP.get(raw_category.lower(), FindingCategory.LLM_REVIEW) + + +def _parse_single_finding(item: dict, file_path: str) -> Finding: + """Parse a single finding dictionary from LLM output into a Finding object. + + Args: + item: Dictionary containing finding fields from LLM JSON response. + file_path: Source file path the finding relates to. + + Returns: + Populated Finding dataclass instance. + """ + return Finding( + title=item.get("title", "LLM Finding"), + description=item.get("description", ""), + severity=_parse_severity(item.get("severity", "medium")), + category=_parse_category(item.get("category", "llm_review")), + file_path=file_path, + line_number=item.get("line_number", 0), + code_snippet="", + recommendation=item.get("recommendation", ""), + confidence=float(item.get("confidence", 0.7)), + cwe_id=item.get("cwe_id"), + ) + + +def _parse_findings_json(raw_text: str, file_path: str) -> List[Finding]: + """Parse LLM JSON response text into a list of Finding objects. + + Args: + raw_text: Raw JSON text from the LLM response. + file_path: Source file path the findings relate to. + + Returns: + List of parsed Finding objects, empty list on parse failure. + """ + try: + cleaned = _strip_markdown_fences(raw_text) + items = json.loads(cleaned) + if not isinstance(items, list): + items = [items] + return [_parse_single_finding(item, file_path) for item in items] + except (json.JSONDecodeError, TypeError, KeyError) as exc: + logger.warning("Failed to parse LLM response as JSON: %s", exc) + return [] + + +class LLMReviewer: + """Orchestrates LLM-powered security reviews of source code files. + + Manages prompt construction, API communication, response parsing, + and finding assembly for OpenAI and Anthropic providers. + """ + + def __init__(self, config: LLMConfig): + """Initialize the LLM reviewer with provider configuration. + + Args: + config: LLM configuration controlling provider, model, and limits. + """ + self.config = config + self.prompt_library = PromptLibrary() + self._call_fn = _select_provider_call(config.provider) + + def _send_review_request(self, system: str, user: str) -> Dict: + """Send a prompt pair to the configured LLM provider. + + Args: + system: System persona prompt text. + user: User review prompt text with code. + + Returns: + Provider response dictionary with content and token counts. + """ + logger.info( + "Sending review request to %s/%s", self.config.provider, self.config.model + ) + return self._call_fn(self.config, system, user) + + def _build_review_result( + self, response: Dict, file_path: str, template_name: str + ) -> LLMReview: + """Assemble an LLMReview from a provider response and parsed findings. + + Args: + response: Provider response with content and token usage. + file_path: Source file that was reviewed. + template_name: Name of the prompt template used. + + Returns: + Populated LLMReview with parsed findings. + """ + findings = _parse_findings_json(response["content"], file_path) + return LLMReview( + reviewer_name=f"llm_{template_name}", + model_used=self.config.model, + prompt_tokens=response.get("prompt_tokens", 0), + completion_tokens=response.get("completion_tokens", 0), + findings=findings, + raw_response=response["content"], + ) + + def review_file( + self, file_path: str, code: str, template_name: str = "comprehensive" + ) -> LLMReview: + """Review a single source file using a specified prompt template. + + Args: + file_path: Path to the source file being reviewed. + code: Full source code content of the file. + template_name: Prompt template to use for the review. + + Returns: + LLMReview containing all findings from the review. + """ + prompts = self.prompt_library.build_prompt(template_name, code) + response = self._send_review_request(prompts["system"], prompts["user"]) + review = self._build_review_result(response, file_path, template_name) + logger.info("Review of %s found %d findings", file_path, len(review.findings)) + return review + + def _read_file_safe(self, file_path: str) -> Optional[str]: + """Read a file with graceful error handling for LLM review. + + Args: + file_path: Absolute path to the file to read. + + Returns: + File contents as string, or None on read failure. + """ + try: + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + return f.read() + except OSError as exc: + logger.warning("Could not read %s for LLM review: %s", file_path, exc) + return None + + def _review_single_file( + self, file_path: str, template_name: str + ) -> Optional[LLMReview]: + """Read and review a single file, handling errors gracefully. + + Args: + file_path: Path to the source file to review. + template_name: Prompt template name to use. + + Returns: + LLMReview if successful, None if file could not be read. + """ + code = self._read_file_safe(file_path) + if not code: + return None + return self.review_file(file_path, code, template_name) + + def review_files( + self, file_paths: List[str], template_name: str = "comprehensive" + ) -> List[LLMReview]: + """Review multiple files sequentially with the specified template. + + Args: + file_paths: List of source file paths to review. + template_name: Prompt template to use for all reviews. + + Returns: + List of LLMReview objects, one per successfully reviewed file. + """ + reviews = [] + for file_path in file_paths: + review = self._review_single_file(file_path, template_name) + if review: + reviews.append(review) + logger.info( + "Completed LLM review of %d/%d files", len(reviews), len(file_paths) + ) + return reviews diff --git a/tmt/models.py b/tmt/models.py new file mode 100644 index 0000000..c1acaab --- /dev/null +++ b/tmt/models.py @@ -0,0 +1,197 @@ +"""Data models for threat modeling findings, scan results, and reports. + +Provides dataclass-based models for representing security findings, +scan results from pattern-based scanners, LLM review outputs, and +aggregated threat model reports. +""" + +import datetime +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional + + +class Severity(Enum): + """Enumeration of finding severity levels aligned with CVSS qualitative ratings.""" + + CRITICAL = "critical" + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + INFO = "info" + + +class FindingCategory(Enum): + """Enumeration of threat finding categories tracked by the toolkit.""" + + REPLAY_ATTACK = "replay_attack" + RACE_CONDITION = "race_condition" + TOKEN_ABUSE = "token_abuse" + AUTH_SESSION = "auth_session" + API_ROUTE = "api_route" + LLM_REVIEW = "llm_review" + + +SEVERITY_RANK = { + Severity.CRITICAL: 5, + Severity.HIGH: 4, + Severity.MEDIUM: 3, + Severity.LOW: 2, + Severity.INFO: 1, +} +"""Numeric ranking for severity comparison and sorting.""" + + +@dataclass +class Finding: + """Represents a single security finding from a scan or LLM review. + + Attributes: + title: Short descriptive title of the finding. + description: Detailed explanation of the vulnerability. + severity: Severity level of the finding. + category: Category classification of the finding. + file_path: Path to the affected source file. + line_number: Line number where the issue was detected. + code_snippet: Relevant code excerpt surrounding the finding. + recommendation: Actionable remediation guidance. + confidence: Confidence score between 0.0 and 1.0. + cwe_id: Optional CWE identifier for the vulnerability class. + """ + + title: str + description: str + severity: Severity + category: FindingCategory + file_path: str + line_number: int + code_snippet: str + recommendation: str + confidence: float = 0.8 + cwe_id: Optional[str] = None + + +def _utc_now_iso() -> str: + """Return the current UTC time as an ISO 8601 formatted string. + + Returns: + ISO 8601 timestamp string with UTC timezone. + """ + return datetime.datetime.now(datetime.timezone.utc).isoformat() + + +@dataclass +class ScanResult: + """Container for results produced by a single scanner execution. + + Attributes: + scanner_name: Identifier of the scanner that produced results. + findings: List of security findings detected. + files_scanned: Number of files analyzed during the scan. + scan_duration_seconds: Wall-clock time taken for the scan. + timestamp: ISO 8601 timestamp of when the scan completed. + """ + + scanner_name: str + findings: List[Finding] = field(default_factory=list) + files_scanned: int = 0 + scan_duration_seconds: float = 0.0 + timestamp: str = field(default_factory=_utc_now_iso) + + +@dataclass +class LLMReview: + """Container for results from an LLM-powered security review. + + Attributes: + reviewer_name: Identifier of the review workflow used. + model_used: LLM model identifier used for the review. + prompt_tokens: Number of input tokens consumed. + completion_tokens: Number of output tokens generated. + findings: List of security findings from the review. + raw_response: Unprocessed LLM response text. + timestamp: ISO 8601 timestamp of when the review completed. + """ + + reviewer_name: str + model_used: str + prompt_tokens: int = 0 + completion_tokens: int = 0 + findings: List[Finding] = field(default_factory=list) + raw_response: str = "" + timestamp: str = field(default_factory=_utc_now_iso) + + +@dataclass +class ThreatModelReport: + """Complete threat model report aggregating all scan and review results. + + Attributes: + project_name: Name of the project being assessed. + scan_results: Results from all pattern-based scanners. + llm_reviews: Results from all LLM-powered reviews. + total_findings: Total count of all findings across sources. + critical_count: Number of critical severity findings. + high_count: Number of high severity findings. + medium_count: Number of medium severity findings. + low_count: Number of low severity findings. + info_count: Number of informational findings. + timestamp: ISO 8601 timestamp of report generation. + """ + + project_name: str = "" + scan_results: List[ScanResult] = field(default_factory=list) + llm_reviews: List[LLMReview] = field(default_factory=list) + total_findings: int = 0 + critical_count: int = 0 + high_count: int = 0 + medium_count: int = 0 + low_count: int = 0 + info_count: int = 0 + timestamp: str = field(default_factory=_utc_now_iso) + + +def _count_by_severity(findings: List[Finding], severity: Severity) -> int: + """Count findings matching a specific severity level. + + Args: + findings: List of Finding objects to count. + severity: Target severity level to match. + + Returns: + Integer count of findings with the specified severity. + """ + return sum(1 for f in findings if f.severity == severity) + + +def _gather_all_findings(report: ThreatModelReport) -> List[Finding]: + """Collect all findings from scan results and LLM reviews into one list. + + Args: + report: ThreatModelReport containing scan results and LLM reviews. + + Returns: + Flat list of all Finding objects from every source. + """ + scan_findings = [f for sr in report.scan_results for f in sr.findings] + llm_findings = [f for lr in report.llm_reviews for f in lr.findings] + return scan_findings + llm_findings + + +def compute_report_statistics(report: ThreatModelReport) -> ThreatModelReport: + """Compute and populate severity counts on a ThreatModelReport. + + Args: + report: ThreatModelReport to update with computed statistics. + + Returns: + The same ThreatModelReport with updated count fields. + """ + all_findings = _gather_all_findings(report) + report.total_findings = len(all_findings) + report.critical_count = _count_by_severity(all_findings, Severity.CRITICAL) + report.high_count = _count_by_severity(all_findings, Severity.HIGH) + report.medium_count = _count_by_severity(all_findings, Severity.MEDIUM) + report.low_count = _count_by_severity(all_findings, Severity.LOW) + report.info_count = _count_by_severity(all_findings, Severity.INFO) + return report diff --git a/tmt/runner.py b/tmt/runner.py new file mode 100644 index 0000000..06876bb --- /dev/null +++ b/tmt/runner.py @@ -0,0 +1,233 @@ +"""Threat model runner orchestrating the full scan-review-report loop. + +Coordinates pattern-based scanners, optional LLM-powered reviews, +and report generation into a single repeatable workflow that teams +can execute each release cycle. +""" + +import logging +import os +import time +from typing import List, Optional + +from tmt.config import TMTConfig +from tmt.llm.reviewer import LLMReviewer +from tmt.models import LLMReview, ScanResult, ThreatModelReport +from tmt.reports.generator import ReportGenerator +from tmt.scanners.api_route_scanner import APIRouteScanner +from tmt.scanners.auth_session_scanner import AuthSessionScanner +from tmt.scanners.base_scanner import BaseScanner +from tmt.scanners.race_condition_scanner import RaceConditionScanner +from tmt.scanners.replay_scanner import ReplayScanner +from tmt.scanners.token_abuse_scanner import TokenAbuseScanner + +logger = logging.getLogger(__name__) + + +def _create_scanner(scanner_cls, config: TMTConfig) -> BaseScanner: + """Instantiate a scanner with shared configuration parameters. + + Args: + scanner_cls: Scanner class to instantiate. + config: Top-level TMT configuration. + + Returns: + Initialized scanner instance. + """ + return scanner_cls( + config=config.scanner, + file_extensions=config.file_extensions, + exclude_dirs=config.exclude_dirs, + ) + + +def _build_all_scanners(config: TMTConfig) -> List[BaseScanner]: + """Build the complete set of pattern-based scanners. + + Args: + config: Top-level TMT configuration. + + Returns: + List of initialized scanner instances ordered by category. + """ + scanner_classes = [ + ReplayScanner, + RaceConditionScanner, + TokenAbuseScanner, + AuthSessionScanner, + APIRouteScanner, + ] + return [_create_scanner(cls, config) for cls in scanner_classes] + + +def _run_all_scanners( + scanners: List[BaseScanner], target_path: str +) -> List[ScanResult]: + """Execute all scanners against a target directory. + + Args: + scanners: List of initialized scanner instances. + target_path: Root directory path to scan. + + Returns: + List of ScanResult objects from all scanners. + """ + results = [] + for scanner in scanners: + logger.info("Running %s...", scanner.scanner_name) + result = scanner.scan(target_path) + results.append(result) + return results + + +def _collect_llm_target_files(config: TMTConfig, target_path: str) -> List[str]: + """Collect files for LLM review using the first scanner's file collection. + + Args: + config: Top-level TMT configuration. + target_path: Root directory path to collect files from. + + Returns: + List of file paths suitable for LLM review. + """ + collector = _create_scanner(ReplayScanner, config) + return collector._collect_files(target_path) + + +def _run_llm_reviews(config: TMTConfig, target_path: str) -> List[LLMReview]: + """Execute LLM-powered reviews against target files. + + Args: + config: Top-level TMT configuration with LLM settings. + target_path: Root directory path containing files to review. + + Returns: + List of LLMReview objects from completed reviews. + """ + reviewer = LLMReviewer(config.llm) + file_paths = _collect_llm_target_files(config, target_path) + logger.info("Submitting %d files for LLM review", len(file_paths)) + return reviewer.review_files(file_paths) + + +def _build_report( + config: TMTConfig, scan_results: List[ScanResult], llm_reviews: List[LLMReview] +) -> ThreatModelReport: + """Assemble a ThreatModelReport from scan results and LLM reviews. + + Args: + config: Top-level TMT configuration. + scan_results: Results from pattern-based scanners. + llm_reviews: Results from LLM-powered reviews. + + Returns: + Assembled ThreatModelReport ready for rendering. + """ + return ThreatModelReport( + project_name=config.project_name, + scan_results=scan_results, + llm_reviews=llm_reviews, + ) + + +class ThreatModelRunner: + """Orchestrates the complete threat modeling loop for a release cycle. + + Manages the lifecycle of initializing scanners, executing pattern-based + scans, optionally running LLM-powered reviews, and generating reports. + """ + + def __init__(self, config: TMTConfig): + """Initialize the runner with project configuration. + + Args: + config: Top-level TMT configuration for all components. + """ + self.config = config + self.scanners = _build_all_scanners(config) + self.report_generator = ReportGenerator(config.report) + + def _resolve_target_path(self, target_path: Optional[str]) -> str: + """Resolve the target scan directory from explicit path or config. + + Args: + target_path: Optional explicit target directory override. + + Returns: + Absolute path to the target directory for scanning. + """ + if target_path: + return os.path.abspath(target_path) + return os.path.abspath(".") + + def _execute_scans(self, target_path: str) -> List[ScanResult]: + """Run all pattern-based scanners if scanning is enabled. + + Args: + target_path: Resolved absolute path to scan. + + Returns: + List of ScanResult objects, empty if scanning disabled. + """ + if not self.config.scanner.enabled: + logger.info("Pattern scanning disabled, skipping") + return [] + return _run_all_scanners(self.scanners, target_path) + + def _execute_llm_reviews(self, target_path: str) -> List[LLMReview]: + """Run LLM-powered reviews if LLM integration is enabled. + + Args: + target_path: Resolved absolute path to review. + + Returns: + List of LLMReview objects, empty if LLM disabled. + """ + if not self.config.llm.enabled: + logger.info("LLM review disabled, skipping") + return [] + return _run_llm_reviews(self.config, target_path) + + def _log_completion_summary( + self, report: ThreatModelReport, elapsed: float + ) -> None: + """Log a summary of the completed threat model run. + + Args: + report: Completed report with computed statistics. + elapsed: Total wall-clock time in seconds. + """ + logger.info( + "Threat model complete in %.1fs: %d findings " + "(%d critical, %d high, %d medium, %d low, %d info)", + elapsed, + report.total_findings, + report.critical_count, + report.high_count, + report.medium_count, + report.low_count, + report.info_count, + ) + + def run(self, target_path: Optional[str] = None) -> ThreatModelReport: + """Execute the complete threat modeling loop and generate reports. + + Args: + target_path: Optional directory to scan. Defaults to current directory. + + Returns: + Completed ThreatModelReport with all findings and statistics. + """ + start_time = time.time() + resolved_path = self._resolve_target_path(target_path) + logger.info( + "Starting threat model for '%s' at %s", + self.config.project_name, + resolved_path, + ) + scan_results = self._execute_scans(resolved_path) + llm_reviews = self._execute_llm_reviews(resolved_path) + report = _build_report(self.config, scan_results, llm_reviews) + output_paths = self.report_generator.generate(report) + self._log_completion_summary(report, time.time() - start_time) + return report diff --git a/tmt/scanners/__init__.py b/tmt/scanners/__init__.py new file mode 100644 index 0000000..af0c3af --- /dev/null +++ b/tmt/scanners/__init__.py @@ -0,0 +1,17 @@ +"""Pattern-based security scanners for threat modeling.""" + +from tmt.scanners.base_scanner import BaseScanner +from tmt.scanners.replay_scanner import ReplayScanner +from tmt.scanners.race_condition_scanner import RaceConditionScanner +from tmt.scanners.token_abuse_scanner import TokenAbuseScanner +from tmt.scanners.auth_session_scanner import AuthSessionScanner +from tmt.scanners.api_route_scanner import APIRouteScanner + +__all__ = [ + "BaseScanner", + "ReplayScanner", + "RaceConditionScanner", + "TokenAbuseScanner", + "AuthSessionScanner", + "APIRouteScanner", +] diff --git a/tmt/scanners/api_route_scanner.py b/tmt/scanners/api_route_scanner.py new file mode 100644 index 0000000..d22ac86 --- /dev/null +++ b/tmt/scanners/api_route_scanner.py @@ -0,0 +1,192 @@ +"""Scanner for detecting API route security vulnerabilities. + +Identifies missing input validation, absent rate limiting, verbose +error exposure, insecure CORS configuration, unprotected admin +endpoints, and mass assignment risks in API route handlers. +""" + +from typing import List + +from tmt.config import ScannerConfig +from tmt.models import FindingCategory, Severity +from tmt.scanners.base_scanner import BaseScanner, VulnerabilityPattern + +# ────────────────────────────────────────────────────────────────────── +# Missing input validation +# ────────────────────────────────────────────────────────────────────── + +MISSING_INPUT_VALIDATION = VulnerabilityPattern( + name="Endpoint Missing Input Validation", + trigger_pattern=r"(request\.(json|form|data|body|args|params|query)\s*(\[|\.get))", + defense_pattern=r"(validate|schema|serializer|pydantic|marshmallow|cerberus|joi\.|yup\.|zod\.|express-validator|class-validator)", + context_window=15, + description=( + "Request data is accessed directly without visible schema validation. " + "Missing input validation can lead to injection attacks, type confusion, " + "and unexpected application behavior from malformed data." + ), + severity=Severity.MEDIUM, + category=FindingCategory.API_ROUTE, + recommendation=( + "Validate all input using a schema library: Pydantic or Marshmallow " + "for Python, Joi or Zod for JavaScript. Define strict schemas with " + "type constraints, length limits, and allowed value ranges." + ), + cwe_id="CWE-20", + confidence=0.6, +) + +# ────────────────────────────────────────────────────────────────────── +# Missing rate limiting +# ────────────────────────────────────────────────────────────────────── + +MISSING_RATE_LIMIT = VulnerabilityPattern( + name="Endpoint Missing Rate Limiting", + trigger_pattern=r"@(app|router|blueprint)\.(post|put|patch|delete)\s*\(", + defense_pattern=r"(rate_limit|throttle|RateLimit|slowapi|ratelimit|limiter|express-rate-limit|bottleneck)", + context_window=10, + description=( + "Mutating endpoint has no rate limiting. Without rate limiting, " + "attackers can brute-force credentials, exhaust resources, or " + "abuse business logic at scale." + ), + severity=Severity.MEDIUM, + category=FindingCategory.API_ROUTE, + recommendation=( + "Apply rate limiting to all endpoints, with stricter limits on " + "authentication and resource-creation routes. Use per-user and " + "per-IP limits with a sliding window algorithm." + ), + cwe_id="CWE-770", + confidence=0.5, +) + +# ────────────────────────────────────────────────────────────────────── +# Verbose error exposure +# ────────────────────────────────────────────────────────────────────── + +VERBOSE_ERROR_EXPOSURE = VulnerabilityPattern( + name="Verbose Error Details Exposed in Response", + trigger_pattern=r"(traceback\.|str\(e\)|str\(err\)|exc_info|stack.*trace|error.*message.*str\(|\.message\s*\})", + defense_pattern=r"(if\s+.*DEBUG|production|sanitize.*error|generic.*error|log.*error.*return|sentry|logging\.exception)", + context_window=8, + description=( + "Exception details or stack traces may be returned to API clients. " + "Verbose errors leak implementation details, library versions, file " + "paths, and database structure to attackers." + ), + severity=Severity.MEDIUM, + category=FindingCategory.API_ROUTE, + recommendation=( + "Return generic error messages to clients and log full details " + "server-side. Use a global error handler that returns sanitized " + "responses with error codes rather than internal messages." + ), + cwe_id="CWE-209", + confidence=0.7, +) + +# ────────────────────────────────────────────────────────────────────── +# Insecure CORS configuration +# ────────────────────────────────────────────────────────────────────── + +INSECURE_CORS = VulnerabilityPattern( + name="Overly Permissive CORS Configuration", + trigger_pattern=r"""(origins?\s*=\s*['"]\*['"]|Access-Control-Allow-Origin.*\*|allow_origins\s*=\s*\[['"]?\*['"]?\])""", + defense_pattern=None, + context_window=5, + description=( + "CORS is configured to allow all origins with a wildcard. This " + "permits any website to make authenticated cross-origin requests " + "to the API when credentials are included." + ), + severity=Severity.HIGH, + category=FindingCategory.API_ROUTE, + recommendation=( + "Specify an explicit allowlist of trusted origins. Never combine " + "wildcard origins with allow_credentials=True. Use environment-based " + "configuration to set different origins per deployment." + ), + cwe_id="CWE-942", + confidence=0.9, +) + +# ────────────────────────────────────────────────────────────────────── +# Unprotected admin endpoints +# ────────────────────────────────────────────────────────────────────── + +UNPROTECTED_ADMIN = VulnerabilityPattern( + name="Admin Endpoint Without Role Check", + trigger_pattern=r"""(admin|superuser|staff|management|internal)[/'"]""", + defense_pattern=r"(is_admin|is_superuser|is_staff|role.*admin|admin_required|@admin|permission_classes|has_role|authorize.*admin)", + context_window=10, + description=( + "An endpoint with an admin-related path does not have visible " + "role-based authorization checks. Access to admin functionality " + "without proper role verification is a privilege escalation risk." + ), + severity=Severity.CRITICAL, + category=FindingCategory.API_ROUTE, + recommendation=( + "Enforce role-based access control on all admin endpoints. Use " + "decorators or middleware that verify the user has an admin role " + "before executing handler logic. Apply defense in depth." + ), + cwe_id="CWE-269", + confidence=0.7, +) + +# ────────────────────────────────────────────────────────────────────── +# Mass assignment +# ────────────────────────────────────────────────────────────────────── + +MASS_ASSIGNMENT = VulnerabilityPattern( + name="Potential Mass Assignment Vulnerability", + trigger_pattern=r"(\*\*request\.(json|data|form|body)|\*\*req\.body|\.update\s*\(\s*request\.(json|data)|\.create\s*\(\s*\*\*)", + defense_pattern=r"(schema|serializer|allow(ed)?_fields|pick\s*\(|whitelist|only\s*=|exclude\s*=|fields\s*=)", + context_window=10, + description=( + "User input is spread directly into a model create or update call. " + "An attacker can inject unexpected fields like is_admin=True or " + "role=superuser to escalate their privileges." + ), + severity=Severity.CRITICAL, + category=FindingCategory.API_ROUTE, + recommendation=( + "Never spread raw request data into models. Use a schema or " + "serializer to explicitly define which fields are accepted. " + "Reject unknown fields and validate types strictly." + ), + cwe_id="CWE-915", + confidence=0.8, +) + + +class APIRouteScanner(BaseScanner): + """Scanner specialized in detecting API route security vulnerabilities. + + Detects missing input validation, absent rate limiting, verbose errors, + insecure CORS, unprotected admin endpoints, and mass assignment risks + across Python and JavaScript codebases. + """ + + PATTERNS: List[VulnerabilityPattern] = [ + MISSING_INPUT_VALIDATION, + MISSING_RATE_LIMIT, + VERBOSE_ERROR_EXPOSURE, + INSECURE_CORS, + UNPROTECTED_ADMIN, + MASS_ASSIGNMENT, + ] + + def __init__( + self, config: ScannerConfig, file_extensions: List[str], exclude_dirs: List[str] + ): + """Initialize the API route scanner. + + Args: + config: Scanner configuration controlling behavior. + file_extensions: File extensions to include in scanning. + exclude_dirs: Directory names to skip during traversal. + """ + super().__init__(config, file_extensions, exclude_dirs) diff --git a/tmt/scanners/auth_session_scanner.py b/tmt/scanners/auth_session_scanner.py new file mode 100644 index 0000000..25abb35 --- /dev/null +++ b/tmt/scanners/auth_session_scanner.py @@ -0,0 +1,192 @@ +"""Scanner for detecting authentication and session management vulnerabilities. + +Identifies missing authentication decorators, insecure session configuration, +absent CSRF protection, weak password handling, session fixation risks, +and missing authorization checks on protected endpoints. +""" + +from typing import List + +from tmt.config import ScannerConfig +from tmt.models import FindingCategory, Severity +from tmt.scanners.base_scanner import BaseScanner, VulnerabilityPattern + +# ────────────────────────────────────────────────────────────────────── +# Missing authentication on route handlers +# ────────────────────────────────────────────────────────────────────── + +MISSING_AUTH_DECORATOR = VulnerabilityPattern( + name="Route Handler Missing Authentication", + trigger_pattern=r"@(app|router|blueprint)\.(get|post|put|patch|delete)\s*\(", + defense_pattern=r"(login_required|auth_required|authenticated|Depends.*auth|jwt_required|token_required|IsAuthenticated|@require|@protect|@secured)", + context_window=8, + description=( + "An API route handler does not have a visible authentication " + "decorator or dependency. Unauthenticated access to endpoints " + "can expose sensitive data or allow unauthorized operations." + ), + severity=Severity.HIGH, + category=FindingCategory.AUTH_SESSION, + recommendation=( + "Apply an authentication decorator or dependency to every non-public " + "endpoint. Use a whitelist approach where routes are authenticated " + "by default and explicitly marked as public." + ), + cwe_id="CWE-306", + confidence=0.6, +) + +# ────────────────────────────────────────────────────────────────────── +# Insecure session configuration +# ────────────────────────────────────────────────────────────────────── + +INSECURE_SESSION_CONFIG = VulnerabilityPattern( + name="Insecure Session Cookie Configuration", + trigger_pattern=r"(SESSION_COOKIE_SECURE|session.*secure|cookie.*secure)\s*=\s*(False|false|0)", + defense_pattern=None, + context_window=5, + description=( + "Session cookie is configured without the Secure flag. Cookies " + "will be transmitted over unencrypted HTTP connections, allowing " + "session hijacking via network sniffing." + ), + severity=Severity.HIGH, + category=FindingCategory.AUTH_SESSION, + recommendation=( + "Set SESSION_COOKIE_SECURE=True and SESSION_COOKIE_HTTPONLY=True. " + "Also set SESSION_COOKIE_SAMESITE='Lax' or 'Strict' to prevent " + "CSRF attacks via cookie inclusion." + ), + cwe_id="CWE-614", + confidence=0.95, +) + +# ────────────────────────────────────────────────────────────────────── +# Missing CSRF protection +# ────────────────────────────────────────────────────────────────────── + +MISSING_CSRF = VulnerabilityPattern( + name="Missing CSRF Protection on State-Changing Endpoint", + trigger_pattern=r"@(app|router|blueprint)\.(post|put|patch|delete)\s*\(", + defense_pattern=r"(csrf|CSRFProtect|CsrfViewMiddleware|csurf|_token|xsrf|anti_forgery|SameSite|Bearer)", + context_window=20, + description=( + "State-changing endpoint lacks visible CSRF protection. Without CSRF " + "tokens or SameSite cookie policy, an attacker can craft malicious " + "pages that trigger authenticated actions on behalf of users." + ), + severity=Severity.MEDIUM, + category=FindingCategory.AUTH_SESSION, + recommendation=( + "For cookie-based auth: implement CSRF tokens on all state-changing " + "endpoints. For token-based auth (Bearer tokens): ensure tokens are " + "not stored in cookies. Set SameSite=Lax on session cookies." + ), + cwe_id="CWE-352", + confidence=0.5, +) + +# ────────────────────────────────────────────────────────────────────── +# Weak password hashing +# ────────────────────────────────────────────────────────────────────── + +WEAK_PASSWORD_HASH = VulnerabilityPattern( + name="Weak Password Hashing Algorithm", + trigger_pattern=r"(hashlib\.(md5|sha1|sha256)\s*\(|MD5|SHA1|createHash\s*\(\s*['\"](?:md5|sha1)['\"])", + defense_pattern=r"(bcrypt|argon2|scrypt|pbkdf2|passlib|password_hash|hash_password)", + context_window=10, + description=( + "Password hashing uses a fast, non-salted algorithm like MD5 or SHA1. " + "These can be reversed with rainbow tables or brute-forced at billions " + "of attempts per second on modern GPUs." + ), + severity=Severity.CRITICAL, + category=FindingCategory.AUTH_SESSION, + recommendation=( + "Use bcrypt, argon2id, or scrypt for password hashing. These algorithms " + "include salting and configurable work factors that resist brute-force. " + "Migrate existing hashes on next user login." + ), + cwe_id="CWE-916", + confidence=0.85, +) + +# ────────────────────────────────────────────────────────────────────── +# Session fixation risk +# ────────────────────────────────────────────────────────────────────── + +SESSION_FIXATION = VulnerabilityPattern( + name="Session Not Regenerated After Authentication", + trigger_pattern=r"(def\s+login|def\s+authenticate|def\s+sign_in|async\s+def\s+login)\s*\(", + defense_pattern=r"(session\.regenerate|cycle_key|rotate.*session|new_session|session\.clear|flush.*session|create_session)", + context_window=15, + description=( + "Login handler does not regenerate the session ID after successful " + "authentication. An attacker who sets a session cookie before login " + "retains access to the authenticated session." + ), + severity=Severity.HIGH, + category=FindingCategory.AUTH_SESSION, + recommendation=( + "Regenerate the session ID immediately after successful authentication. " + "In Django use request.session.cycle_key(), in Flask use " + "session.regenerate(), in Express use req.session.regenerate()." + ), + cwe_id="CWE-384", + confidence=0.7, +) + +# ────────────────────────────────────────────────────────────────────── +# Missing authorization (IDOR risk) +# ────────────────────────────────────────────────────────────────────── + +MISSING_AUTHORIZATION_CHECK = VulnerabilityPattern( + name="Object Access Without Authorization Check", + trigger_pattern=r"\.(get|filter|find_one|findById|findOne)\s*\(\s*(request\.(args|params|query|json|form|data)|req\.(params|query|body))", + defense_pattern=r"(owner|user_id.*current|current_user|request\.user|belongs_to|authorize|permission|can\s*\(|has_perm)", + context_window=10, + description=( + "Database query uses user-supplied ID without checking ownership or " + "permissions. An attacker can modify the ID parameter to access " + "other users' data (Insecure Direct Object Reference)." + ), + severity=Severity.CRITICAL, + category=FindingCategory.AUTH_SESSION, + recommendation=( + "Always filter queries by the authenticated user's ID or check " + "object ownership before returning data. Use a policy layer or " + "scope queries: Model.objects.filter(user=request.user, id=obj_id)." + ), + cwe_id="CWE-639", + confidence=0.75, +) + + +class AuthSessionScanner(BaseScanner): + """Scanner specialized in detecting authentication and session vulnerabilities. + + Detects missing authentication, insecure session configuration, + CSRF gaps, weak password hashing, session fixation, and IDOR + risks across Python and JavaScript codebases. + """ + + PATTERNS: List[VulnerabilityPattern] = [ + MISSING_AUTH_DECORATOR, + INSECURE_SESSION_CONFIG, + MISSING_CSRF, + WEAK_PASSWORD_HASH, + SESSION_FIXATION, + MISSING_AUTHORIZATION_CHECK, + ] + + def __init__( + self, config: ScannerConfig, file_extensions: List[str], exclude_dirs: List[str] + ): + """Initialize the authentication and session scanner. + + Args: + config: Scanner configuration controlling behavior. + file_extensions: File extensions to include in scanning. + exclude_dirs: Directory names to skip during traversal. + """ + super().__init__(config, file_extensions, exclude_dirs) diff --git a/tmt/scanners/base_scanner.py b/tmt/scanners/base_scanner.py new file mode 100644 index 0000000..d8b2e11 --- /dev/null +++ b/tmt/scanners/base_scanner.py @@ -0,0 +1,285 @@ +"""Base scanner providing shared file collection and pattern matching logic. + +All concrete scanners inherit from BaseScanner, which handles directory +traversal, file reading, regex-based vulnerability detection, and +finding creation in a framework-agnostic manner. +""" + +import logging +import os +import re +import time +from dataclasses import dataclass +from typing import List, Optional + +from tmt.config import ScannerConfig +from tmt.models import Finding, FindingCategory, ScanResult, Severity + +logger = logging.getLogger(__name__) + + +@dataclass +class VulnerabilityPattern: + """Defines a single vulnerability detection rule. + + Attributes: + name: Human-readable name of the vulnerability. + trigger_pattern: Regex that identifies potentially vulnerable code. + defense_pattern: Regex for defensive code that mitigates the issue. + context_window: Number of lines around a match to check for defenses. + description: Detailed description of the vulnerability. + severity: Severity level assigned to findings from this pattern. + category: Finding category classification. + recommendation: Remediation guidance for developers. + cwe_id: CWE identifier for the vulnerability class. + confidence: Default confidence score for matches. + """ + + name: str + trigger_pattern: str + defense_pattern: Optional[str] + context_window: int + description: str + severity: Severity + category: FindingCategory + recommendation: str + cwe_id: Optional[str] = None + confidence: float = 0.8 + + +class BaseScanner: + """Abstract base scanner providing pattern-based vulnerability detection. + + Subclasses must define their own PATTERNS list of VulnerabilityPattern + objects. The base class handles file traversal, content reading, + pattern matching, and finding assembly. + """ + + PATTERNS: List[VulnerabilityPattern] = [] + + def __init__( + self, config: ScannerConfig, file_extensions: List[str], exclude_dirs: List[str] + ): + """Initialize the base scanner with configuration parameters. + + Args: + config: Scanner configuration controlling behavior. + file_extensions: File extensions to include in scanning. + exclude_dirs: Directory names to skip during traversal. + """ + self.config = config + self.file_extensions = file_extensions + self.exclude_dirs = set(exclude_dirs) + self.scanner_name = self.__class__.__name__ + + def _is_excluded_dir(self, dir_name: str) -> bool: + """Check whether a directory name should be excluded from scanning. + + Args: + dir_name: Name of the directory to check. + + Returns: + True if the directory should be skipped. + """ + return dir_name in self.exclude_dirs or dir_name.startswith(".") + + def _has_valid_extension(self, file_name: str) -> bool: + """Check whether a file has an extension included in the scan scope. + + Args: + file_name: Name of the file to check. + + Returns: + True if the file extension matches a configured extension. + """ + return any(file_name.endswith(ext) for ext in self.file_extensions) + + def _collect_files(self, target_path: str) -> List[str]: + """Walk a directory tree and collect all files matching scan criteria. + + Args: + target_path: Root directory path to begin traversal. + + Returns: + List of absolute file paths matching extension and exclusion rules. + """ + collected = [] + for root, dirs, files in os.walk(target_path): + dirs[:] = [d for d in dirs if not self._is_excluded_dir(d)] + for fname in files: + if self._has_valid_extension(fname): + collected.append(os.path.join(root, fname)) + return collected + + def _read_file_safe(self, file_path: str) -> Optional[str]: + """Read a file's contents with graceful error handling. + + Args: + file_path: Absolute path to the file to read. + + Returns: + File contents as a string, or None if reading failed. + """ + try: + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + return f.read() + except OSError as exc: + logger.warning("Could not read %s: %s", file_path, exc) + return None + + def _extract_context(self, lines: List[str], line_num: int, window: int) -> str: + """Extract a context window of code lines around a specific line. + + Args: + lines: All lines of the source file. + line_num: Zero-based line index of the match. + window: Number of lines above and below to include. + + Returns: + Concatenated string of context lines. + """ + start = max(0, line_num - window) + end = min(len(lines), line_num + window + 1) + return "\n".join(lines[start:end]) + + def _has_defense(self, context: str, defense_pattern: Optional[str]) -> bool: + """Check whether defensive code exists within a context block. + + Args: + context: Code context string to search. + defense_pattern: Regex pattern indicating proper defense. + + Returns: + True if defense pattern found or no defense pattern required. + """ + if not defense_pattern: + return False + return bool(re.search(defense_pattern, context, re.IGNORECASE)) + + def _find_trigger_lines(self, content: str, trigger_pattern: str) -> List[int]: + """Find all line numbers where a trigger pattern matches. + + Args: + content: Full file content to search. + trigger_pattern: Regex pattern identifying potentially vulnerable code. + + Returns: + List of zero-based line numbers with matches. + """ + lines = content.split("\n") + matched = [] + for i, line in enumerate(lines): + if re.search(trigger_pattern, line, re.IGNORECASE): + matched.append(i) + return matched + + def _create_finding( + self, pattern: VulnerabilityPattern, file_path: str, line_num: int, snippet: str + ) -> Finding: + """Create a Finding object from a matched vulnerability pattern. + + Args: + pattern: The VulnerabilityPattern that was matched. + file_path: Path to the file containing the finding. + line_num: One-based line number of the finding. + snippet: Code snippet from the surrounding context. + + Returns: + Populated Finding dataclass instance. + """ + return Finding( + title=pattern.name, + description=pattern.description, + severity=pattern.severity, + category=pattern.category, + file_path=file_path, + line_number=line_num, + code_snippet=snippet, + recommendation=pattern.recommendation, + confidence=pattern.confidence, + cwe_id=pattern.cwe_id, + ) + + def _scan_file_for_pattern( + self, content: str, file_path: str, pattern: VulnerabilityPattern + ) -> List[Finding]: + """Scan a single file's content against one vulnerability pattern. + + Args: + content: Full text content of the source file. + file_path: Path to the source file being scanned. + pattern: VulnerabilityPattern to match against. + + Returns: + List of Finding objects for undefended trigger matches. + """ + findings = [] + lines = content.split("\n") + trigger_lines = self._find_trigger_lines(content, pattern.trigger_pattern) + for line_num in trigger_lines: + context = self._extract_context(lines, line_num, pattern.context_window) + if not self._has_defense(context, pattern.defense_pattern): + finding = self._create_finding( + pattern, file_path, line_num + 1, context + ) + findings.append(finding) + return findings + + def _scan_single_file(self, file_path: str, content: str) -> List[Finding]: + """Apply all vulnerability patterns against a single file. + + Args: + file_path: Path to the file being scanned. + content: Full text content of the file. + + Returns: + Aggregated list of findings from all pattern checks. + """ + findings = [] + for pattern in self.PATTERNS: + pattern_findings = self._scan_file_for_pattern(content, file_path, pattern) + findings.extend(pattern_findings) + return findings + + def _process_files(self, file_paths: List[str]) -> List[Finding]: + """Read and scan each file, collecting all findings. + + Args: + file_paths: List of absolute file paths to scan. + + Returns: + Combined list of findings from all files. + """ + findings = [] + for file_path in file_paths: + content = self._read_file_safe(file_path) + if content: + findings.extend(self._scan_single_file(file_path, content)) + return findings + + def scan(self, target_path: str) -> ScanResult: + """Execute the full scan workflow against a target directory. + + Args: + target_path: Root directory path to scan for vulnerabilities. + + Returns: + ScanResult containing all findings and scan metadata. + """ + start_time = time.time() + file_paths = self._collect_files(target_path) + findings = self._process_files(file_paths) + elapsed = time.time() - start_time + logger.info( + "%s scanned %d files in %.2fs, found %d issues", + self.scanner_name, + len(file_paths), + elapsed, + len(findings), + ) + return ScanResult( + scanner_name=self.scanner_name, + findings=findings, + files_scanned=len(file_paths), + scan_duration_seconds=round(elapsed, 3), + ) diff --git a/tmt/scanners/race_condition_scanner.py b/tmt/scanners/race_condition_scanner.py new file mode 100644 index 0000000..2344597 --- /dev/null +++ b/tmt/scanners/race_condition_scanner.py @@ -0,0 +1,165 @@ +"""Scanner for detecting race condition vulnerabilities in application code. + +Identifies non-atomic read-modify-write sequences, time-of-check to +time-of-use (TOCTOU) patterns, concurrent resource access without locking, +and unprotected shared state modifications that enable race conditions. +""" + +from typing import List + +from tmt.config import ScannerConfig +from tmt.models import FindingCategory, Severity +from tmt.scanners.base_scanner import BaseScanner, VulnerabilityPattern + +# ────────────────────────────────────────────────────────────────────── +# Read-modify-write without atomicity +# ────────────────────────────────────────────────────────────────────── + +NON_ATOMIC_READ_MODIFY_WRITE = VulnerabilityPattern( + name="Non-Atomic Read-Modify-Write Sequence", + trigger_pattern=r"(\.\s*save\s*\(|\.\s*update\s*\(|UPDATE\s+.*SET)", + defense_pattern=r"(select_for_update|FOR UPDATE|atomic|transaction|lock|mutex|semaphore|compare_and_swap|F\s*\()", + context_window=10, + description=( + "A database record is read and then updated without atomic protection. " + "Concurrent requests can read stale state and apply conflicting writes, " + "leading to lost updates (e.g., balance overdraws, inventory oversells)." + ), + severity=Severity.HIGH, + category=FindingCategory.RACE_CONDITION, + recommendation=( + "Use SELECT FOR UPDATE, database-level atomic operations (e.g., " + "Django F() expressions), or application-level distributed locks. " + "Wrap read-modify-write in a serializable transaction." + ), + cwe_id="CWE-362", + confidence=0.7, +) + +# ────────────────────────────────────────────────────────────────────── +# TOCTOU (Time-of-Check to Time-of-Use) +# ────────────────────────────────────────────────────────────────────── + +TOCTOU_CHECK_THEN_ACT = VulnerabilityPattern( + name="TOCTOU Check-Then-Act Pattern", + trigger_pattern=r"(if\s+.*\.(exists|count|filter|find|get)\s*\(.*\).*:[\s\S]*?\.(create|save|insert|delete|remove)\s*\()", + defense_pattern=r"(atomic|transaction|lock|unique_together|unique=True|get_or_create|upsert|ON CONFLICT)", + context_window=12, + description=( + "Code checks for existence then acts on the result without atomicity. " + "Between the check and the action, another request can change the state, " + "causing phantom reads or duplicate inserts." + ), + severity=Severity.HIGH, + category=FindingCategory.RACE_CONDITION, + recommendation=( + "Replace check-then-act with atomic operations like get_or_create, " + "upsert, or INSERT ON CONFLICT. If not possible, wrap both the check " + "and action in a serializable transaction with proper locking." + ), + cwe_id="CWE-367", + confidence=0.7, +) + +# ────────────────────────────────────────────────────────────────────── +# Concurrent token/coupon redemption +# ────────────────────────────────────────────────────────────────────── + +CONCURRENT_REDEMPTION = VulnerabilityPattern( + name="Unguarded Concurrent Redemption", + trigger_pattern=r"(redeem|claim|activate|consume|use_coupon|apply_code|accept_invite)\s*\(", + defense_pattern=r"(atomic|transaction|lock|select_for_update|FOR UPDATE|mutex|semaphore|compare_and_swap)", + context_window=15, + description=( + "A redemption or claim operation is not protected against concurrent " + "execution. Multiple simultaneous requests can redeem the same token, " + "coupon, or invite before any single request marks it as consumed." + ), + severity=Severity.CRITICAL, + category=FindingCategory.RACE_CONDITION, + recommendation=( + "Use SELECT FOR UPDATE or a distributed lock around the redemption " + "check and mark-as-used operation. Ensure both steps execute within " + "a single atomic transaction." + ), + cwe_id="CWE-362", + confidence=0.8, +) + +# ────────────────────────────────────────────────────────────────────── +# Shared mutable state without synchronization +# ────────────────────────────────────────────────────────────────────── + +UNPROTECTED_SHARED_STATE = VulnerabilityPattern( + name="Shared Mutable State Without Synchronization", + trigger_pattern=r"(global\s+\w|threading\.Thread|asyncio\.\w+|celery.*\.delay|\.apply_async)", + defense_pattern=r"(Lock|RLock|Semaphore|Event|Condition|Queue|atomic|mutex|synchronized)", + context_window=10, + description=( + "Code uses global mutable state or spawns concurrent execution without " + "visible synchronization primitives. Unsynchronized shared state leads " + "to data corruption and non-deterministic behavior." + ), + severity=Severity.MEDIUM, + category=FindingCategory.RACE_CONDITION, + recommendation=( + "Use threading.Lock, asyncio.Lock, or move shared state to a " + "thread-safe data structure like queue.Queue. Prefer stateless " + "request handlers with database-backed state." + ), + cwe_id="CWE-362", + confidence=0.6, +) + +# ────────────────────────────────────────────────────────────────────── +# JavaScript concurrent state patterns +# ────────────────────────────────────────────────────────────────────── + +JS_ASYNC_RACE = VulnerabilityPattern( + name="JS Async Race Condition", + trigger_pattern=r"(await\s+.*find|await\s+.*get)[\s\S]*?(await\s+.*save|await\s+.*update)", + defense_pattern=r"(transaction|findOneAndUpdate|atomicUpdate|\$inc|\$set.*upsert|lock|mutex|semaphore)", + context_window=12, + description=( + "An async find-then-update pattern without atomicity. In Node.js with " + "concurrent request handling, this window allows race conditions." + ), + severity=Severity.HIGH, + category=FindingCategory.RACE_CONDITION, + recommendation=( + "Use MongoDB findOneAndUpdate, Sequelize transactions, or Prisma " + "interactive transactions. Avoid separate find-then-save in " + "concurrent contexts." + ), + cwe_id="CWE-362", + confidence=0.7, +) + + +class RaceConditionScanner(BaseScanner): + """Scanner specialized in detecting race condition vulnerabilities. + + Detects non-atomic operations, TOCTOU patterns, unguarded concurrent + redemptions, unsynchronized shared state, and async race conditions + across Python and JavaScript codebases. + """ + + PATTERNS: List[VulnerabilityPattern] = [ + NON_ATOMIC_READ_MODIFY_WRITE, + TOCTOU_CHECK_THEN_ACT, + CONCURRENT_REDEMPTION, + UNPROTECTED_SHARED_STATE, + JS_ASYNC_RACE, + ] + + def __init__( + self, config: ScannerConfig, file_extensions: List[str], exclude_dirs: List[str] + ): + """Initialize the race condition scanner. + + Args: + config: Scanner configuration controlling behavior. + file_extensions: File extensions to include in scanning. + exclude_dirs: Directory names to skip during traversal. + """ + super().__init__(config, file_extensions, exclude_dirs) diff --git a/tmt/scanners/replay_scanner.py b/tmt/scanners/replay_scanner.py new file mode 100644 index 0000000..efb766d --- /dev/null +++ b/tmt/scanners/replay_scanner.py @@ -0,0 +1,130 @@ +"""Scanner for detecting replay attack vulnerabilities in API endpoints. + +Identifies mutating endpoints that lack idempotency keys, nonce validation, +timestamp checks, and request deduplication defenses that prevent replay +attacks where captured requests are maliciously re-submitted. +""" + +from typing import List + +from tmt.config import ScannerConfig +from tmt.models import FindingCategory, Severity +from tmt.scanners.base_scanner import BaseScanner, VulnerabilityPattern + +# ────────────────────────────────────────────────────────────────────── +# Python / Flask / FastAPI replay attack patterns +# ────────────────────────────────────────────────────────────────────── + +PYTHON_POST_WITHOUT_IDEMPOTENCY = VulnerabilityPattern( + name="POST Endpoint Missing Idempotency Key", + trigger_pattern=r"@(app|router|blueprint)\.(post|put|patch)\s*\(", + defense_pattern=r"idempoten|nonce|request_id|x-request-id|dedup|unique_token", + context_window=15, + description=( + "Mutating endpoint does not check for an idempotency key or nonce. " + "An attacker can replay a captured POST/PUT/PATCH request to cause " + "duplicate side effects such as double charges or duplicate records." + ), + severity=Severity.MEDIUM, + category=FindingCategory.REPLAY_ATTACK, + recommendation=( + "Accept an Idempotency-Key header or request_id field. Store processed " + "keys server-side (e.g., Redis with TTL) and reject duplicates before " + "executing business logic." + ), + cwe_id="CWE-294", + confidence=0.7, +) + +PYTHON_NO_TIMESTAMP_VALIDATION = VulnerabilityPattern( + name="Request Missing Timestamp Validation", + trigger_pattern=r"@(app|router|blueprint)\.(post|put|patch|delete)\s*\(", + defense_pattern=r"timestamp|expires?_at|valid_until|time_window|max_age|request_time", + context_window=20, + description=( + "Endpoint does not validate a request timestamp or expiration window. " + "Captured requests can be replayed hours or days later without detection." + ), + severity=Severity.LOW, + category=FindingCategory.REPLAY_ATTACK, + recommendation=( + "Include a timestamp in signed requests and reject any request older " + "than a configurable window (e.g., 5 minutes). Combine with nonce " + "tracking for best protection." + ), + cwe_id="CWE-294", + confidence=0.6, +) + +# ────────────────────────────────────────────────────────────────────── +# JavaScript / Express / Node replay attack patterns +# ────────────────────────────────────────────────────────────────────── + +JS_POST_WITHOUT_IDEMPOTENCY = VulnerabilityPattern( + name="JS POST Endpoint Missing Idempotency Key", + trigger_pattern=r"(app|router)\.(post|put|patch)\s*\(", + defense_pattern=r"idempoten|nonce|requestId|x-request-id|dedup|uniqueToken", + context_window=15, + description=( + "JavaScript mutating endpoint lacks idempotency key or nonce validation. " + "Replayed requests may cause duplicate side effects." + ), + severity=Severity.MEDIUM, + category=FindingCategory.REPLAY_ATTACK, + recommendation=( + "Require an Idempotency-Key header on mutating endpoints. Store " + "processed keys in Redis with a TTL and return cached responses " + "for duplicate keys." + ), + cwe_id="CWE-294", + confidence=0.7, +) + +PYTHON_TOKEN_REUSE_NO_INVALIDATION = VulnerabilityPattern( + name="Token Used Without Single-Use Invalidation", + trigger_pattern=r"(verify_token|validate_token|check_token|decode_token)\s*\(", + defense_pattern=r"(delete|invalidat|revoke|mark_used|consume|burn).*token", + context_window=20, + description=( + "A token is verified but not invalidated after use. One-time tokens " + "(e.g., password reset, email verification) that remain valid after " + "consumption are vulnerable to replay." + ), + severity=Severity.HIGH, + category=FindingCategory.REPLAY_ATTACK, + recommendation=( + "Immediately invalidate single-use tokens after successful verification. " + "Use a database flag or delete the token record within the same " + "transaction as the action it authorizes." + ), + cwe_id="CWE-294", + confidence=0.75, +) + + +class ReplayScanner(BaseScanner): + """Scanner specialized in detecting replay attack vulnerabilities. + + Detects missing idempotency keys, absent timestamp validation, + and token reuse without invalidation across Python and JavaScript + web application codebases. + """ + + PATTERNS: List[VulnerabilityPattern] = [ + PYTHON_POST_WITHOUT_IDEMPOTENCY, + PYTHON_NO_TIMESTAMP_VALIDATION, + JS_POST_WITHOUT_IDEMPOTENCY, + PYTHON_TOKEN_REUSE_NO_INVALIDATION, + ] + + def __init__( + self, config: ScannerConfig, file_extensions: List[str], exclude_dirs: List[str] + ): + """Initialize the replay attack scanner. + + Args: + config: Scanner configuration controlling behavior. + file_extensions: File extensions to include in scanning. + exclude_dirs: Directory names to skip during traversal. + """ + super().__init__(config, file_extensions, exclude_dirs) diff --git a/tmt/scanners/token_abuse_scanner.py b/tmt/scanners/token_abuse_scanner.py new file mode 100644 index 0000000..dd2828d --- /dev/null +++ b/tmt/scanners/token_abuse_scanner.py @@ -0,0 +1,166 @@ +"""Scanner for detecting token and invite abuse vulnerabilities. + +Identifies unbounded token generation, predictable token creation, +missing expiration on tokens, multi-use invite tokens, and absent +rate limiting on token issuance endpoints. +""" + +from typing import List + +from tmt.config import ScannerConfig +from tmt.models import FindingCategory, Severity +from tmt.scanners.base_scanner import BaseScanner, VulnerabilityPattern + +# ────────────────────────────────────────────────────────────────────── +# Token generation without rate limiting +# ────────────────────────────────────────────────────────────────────── + +UNBOUNDED_TOKEN_GENERATION = VulnerabilityPattern( + name="Token Generation Without Rate Limiting", + trigger_pattern=r"(generate_token|create_token|issue_token|create_invite|generate_invite|send_invite)\s*\(", + defense_pattern=r"(rate_limit|throttle|cooldown|max_attempts|limit_per|RateLimit|slowapi|ratelimit)", + context_window=20, + description=( + "Token or invite generation endpoint lacks rate limiting. An attacker " + "can flood the endpoint to generate excessive tokens, exhausting " + "resources or creating mass invite abuse." + ), + severity=Severity.HIGH, + category=FindingCategory.TOKEN_ABUSE, + recommendation=( + "Apply rate limiting per user/IP on token generation endpoints. " + "Use a sliding window counter (e.g., Redis-based) with reasonable " + "limits such as 5 invites per hour per user." + ), + cwe_id="CWE-799", + confidence=0.75, +) + +# ────────────────────────────────────────────────────────────────────── +# Predictable token generation +# ────────────────────────────────────────────────────────────────────── + +PREDICTABLE_TOKEN = VulnerabilityPattern( + name="Predictable Token Generation", + trigger_pattern=r"(uuid\.uuid1|random\.random|random\.randint|Math\.random|hashlib\.(md5|sha1)\(.*time|str\(.*id\))", + defense_pattern=r"(secrets\.|crypto\.random|uuid\.uuid4|os\.urandom|token_hex|token_urlsafe|randomBytes)", + context_window=8, + description=( + "Token generation uses predictable sources like UUID1 (MAC-based), " + "Python's random module (not CSPRNG), or timestamp-based hashing. " + "Predictable tokens can be guessed or brute-forced by attackers." + ), + severity=Severity.CRITICAL, + category=FindingCategory.TOKEN_ABUSE, + recommendation=( + "Use cryptographically secure random generators: secrets.token_urlsafe() " + "in Python, crypto.randomBytes() in Node.js, or uuid4 for identifiers. " + "Never derive tokens from timestamps, sequential IDs, or weak PRNGs." + ), + cwe_id="CWE-330", + confidence=0.85, +) + +# ────────────────────────────────────────────────────────────────────── +# Tokens without expiration +# ────────────────────────────────────────────────────────────────────── + +TOKEN_NO_EXPIRY = VulnerabilityPattern( + name="Token Created Without Expiration", + trigger_pattern=r"(Token\.create|Token\.objects\.create|create_token|generate_token|new\s+Token|InviteToken)\s*\(", + defense_pattern=r"(expir|ttl|valid_until|max_age|lifetime|duration|exp\s*=|expiresAt|expires_at)", + context_window=10, + description=( + "Tokens are created without an expiration time. Long-lived tokens " + "increase the window for token theft and abuse, and make revocation " + "more critical and harder to enforce." + ), + severity=Severity.HIGH, + category=FindingCategory.TOKEN_ABUSE, + recommendation=( + "Set a reasonable TTL on all tokens: 15 minutes for password reset, " + "24-72 hours for invites, 1 hour for session tokens. Store the " + "expiration and check it on every validation." + ), + cwe_id="CWE-613", + confidence=0.75, +) + +# ────────────────────────────────────────────────────────────────────── +# Multi-use invite tokens +# ────────────────────────────────────────────────────────────────────── + +MULTI_USE_INVITE = VulnerabilityPattern( + name="Invite Token Allows Multiple Redemptions", + trigger_pattern=r"(accept_invite|redeem_invite|use_invite|claim_invite|join_.*invite)\s*\(", + defense_pattern=r"(is_used|used_at|redeemed|consumed|single_use|max_uses|use_count|delete.*invite|mark.*used)", + context_window=15, + description=( + "Invite acceptance logic does not check or enforce single-use. " + "An invite link can be shared and used by multiple unauthorized " + "users to gain access to the system." + ), + severity=Severity.HIGH, + category=FindingCategory.TOKEN_ABUSE, + recommendation=( + "Track invite usage with a used_at timestamp or use_count field. " + "Atomically mark invites as consumed during acceptance. Consider " + "binding invites to specific email addresses." + ), + cwe_id="CWE-841", + confidence=0.8, +) + +# ────────────────────────────────────────────────────────────────────── +# Missing token revocation +# ────────────────────────────────────────────────────────────────────── + +NO_TOKEN_REVOCATION = VulnerabilityPattern( + name="No Token Revocation Mechanism", + trigger_pattern=r"(def\s+logout|def\s+revoke|def\s+invalidate|signOut|logOut)\s*\(", + defense_pattern=r"(delete.*token|revoke.*token|blacklist|blocklist|token.*delete|destroy.*session|clear.*token)", + context_window=15, + description=( + "Logout or revocation endpoint does not actually invalidate the " + "token server-side. The token remains valid and usable even after " + "the user believes they have logged out." + ), + severity=Severity.HIGH, + category=FindingCategory.TOKEN_ABUSE, + recommendation=( + "Maintain a server-side token blocklist or delete the token record " + "on logout. For JWTs, use short expiration combined with a refresh " + "token that can be revoked from the database." + ), + cwe_id="CWE-613", + confidence=0.7, +) + + +class TokenAbuseScanner(BaseScanner): + """Scanner specialized in detecting token and invite abuse vulnerabilities. + + Detects unbounded generation, predictable tokens, missing expiration, + multi-use invites, and absent revocation mechanisms across Python + and JavaScript codebases. + """ + + PATTERNS: List[VulnerabilityPattern] = [ + UNBOUNDED_TOKEN_GENERATION, + PREDICTABLE_TOKEN, + TOKEN_NO_EXPIRY, + MULTI_USE_INVITE, + NO_TOKEN_REVOCATION, + ] + + def __init__( + self, config: ScannerConfig, file_extensions: List[str], exclude_dirs: List[str] + ): + """Initialize the token abuse scanner. + + Args: + config: Scanner configuration controlling behavior. + file_extensions: File extensions to include in scanning. + exclude_dirs: Directory names to skip during traversal. + """ + super().__init__(config, file_extensions, exclude_dirs)