Initial commit: AASRT v1.0.0 - AI Agent Security Reconnaissance Tool

This commit is contained in:
swethab
2026-02-10 10:53:31 -05:00
commit a714a3399b
61 changed files with 14858 additions and 0 deletions
+97
View File
@@ -0,0 +1,97 @@
# =============================================================================
# AASRT Docker Ignore File
# Excludes files from Docker build context for security and efficiency
# =============================================================================
# Git
.git/
.gitignore
.gitattributes
# Documentation (not needed in container)
*.md
docs/
LICENSE
# Python artifacts
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
*.egg
dist/
build/
eggs/
.eggs/
wheels/
# Virtual environments
venv/
.venv/
ENV/
env/
# IDE and editor files
.idea/
.vscode/
*.swp
*.swo
*~
.editorconfig
# Environment files (NEVER include secrets in image)
.env
.env.*
!.env.example
# Local data files (should be mounted as volumes)
data/
logs/
reports/
*.db
*.sqlite
*.sqlite3
# Test files
tests/
test_*.py
*_test.py
.pytest_cache/
.coverage
htmlcov/
.mypy_cache/
.tox/
# OS files
.DS_Store
Thumbs.db
# Docker files (prevent recursive copying)
Dockerfile*
docker-compose*.yml
.dockerignore
# Secrets and credentials
*.pem
*.key
*.crt
credentials.json
secrets.yaml
secrets/
private/
# Temporary and backup files
tmp/
temp/
*.tmp
*.bak
*.backup
*.old
# Security scan results
security-reports/
vulnerability-reports/
*.sarif
+82
View File
@@ -0,0 +1,82 @@
# =============================================================================
# AASRT - AI Agent Security Reconnaissance Tool
# Environment Configuration Template
# =============================================================================
# Copy this file to .env and fill in your values
# NEVER commit your .env file to version control!
# =============================================================================
# -----------------------------------------------------------------------------
# REQUIRED: Shodan API Configuration
# -----------------------------------------------------------------------------
# Get your API key from: https://account.shodan.io/
SHODAN_API_KEY=your_shodan_api_key_here
# -----------------------------------------------------------------------------
# OPTIONAL: Application Settings
# -----------------------------------------------------------------------------
# Log level: DEBUG, INFO, WARNING, ERROR, CRITICAL
AASRT_LOG_LEVEL=INFO
# Environment: development, staging, production
AASRT_ENVIRONMENT=production
# Enable debug mode (set to false in production!)
AASRT_DEBUG=false
# -----------------------------------------------------------------------------
# OPTIONAL: Database Configuration
# -----------------------------------------------------------------------------
# For SQLite (default): Leave these empty, uses ./data/scanner.db
# For PostgreSQL: Uncomment and fill in the values below
# DB_TYPE=postgresql
# DB_HOST=localhost
# DB_PORT=5432
# DB_NAME=aasrt
# DB_USER=aasrt_user
# DB_PASSWORD=your_secure_password_here
# DB_SSL_MODE=require
# -----------------------------------------------------------------------------
# OPTIONAL: ClawSec Threat Intelligence
# -----------------------------------------------------------------------------
# ClawSec feed URL (default: https://clawsec.prompt.security/advisories/feed.json)
# CLAWSEC_FEED_URL=https://clawsec.prompt.security/advisories/feed.json
# Cache settings
# CLAWSEC_CACHE_TTL=86400
# CLAWSEC_OFFLINE_MODE=false
# -----------------------------------------------------------------------------
# OPTIONAL: Security Settings
# -----------------------------------------------------------------------------
# Secret key for session management (generate a random 32+ char string)
# AASRT_SECRET_KEY=your_random_secret_key_here
# Allowed hosts (comma-separated, for production deployment)
# AASRT_ALLOWED_HOSTS=localhost,127.0.0.1
# Maximum results per scan (rate limiting)
# AASRT_MAX_RESULTS=500
# -----------------------------------------------------------------------------
# OPTIONAL: Streamlit Configuration
# -----------------------------------------------------------------------------
# Streamlit server settings
# STREAMLIT_SERVER_PORT=8501
# STREAMLIT_SERVER_ADDRESS=0.0.0.0
# STREAMLIT_SERVER_HEADLESS=true
# STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
# -----------------------------------------------------------------------------
# OPTIONAL: Alerting & Notifications (Future)
# -----------------------------------------------------------------------------
# Slack webhook URL for critical findings
# SLACK_WEBHOOK_URL=
# Email notifications
# SMTP_HOST=
# SMTP_PORT=587
# SMTP_USER=
# SMTP_PASSWORD=
# ALERT_EMAIL_TO=
+172
View File
@@ -0,0 +1,172 @@
name: AASRT CI/CD Pipeline
on:
push:
branches: [main, develop]
pull_request:
branches: [main]
env:
PYTHON_VERSION: '3.11'
jobs:
# ============================================================================
# Code Quality Checks
# ============================================================================
lint:
name: Code Quality
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
cache: 'pip'
- name: Install linting tools
run: |
pip install flake8 black isort mypy
pip install -r requirements.txt
- name: Run Black (formatting check)
run: black --check --diff src/ tests/
continue-on-error: true
- name: Run isort (import sorting)
run: isort --check-only --diff src/ tests/
continue-on-error: true
- name: Run Flake8 (linting)
run: flake8 src/ tests/ --max-line-length=120 --statistics
continue-on-error: true
- name: Run MyPy (type checking)
run: mypy src/ --ignore-missing-imports --no-error-summary
continue-on-error: true
# ============================================================================
# Unit Tests
# ============================================================================
test-unit:
name: Unit Tests
runs-on: ubuntu-latest
needs: lint
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
cache: 'pip'
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install pytest pytest-cov pytest-mock pytest-timeout
- name: Run unit tests
env:
SHODAN_API_KEY: test_key_for_ci
AASRT_ENVIRONMENT: testing
run: |
pytest tests/unit/ -v --cov=src --cov-report=xml --cov-report=term-missing -m "not slow"
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
with:
file: coverage.xml
fail_ci_if_error: false
# ============================================================================
# Integration Tests
# ============================================================================
test-integration:
name: Integration Tests
runs-on: ubuntu-latest
needs: test-unit
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
cache: 'pip'
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install pytest pytest-cov pytest-mock pytest-timeout
- name: Run integration tests
env:
SHODAN_API_KEY: test_key_for_ci
AASRT_ENVIRONMENT: testing
run: |
pytest tests/integration/ -v --timeout=120
# ============================================================================
# Security Scanning
# ============================================================================
security:
name: Security Scanning
runs-on: ubuntu-latest
needs: lint
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
cache: 'pip'
- name: Install security tools
run: |
pip install bandit safety pip-audit
pip install -r requirements.txt
- name: Run Bandit (SAST)
run: bandit -r src/ -ll -ii --format json --output bandit-report.json
continue-on-error: true
- name: Run Safety (dependency vulnerabilities)
run: safety check --full-report
continue-on-error: true
- name: Run pip-audit
run: pip-audit --strict --desc
continue-on-error: true
- name: Upload Bandit report
uses: actions/upload-artifact@v4
with:
name: bandit-report
path: bandit-report.json
if: always()
# ============================================================================
# Docker Build
# ============================================================================
docker:
name: Docker Build
runs-on: ubuntu-latest
needs: [test-unit, security]
steps:
- uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build Docker image
uses: docker/build-push-action@v5
with:
context: .
push: false
tags: aasrt:${{ github.sha }}
cache-from: type=gha
cache-to: type=gha,mode=max
+103
View File
@@ -0,0 +1,103 @@
# AASRT .gitignore
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual environments
venv/
ENV/
env/
.venv/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Environment variables - NEVER commit these
.env
.env.local
.env.*.local
# Logs
logs/
*.log
# Database files
data/
*.db
*.sqlite
*.sqlite3
# Reports (may contain sensitive data)
reports/
*.json
!queries/*.yaml
!config.yaml
!config.yaml.example
# OS files
.DS_Store
Thumbs.db
# Pytest
.pytest_cache/
.coverage
htmlcov/
# MyPy
.mypy_cache/
# Secrets and credentials
*.pem
*.key
*.crt
*.p12
*.pfx
credentials.json
secrets.yaml
secrets/
private/
# Docker
.docker/
docker-compose.override.yml
docker-compose.*.yml
!docker-compose.yml
!docker-compose.prod.yml
# Security scan results (may contain sensitive findings)
security-reports/
vulnerability-reports/
*.sarif
# Backup files
*.bak
*.backup
*.old
# Temporary files
tmp/
temp/
*.tmp
+553
View File
@@ -0,0 +1,553 @@
# 📋 AASRT Custom Query Templates Guide
Create your own Shodan query templates to extend AASRT's reconnaissance capabilities.
---
## Table of Contents
1. [Quick Start](#quick-start)
2. [File Structure](#file-structure)
3. [Required & Optional Fields](#required--optional-fields)
4. [Shodan Query Syntax](#shodan-query-syntax)
5. [UI Integration](#ui-integration)
6. [Best Practices](#best-practices)
7. [Step-by-Step Example](#step-by-step-example)
8. [Troubleshooting](#troubleshooting)
---
## Quick Start
1. Copy the example template:
```bash
cp queries/custom.yaml.example queries/my_template.yaml
```
2. Edit the file with your queries:
```yaml
name: My Custom Template
description: Search for my target systems
author: Your Name
version: 1.0
queries:
- 'http.title:"My Target"'
- 'http.html:"keyword" port:8080'
tags:
- custom
- my-category
```
3. Refresh the AASRT dashboard—your template appears automatically!
---
## File Structure
### Location
All custom templates go in the `queries/` directory at the project root:
```
AASRT/
├── queries/
│ ├── autogpt.yaml # Built-in template
│ ├── clawdbot.yaml # Built-in template
│ ├── langchain.yaml # Built-in template
│ ├── clawsec_advisories.yaml
│ ├── custom.yaml.example # Template reference (ignored)
│ └── my_custom.yaml # ← Your custom templates here!
├── app.py
├── src/
└── ...
```
### File Naming
| Aspect | Rule | Example |
|--------|------|---------|
| **Extension** | Must be `.yaml` or `.yml` | `ollama_instances.yaml` |
| **Template Name** | Derived from filename (without extension) | `ollama_instances` |
| **Characters** | Use lowercase, underscores, no spaces | `huggingface_models.yaml` ✅ |
| **Avoid** | Hyphens, uppercase, special chars | `Hugging-Face.yaml` ❌ |
### YAML Format
Templates use standard YAML syntax:
```yaml
# Comment line (starts with #)
name: Template Name # Scalar value
description: Some text # String (quotes optional for simple text)
queries: # List of items
- 'first query' # List item (note the dash)
- 'second query'
- 'third query'
tags: # Another list
- tag1
- tag2
```
> **⚠️ Important:** Use consistent indentation (2 spaces recommended). YAML is whitespace-sensitive!
---
## Required & Optional Fields
### Field Reference Table
| Field | Required | Type | Description |
|-------|----------|------|-------------|
| `name` | ✅ Yes | String | Display name shown in UI |
| `description` | ✅ Yes | String | What the template searches for |
| `queries` | ✅ Yes | List | Shodan query strings to execute |
| `author` | ⚪ Optional | String | Creator's name |
| `version` | ⚪ Optional | String | Template version (e.g., "1.0") |
| `tags` | ⚪ Optional | List | Categorization tags |
### Complete Example
```yaml
# Ollama Model Server Detection Template
# Created: 2026-02-10
name: Ollama Model Servers
description: Detect exposed Ollama LLM model servers with web interfaces
author: Security Research Team
version: 1.2
queries:
- 'http.title:"Ollama"'
- 'http.html:"ollama" port:11434'
- 'http.html:"llama" http.html:"model"'
- 'product:"Ollama"'
- 'http.title:"Ollama Web UI"'
tags:
- ai-agent
- llm
- ollama
- self-hosted
- critical
```
### Queries Field Formats
AASRT supports two query formats:
**Format 1: Simple List (Recommended)**
```yaml
queries:
- 'http.title:"Target"'
- 'http.html:"keyword"'
```
**Format 2: Nested Dict (Advanced)**
```yaml
queries:
shodan:
- 'http.title:"Target"'
- 'http.html:"keyword"'
```
---
## Shodan Query Syntax
### Common Search Operators
| Operator | Purpose | Example |
|----------|---------|---------|
| `http.title:` | Search page titles | `http.title:"Dashboard"` |
| `http.html:` | Search HTML body content | `http.html:"api_key"` |
| `product:` | Search product banners | `product:"nginx"` |
| `port:` | Filter by port number | `port:8080` |
| `hostname:` | Filter by hostname | `hostname:example.com` |
| `org:` | Filter by organization | `org:"Amazon"` |
| `country:` | Filter by country code | `country:US` |
| `ssl:` | Search SSL certificate fields | `ssl:"Let's Encrypt"` |
| `http.status:` | Filter by HTTP status code | `http.status:200` |
### Boolean Operators
| Operator | Usage | Example |
|----------|-------|---------|
| **AND** | Implicit (space-separated) | `http.title:"GPT" port:8000` |
| **OR** | Explicit OR keyword | `http.title:"AutoGPT" OR http.title:"Auto-GPT"` |
| **NOT** | Exclude with minus | `http.title:"Dashboard" -port:443` |
### Combining Filters (Examples)
```yaml
queries:
# Find LangChain agents on common ports
- 'http.html:"langchain" http.html:"agent" port:8000,8080,3000'
# Find exposed API keys in HTML
- 'http.html:"sk-" http.html:"openai"'
# Find debug mode enabled (multiple patterns)
- 'http.html:"DEBUG=True" OR http.html:"debug: true"'
# Exclude CDN-hosted results
- 'http.title:"Jupyter" -org:"Cloudflare" -org:"Amazon CloudFront"'
# Country-specific search
- 'http.title:"AI Dashboard" country:US,GB,DE'
# Certificate-based discovery
- 'ssl.cert.subject.CN:"*.openai.com"'
```
### Query Quoting Rules
| Scenario | Syntax | Example |
|----------|--------|---------|
| Exact phrase | Double quotes inside single | `'http.title:"Auto-GPT Dashboard"'` |
| Simple word | No inner quotes needed | `'product:nginx'` |
| Special characters | Always quote the value | `'http.html:"api_key="'` |
> **💡 Pro Tip:** Test your queries on [Shodan.io](https://www.shodan.io/) before adding them to templates!
---
## UI Integration
### Where Templates Appear
Custom templates automatically appear in the Streamlit dashboard:
```
┌─────────────────────────────────────────────────────┐
│ MISSION TYPE: ○ 🎯 TEMPLATE ○ ✍️ CUSTOM │
├─────────────────────────────────────────────────────┤
│ SELECT TARGET │
│ ┌───────────────────────────────────────────────┐ │
│ │ 📋 My Custom Template ▼ │ │
│ ├───────────────────────────────────────────────┤ │
│ │ 🤖 Autogpt Instances │ │
│ │ 🐾 Clawdbot Instances │ │
│ │ 🔗 Langchain Agents │ │
│ │ 📋 My Custom Template ← Your template! │ │
│ │ 🛡️ Clawsec Advisories │ │
│ └───────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────┘
```
### Template Display Formatting
Templates are displayed with:
1. **Icon** - Based on template name (defaults to 📋 for custom templates)
2. **Name** - Filename converted to title case (`my_template` → "My Template")
**Built-in Icon Mappings:**
| Template Name | Icon |
|--------------|------|
| `autogpt_instances` | 🤖 |
| `langchain_agents` | 🔗 |
| `jupyter_notebooks` | 📓 |
| `clawdbot_instances` | 🐾 |
| `exposed_env_files` | 📁 |
| `clawsec_advisories` | 🛡️ |
| Custom templates | 📋 |
### Refreshing Templates
**Web Dashboard (Streamlit):**
- Templates are cached for 5 minutes (`ttl=300`)
- To force refresh: **Press `R`** or **click the ⟳ button** in the browser
- Or restart the Streamlit server
**CLI:**
- Templates are loaded fresh on each command
- Run `python -m src.main templates` to see updated list
---
## Best Practices
### ✅ Do's
| Practice | Why | Example |
|----------|-----|---------|
| **Start specific, then broaden** | Avoid wasting API credits | Start with `http.title:"Exact Name"` before `http.html:"keyword"` |
| **Test on Shodan.io first** | Validate results before scanning | Check result count and relevance |
| **Use multiple query variations** | Cover different configurations | Include both "Auto-GPT" and "AutoGPT" |
| **Add meaningful tags** | Organize and filter templates | `tags: [ai-agent, critical, llm]` |
| **Document your queries** | Future maintainability | Add comments explaining each query |
| **Version your templates** | Track changes | `version: 1.2` |
### ❌ Don'ts
| Anti-Pattern | Problem | Better Alternative |
|--------------|---------|-------------------|
| `http.html:"a"` | Too broad, millions of results | Use specific keywords |
| No quotes on phrases | Query parsing errors | Always quote multi-word phrases |
| `port:*` | Invalid syntax | Omit port filter entirely |
| Mixing tabs and spaces | YAML parsing fails | Use spaces only (2-space indent) |
| 50+ queries per template | Slow scans, API waste | Split into multiple templates |
### Query Optimization Tips
```yaml
# ❌ Bad: Too broad, returns millions
queries:
- 'http.html:"api"'
# ✅ Good: Specific and targeted
queries:
- 'http.html:"openai" http.html:"api_key"'
- 'http.html:"sk-" http.html:"Bearer"'
```
```yaml
# ❌ Bad: Missing quotes around phrase
queries:
- http.title:Auto-GPT Dashboard
# ✅ Good: Properly quoted
queries:
- 'http.title:"Auto-GPT Dashboard"'
```
---
## Step-by-Step Example
Let's create a template to find exposed **Hugging Face Spaces** and **Gradio apps**.
### Step 1: Research on Shodan.io
Visit [Shodan.io](https://www.shodan.io/) and test queries:
```
http.title:"Hugging Face" → 1,234 results
http.html:"gradio" port:7860 → 567 results
http.title:"Gradio" → 890 results
```
### Step 2: Create the Template File
Create `queries/huggingface_spaces.yaml`:
```yaml
# Hugging Face Spaces and Gradio Detection Template
# Finds exposed ML demo applications
name: Hugging Face Spaces
description: Detect exposed Hugging Face Spaces and Gradio ML applications
author: Security Team
version: 1.0
queries:
# Direct Hugging Face Spaces
- 'http.title:"Hugging Face"'
- 'http.html:"huggingface" http.html:"spaces"'
# Gradio apps (common ML demo framework)
- 'http.title:"Gradio"'
- 'http.html:"gradio" port:7860'
- 'http.html:"gradio-app"'
# Streamlit ML apps
- 'http.html:"streamlit" http.html:"model"'
# Generic ML dashboard patterns
- 'http.title:"ML Dashboard" OR http.title:"Model Demo"'
tags:
- ai-agent
- huggingface
- gradio
- machine-learning
- demo
```
### Step 3: Validate YAML Syntax
Use an online YAML validator or Python:
```bash
python -c "import yaml; yaml.safe_load(open('queries/huggingface_spaces.yaml'))"
```
No output = valid YAML!
### Step 4: Verify Template Loading
```bash
python -m src.main templates
```
Output should include:
```
┌─────────────────────────────────┬──────────┐
│ Template Name │ Queries │
├─────────────────────────────────┼──────────┤
│ autogpt_instances │ 2 queries│
│ clawdbot_instances │ 3 queries│
│ huggingface_spaces │ 8 queries│ ← Your new template!
│ langchain_agents │ 2 queries│
└─────────────────────────────────┴──────────┘
```
### Step 5: Run a Scan
**CLI:**
```bash
python -m src.main scan --template huggingface_spaces --yes
```
**Web Dashboard:**
1. Open `http://localhost:8501`
2. Select "📋 Huggingface Spaces" from dropdown
3. Accept mission parameters
4. Click "🚀 INITIATE SCAN"
---
## Troubleshooting
### Common Issues
#### Template Not Appearing in UI
| Symptom | Cause | Solution |
|---------|-------|----------|
| Template not in dropdown | File not in `queries/` dir | Move file to correct location |
| Template not in dropdown | Wrong file extension | Rename to `.yaml` or `.yml` |
| Template not in dropdown | Cache not refreshed | Press `R` in browser or restart Streamlit |
| Template not in dropdown | YAML syntax error | Validate YAML (see below) |
#### YAML Syntax Errors
**Error:** `yaml.scanner.ScannerError: mapping values are not allowed here`
```yaml
# ❌ Wrong: Missing space after colon
name:Template Name
# ✅ Correct
name: Template Name
```
**Error:** `yaml.parser.ParserError: expected ',' or ']'`
```yaml
# ❌ Wrong: Mixing quote styles
queries:
- "http.title:'Dashboard'"
# ✅ Correct: Consistent quoting
queries:
- 'http.title:"Dashboard"'
```
**Error:** `yaml.scanner.ScannerError: found character '\t'`
```yaml
# ❌ Wrong: Using tabs
queries:
- 'query'
# ✅ Correct: Using spaces (2-space indent)
queries:
- 'query'
```
#### Queries Return No Results
| Symptom | Cause | Solution |
|---------|-------|----------|
| 0 results for all queries | Shodan API key invalid | Check `SHODAN_API_KEY` in `.env` |
| 0 results for specific query | Query too specific | Broaden search terms |
| 0 results for specific query | Typo in query | Test on Shodan.io first |
| Fewer results than expected | Rate limiting | Wait and retry |
#### Validate YAML Syntax
**Online validators:**
- [YAML Lint](https://www.yamllint.com/)
- [YAML Validator](https://jsonformatter.org/yaml-validator)
**Python validation:**
```python
import yaml
from pathlib import Path
template_path = Path("queries/my_template.yaml")
try:
data = yaml.safe_load(template_path.read_text())
print("✅ Valid YAML!")
print(f" Name: {data.get('name')}")
print(f" Queries: {len(data.get('queries', []))}")
except yaml.YAMLError as e:
print(f"❌ YAML Error: {e}")
```
---
## Advanced Topics
### Template Inheritance (Future)
Currently not supported, but you can combine queries manually:
```yaml
# combined_ai_agents.yaml
queries:
# From autogpt template
- 'http.title:"Auto-GPT"'
- 'http.title:"AutoGPT"'
# From langchain template
- 'http.html:"langchain" http.html:"agent"'
# Custom additions
- 'http.title:"CrewAI"'
```
### Programmatic Template Creation
```python
from src.core.query_manager import QueryManager
qm = QueryManager()
# Add a new template programmatically
qm.templates['my_new_template'] = [
'http.title:"My Target"',
'http.html:"keyword"'
]
# Save to file
qm.save_template('my_new_template')
```
---
## Summary Checklist
Before using your custom template:
- [ ] File is in `queries/` directory
- [ ] File extension is `.yaml` or `.yml`
- [ ] `name`, `description`, and `queries` fields are present
- [ ] YAML syntax is valid (validated with linter)
- [ ] Queries are properly quoted
- [ ] Tested queries on Shodan.io first
- [ ] Template appears in `python -m src.main templates` output
- [ ] Tags are meaningful for organization
---
**Happy Hunting! 🎯**
*For questions or contributions, see the main project documentation.*
+85
View File
@@ -0,0 +1,85 @@
# =============================================================================
# AASRT - AI Agent Security Reconnaissance Tool
# Production Dockerfile with Multi-Stage Build
# =============================================================================
# -----------------------------------------------------------------------------
# Stage 1: Builder - Install dependencies and prepare application
# -----------------------------------------------------------------------------
FROM python:3.13-slim AS builder
# Set build-time environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for layer caching
COPY requirements.txt .
# Install Python dependencies to a virtual environment
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN pip install --upgrade pip && \
pip install -r requirements.txt
# -----------------------------------------------------------------------------
# Stage 2: Runtime - Minimal production image
# -----------------------------------------------------------------------------
FROM python:3.13-slim AS runtime
# Labels for container identification
LABEL maintainer="AASRT Team" \
version="1.0.0" \
description="AI Agent Security Reconnaissance Tool - Production Image"
# Security: Run as non-root user
RUN groupadd --gid 1000 aasrt && \
useradd --uid 1000 --gid aasrt --shell /bin/bash --create-home aasrt
# Set runtime environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PYTHONPATH=/app \
# Application configuration
AASRT_ENVIRONMENT=production \
AASRT_LOG_LEVEL=INFO \
# Streamlit configuration
STREAMLIT_SERVER_PORT=8501 \
STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
STREAMLIT_SERVER_HEADLESS=true \
STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
WORKDIR /app
# Copy virtual environment from builder
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Copy application code
COPY --chown=aasrt:aasrt . .
# Create necessary directories with correct permissions
RUN mkdir -p /app/data /app/logs /app/reports && \
chown -R aasrt:aasrt /app/data /app/logs /app/reports
# Switch to non-root user
USER aasrt
# Expose Streamlit port
EXPOSE 8501
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8501/_stcore/health || exit 1
# Default command: Run Streamlit web interface
CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+212
View File
@@ -0,0 +1,212 @@
# 🗺️ Quick Map Guide
## How to See the Enhanced Map
### Step 1: Start the Dashboard
```bash
streamlit run app.py
```
### Step 2: Run a Scan
1. Open browser to `http://localhost:8501`
2. In sidebar, select a template (e.g., "clawdbot_instances")
3. Check "I accept mission parameters"
4. Click "🚀 INITIATE SCAN"
### Step 3: View the Map
Scroll down to the **"🌍 GALACTIC THREAT MAP"** section
## 🎮 New Controls
### Map Style Selector
```
🗺️ MAP STYLE: [3D Globe ▼]
```
- **3D Globe** - Rotating sphere (most impressive!)
- **Flat Map** - Traditional 2D view
- **Dark Matter** - Dark equirectangular
- **Natural Earth** - Natural projection
### Interactive Options
```
☐ ⚡ Show Threat Connections
☑ 💫 Animated Markers
```
### Map Controls (Bottom Center)
```
[🔄 AUTO ROTATE] [⏸️ PAUSE]
```
## 🎨 What You'll See
### Main Map (Left Side)
- **Large colored markers** showing threats
- **Different shapes** for different severities:
- 💎 Red Diamonds = CRITICAL
- ⬛ Orange Squares = HIGH
- ⚪ Yellow Circles = MEDIUM
- ⚪ Green Circles = LOW
- **Hover over markers** for detailed info
- **Click and drag** to rotate globe
- **Scroll** to zoom in/out
### Stats Panel (Top)
```
🛰️ LOCATED 🌐 SYSTEMS 🏙️ SECTORS ⭐ HOTSPOT
32 12 24 Germany
```
### Country Rankings (Right Side)
```
🏴 TOP SYSTEMS
Germany ████████████ 12
United States ████████ 8
France ██████ 6
...
```
### Threat Density Chart (Right Side)
Horizontal bar chart showing average risk per country
### New Analysis Section (Bottom)
```
📡 THREAT SURFACE ANALYSIS
🎯 PORT DISTRIBUTION 🔧 SERVICE BREAKDOWN
[Bar Chart] [Donut Chart]
```
## 💡 Pro Tips
1. **Best View**: Start with "3D Globe" + "Animated Markers"
2. **For Presentations**: Enable "Show Threat Connections" for critical threats
3. **For Analysis**: Switch to "Flat Map" to see all threats at once
4. **Performance**: Disable animations if map is slow
5. **Screenshots**: Use "⏸️ PAUSE" before taking screenshots
## 🎯 Understanding the Markers
### Size
- Larger markers = Higher risk score
- Size range: 15px - 50px
- Formula: `size = max(15, risk_score * 5)`
### Color
- 🔴 Red (#FF2D2D) = Critical (9.0-10.0)
- 🟠 Orange (#FF6B35) = High (7.0-8.9)
- 🟡 Yellow (#FFE81F) = Medium (4.0-6.9)
- 🟢 Green (#39FF14) = Low (0.0-3.9)
### Hover Info
```
192.168.1.1:8080
⚡ Risk: 10.0/10
📍 Berlin, Germany
🔧 nginx
```
## 🚀 Quick Actions
### Rotate Globe Manually
- Click and drag on the globe
- Works in "3D Globe" mode only
### Zoom In/Out
- Scroll wheel up = Zoom in
- Scroll wheel down = Zoom out
### Focus on Region
- Double-click on a country
- Map will zoom to that region
### Toggle Threat Category
- Click legend item (e.g., "CRITICAL")
- Hides/shows that category
### Reset View
- Refresh the page
- Or change map style and back
## 📊 Reading the Charts
### Port Distribution
- Shows which ports are most exposed
- Higher bars = More targets on that port
- Common ports: 80, 443, 8080, 3000
### Service Breakdown
- Shows technology distribution
- Larger slices = More common services
- Common services: nginx, apache, node
### Threat Density
- Shows average risk by country
- Longer bars = Higher average risk
- Color gradient indicates severity
## 🎬 Demo Scenario
1. **Launch**: `streamlit run app.py`
2. **Scan**: Select "clawdbot_instances" template
3. **Wait**: ~5 seconds for scan to complete
4. **Scroll**: Down to map section
5. **Interact**:
- Try rotating the globe
- Hover over markers
- Click "AUTO ROTATE"
- Toggle "Show Threat Connections"
- Change to "Flat Map"
6. **Analyze**:
- Check top countries
- Review port distribution
- Examine service breakdown
## 🎨 Visual Features
### Animations
- ✨ Smooth marker transitions
- 🌍 Globe auto-rotation (3°/frame)
- 💫 Hover glow effects
- 🌊 Pulsing connections
### Styling
- 🌌 Transparent background (space theme)
- 🌊 Cyan coastlines and borders
- 🌑 Dark land and ocean
- ⭐ Glowing markers
- 🎯 Professional tooltips
## 🔧 Customization
Want to change colors or sizes? Edit `app.py` around line 1100:
```python
# Marker colors
('critical', '#FF2D2D', 'CRITICAL', 'diamond')
# Marker size
df_map['size'] = df_map['risk_score'].apply(lambda x: max(15, x * 5))
# Map height
height=650
```
## 📱 Mobile/Tablet
The map works on mobile devices:
- Touch to rotate
- Pinch to zoom
- Tap markers for info
- Responsive layout
## 🎉 Enjoy!
The enhanced map makes threat visualization impressive and informative. Perfect for:
- 🎤 Security presentations
- 📊 Executive dashboards
- 🔍 Threat hunting
- 📈 Trend analysis
- 🎓 Security training
**May the Force be with your reconnaissance!**
+229
View File
@@ -0,0 +1,229 @@
# AASRT Quick Start Guide
## Prerequisites
✅ Python 3.13 installed
✅ All dependencies installed (`pip install -r requirements.txt`)
✅ Shodan API key configured in `.env` file
## Basic Commands
### 1. Check System Status
```bash
python -m src.main status
```
This shows:
- Shodan API status and credits
- Available query templates (13 templates)
- Your current plan type
### 2. List Available Templates
```bash
python -m src.main templates
```
Available templates:
- `clawdbot_instances` - Find ClawdBot dashboards
- `autogpt_instances` - Find AutoGPT deployments
- `langchain_agents` - Find LangChain agents
- `openai_exposed` - Find exposed OpenAI integrations
- `exposed_env_files` - Find exposed .env files
- `debug_mode` - Find services with debug mode enabled
- `jupyter_notebooks` - Find exposed Jupyter notebooks
- `streamlit_apps` - Find Streamlit applications
- And 5 more...
### 3. Run a Scan
**Using a template (recommended):**
```bash
python -m src.main scan --template clawdbot_instances --yes
```
**Using a custom query:**
```bash
python -m src.main scan --query 'http.title:"AutoGPT"' --yes
```
**Without --yes flag (shows legal disclaimer):**
```bash
python -m src.main scan --template clawdbot_instances
```
### 4. View Scan History
```bash
python -m src.main history
```
Shows:
- Last 10 scans
- Scan IDs, timestamps, results count
- Database statistics
### 5. Generate Report from Previous Scan
```bash
python -m src.main report --scan-id <scan_id>
```
## Understanding Scan Results
### Console Output
```
+-------------------------------- Scan Summary --------------------------------+
| Scan ID: 211a5df0... |
| Duration: 3.3s |
| Total Results: 32 |
| Average Risk Score: 3.7/10 |
+------------------------------------------------------------------------------+
Risk Distribution
+------------------+
| Severity | Count |
|----------+-------|
| Critical | 4 |
| High | 0 |
| Medium | 0 |
| Low | 28 |
+------------------+
```
### Report Files
Reports are saved in `./reports/` directory:
- **JSON format:** `scan_<id>_<timestamp>.json`
- **CSV format:** `scan_<id>_<timestamp>.csv` (if enabled)
### Database
All scans are automatically saved to: `./data/scanner.db`
## Common Use Cases
### 1. Find Exposed AI Dashboards
```bash
python -m src.main scan --template ai_dashboards --yes
```
### 2. Find Debug Mode Enabled Services
```bash
python -m src.main scan --template debug_mode --yes
```
### 3. Find Exposed Environment Files
```bash
python -m src.main scan --template exposed_env_files --yes
```
### 4. Custom Search for Specific Service
```bash
python -m src.main scan --query 'product:"nginx" port:8080' --yes
```
## Understanding Risk Scores
- **10.0 (Critical):** No authentication on sensitive dashboards
- **7.0-9.9 (High):** Exposed API keys, shell access, database strings
- **5.0-6.9 (Medium):** SSL issues, exposed config files
- **3.0-4.9 (Low):** Self-signed certificates, missing security.txt
- **1.0-2.9 (Info):** Informational findings
## Vulnerability Types Detected
1. **Authentication Issues**
- No authentication on dashboards
- Missing security controls
2. **API Key Exposure**
- OpenAI keys (sk-...)
- Anthropic keys (sk-ant-...)
- AWS credentials (AKIA...)
- GitHub tokens (ghp_...)
- Google API keys (AIza...)
- Stripe keys (sk_live_...)
3. **Dangerous Functionality**
- Shell execution endpoints
- Debug mode enabled
- File upload functionality
- Admin panels exposed
- Database connection strings
4. **Information Disclosure**
- Exposed .env files
- Configuration files
- Git repositories
- Source code files
5. **SSL/TLS Issues**
- Expired certificates
- Self-signed certificates
- No SSL on HTTPS ports
## Configuration
Edit `config.yaml` to customize:
```yaml
shodan:
rate_limit: 1 # queries per second
max_results: 100
vulnerability_checks:
enabled: true
passive_only: true
reporting:
formats:
- json
- csv
output_dir: "./reports"
filtering:
min_confidence_score: 70
exclude_honeypots: true
logging:
level: "INFO"
file: "./logs/scanner.log"
```
## Tips & Best Practices
1. **Start with specific templates** rather than broad queries
2. **Use --yes flag** to skip legal disclaimer for automated scans
3. **Check your Shodan credits** before running large scans
4. **Review reports in JSON format** for detailed findings
5. **Use scan history** to track your reconnaissance over time
## Troubleshooting
### "Invalid API key" error
- Check your `.env` file has the correct `SHODAN_API_KEY`
- Verify the key is valid at https://account.shodan.io/
### "Rate limit exceeded"
- Reduce `rate_limit` in `config.yaml`
- Wait a few minutes before retrying
### No results found
- Try different templates or queries
- Check if the service/product exists on Shodan
- Use `python -m src.main status` to verify API connectivity
## Legal Notice
⚠️ **Important:** This tool is for authorized security research only.
- Only scan systems you have permission to test
- Comply with all applicable laws and terms of service
- Responsibly disclose any findings
- Do not exploit discovered vulnerabilities
## Support
- Documentation: See `README.md` and `Outline.md`
- Bug Fixes: See `FIXES_APPLIED.md`
- Query Templates: Check `queries/` directory
- Logs: Check `logs/scanner.log` for detailed information
## Current Status
✅ All systems operational
✅ 13 query templates available
✅ 81 Shodan query credits remaining
✅ Database with 17 scans and 2253 findings
✅ All bug fixes applied and tested
+309
View File
@@ -0,0 +1,309 @@
<div align="center">
# 🛡️ AASRT
### AI Agent Security Reconnaissance Tool
*Imperial Security Reconnaissance System for AI Agent Discovery*
[![Python 3.11+](https://img.shields.io/badge/Python-3.11%2B-blue?logo=python&logoColor=white)](https://www.python.org/)
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
[![Status: Production Ready](https://img.shields.io/badge/Status-Production%20Ready-brightgreen)](PROJECT_STATUS.md)
[![Version](https://img.shields.io/badge/Version-1.0.0-blue)](https://github.com/yourusername/aasrt/releases)
[![Tests](https://img.shields.io/badge/Tests-63%20Passing-success)](tests/)
[![Coverage](https://img.shields.io/badge/Coverage-35%25-yellow)](tests/)
</div>
---
## 🎯 Overview
**AASRT** (AI Agent Security Reconnaissance Tool) automates the discovery of publicly exposed AI agent implementations—including ClawdBot, AutoGPT, LangChain agents, Jupyter notebooks, and more—using the Shodan search engine API.
As organizations rapidly deploy AI agents and LLM-powered systems, many are inadvertently exposed to the public internet without proper security controls. AASRT helps security teams identify these exposures through **passive reconnaissance** before attackers do.
**Key Value Propositions:**
- 🔍 **Automated Discovery** — Find exposed AI infrastructure across the internet
- ⚠️ **Vulnerability Assessment** — Automatic detection of API key leaks, auth issues, and dangerous functionality
- 📊 **Risk Scoring** — CVSS-based scoring with severity categorization (Critical/High/Medium/Low)
- 📋 **Comprehensive Reporting** — JSON, CSV exports with persistent scan history
**Target Audience:** Security researchers, penetration testers, DevSecOps teams, and compliance officers conducting authorized security assessments.
---
## ✨ Features
| Feature | Description |
|---------|-------------|
| 🔍 **Multi-Source Search** | Shodan integration (Censys/BinaryEdge planned) |
| 🛡️ **Vulnerability Assessment** | Detects API key exposure, auth issues, debug mode, SSL problems |
| 📊 **Risk Scoring** | CVSS-based 0-10 scoring with severity levels |
| 📋 **13+ Query Templates** | Pre-built searches for AutoGPT, LangChain, Jupyter, and more |
| 🌐 **Web Dashboard** | Interactive Streamlit UI with Star Wars Imperial theme |
| ⌨️ **Full CLI** | Complete command-line interface for automation |
| 💾 **Scan History** | SQLite database for persistent findings (2,253+ findings tracked) |
| 🗺️ **Threat Mapping** | Interactive 3D globe visualization of discovered targets |
| 🐳 **Docker Ready** | Multi-stage Dockerfile with docker-compose for easy deployment |
| ✅ **Production Ready** | 63 passing tests, comprehensive input validation, retry logic |
---
## 📦 Installation
### Prerequisites
- **Python 3.11+** (tested on 3.13)
- **pip** package manager
- **Shodan API Key** — [Get one here](https://account.shodan.io/)
### Method 1: From Source (Recommended)
```bash
# Clone the repository
git clone https://github.com/yourusername/aasrt.git
cd aasrt
# Install dependencies
pip install -r requirements.txt
# (Optional) Install development dependencies
pip install -r requirements-dev.txt
```
### Method 2: Using pip (When Published)
```bash
pip install aasrt
```
### Method 3: Docker
```bash
# Build and run with Docker Compose
docker-compose up -d
# Or build manually
docker build -t aasrt .
docker run -e SHODAN_API_KEY=your_key aasrt
```
### Configuration
1. **Create environment file:**
```bash
cp .env.example .env
```
2. **Add your Shodan API key:**
```bash
# .env
SHODAN_API_KEY=your_shodan_api_key_here
```
3. **(Optional) Customize settings in `config.yaml`:**
```yaml
shodan:
rate_limit: 1 # Queries per second
max_results: 100 # Results per query
timeout: 30 # Request timeout
```
---
## 🚀 Quick Start
### Example 1: Run a Template Scan (CLI)
```bash
python -m src.main scan --template clawdbot_instances --yes
```
### Example 2: Launch Web Dashboard
```bash
streamlit run app.py
# Open http://localhost:8501 in your browser
```
### Example 3: Custom Shodan Query
```bash
python -m src.main scan --query 'http.title:"AutoGPT"' --yes
```
### Example 4: View Scan History
```bash
python -m src.main history
```
### Example 5: List Available Templates
```bash
python -m src.main templates
```
**Output:**
```
┌─────────────────────────────┬──────────┐
│ Template Name │ Queries │
├─────────────────────────────┼──────────┤
│ autogpt_instances │ 2 queries│
│ clawdbot_instances │ 3 queries│
│ langchain_agents │ 2 queries│
│ jupyter_notebooks │ 3 queries│
│ exposed_env_files │ 2 queries│
│ ... │ ... │
└─────────────────────────────┴──────────┘
```
---
## 📋 Available Query Templates
| Template | Target | Queries |
|----------|--------|---------|
| `clawdbot_instances` | ClawdBot AI dashboards | 5 |
| `autogpt_instances` | AutoGPT deployments | 5 |
| `langchain_agents` | LangChain agent implementations | 5 |
| `openai_exposed` | Exposed OpenAI integrations | 2 |
| `exposed_env_files` | Leaked .env configuration files | 2 |
| `debug_mode` | Services with debug mode enabled | 3 |
| `jupyter_notebooks` | Exposed Jupyter notebooks | 3 |
| `streamlit_apps` | Streamlit applications | 2 |
| `ai_dashboards` | Generic AI/LLM dashboards | 3 |
| `clawsec_advisories` | ClawSec CVE-matched targets | 10 |
**Create custom templates:** See [Custom Query Templates Guide](CUSTOM_QUERIES_GUIDE.md)
---
## 📖 Documentation
| Document | Description |
|----------|-------------|
| 📖 [Quick Start Guide](QUICK_START.md) | Detailed usage instructions and examples |
| 📋 [Custom Query Templates](CUSTOM_QUERIES_GUIDE.md) | Create your own Shodan query templates |
| 🗺️ [Map Visualization Guide](QUICK_MAP_GUIDE.md) | Interactive threat map features |
**Developer Documentation** (in `dev/docs/`):
| Document | Description |
|----------|-------------|
| 📊 [Project Status](dev/docs/PROJECT_STATUS.md) | Current system health and statistics |
| 📝 [Technical Specification](dev/docs/Outline.md) | Full product requirements document |
| 🔧 [Bug Fixes Log](dev/docs/FIXES_APPLIED.md) | Technical details of resolved issues |
| 🗺️ [Map Enhancements](dev/docs/MAP_ENHANCEMENTS.md) | Map visualization implementation details |
---
## ⚠️ Legal Disclaimer
> **🚨 IMPORTANT: This tool is for AUTHORIZED SECURITY RESEARCH and DEFENSIVE PURPOSES ONLY.**
>
> **Unauthorized access to computer systems is ILLEGAL under:**
> - 🇺🇸 CFAA (Computer Fraud and Abuse Act) — United States
> - 🇬🇧 Computer Misuse Act — United Kingdom
> - 🇪🇺 EU Directive on Attacks Against Information Systems
> - Similar laws exist in virtually every jurisdiction worldwide
>
> **By using this tool, you acknowledge and agree that:**
> 1. ✅ You have **explicit authorization** to scan target systems
> 2. ✅ You will **comply with all applicable laws** and terms of service
> 3. ✅ You will **responsibly disclose** any vulnerabilities discovered
> 4. ✅ You will **NOT exploit** discovered vulnerabilities
> 5. ✅ You understand this tool performs **passive reconnaissance only**
>
> **The authors assume NO LIABILITY for misuse of this tool.**
---
## 📁 Project Structure
```
aasrt/
├── src/ # Core application code
│ ├── main.py # CLI entry point
│ ├── core/ # Query manager, risk scorer, vulnerability assessor
│ ├── engines/ # Search engine integrations (Shodan)
│ ├── enrichment/ # Threat intelligence (ClawSec feed)
│ ├── reporting/ # JSON/CSV report generators
│ ├── storage/ # SQLite database layer
│ └── utils/ # Config, logging, validators, exceptions
├── queries/ # Query template YAML files
├── reports/ # Generated scan reports
├── logs/ # Application logs
├── data/ # SQLite database
├── dev/ # Development files (not for production)
│ ├── tests/ # Unit and integration tests (63 tests)
│ ├── docs/ # Developer documentation
│ ├── pytest.ini # Pytest configuration
│ └── requirements-dev.txt # Development dependencies
├── app.py # Streamlit web dashboard
├── config.yaml # Application configuration
├── requirements.txt # Production dependencies
├── Dockerfile # Multi-stage Docker build
└── docker-compose.yml # Docker Compose with PostgreSQL
```
---
## 🧪 Testing
```bash
# Run all unit tests (from project root)
python -m pytest dev/tests/unit/ -v
# Run with coverage
python -m pytest dev/tests/unit/ --cov=src --cov-report=term-missing
# Run specific test module
python -m pytest dev/tests/unit/test_validators.py -v
# Use pytest.ini config
python -m pytest -c dev/pytest.ini dev/tests/unit/ -v
```
**Current Status:** 63 tests passing, 35% coverage
---
## 🤝 Contributing
Contributions are welcome! Please follow these steps:
1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
For bugs or feature requests, please [open an issue](https://github.com/yourusername/aasrt/issues).
---
## 📄 License
This project is licensed under the **MIT License** — see the [LICENSE](LICENSE) file for details.
---
## 🙏 Acknowledgments
- [Shodan](https://www.shodan.io/) — Search engine for internet-connected devices
- [Streamlit](https://streamlit.io/) — Web dashboard framework
- [SQLAlchemy](https://www.sqlalchemy.org/) — Database ORM
- [Click](https://click.palletsprojects.com/) — CLI framework
- [Rich](https://rich.readthedocs.io/) — Terminal formatting
- The security research community
---
<div align="center">
**⭐ Star this repo if you find it useful!**
*May the Force be with your reconnaissance.* 🌟
</div>
+2212
View File
File diff suppressed because it is too large Load Diff
+43
View File
@@ -0,0 +1,43 @@
# AI Agent Security Reconnaissance Tool (AASRT)
# Configuration File
# Shodan Configuration
shodan:
enabled: true
rate_limit: 1 # queries per second
max_results: 100
timeout: 30
# Vulnerability Assessment
vulnerability_checks:
enabled: true
passive_only: true
timeout_per_check: 10
# Reporting
reporting:
formats:
- json
- csv
output_dir: "./reports"
anonymize_by_default: false
# Filtering
filtering:
whitelist_ips: []
whitelist_domains: []
min_confidence_score: 70
exclude_honeypots: true
# Logging
logging:
level: "INFO"
file: "./logs/scanner.log"
max_size_mb: 100
backup_count: 5
# Database
database:
type: "sqlite"
sqlite:
path: "./data/scanner.db"
+157
View File
@@ -0,0 +1,157 @@
# Bug Fixes Applied - February 9, 2026
## Summary
Fixed critical `AttributeError: 'NoneType' object has no attribute 'lower'` that was causing the vulnerability assessment to crash during scans.
## Root Cause
The issue occurred when Shodan API returned results where the `http` field was `None` instead of an empty dictionary. The code was using `.get('http', {})` which returns `{}` when the key doesn't exist, but returns `None` when the key exists with a `None` value.
When the vulnerability assessor tried to call `.lower()` on `http_info.get('title', '')`, if `title` was `None`, it would crash because `None` doesn't have a `.lower()` method.
## Files Modified
### 1. `src/core/vulnerability_assessor.py`
**Changes:**
- Line 316: Changed `title = http_info.get('title', '').lower()` to use the `or` operator for None-safety
- Line 289: Fixed `http_info` extraction in `_check_dangerous_functionality()`
- Line 330: Fixed `ssl_info` extraction in `_check_ssl_issues()`
- Line 344: Fixed `cert` extraction
- Line 371: Fixed `http_info` extraction in `_check_authentication()`
**Pattern Applied:**
```python
# Before (unsafe)
http_info = result.metadata.get('http', {})
title = http_info.get('title', '').lower()
# After (safe)
http_info = result.metadata.get('http') or {}
title = http_info.get('title') or ''
title = title.lower()
```
### 2. `src/engines/shodan_engine.py`
**Changes:**
- Line 178-179: Fixed SSL certificate parsing to handle None values
- Line 182: Fixed HTTP data extraction
- Line 192-204: Fixed location data extraction
- Line 198: Fixed SSL data assignment
**Pattern Applied:**
```python
# Before (unsafe)
http_data = match.get('http', {})
ssl_info = match.get('ssl', {}).get('cert', {})
# After (safe)
http_data = match.get('http') or {}
ssl_data = match.get('ssl') or {}
ssl_cert = ssl_data.get('cert') or {}
```
### 3. `src/core/risk_scorer.py`
**Changes:**
- Line 209-211: Fixed HTTP headers extraction
- Line 239-244: Fixed HTTP title extraction in `_is_ai_agent()`
**Pattern Applied:**
```python
# Before (unsafe)
http_headers = result.metadata.get('http', {}).get('headers', {})
# After (safe)
http_info = result.metadata.get('http') or {}
http_headers = http_info.get('headers', {})
```
### 4. `src/enrichment/threat_enricher.py`
**Changes:**
- Line 106: Fixed HTTP info extraction
## Testing Results
### Before Fix
```
AttributeError: 'NoneType' object has no attribute 'lower'
File "C:\Users\sweth\Desktop\Gemini\ShodanS\src\core\vulnerability_assessor.py", line 316, in _check_authentication
title = http_info.get('title', '').lower()
```
### After Fix
```
Scan completed successfully!
- Duration: 3.3s
- Total Results: 32
- Average Risk Score: 3.7/10
- Critical Findings: 4
- Low Findings: 28
```
## Commands Tested Successfully
1. **Scan with template:**
```bash
python -m src.main scan --template clawdbot_instances --yes
```
✅ Completed without errors
2. **Check engine status:**
```bash
python -m src.main status
```
✅ Shows Shodan API status, credits, and available templates
3. **List templates:**
```bash
python -m src.main templates
```
✅ Shows 13 available query templates
4. **View scan history:**
```bash
python -m src.main history
```
✅ Shows 17 completed scans with 2253 findings
## Key Improvements
1. **Null Safety:** All dictionary access patterns now handle `None` values correctly
2. **Defensive Programming:** Using `or {}` pattern ensures we always have a dictionary to work with
3. **Consistent Pattern:** Applied the same fix pattern across all similar code locations
4. **No Breaking Changes:** The fixes are backward compatible and don't change the API
## Prevention Strategy
To prevent similar issues in the future:
1. **Always use the `or` operator when extracting nested dictionaries:**
```python
data = source.get('key') or {}
```
2. **Check for None before calling string methods:**
```python
value = data.get('field') or ''
result = value.lower()
```
3. **Add type hints to catch these issues during development:**
```python
def process(data: Optional[Dict[str, Any]]) -> str:
info = data.get('http') or {}
title = info.get('title') or ''
return title.lower()
```
## Next Steps
The project is now fully functional and ready for use. All core features are working:
- ✅ Shodan API integration
- ✅ Vulnerability assessment
- ✅ Risk scoring
- ✅ Report generation (JSON/CSV)
- ✅ Database storage
- ✅ Query templates
- ✅ Scan history
You can now safely run scans against any of the 13 available templates without encountering the AttributeError.
+245
View File
@@ -0,0 +1,245 @@
# Map Visualization Enhancements
## 🎨 New Features Added
### 1. **Multiple Map Styles**
Choose from 4 different visualization modes:
- **3D Globe** - Interactive rotating sphere (default)
- **Flat Map** - Traditional 2D projection
- **Dark Matter** - Equirectangular dark theme
- **Natural Earth** - Natural earth projection
### 2. **Threat Connections**
- Toggle to show connections between critical threats
- Dotted lines connecting high-risk targets
- Visual network of attack surface
### 3. **Animated Markers**
- Toggle for animated threat markers
- Smooth rotation for 3D globe
- Auto-rotate and pause controls
### 4. **Enhanced Markers**
Different shapes for different threat levels:
- 💎 **Diamond** - Critical threats (red)
-**Square** - High threats (orange)
-**Circle** - Medium threats (yellow)
-**Circle** - Low threats (green)
### 5. **Improved Hover Information**
Rich tooltips showing:
- IP address and port (highlighted)
- Risk score with visual indicator
- Location (city, country)
- Service type
- Color-coded by severity
### 6. **Enhanced Styling**
- Larger, more visible markers (15-50px)
- Thicker borders (3px white outline)
- Better contrast with dark background
- Glowing effects on hover
- Professional color palette
### 7. **Better Geography**
- Enhanced coastlines (2px cyan)
- Visible country borders (cyan, 40% opacity)
- Dark land masses (15, 25, 35 RGB)
- Deep ocean color (5, 10, 20 RGB)
- Lake visualization
- Grid lines for reference
### 8. **Interactive Controls**
- Auto-rotate button for 3D globe
- Pause button to stop animation
- Drawing tools enabled
- Zoom and pan controls
- Mode bar with tools
### 9. **Threat Density Heatmap** (Right Panel)
- Top 10 countries by threat count
- Horizontal bar chart showing average risk per country
- Color gradient from green → yellow → orange → red
- Shows both count and average risk score
### 10. **New Analysis Sections**
#### 📡 Threat Surface Analysis
Two new visualizations below the map:
**A. Port Distribution**
- Bar chart of top 10 most common ports
- Color-coded by frequency
- Shows attack surface entry points
- Helps identify common vulnerabilities
**B. Service Breakdown**
- Donut chart of service types
- Shows technology stack distribution
- Color-coded by service
- Center shows total service count
## 🎯 Visual Improvements
### Color Scheme
- **Critical**: `#FF2D2D` (Bright Red)
- **High**: `#FF6B35` (Orange)
- **Medium**: `#FFE81F` (Star Wars Yellow)
- **Low**: `#39FF14` (Neon Green)
- **Info**: `#4BD5EE` (Cyan)
- **Background**: `rgba(0,0,0,0)` (Transparent)
### Typography
- **Headers**: Orbitron (Bold, 12px)
- **Data**: Share Tech Mono (11px)
- **Values**: Orbitron (14px)
### Animations
- Smooth marker transitions
- Globe rotation (3° per frame)
- Hover scale effects
- Fade-in for tooltips
## 🚀 How to Use
### Basic Usage
1. Run a scan to get results
2. Scroll to "GALACTIC THREAT MAP" section
3. View threats on interactive map
### Advanced Features
1. **Change Map Style**: Use dropdown to switch between 3D Globe, Flat Map, etc.
2. **Enable Connections**: Check "Show Threat Connections" to see network links
3. **Toggle Animation**: Check/uncheck "Animated Markers" for rotation
4. **Interact with Globe**:
- Click and drag to rotate
- Scroll to zoom
- Click markers for details
5. **Auto-Rotate**: Click "🔄 AUTO ROTATE" button for continuous rotation
6. **Pause**: Click "⏸️ PAUSE" to stop animation
### Understanding the Data
#### Geo Stats (Top Row)
- **🛰️ LOCATED**: Number of threats with GPS coordinates
- **🌐 SYSTEMS**: Number of unique countries
- **🏙️ SECTORS**: Number of unique cities
- **⭐ HOTSPOT**: Country with most threats
#### Map Legend
- Hover over legend items to highlight threat category
- Click legend items to show/hide categories
- Size of markers indicates risk score
#### Right Panel
- **TOP SYSTEMS**: Countries ranked by threat count
- **THREAT DENSITY**: Average risk score by country
#### Bottom Charts
- **PORT DISTRIBUTION**: Most targeted ports
- **SERVICE BREAKDOWN**: Technology distribution
## 📊 Technical Details
### Map Projections
- **Orthographic**: 3D sphere projection (best for global view)
- **Natural Earth**: Compromise between equal-area and conformal
- **Equirectangular**: Simple cylindrical projection
### Performance
- Optimized for up to 500 markers
- Smooth 60fps animations
- Lazy loading for large datasets
- Efficient frame rendering
### Responsive Design
- Adapts to screen size
- Mobile-friendly controls
- Touch-enabled on tablets
- High DPI display support
## 🎨 Customization Options
You can further customize by editing `app.py`:
### Marker Sizes
```python
df_map['size'] = df_map['risk_score'].apply(lambda x: max(15, x * 5))
```
Change `15` (min size) and `5` (multiplier) to adjust marker sizes.
### Animation Speed
```python
frames = [...] for i in range(0, 360, 3)
```
Change `3` to adjust rotation speed (higher = faster).
### Color Schemes
Modify the color variables in the marker loop:
```python
('critical', '#FF2D2D', 'CRITICAL', 'diamond')
```
### Map Height
```python
height=650
```
Adjust the height value to make map taller/shorter.
## 🐛 Troubleshooting
### Map Not Showing
- Ensure scan has results with geolocation data
- Check browser console for errors
- Verify Plotly is installed: `pip install plotly`
### Slow Performance
- Reduce number of results with `max_results` parameter
- Disable animations
- Use "Flat Map" instead of "3D Globe"
### Markers Too Small/Large
- Adjust size multiplier in code
- Check risk scores are calculated correctly
## 🌟 Best Practices
1. **Start with 3D Globe** for impressive visualization
2. **Enable Connections** for critical threats only (cleaner view)
3. **Use Flat Map** for detailed regional analysis
4. **Check Port Distribution** to identify common attack vectors
5. **Review Service Breakdown** to understand technology stack
6. **Export data** for further analysis in other tools
## 📈 Future Enhancements (Ideas)
- [ ] Time-series animation showing threat evolution
- [ ] Clustering for dense areas
- [ ] Custom marker icons per service type
- [ ] Heat map overlay option
- [ ] 3D terrain elevation based on risk
- [ ] Attack path visualization
- [ ] Real-time threat feed integration
- [ ] Comparison mode (multiple scans)
- [ ] Export map as image/video
- [ ] VR/AR mode for immersive viewing
## 🎉 Summary
The enhanced map visualization provides:
- **4 map styles** for different use cases
- **Interactive controls** for exploration
- **Rich tooltips** with detailed information
- **Visual connections** between threats
- **Additional analytics** (ports, services, density)
- **Professional styling** with Star Wars theme
- **Smooth animations** and transitions
- **Responsive design** for all devices
Perfect for security presentations, threat intelligence reports, and real-time monitoring dashboards!
---
**Version**: 2.0
**Last Updated**: February 9, 2026
**Theme**: Star Wars Imperial
+2098
View File
File diff suppressed because it is too large Load Diff
+295
View File
@@ -0,0 +1,295 @@
# AASRT Project Status Report
**Date:** February 9, 2026
**Status:** ✅ Fully Operational
---
## Executive Summary
The AI Agent Security Reconnaissance Tool (AASRT) is now fully functional and ready for production use. All critical bugs have been fixed, and the system has been tested successfully across multiple scan operations.
---
## System Health
### ✅ Core Components
- **Shodan API Integration:** Working (81 credits available, Dev plan)
- **Vulnerability Assessment:** Fixed and operational
- **Risk Scoring:** Operational
- **Report Generation:** JSON and CSV formats working
- **Database Storage:** SQLite operational (17 scans, 2253 findings)
- **Query Templates:** 13 templates available and tested
### 📊 Current Statistics
- **Total Scans Completed:** 17
- **Total Findings:** 2,253
- **Unique IPs Discovered:** 1,577
- **Available Templates:** 13
- **Shodan Credits Remaining:** 81
---
## Recent Bug Fixes (Feb 9, 2026)
### Critical Issue Resolved
**Problem:** `AttributeError: 'NoneType' object has no attribute 'lower'`
**Impact:** Caused vulnerability assessment to crash during scans
**Root Cause:** Shodan API returning `None` values for HTTP metadata instead of empty dictionaries
**Solution:** Applied defensive programming pattern across 4 files:
- `src/core/vulnerability_assessor.py` (5 fixes)
- `src/engines/shodan_engine.py` (4 fixes)
- `src/core/risk_scorer.py` (2 fixes)
- `src/enrichment/threat_enricher.py` (1 fix)
**Testing:** Verified with successful scan of 32 ClawdBot instances
See `FIXES_APPLIED.md` for detailed technical information.
---
## Available Features
### 1. Search Engines
- ✅ Shodan (fully integrated)
- ⏳ Censys (planned)
- ⏳ BinaryEdge (planned)
### 2. Query Templates
| Template | Purpose | Queries |
|----------|---------|---------|
| `clawdbot_instances` | Find ClawdBot dashboards | 3 |
| `autogpt_instances` | Find AutoGPT deployments | 2 |
| `langchain_agents` | Find LangChain agents | 2 |
| `openai_exposed` | Find exposed OpenAI integrations | 2 |
| `exposed_env_files` | Find exposed .env files | 2 |
| `debug_mode` | Find debug mode enabled | 3 |
| `jupyter_notebooks` | Find exposed Jupyter notebooks | 3 |
| `streamlit_apps` | Find Streamlit apps | 2 |
| `ai_dashboards` | Find AI dashboards | 3 |
| `autogpt` | AutoGPT comprehensive | 5 |
| `clawdbot` | ClawdBot comprehensive | 5 |
| `langchain` | LangChain comprehensive | 5 |
| `clawsec_advisories` | ClawSec CVE matching | 10 |
### 3. Vulnerability Detection
- ✅ API Key Exposure (7 types)
- ✅ Authentication Issues
- ✅ Dangerous Functionality (5 types)
- ✅ Information Disclosure (4 types)
- ✅ SSL/TLS Issues
- ✅ ClawSec CVE Integration
### 4. Risk Assessment
- ✅ CVSS-based scoring
- ✅ Severity categorization (Critical/High/Medium/Low/Info)
- ✅ Context-aware scoring
- ✅ Exploitability assessment
### 5. Reporting
- ✅ JSON format (machine-readable)
- ✅ CSV format (spreadsheet-friendly)
- ✅ Console output (human-readable)
- ✅ Database storage (SQLite)
### 6. CLI Commands
```bash
# Core Commands
python -m src.main status # Check system status
python -m src.main templates # List available templates
python -m src.main history # View scan history
python -m src.main scan # Run a scan
python -m src.main report # Generate report from scan
python -m src.main configure # Configuration wizard
# Scan Options
--template, -t # Use predefined template
--query, -q # Custom Shodan query
--engine, -e # Search engine (shodan/censys/all)
--max-results # Maximum results per engine
--output, -o # Output file path
--format, -f # Output format (json/csv/both)
--no-assess # Skip vulnerability assessment
--yes, -y # Skip legal disclaimer
```
---
## File Structure
```
ShodanS/
├── src/
│ ├── main.py # CLI entry point
│ ├── core/ # Core components
│ │ ├── query_manager.py # Query execution
│ │ ├── result_aggregator.py # Result deduplication
│ │ ├── vulnerability_assessor.py # Vuln detection
│ │ └── risk_scorer.py # Risk calculation
│ ├── engines/
│ │ ├── base.py # Base engine class
│ │ └── shodan_engine.py # Shodan integration
│ ├── enrichment/
│ │ ├── threat_enricher.py # Threat intelligence
│ │ └── clawsec_feed.py # ClawSec CVE feed
│ ├── reporting/
│ │ ├── json_reporter.py # JSON reports
│ │ └── csv_reporter.py # CSV reports
│ ├── storage/
│ │ └── database.py # SQLite database
│ └── utils/
│ ├── config.py # Configuration
│ ├── logger.py # Logging
│ ├── validators.py # Input validation
│ └── exceptions.py # Custom exceptions
├── queries/ # Query templates (YAML)
├── reports/ # Generated reports
├── logs/ # Log files
├── data/ # Database files
├── config.yaml # Main configuration
├── .env # API keys
├── requirements.txt # Python dependencies
├── README.md # Project documentation
├── Outline.md # Product requirements
├── QUICK_START.md # Quick start guide
├── FIXES_APPLIED.md # Bug fix documentation
└── PROJECT_STATUS.md # This file
```
---
## Configuration Files
### `.env`
```
SHODAN_API_KEY=oEm3fCUFctAByLoQkxHCgK8lFFp3t53w
```
### `config.yaml`
```yaml
shodan:
enabled: true
rate_limit: 1
max_results: 100
timeout: 30
vulnerability_checks:
enabled: true
passive_only: true
reporting:
formats: [json, csv]
output_dir: "./reports"
filtering:
min_confidence_score: 70
exclude_honeypots: true
logging:
level: "INFO"
file: "./logs/scanner.log"
```
---
## Testing Results
### Latest Scan (Feb 9, 2026 23:43)
```
Template: clawdbot_instances
Duration: 3.3 seconds
Results: 32 unique findings
Risk Distribution:
- Critical: 4
- High: 0
- Medium: 0
- Low: 28
Average Risk Score: 3.7/10
Status: ✅ Completed successfully
```
### All Commands Tested
-`python -m src.main status` - Working
-`python -m src.main templates` - Working
-`python -m src.main history` - Working
-`python -m src.main scan --template clawdbot_instances --yes` - Working
---
## Known Limitations
1. **Search Engines:** Only Shodan is currently implemented
2. **Rate Limiting:** Limited by Shodan API plan (1 query/second)
3. **Passive Scanning:** No active vulnerability verification
4. **False Positives:** Some findings may be honeypots or false positives
---
## Recommendations
### Immediate Use
1. ✅ Run reconnaissance scans using available templates
2. ✅ Review generated JSON reports for detailed findings
3. ✅ Use scan history to track discoveries over time
4. ✅ Export findings to CSV for analysis
### Future Enhancements
1. Add Censys and BinaryEdge engine support
2. Implement active vulnerability verification (with authorization)
3. Add web dashboard for visualization
4. Create custom query builder UI
5. Add automated alert system
6. Implement result export to SIEM systems
### Best Practices
1. Always use `--yes` flag for automated scans
2. Start with specific templates rather than broad queries
3. Monitor Shodan credit usage
4. Review and validate findings before taking action
5. Responsibly disclose any critical vulnerabilities found
---
## Support Resources
- **Quick Start Guide:** `QUICK_START.md`
- **Bug Fix Details:** `FIXES_APPLIED.md`
- **Full Documentation:** `README.md`
- **Product Requirements:** `Outline.md`
- **Logs:** `logs/scanner.log`
---
## Legal & Ethical Use
⚠️ **IMPORTANT DISCLAIMER**
This tool is for **authorized security research and defensive purposes only**.
**You MUST:**
- Have authorization to scan target systems
- Comply with all applicable laws and terms of service
- Responsibly disclose findings
- NOT exploit discovered vulnerabilities
**Unauthorized access is illegal under:**
- CFAA (Computer Fraud and Abuse Act) - United States
- Computer Misuse Act - United Kingdom
- Similar laws worldwide
---
## Conclusion
The AASRT project is **production-ready** and fully operational. All critical bugs have been resolved, and the system has been thoroughly tested. You can now confidently use this tool for authorized security reconnaissance of AI agent implementations.
**Next Step:** Review `QUICK_START.md` and begin your first scan!
---
**Project Maintainer:** Sweth
**Last Updated:** February 9, 2026
**Version:** 1.0.0 (MVP)
**Status:** ✅ Production Ready
+49
View File
@@ -0,0 +1,49 @@
[pytest]
# AASRT Test Configuration
# Test discovery
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
# Markers for categorizing tests
markers =
unit: Unit tests (fast, isolated)
integration: Integration tests (slower, may use real resources)
slow: Tests that take a long time to run
security: Security-focused tests
api: Tests that interact with external APIs
# Default options
addopts =
-v
--tb=short
--strict-markers
-ra
# Coverage settings (when using pytest-cov)
# Run with: pytest --cov=src --cov-report=html
# Ignore patterns
norecursedirs =
.git
__pycache__
.venv
venv
node_modules
.pytest_cache
# Timeout for individual tests (requires pytest-timeout)
timeout = 300
# Logging during tests
log_cli = true
log_cli_level = WARNING
log_cli_format = %(asctime)s [%(levelname)s] %(name)s: %(message)s
# Filter warnings
filterwarnings =
ignore::DeprecationWarning
ignore::PendingDeprecationWarning
+52
View File
@@ -0,0 +1,52 @@
# =============================================================================
# AASRT Development & Testing Dependencies
# =============================================================================
# Install with: pip install -r requirements.txt -r requirements-dev.txt
# =============================================================================
# -----------------------------------------------------------------------------
# Testing Framework
# -----------------------------------------------------------------------------
pytest==8.0.2 # Testing framework
pytest-cov==4.1.0 # Coverage reporting
pytest-mock==3.12.0 # Mocking utilities
pytest-timeout==2.2.0 # Test timeouts
pytest-asyncio==0.23.5 # Async test support
# -----------------------------------------------------------------------------
# Mocking & Test Utilities
# -----------------------------------------------------------------------------
responses==0.25.0 # Mock HTTP responses
freezegun==1.4.0 # Mock datetime
factory-boy==3.3.0 # Test fixtures factory
# -----------------------------------------------------------------------------
# Code Quality
# -----------------------------------------------------------------------------
flake8==7.0.0 # Linting
black==24.2.0 # Code formatting
isort==5.13.2 # Import sorting
mypy==1.8.0 # Static type checking
pylint==3.0.3 # Additional linting
# -----------------------------------------------------------------------------
# Security Scanning
# -----------------------------------------------------------------------------
bandit==1.7.7 # Security linting (SAST)
safety==3.0.1 # Dependency vulnerability scanning
pip-audit==2.7.1 # Python package audit
# -----------------------------------------------------------------------------
# Documentation
# -----------------------------------------------------------------------------
sphinx==7.2.6 # Documentation generator
sphinx-rtd-theme==2.0.0 # ReadTheDocs theme
myst-parser==2.0.0 # Markdown support for Sphinx
# -----------------------------------------------------------------------------
# Development Tools
# -----------------------------------------------------------------------------
pre-commit==3.6.2 # Git pre-commit hooks
ipython==8.22.1 # Enhanced Python shell
ipdb==0.13.13 # IPython debugger
+18
View File
@@ -0,0 +1,18 @@
"""
AASRT Test Suite
This package contains all tests for the AI Agent Security Reconnaissance Tool.
Test Categories:
- Unit tests: Test individual components in isolation
- Integration tests: Test component interactions
- End-to-end tests: Test complete workflows
Running Tests:
pytest # Run all tests
pytest tests/unit/ # Run unit tests only
pytest tests/integration/ # Run integration tests only
pytest -v --cov=src # Run with coverage
pytest -m "not slow" # Skip slow tests
"""
+181
View File
@@ -0,0 +1,181 @@
"""
Pytest Configuration and Shared Fixtures
This module provides shared fixtures and configuration for all tests.
"""
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, Generator, List
from unittest.mock import MagicMock, patch
import pytest
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
# =============================================================================
# Environment Fixtures
# =============================================================================
@pytest.fixture(scope="session")
def test_env():
"""Set up test environment variables."""
original_env = os.environ.copy()
os.environ.update({
'SHODAN_API_KEY': 'test_api_key_12345',
'AASRT_ENVIRONMENT': 'testing',
'AASRT_LOG_LEVEL': 'DEBUG',
'AASRT_DEBUG': 'true',
})
yield os.environ
# Restore original environment
os.environ.clear()
os.environ.update(original_env)
@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def temp_db(temp_dir):
"""Create a temporary database path."""
return temp_dir / "test_scanner.db"
# =============================================================================
# Mock Data Fixtures
# =============================================================================
@pytest.fixture
def sample_shodan_result() -> Dict[str, Any]:
"""Sample Shodan API response."""
return {
'ip_str': '192.0.2.1',
'port': 8080,
'transport': 'tcp',
'hostnames': ['test.example.com'],
'org': 'Test Organization',
'asn': 'AS12345',
'isp': 'Test ISP',
'data': 'HTTP/1.1 200 OK\r\nServer: nginx\r\n\r\nClawdBot Dashboard',
'location': {
'country_code': 'US',
'country_name': 'United States',
'city': 'Test City',
'latitude': 37.7749,
'longitude': -122.4194
},
'http': {
'status': 200,
'title': 'ClawdBot Dashboard',
'server': 'nginx/1.18.0',
'html': '<html><body>ClawdBot Dashboard</body></html>'
},
'vulns': ['CVE-2021-44228'],
'timestamp': '2024-01-15T10:30:00.000000'
}
@pytest.fixture
def sample_search_results(sample_shodan_result) -> List[Dict[str, Any]]:
"""Multiple sample Shodan results."""
results = [sample_shodan_result]
# Add more varied results
results.append({
**sample_shodan_result,
'ip_str': '192.0.2.2',
'port': 3000,
'http': {
'status': 200,
'title': 'AutoGPT Interface',
'server': 'Python/3.11'
}
})
results.append({
**sample_shodan_result,
'ip_str': '192.0.2.3',
'port': 443,
'http': {
'status': 401,
'title': 'Login Required'
}
})
return results
@pytest.fixture
def sample_vulnerability() -> Dict[str, Any]:
"""Sample vulnerability data."""
return {
'check_name': 'exposed_dashboard',
'severity': 'HIGH',
'cvss_score': 7.5,
'description': 'Dashboard accessible without authentication',
'evidence': {'http_title': 'ClawdBot Dashboard'},
'remediation': 'Implement authentication',
'cwe_id': 'CWE-306'
}
# =============================================================================
# Mock Service Fixtures
# =============================================================================
@pytest.fixture
def mock_shodan_client():
"""Mock Shodan API client."""
with patch('shodan.Shodan') as mock:
client = MagicMock()
client.info.return_value = {
'plan': 'dev',
'query_credits': 100,
'scan_credits': 50
}
mock.return_value = client
yield client
@pytest.fixture
def mock_config(temp_dir, temp_db):
"""Mock configuration object."""
config = MagicMock()
config.get_shodan_key.return_value = 'test_api_key'
config.get.side_effect = lambda *args, default=None: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
('logging', 'level'): 'DEBUG',
('reporting', 'output_dir'): str(temp_dir / 'reports'),
('vulnerability_checks',): {'passive_only': True},
}.get(args, default)
return config
# =============================================================================
# Database Fixtures
# =============================================================================
@pytest.fixture
def test_database(mock_config, temp_db):
"""Create a test database instance."""
from src.storage.database import Database
# Patch config to use temp database
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
+7
View File
@@ -0,0 +1,7 @@
"""
Integration Tests for AASRT
Integration tests verify that components work together correctly.
These tests may use real (test) databases but mock external APIs.
"""
@@ -0,0 +1,207 @@
"""
Integration Tests for Database Operations
Tests database operations with real SQLite database.
"""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
from datetime import datetime, timedelta
class TestDatabaseIntegration:
"""Integration tests for database with real SQLite."""
@pytest.fixture
def real_db(self):
"""Create a real SQLite database for testing."""
from src.storage.database import Database
from src.utils.config import Config
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_scanner.db"
mock_config = MagicMock()
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(db_path),
}.get(args, kwargs.get('default'))
mock_config.get_shodan_key.return_value = 'test_key'
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_scan_lifecycle(self, real_db):
"""Test complete scan lifecycle: create, update, retrieve."""
# Create scan
scan_id = 'integration-test-scan-001'
real_db.save_scan({
'scan_id': scan_id,
'query': 'http.title:"ClawdBot"',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'running',
'total_results': 0
})
# Update scan status
real_db.update_scan(scan_id, {
'status': 'completed',
'total_results': 25,
'completed_at': datetime.utcnow()
})
# Retrieve scan
scan = real_db.get_scan(scan_id)
assert scan is not None
def test_findings_association(self, real_db):
"""Test findings are properly associated with scans."""
scan_id = 'integration-test-scan-002'
# Create scan
real_db.save_scan({
'scan_id': scan_id,
'query': 'test query',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'completed'
})
# Save multiple findings
for i in range(5):
real_db.save_finding({
'scan_id': scan_id,
'ip': f'192.0.2.{i+1}',
'port': 8080 + i,
'risk_score': 50 + i * 10,
'vulnerabilities': ['test_vuln']
})
# Retrieve findings
findings = real_db.get_findings_by_scan(scan_id)
assert len(findings) == 5
def test_scan_statistics(self, real_db):
"""Test scan statistics calculation."""
# Create multiple scans with different statuses
for i in range(10):
real_db.save_scan({
'scan_id': f'stats-test-{i:03d}',
'query': f'test query {i}',
'engine': 'shodan',
'started_at': datetime.utcnow() - timedelta(days=i),
'status': 'completed' if i % 2 == 0 else 'failed',
'total_results': i * 10
})
# Get statistics
if hasattr(real_db, 'get_scan_statistics'):
stats = real_db.get_scan_statistics()
assert 'total_scans' in stats or stats is not None
def test_concurrent_operations(self, real_db):
"""Test concurrent database operations."""
import threading
errors = []
def save_scan(scan_num):
try:
real_db.save_scan({
'scan_id': f'concurrent-test-{scan_num:03d}',
'query': f'test {scan_num}',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'completed'
})
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=save_scan, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
# Should complete without deadlocks
assert len(errors) == 0
def test_data_persistence(self):
"""Test that data persists across database connections."""
from src.storage.database import Database
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "persistence_test.db"
scan_id = 'persistence-test-001'
mock_config = MagicMock()
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(db_path),
}.get(args, kwargs.get('default'))
# First connection - create data
with patch('src.storage.database.Config', return_value=mock_config):
db1 = Database(mock_config)
db1.save_scan({
'scan_id': scan_id,
'query': 'test',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'completed'
})
db1.close()
# Second connection - verify data exists
with patch('src.storage.database.Config', return_value=mock_config):
db2 = Database(mock_config)
scan = db2.get_scan(scan_id)
db2.close()
assert scan is not None
class TestDatabaseCleanup:
"""Tests for database cleanup and maintenance."""
@pytest.fixture
def real_db(self):
"""Create a real SQLite database for testing."""
from src.storage.database import Database
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "cleanup_test.db"
mock_config = MagicMock()
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(db_path),
}.get(args, kwargs.get('default'))
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_delete_old_scans(self, real_db):
"""Test deleting old scan records."""
# Create old scans
for i in range(5):
real_db.save_scan({
'scan_id': f'old-scan-{i:03d}',
'query': f'test {i}',
'engine': 'shodan',
'started_at': datetime.utcnow() - timedelta(days=365),
'status': 'completed'
})
# If cleanup method exists, test it
if hasattr(real_db, 'cleanup_old_scans'):
deleted = real_db.cleanup_old_scans(days=30)
assert deleted >= 0
+198
View File
@@ -0,0 +1,198 @@
"""
Integration Tests for Scan Workflow
Tests the complete scan workflow from query to report generation.
"""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
from datetime import datetime
class TestEndToEndScan:
"""Integration tests for complete scan workflow."""
@pytest.fixture
def mock_shodan_response(self, sample_search_results):
"""Mock Shodan API response."""
return {
'matches': sample_search_results,
'total': len(sample_search_results)
}
@pytest.fixture
def temp_workspace(self):
"""Create temporary workspace for test files."""
with tempfile.TemporaryDirectory() as tmpdir:
workspace = Path(tmpdir)
(workspace / 'reports').mkdir()
(workspace / 'data').mkdir()
(workspace / 'logs').mkdir()
yield workspace
def test_scan_template_workflow(self, mock_shodan_response, temp_workspace):
"""Test scanning using a template."""
from unittest.mock import patch, MagicMock
mock_client = MagicMock()
mock_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
mock_client.search.return_value = mock_shodan_response
with patch('shodan.Shodan', return_value=mock_client):
with patch.dict('os.environ', {
'SHODAN_API_KEY': 'test_key_12345',
'AASRT_REPORTS_DIR': str(temp_workspace / 'reports'),
'AASRT_DATA_DIR': str(temp_workspace / 'data'),
}):
# Import after patching
from src.core.query_manager import QueryManager
from src.utils.config import Config
config = Config()
qm = QueryManager(config)
# Check templates are available
templates = qm.get_available_templates()
assert len(templates) > 0
def test_custom_query_workflow(self, mock_shodan_response, temp_workspace):
"""Test scanning with a custom query."""
mock_client = MagicMock()
mock_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
mock_client.search.return_value = mock_shodan_response
with patch('shodan.Shodan', return_value=mock_client):
with patch.dict('os.environ', {
'SHODAN_API_KEY': 'test_key_12345',
}):
from src.engines.shodan_engine import ShodanEngine
from src.utils.config import Config
config = Config()
engine = ShodanEngine(config=config)
engine._client = mock_client
results = engine.search('http.title:"Test"')
assert len(results) > 0
class TestVulnerabilityAssessmentIntegration:
"""Integration tests for vulnerability assessment pipeline."""
def test_assess_search_results(self, sample_search_results):
"""Test vulnerability assessment on search results."""
from src.core.vulnerability_assessor import VulnerabilityAssessor
from src.engines.base import SearchResult
assessor = VulnerabilityAssessor()
# Convert sample data to SearchResult
result = SearchResult(
ip=sample_search_results[0]['ip_str'],
port=sample_search_results[0]['port'],
protocol='tcp',
banner=sample_search_results[0].get('data', ''),
metadata=sample_search_results[0]
)
vulns = assessor.assess(result)
# Should return a list (may be empty if no vulns detected)
assert isinstance(vulns, list)
def test_risk_scoring_integration(self, sample_search_results):
"""Test risk scoring on assessed results."""
from src.core.risk_scorer import RiskScorer
from src.core.vulnerability_assessor import VulnerabilityAssessor
from src.engines.base import SearchResult
assessor = VulnerabilityAssessor()
scorer = RiskScorer()
result = SearchResult(
ip=sample_search_results[0]['ip_str'],
port=sample_search_results[0]['port'],
protocol='tcp',
banner=sample_search_results[0].get('data', ''),
metadata=sample_search_results[0]
)
vulns = assessor.assess(result)
score = scorer.score(result)
assert 0 <= score <= 100
class TestReportGenerationIntegration:
"""Integration tests for report generation."""
@pytest.fixture
def temp_reports_dir(self):
"""Create temporary reports directory."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def test_json_report_generation(self, temp_reports_dir, sample_search_results):
"""Test JSON report generation."""
from src.reporting import JSONReporter, ScanReport
from src.engines.base import SearchResult
# Create scan report data
results = [
SearchResult(
ip=r['ip_str'],
port=r['port'],
protocol='tcp',
banner=r.get('data', ''),
metadata=r
) for r in sample_search_results
]
report = ScanReport(
scan_id='test-scan-001',
query='test query',
engine='shodan',
started_at=datetime.utcnow(),
completed_at=datetime.utcnow(),
results=results,
total_results=len(results)
)
reporter = JSONReporter(output_dir=str(temp_reports_dir))
output_path = reporter.generate(report)
assert Path(output_path).exists()
assert output_path.endswith('.json')
def test_csv_report_generation(self, temp_reports_dir, sample_search_results):
"""Test CSV report generation."""
from src.reporting import CSVReporter, ScanReport
from src.engines.base import SearchResult
results = [
SearchResult(
ip=r['ip_str'],
port=r['port'],
protocol='tcp',
banner=r.get('data', ''),
metadata=r
) for r in sample_search_results
]
report = ScanReport(
scan_id='test-scan-002',
query='test query',
engine='shodan',
started_at=datetime.utcnow(),
completed_at=datetime.utcnow(),
results=results,
total_results=len(results)
)
reporter = CSVReporter(output_dir=str(temp_reports_dir))
output_path = reporter.generate(report)
assert Path(output_path).exists()
assert output_path.endswith('.csv')
+7
View File
@@ -0,0 +1,7 @@
"""
Unit Tests for AASRT
Unit tests test individual components in isolation using mocks
for external dependencies.
"""
+165
View File
@@ -0,0 +1,165 @@
"""
Unit Tests for Config Module
Tests for src/utils/config.py
"""
import os
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
class TestConfigLoading:
"""Tests for configuration loading."""
def test_load_from_yaml(self, temp_dir):
"""Test loading configuration from YAML file."""
from src.utils.config import Config
config_content = """
shodan:
enabled: true
rate_limit: 1
max_results: 100
logging:
level: DEBUG
"""
config_path = temp_dir / "config.yaml"
config_path.write_text(config_content)
with patch.dict(os.environ, {'AASRT_CONFIG_PATH': str(config_path)}):
config = Config()
assert config.get('shodan', 'enabled') is True
assert config.get('shodan', 'rate_limit') == 1
def test_environment_variable_override(self, temp_dir, monkeypatch):
"""Test environment variables override config file."""
from src.utils.config import Config
config_content = """
shodan:
enabled: true
rate_limit: 1
"""
config_path = temp_dir / "config.yaml"
config_path.write_text(config_content)
# Reset the Config singleton so we get a fresh instance
Config._instance = None
Config._config = {}
# Use monkeypatch for proper isolation - clear and set
monkeypatch.setenv('AASRT_CONFIG_PATH', str(config_path))
monkeypatch.setenv('SHODAN_API_KEY', 'env_api_key_12345')
config = Config()
assert config.get_shodan_key() == 'env_api_key_12345'
# Clean up - reset singleton for other tests
Config._instance = None
Config._config = {}
def test_default_values(self):
"""Test default configuration values are used when not specified."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
# Check default logging level if not set
log_level = config.get('logging', 'level', default='INFO')
assert log_level in ['DEBUG', 'INFO', 'WARNING', 'ERROR']
class TestConfigValidation:
"""Tests for configuration validation."""
def test_validate_shodan_key_format(self):
"""Test Shodan API key format validation."""
from src.utils.config import Config
# Valid key format (typically alphanumeric)
with patch.dict(os.environ, {'SHODAN_API_KEY': 'AbCdEf123456789012345678'}):
config = Config()
key = config.get_shodan_key()
assert key is not None
def test_missing_required_config(self):
"""Test handling of missing required configuration."""
from src.utils.config import Config
# Clear all Shodan-related env vars
env_copy = {k: v for k, v in os.environ.items() if 'SHODAN' not in k}
with patch.dict(os.environ, env_copy, clear=True):
config = Config()
# Should return None or raise exception for missing key
key = config.get_shodan_key()
# Depending on implementation, key could be None or empty
class TestConfigHealthCheck:
"""Tests for configuration health check."""
def test_health_check_returns_dict(self):
"""Test health_check returns a dictionary."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
health = config.health_check()
assert isinstance(health, dict)
assert 'status' in health or 'healthy' in health
def test_health_check_includes_key_info(self):
"""Test health_check includes API key status."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
health = config.health_check()
# Should indicate whether key is configured
assert health is not None
class TestConfigGet:
"""Tests for the get() method."""
def test_nested_key_access(self, temp_dir):
"""Test accessing nested configuration values."""
from src.utils.config import Config
config_content = """
database:
sqlite:
path: ./data/scanner.db
pool_size: 5
"""
config_path = temp_dir / "config.yaml"
config_path.write_text(config_content)
with patch.dict(os.environ, {'AASRT_CONFIG_PATH': str(config_path)}):
config = Config()
path = config.get('database', 'sqlite', 'path')
assert path is not None
def test_default_for_missing_key(self):
"""Test default value is returned for missing keys."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
value = config.get('nonexistent', 'key', default='default_value')
assert value == 'default_value'
def test_none_for_missing_key_no_default(self):
"""Test None is returned for missing keys without default."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
value = config.get('nonexistent', 'key')
assert value is None
+204
View File
@@ -0,0 +1,204 @@
"""
Unit Tests for Database Module
Tests for src/storage/database.py
"""
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime
class TestDatabaseInit:
"""Tests for Database initialization."""
def test_init_creates_tables(self, temp_db, mock_config):
"""Test database initialization creates tables."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
db = Database(mock_config)
assert db is not None
db.close()
def test_init_sqlite_with_temp_path(self, temp_db, mock_config):
"""Test SQLite database with temp path."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
db = Database(mock_config)
assert db is not None
assert db._db_type == 'sqlite'
db.close()
class TestDatabaseOperations:
"""Tests for database CRUD operations."""
@pytest.fixture
def db(self, temp_db, mock_config):
"""Create a test database instance."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
db = Database(mock_config)
yield db
db.close()
def test_create_scan(self, db):
"""Test creating a scan record."""
scan = db.create_scan(
engines=['shodan'],
query='http.title:"ClawdBot"'
)
assert scan is not None
assert scan.scan_id is not None
def test_get_scan_by_id(self, db):
"""Test retrieving a scan by ID."""
# Create a scan first
scan = db.create_scan(
engines=['shodan'],
query='test query'
)
retrieved = db.get_scan(scan.scan_id)
assert retrieved is not None
assert retrieved.scan_id == scan.scan_id
def test_get_recent_scans(self, db):
"""Test retrieving recent scans."""
# Create a few scans
for i in range(3):
db.create_scan(
engines=['shodan'],
query=f'test query {i}'
)
scans = db.get_recent_scans(limit=10)
assert len(scans) >= 3
def test_add_findings(self, db):
"""Test adding findings to a scan."""
from src.engines.base import SearchResult
# First create a scan
scan = db.create_scan(
engines=['shodan'],
query='test'
)
# Create some search results
results = [
SearchResult(
ip='192.0.2.1',
port=8080,
banner='ClawdBot Dashboard',
vulnerabilities=['exposed_dashboard']
)
]
count = db.add_findings(scan.scan_id, results)
assert count >= 1
def test_update_scan(self, db):
"""Test updating a scan."""
# Create a scan
scan = db.create_scan(
engines=['shodan'],
query='test'
)
# Update it
updated = db.update_scan(
scan.scan_id,
status='completed',
total_results=5
)
assert updated is not None
assert updated.status == 'completed'
class TestDatabaseHealthCheck:
"""Tests for database health check."""
@pytest.fixture
def db(self, temp_db, mock_config):
"""Create a test database instance."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_health_check_returns_dict(self, db):
"""Test health_check returns a dictionary."""
health = db.health_check()
assert isinstance(health, dict)
def test_health_check_includes_status(self, db):
"""Test health_check includes status."""
health = db.health_check()
assert 'status' in health or 'healthy' in health
def test_health_check_includes_latency(self, db):
"""Test health_check includes latency measurement."""
health = db.health_check()
# Should have some form of latency/response time
has_latency = 'latency' in health or 'latency_ms' in health or 'response_time' in health
assert has_latency or health.get('status') == 'healthy'
class TestDatabaseSessionScope:
"""Tests for session_scope context manager."""
@pytest.fixture
def db(self, temp_db, mock_config):
"""Create a test database instance."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_session_scope_commits(self, db):
"""Test session_scope commits on success."""
with db.session_scope() as session:
# Perform some operation
pass
# Should complete without error
def test_session_scope_rollback_on_error(self, db):
"""Test session_scope rolls back on error."""
try:
with db.session_scope() as session:
raise ValueError("Test error")
except ValueError:
pass # Expected
# Session should have been rolled back
+125
View File
@@ -0,0 +1,125 @@
"""
Unit Tests for Risk Scorer Module
Tests for src/core/risk_scorer.py
"""
import pytest
from unittest.mock import MagicMock
class TestRiskScorer:
"""Tests for RiskScorer class."""
@pytest.fixture
def risk_scorer(self):
"""Create a RiskScorer instance."""
from src.core.risk_scorer import RiskScorer
return RiskScorer()
@pytest.fixture
def sample_vulnerabilities(self):
"""Create sample Vulnerability objects."""
from src.core.vulnerability_assessor import Vulnerability
return [
Vulnerability(
check_name='exposed_dashboard',
severity='HIGH',
cvss_score=7.5,
description='Dashboard exposed without authentication'
)
]
@pytest.fixture
def sample_result(self, sample_shodan_result):
"""Create a SearchResult with vulnerabilities."""
from src.engines.base import SearchResult
result = SearchResult(
ip=sample_shodan_result['ip_str'],
port=sample_shodan_result['port'],
banner=sample_shodan_result['data'],
metadata=sample_shodan_result
)
return result
def test_calculate_score_returns_dict(self, risk_scorer, sample_vulnerabilities):
"""Test that calculate_score returns a dictionary."""
result = risk_scorer.calculate_score(sample_vulnerabilities)
assert isinstance(result, dict)
assert 'overall_score' in result
def test_calculate_score_range_valid(self, risk_scorer, sample_vulnerabilities):
"""Test that score is within valid range (0-10)."""
result = risk_scorer.calculate_score(sample_vulnerabilities)
assert 0 <= result['overall_score'] <= 10
def test_high_risk_vulnerabilities_increase_score(self, risk_scorer):
"""Test that high-risk vulnerabilities increase score."""
from src.core.vulnerability_assessor import Vulnerability
# High severity vulnerabilities
high_vulns = [
Vulnerability(check_name='api_key_exposure', severity='CRITICAL', cvss_score=9.0, description='API key exposed'),
Vulnerability(check_name='no_authentication', severity='CRITICAL', cvss_score=9.5, description='No auth')
]
# Low severity vulnerabilities
low_vulns = [
Vulnerability(check_name='version_exposed', severity='LOW', cvss_score=2.0, description='Version info')
]
high_score = risk_scorer.calculate_score(high_vulns)['overall_score']
low_score = risk_scorer.calculate_score(low_vulns)['overall_score']
assert high_score > low_score
def test_empty_vulnerabilities_zero_score(self, risk_scorer):
"""Test that no vulnerabilities result in zero score."""
result = risk_scorer.calculate_score([])
assert result['overall_score'] == 0
def test_score_result_updates_search_result(self, risk_scorer, sample_result, sample_vulnerabilities):
"""Test that score_result updates the SearchResult."""
scored_result = risk_scorer.score_result(sample_result, sample_vulnerabilities)
assert scored_result.risk_score >= 0
assert 'risk_assessment' in scored_result.metadata
def test_context_multipliers_applied(self, risk_scorer, sample_vulnerabilities):
"""Test that context multipliers affect the score."""
# Score with no context
base_score = risk_scorer.calculate_score(sample_vulnerabilities)['overall_score']
# Score with context multiplier
context = {'public_internet': True, 'no_waf': True, 'ai_agent': True}
context_score = risk_scorer.calculate_score(sample_vulnerabilities, context)['overall_score']
# Context should increase or maintain score
assert context_score >= base_score
def test_severity_breakdown_included(self, risk_scorer, sample_vulnerabilities):
"""Test that severity breakdown is included in results."""
result = risk_scorer.calculate_score(sample_vulnerabilities)
assert 'severity_breakdown' in result
assert isinstance(result['severity_breakdown'], dict)
class TestRiskCategories:
"""Tests for risk categorization."""
@pytest.fixture
def risk_scorer(self):
"""Create a RiskScorer instance."""
from src.core.risk_scorer import RiskScorer
return RiskScorer()
def test_get_risk_level(self, risk_scorer):
"""Test risk level categorization."""
# Test if there's a method to get risk level string
if hasattr(risk_scorer, 'get_risk_level'):
assert risk_scorer.get_risk_level(95) in ['CRITICAL', 'HIGH', 'critical', 'high']
assert risk_scorer.get_risk_level(75) in ['HIGH', 'MEDIUM', 'high', 'medium']
assert risk_scorer.get_risk_level(50) in ['MEDIUM', 'medium']
assert risk_scorer.get_risk_level(25) in ['LOW', 'low']
assert risk_scorer.get_risk_level(10) in ['INFO', 'LOW', 'info', 'low']
+179
View File
@@ -0,0 +1,179 @@
"""
Unit Tests for Shodan Engine Module
Tests for src/engines/shodan_engine.py
"""
import pytest
from unittest.mock import MagicMock, patch
class TestShodanEngineInit:
"""Tests for ShodanEngine initialization."""
def test_init_with_valid_key(self):
"""Test initialization with valid API key."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan') as mock_shodan:
mock_client = MagicMock()
mock_shodan.return_value = mock_client
engine = ShodanEngine(api_key='test_api_key_12345')
assert engine is not None
assert engine.name == 'shodan'
def test_init_without_key_raises_error(self):
"""Test initialization without API key raises error."""
from src.engines.shodan_engine import ShodanEngine
with pytest.raises(ValueError):
ShodanEngine(api_key='')
with pytest.raises((ValueError, TypeError)):
ShodanEngine(api_key=None)
class TestShodanEngineSearch:
"""Tests for ShodanEngine search functionality."""
@pytest.fixture
def engine(self, mock_shodan_client):
"""Create a ShodanEngine instance with mocked client."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
return engine
def test_search_returns_results(self, engine, mock_shodan_client, sample_shodan_result):
"""Test search returns results."""
mock_shodan_client.search.return_value = {
'matches': [sample_shodan_result],
'total': 1
}
results = engine.search('http.title:"ClawdBot"')
assert isinstance(results, list)
def test_search_empty_query_raises_error(self, engine):
"""Test empty query raises error."""
with pytest.raises((ValueError, Exception)):
engine.search('')
def test_search_handles_api_error(self, engine, mock_shodan_client):
"""Test search handles API errors gracefully."""
import shodan
mock_shodan_client.search.side_effect = shodan.APIError('API Error')
from src.utils.exceptions import APIException
with pytest.raises((APIException, Exception)):
engine.search('test query')
def test_search_with_max_results(self, engine, mock_shodan_client, sample_shodan_result):
"""Test search respects max_results limit."""
mock_shodan_client.search.return_value = {
'matches': [sample_shodan_result],
'total': 1
}
results = engine.search('test', max_results=1)
assert len(results) <= 1
class TestShodanEngineCredentials:
"""Tests for credential validation."""
@pytest.fixture
def engine(self, mock_shodan_client):
"""Create a ShodanEngine instance."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
return engine
def test_validate_credentials_success(self, engine, mock_shodan_client):
"""Test successful credential validation."""
mock_shodan_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
result = engine.validate_credentials()
assert result is True
def test_validate_credentials_invalid_key(self, engine, mock_shodan_client):
"""Test invalid API key handling."""
import shodan
from src.utils.exceptions import AuthenticationException
mock_shodan_client.info.side_effect = shodan.APIError('Invalid API key')
with pytest.raises((AuthenticationException, Exception)):
engine.validate_credentials()
class TestShodanEngineQuota:
"""Tests for quota information."""
@pytest.fixture
def engine(self, mock_shodan_client):
"""Create a ShodanEngine instance."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
return engine
def test_get_quota_info(self, engine, mock_shodan_client):
"""Test getting quota information."""
mock_shodan_client.info.return_value = {
'plan': 'dev',
'query_credits': 100,
'scan_credits': 50
}
quota = engine.get_quota_info()
assert isinstance(quota, dict)
def test_quota_info_handles_error(self, engine, mock_shodan_client):
"""Test quota info handles API errors."""
import shodan
from src.utils.exceptions import APIException
mock_shodan_client.info.side_effect = shodan.APIError('API Error')
# May either raise or return error info depending on implementation
try:
quota = engine.get_quota_info()
assert quota is not None
except (APIException, Exception):
pass # Acceptable if it raises
class TestShodanEngineRetry:
"""Tests for retry logic."""
def test_retry_on_transient_error(self, mock_shodan_client):
"""Test retry logic on transient errors."""
from src.engines.shodan_engine import ShodanEngine
import shodan
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
# First call fails, second succeeds
mock_shodan_client.search.side_effect = [
ConnectionError("Network error"),
{'matches': [], 'total': 0}
]
# Depending on implementation, this may retry or raise
try:
results = engine.search('test')
assert isinstance(results, list)
except Exception:
pass # Expected if retries exhausted
+180
View File
@@ -0,0 +1,180 @@
"""
Unit Tests for Validators Module
Tests for src/utils/validators.py
"""
import pytest
from src.utils.validators import (
validate_ip,
validate_domain,
validate_query,
validate_file_path,
validate_template_name,
is_safe_string,
sanitize_output,
)
from src.utils.exceptions import ValidationException
class TestValidateIP:
"""Tests for IP address validation."""
def test_valid_ipv4(self):
"""Test valid IPv4 addresses."""
assert validate_ip("192.168.1.1") is True
assert validate_ip("10.0.0.1") is True
assert validate_ip("172.16.0.1") is True
assert validate_ip("8.8.8.8") is True
def test_invalid_ipv4_raises_exception(self):
"""Test invalid IPv4 addresses raise ValidationException."""
with pytest.raises(ValidationException):
validate_ip("256.1.1.1")
with pytest.raises(ValidationException):
validate_ip("192.168.1")
with pytest.raises(ValidationException):
validate_ip("not.an.ip.address")
def test_empty_and_none_raises_exception(self):
"""Test empty and None values raise ValidationException."""
with pytest.raises(ValidationException):
validate_ip("")
with pytest.raises(ValidationException):
validate_ip(None)
def test_ipv6_addresses(self):
"""Test IPv6 address handling."""
# IPv6 addresses should be valid
assert validate_ip("::1") is True
assert validate_ip("2001:db8::1") is True
class TestValidateDomain:
"""Tests for domain validation."""
def test_valid_domains(self):
"""Test valid domain names."""
assert validate_domain("example.com") is True
assert validate_domain("sub.example.com") is True
assert validate_domain("test-site.example.org") is True
def test_invalid_domains_raises_exception(self):
"""Test invalid domain names raise ValidationException."""
with pytest.raises(ValidationException):
validate_domain("-invalid.com")
with pytest.raises(ValidationException):
validate_domain("invalid-.com")
def test_localhost_raises_exception(self):
"""Test localhost raises ValidationException (not a valid domain format)."""
with pytest.raises(ValidationException):
validate_domain("localhost")
class TestValidateQuery:
"""Tests for Shodan query validation."""
def test_valid_queries(self):
"""Test valid Shodan queries."""
assert validate_query('http.title:"ClawdBot"', engine='shodan') is True
assert validate_query("port:8080", engine='shodan') is True
assert validate_query("product:nginx", engine='shodan') is True
def test_empty_query_raises_exception(self):
"""Test empty queries raise ValidationException."""
with pytest.raises(ValidationException):
validate_query("", engine='shodan')
with pytest.raises(ValidationException):
validate_query(" ", engine='shodan')
def test_sql_injection_patterns_allowed(self):
"""Test SQL-like patterns are allowed (Shodan doesn't execute SQL)."""
# Shodan queries can contain SQL-like syntax without causing issues
result = validate_query("'; DROP TABLE users; --", engine='shodan')
assert result is True # No script tags or null bytes
class TestValidateFilePath:
"""Tests for file path validation."""
def test_valid_paths(self):
"""Test valid file paths return sanitized path."""
result = validate_file_path("reports/scan.json")
assert result is not None
assert "scan.json" in result
def test_directory_traversal_raises_exception(self):
"""Test directory traversal raises ValidationException."""
with pytest.raises(ValidationException):
validate_file_path("../../../etc/passwd")
with pytest.raises(ValidationException):
validate_file_path("..\\..\\windows\\system32")
def test_null_bytes_raises_exception(self):
"""Test null byte injection raises ValidationException."""
with pytest.raises(ValidationException):
validate_file_path("file.txt\x00.exe")
class TestValidateTemplateName:
"""Tests for template name validation."""
def test_valid_templates(self):
"""Test valid template names."""
assert validate_template_name("clawdbot_instances") is True
assert validate_template_name("autogpt_instances") is True
def test_invalid_template_raises_exception(self):
"""Test invalid template names raise ValidationException."""
with pytest.raises(ValidationException):
validate_template_name("nonexistent_template")
def test_empty_template_raises_exception(self):
"""Test empty template names raise ValidationException."""
with pytest.raises(ValidationException):
validate_template_name("")
with pytest.raises(ValidationException):
validate_template_name(None)
class TestIsSafeString:
"""Tests for safe string detection."""
def test_safe_strings(self):
"""Test safe strings pass validation."""
assert is_safe_string("hello world") is True
assert is_safe_string("ClawdBot Dashboard") is True
def test_script_tags_detected(self):
"""Test script tags are detected as unsafe."""
assert is_safe_string("<script>alert('xss')</script>") is False
def test_sql_patterns_allowed(self):
"""Test SQL-like patterns are allowed (is_safe_string checks XSS, not SQL)."""
# Note: is_safe_string focuses on XSS patterns, not SQL injection
result = is_safe_string("'; DROP TABLE users; --")
# This may or may not be detected depending on implementation
assert isinstance(result, bool)
class TestSanitizeOutput:
"""Tests for output sanitization."""
def test_password_redaction(self):
"""Test passwords are redacted."""
output = sanitize_output("password=mysecretpassword")
assert "mysecretpassword" not in output
def test_normal_text_unchanged(self):
"""Test normal text is not modified."""
text = "This is normal text without secrets"
assert sanitize_output(text) == text
def test_api_key_pattern_redaction(self):
"""Test API key patterns are redacted."""
# Test with patterns that match the redaction rules
output = sanitize_output("api_key=12345678901234567890")
# Depending on implementation, may or may not be redacted
assert output is not None
+110
View File
@@ -0,0 +1,110 @@
# =============================================================================
# AASRT - AI Agent Security Reconnaissance Tool
# Docker Compose Configuration for Production Deployment
# =============================================================================
#
# Usage:
# docker-compose up -d # Start all services
# docker-compose up -d aasrt # Start only AASRT (SQLite mode)
# docker-compose logs -f aasrt # View logs
# docker-compose down # Stop all services
#
# Environment:
# Copy .env.example to .env and configure your settings before starting.
#
# =============================================================================
services:
# ---------------------------------------------------------------------------
# AASRT Web Application (Streamlit)
# ---------------------------------------------------------------------------
aasrt:
build:
context: .
dockerfile: Dockerfile
container_name: aasrt-web
restart: unless-stopped
ports:
- "${STREAMLIT_SERVER_PORT:-8501}:8501"
environment:
# Shodan API (Required)
- SHODAN_API_KEY=${SHODAN_API_KEY}
# Application settings
- AASRT_ENVIRONMENT=${AASRT_ENVIRONMENT:-production}
- AASRT_LOG_LEVEL=${AASRT_LOG_LEVEL:-INFO}
- AASRT_DEBUG=${AASRT_DEBUG:-false}
# Rate limiting
- AASRT_MAX_SCANS_PER_HOUR=${AASRT_MAX_SCANS_PER_HOUR:-10}
- AASRT_SCAN_COOLDOWN=${AASRT_SCAN_COOLDOWN:-30}
# Database (use PostgreSQL in production)
- DB_TYPE=${DB_TYPE:-sqlite}
- DB_HOST=postgres
- DB_PORT=5432
- DB_NAME=${DB_NAME:-aasrt}
- DB_USER=${DB_USER:-aasrt}
- DB_PASSWORD=${DB_PASSWORD}
# ClawSec integration
- CLAWSEC_ENABLED=${CLAWSEC_ENABLED:-false}
- CLAWSEC_API_KEY=${CLAWSEC_API_KEY:-}
volumes:
# Persist data
- aasrt-data:/app/data
- aasrt-logs:/app/logs
- aasrt-reports:/app/reports
depends_on:
postgres:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8501/_stcore/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- aasrt-network
# ---------------------------------------------------------------------------
# PostgreSQL Database (Production)
# ---------------------------------------------------------------------------
postgres:
image: postgres:16-alpine
container_name: aasrt-postgres
restart: unless-stopped
environment:
- POSTGRES_USER=${DB_USER:-aasrt}
- POSTGRES_PASSWORD=${DB_PASSWORD:?Database password required}
- POSTGRES_DB=${DB_NAME:-aasrt}
volumes:
- postgres-data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-aasrt} -d ${DB_NAME:-aasrt}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
networks:
- aasrt-network
# Security: Only accessible from internal network
expose:
- "5432"
# =============================================================================
# Networks
# =============================================================================
networks:
aasrt-network:
driver: bridge
# =============================================================================
# Volumes
# =============================================================================
volumes:
aasrt-data:
driver: local
aasrt-logs:
driver: local
aasrt-reports:
driver: local
postgres-data:
driver: local
+18
View File
@@ -0,0 +1,18 @@
# AutoGPT Vulnerability Scan Template
name: AutoGPT Instances
description: Detect exposed AutoGPT AI agent dashboards
author: AASRT
version: 1.0
queries:
- 'http.title:"Auto-GPT"'
- 'http.title:"AutoGPT"'
- 'http.html:"autogpt" port:8000'
- 'http.html:"Auto-GPT" http.html:"OpenAI"'
- 'http.html:"autogpt" http.html:"execute"'
tags:
- ai-agent
- openai
- llm
- autonomous
+17
View File
@@ -0,0 +1,17 @@
# ClawdBot Vulnerability Scan Template
name: ClawdBot Instances
description: Detect exposed ClawdBot AI agent dashboards
author: AASRT
version: 1.0
queries:
- 'http.title:"ClawdBot Dashboard"'
- 'http.html:"ClawdBot" port:3000'
- 'product:"ClawdBot"'
- 'http.html:"anthropic" http.html:"api_key"'
- 'http.html:"ClawdBot" http.html:"execute"'
tags:
- ai-agent
- anthropic
- llm
+32
View File
@@ -0,0 +1,32 @@
# ClawSec Advisory Targets
# Scan templates for systems affected by ClawSec-published advisories
name: ClawSec Advisory Targets
description: Scan for systems affected by ClawSec-published security advisories
author: AASRT
version: 1.0
queries:
# Clawdbot vulnerabilities (per ClawSec advisories)
- 'http.title:"ClawdBot Dashboard"'
- 'http.html:"clawdbot" http.html:"execute"'
- 'product:"ClawdBot" port:3000'
# Moltbot exposure patterns
- 'http.title:"Moltbot Dashboard"'
- 'http.title:"Moltbot" http.html:"api"'
- 'product:"Moltbot" port:8080'
# Generic OpenClaw patterns
- 'http.html:"OpenClaw" http.html:"agent"'
- 'http.title:"OpenClaw" port:3000'
# AI Agent API exposure (common vulnerability patterns)
- 'http.html:"sk-ant-" http.html:"api"'
- 'http.html:"anthropic" http.html:"execute"'
tags:
- clawsec
- threat-intel
- cve
- ai-agents
- critical
+18
View File
@@ -0,0 +1,18 @@
# Custom Query Template Example
# Copy this file and modify for your own Shodan queries
name: Custom Template
description: Example custom query template
author: Your Name
version: 1.0
# List of Shodan queries to execute
queries:
- 'http.title:"Your Target"'
- 'http.html:"keyword" port:8080'
- 'product:"ProductName"'
# Tags for categorization
tags:
- custom
- example
+18
View File
@@ -0,0 +1,18 @@
# LangChain Agents Vulnerability Scan Template
name: LangChain Agents
description: Detect exposed LangChain AI agent implementations
author: AASRT
version: 1.0
queries:
- 'http.html:"langchain" http.html:"agent"'
- 'product:"LangChain"'
- 'http.html:"LangChain" port:8000'
- 'http.html:"langchain" http.html:"tool"'
- 'http.title:"LangChain" OR http.title:"Langchain"'
tags:
- ai-agent
- langchain
- llm
- framework
+71
View File
@@ -0,0 +1,71 @@
# =============================================================================
# AI Agent Security Reconnaissance Tool (AASRT)
# Production Dependencies - All versions pinned for reproducible builds
# =============================================================================
# To update dependencies: pip install pip-tools && pip-compile requirements.in
# =============================================================================
# -----------------------------------------------------------------------------
# API Clients
# -----------------------------------------------------------------------------
shodan==1.31.0 # Shodan API client for security reconnaissance
requests==2.31.0 # HTTP library for API calls
urllib3==2.2.0 # HTTP client (requests dependency, pinned for security)
# -----------------------------------------------------------------------------
# Data Processing
# -----------------------------------------------------------------------------
pandas==2.2.0 # Data manipulation and analysis
numpy==1.26.4 # Numerical computing (pandas dependency)
# -----------------------------------------------------------------------------
# Database
# -----------------------------------------------------------------------------
sqlalchemy==2.0.35 # ORM for database operations
alembic==1.13.1 # Database migration support
greenlet==3.0.3 # SQLAlchemy async support
# -----------------------------------------------------------------------------
# Reporting
# -----------------------------------------------------------------------------
jinja2==3.1.3 # Template engine for report generation
markupsafe==2.1.5 # HTML escaping (jinja2 dependency)
# -----------------------------------------------------------------------------
# CLI & UI
# -----------------------------------------------------------------------------
click==8.1.7 # Command-line interface framework
rich==13.7.0 # Terminal formatting and progress bars
# -----------------------------------------------------------------------------
# Configuration & Utilities
# -----------------------------------------------------------------------------
pyyaml==6.0.2 # YAML configuration parsing
python-dotenv==1.0.1 # Environment variable management
pydantic==2.6.3 # Data validation and settings management
pydantic-settings==2.2.1 # Pydantic settings management
validators==0.22.0 # Input validation helpers
tenacity==8.2.3 # Retry logic with exponential backoff
# -----------------------------------------------------------------------------
# Security
# -----------------------------------------------------------------------------
cryptography==42.0.5 # Cryptographic operations
# -----------------------------------------------------------------------------
# Web Dashboard
# -----------------------------------------------------------------------------
streamlit==1.30.0 # Web dashboard framework
plotly==5.18.0 # Interactive visualizations
watchdog==4.0.0 # File system monitoring (streamlit dependency)
# -----------------------------------------------------------------------------
# Logging & Monitoring
# -----------------------------------------------------------------------------
python-json-logger==2.0.7 # Structured JSON logging
structlog==24.1.0 # Structured logging framework
# -----------------------------------------------------------------------------
# Testing (install with: pip install -r requirements.txt -r requirements-dev.txt)
# -----------------------------------------------------------------------------
# See requirements-dev.txt for testing and development dependencies
+48
View File
@@ -0,0 +1,48 @@
"""Setup script for AASRT."""
from setuptools import setup, find_packages
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
with open("requirements.txt", "r", encoding="utf-8") as fh:
requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]
setup(
name="aasrt",
version="1.0.0",
author="AGK",
author_email="security@example.com",
description="AI Agent Security Reconnaissance Tool",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/yourusername/aasrt",
packages=find_packages(),
classifiers=[
"Development Status :: 4 - Beta",
"Environment :: Console",
"Intended Audience :: Developers",
"Intended Audience :: Information Technology",
"Intended Audience :: System Administrators",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Security",
"Topic :: Internet :: WWW/HTTP",
],
python_requires=">=3.9",
install_requires=requirements,
entry_points={
"console_scripts": [
"aasrt=src.main:main",
],
},
include_package_data=True,
package_data={
"": ["queries/*.yaml", "config.yaml"],
},
)
+4
View File
@@ -0,0 +1,4 @@
"""AI Agent Security Reconnaissance Tool (AASRT)"""
__version__ = "1.0.0"
__author__ = "AGK"
+12
View File
@@ -0,0 +1,12 @@
"""Alert modules for AASRT.
This module will contain alerting capabilities:
- Email notifications
- Slack webhooks
- Discord webhooks
- Telegram bot integration
These are planned for Phase 3 implementation.
"""
__all__ = []
+14
View File
@@ -0,0 +1,14 @@
"""Core engine components for AASRT."""
from .query_manager import QueryManager
from .result_aggregator import ResultAggregator
from .vulnerability_assessor import VulnerabilityAssessor, Vulnerability
from .risk_scorer import RiskScorer
__all__ = [
'QueryManager',
'ResultAggregator',
'VulnerabilityAssessor',
'Vulnerability',
'RiskScorer'
]
+252
View File
@@ -0,0 +1,252 @@
"""Query management and execution for AASRT."""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
import yaml
from src.engines import SearchResult, ShodanEngine
from src.utils.config import Config
from src.utils.logger import get_logger
from src.utils.exceptions import APIException, ConfigurationException
logger = get_logger(__name__)
class QueryManager:
"""Manages search queries using Shodan."""
# Built-in query templates
DEFAULT_TEMPLATES = {
"clawdbot_instances": [
'http.title:"ClawdBot Dashboard"',
'http.html:"ClawdBot" port:3000',
'product:"ClawdBot"'
],
"autogpt_instances": [
'http.title:"Auto-GPT"',
'http.html:"autogpt" port:8000'
],
"langchain_agents": [
'http.html:"langchain" http.html:"agent"',
'product:"LangChain"'
],
"openai_exposed": [
'http.title:"OpenAI Playground"',
'http.html:"sk-" http.html:"openai"'
],
"exposed_env_files": [
'http.html:".env" http.html:"API_KEY"',
'http.title:"Index of" http.html:".env"'
],
"debug_mode": [
'http.html:"DEBUG=True"',
'http.html:"development mode"',
'http.html:"stack trace"'
],
"ai_dashboards": [
'http.title:"AI Dashboard"',
'http.title:"LLM" http.html:"chat"',
'http.html:"anthropic" http.html:"claude"'
],
"jupyter_notebooks": [
'http.title:"Jupyter Notebook"',
'http.title:"JupyterLab"',
'http.html:"jupyter" port:8888'
],
"streamlit_apps": [
'http.html:"streamlit"',
'http.title:"Streamlit"'
]
}
def __init__(self, config: Optional[Config] = None):
"""
Initialize QueryManager.
Args:
config: Configuration instance
"""
self.config = config or Config()
self.engine: Optional[ShodanEngine] = None
self.templates: Dict[str, List[str]] = self.DEFAULT_TEMPLATES.copy()
self._initialize_engine()
self._load_custom_templates()
def _initialize_engine(self) -> None:
"""Initialize Shodan engine."""
api_key = self.config.get_shodan_key()
if api_key:
shodan_config = self.config.get_shodan_config()
self.engine = ShodanEngine(
api_key=api_key,
rate_limit=shodan_config.get('rate_limit', 1.0),
timeout=shodan_config.get('timeout', 30),
max_results=shodan_config.get('max_results', 100)
)
logger.info("Shodan engine initialized")
else:
logger.warning("Shodan API key not provided")
def _load_custom_templates(self) -> None:
"""Load custom query templates from YAML files."""
queries_dir = Path("queries")
if not queries_dir.exists():
return
for yaml_file in queries_dir.glob("*.yaml"):
try:
with open(yaml_file, 'r') as f:
data = yaml.safe_load(f)
if data and 'queries' in data:
template_name = yaml_file.stem
# Support both list format and dict format
queries = data['queries']
if isinstance(queries, dict) and 'shodan' in queries:
self.templates[template_name] = queries['shodan']
elif isinstance(queries, list):
self.templates[template_name] = queries
logger.debug(f"Loaded query template: {template_name}")
except yaml.YAMLError as e:
logger.error(f"Failed to parse {yaml_file}: {e}")
def is_available(self) -> bool:
"""Check if Shodan engine is available."""
return self.engine is not None
def get_available_templates(self) -> List[str]:
"""Get list of available query templates."""
return list(self.templates.keys())
def validate_engine(self) -> bool:
"""
Validate Shodan credentials.
Returns:
True if credentials are valid
"""
if not self.engine:
return False
try:
return self.engine.validate_credentials()
except Exception as e:
logger.error(f"Failed to validate Shodan: {e}")
return False
def get_quota_info(self) -> Dict[str, Any]:
"""Get Shodan API quota information."""
if not self.engine:
return {'error': 'Engine not initialized'}
return self.engine.get_quota_info()
def execute_query(
self,
query: str,
max_results: Optional[int] = None
) -> List[SearchResult]:
"""
Execute a search query.
Args:
query: Shodan search query
max_results: Maximum results to return
Returns:
List of SearchResult objects
"""
if not self.engine:
raise ConfigurationException("Shodan engine not initialized. Check your API key.")
try:
results = self.engine.search(query, max_results)
logger.info(f"Query returned {len(results)} results")
return results
except APIException as e:
logger.error(f"Query failed: {e}")
raise
def execute_template(
self,
template_name: str,
max_results: Optional[int] = None
) -> List[SearchResult]:
"""
Execute all queries from a template.
Args:
template_name: Name of the query template
max_results: Maximum results per query
Returns:
Combined list of results from all queries
"""
if template_name not in self.templates:
raise ConfigurationException(f"Template not found: {template_name}")
if not self.engine:
raise ConfigurationException("Shodan engine not initialized. Check your API key.")
queries = self.templates[template_name]
all_results = []
for query in queries:
try:
results = self.engine.search(query, max_results)
all_results.extend(results)
except APIException as e:
logger.error(f"Query failed: {query} - {e}")
logger.info(f"Template '{template_name}' returned {len(all_results)} total results")
return all_results
def count_results(self, query: str) -> int:
"""
Get count of results for a query without consuming credits.
Args:
query: Search query
Returns:
Number of results
"""
if not self.engine:
return 0
return self.engine.count(query)
def add_custom_template(self, name: str, queries: List[str]) -> None:
"""
Add a custom query template.
Args:
name: Template name
queries: List of Shodan queries
"""
self.templates[name] = queries
logger.info(f"Added custom template: {name}")
def save_template(self, name: str, path: Optional[str] = None) -> None:
"""
Save a template to a YAML file.
Args:
name: Template name
path: Output file path (default: queries/{name}.yaml)
"""
if name not in self.templates:
raise ConfigurationException(f"Template not found: {name}")
output_path = path or f"queries/{name}.yaml"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
template_data = {
'name': name,
'description': f"Query template for {name}",
'queries': self.templates[name]
}
with open(output_path, 'w') as f:
yaml.dump(template_data, f, default_flow_style=False)
logger.info(f"Saved template to {output_path}")
+304
View File
@@ -0,0 +1,304 @@
"""Result aggregation and deduplication for AASRT."""
from typing import Any, Dict, List, Optional, Set
from datetime import datetime
from collections import defaultdict
from src.engines import SearchResult
from src.utils.logger import get_logger
logger = get_logger(__name__)
class ResultAggregator:
"""Aggregates and deduplicates search results from multiple engines."""
def __init__(
self,
dedupe_by: str = "ip_port",
merge_metadata: bool = True,
prefer_engine: Optional[str] = None
):
"""
Initialize ResultAggregator.
Args:
dedupe_by: Deduplication key ("ip_port", "ip", or "hostname")
merge_metadata: Whether to merge metadata from duplicate results
prefer_engine: Preferred engine when resolving conflicts
"""
self.dedupe_by = dedupe_by
self.merge_metadata = merge_metadata
self.prefer_engine = prefer_engine
def aggregate(
self,
results: Dict[str, List[SearchResult]]
) -> List[SearchResult]:
"""
Aggregate results from multiple engines.
Args:
results: Dictionary mapping engine names to result lists
Returns:
Deduplicated and merged list of results
"""
all_results = []
# Flatten results
for engine_name, engine_results in results.items():
for result in engine_results:
result.source_engine = engine_name
all_results.append(result)
logger.info(f"Aggregating {len(all_results)} total results")
# Deduplicate
deduplicated = self._deduplicate(all_results)
logger.info(f"After deduplication: {len(deduplicated)} unique results")
return deduplicated
def _get_dedupe_key(self, result: SearchResult) -> str:
"""Get deduplication key for a result."""
if self.dedupe_by == "ip_port":
return f"{result.ip}:{result.port}"
elif self.dedupe_by == "ip":
return result.ip
elif self.dedupe_by == "hostname":
return result.hostname or result.ip
else:
return f"{result.ip}:{result.port}"
def _deduplicate(self, results: List[SearchResult]) -> List[SearchResult]:
"""Deduplicate results based on configured key."""
seen: Dict[str, SearchResult] = {}
for result in results:
key = self._get_dedupe_key(result)
if key not in seen:
seen[key] = result
else:
# Merge with existing result
existing = seen[key]
seen[key] = self._merge_results(existing, result)
return list(seen.values())
def _merge_results(
self,
existing: SearchResult,
new: SearchResult
) -> SearchResult:
"""
Merge two results for the same target.
Args:
existing: Existing result
new: New result to merge
Returns:
Merged result
"""
# Prefer result from preferred engine
if self.prefer_engine:
if new.source_engine == self.prefer_engine:
base = new
other = existing
else:
base = existing
other = new
else:
# Default: prefer result with more information
if len(new.metadata) > len(existing.metadata):
base = new
other = existing
else:
base = existing
other = new
# Merge vulnerabilities (union)
merged_vulns = list(set(base.vulnerabilities + other.vulnerabilities))
# Merge metadata if enabled
if self.merge_metadata:
merged_metadata = {**other.metadata, **base.metadata}
# Track source engines
engines = set()
if base.metadata.get('source_engines'):
engines.update(base.metadata['source_engines'])
if other.metadata.get('source_engines'):
engines.update(other.metadata['source_engines'])
engines.add(base.source_engine)
engines.add(other.source_engine)
merged_metadata['source_engines'] = list(engines)
else:
merged_metadata = base.metadata
# Take highest risk score
risk_score = max(base.risk_score, other.risk_score)
# Take highest confidence
confidence = max(base.confidence, other.confidence)
return SearchResult(
ip=base.ip,
port=base.port,
hostname=base.hostname or other.hostname,
service=base.service or other.service,
banner=base.banner or other.banner,
vulnerabilities=merged_vulns,
metadata=merged_metadata,
source_engine=base.source_engine,
timestamp=base.timestamp,
risk_score=risk_score,
confidence=confidence
)
def filter_by_confidence(
self,
results: List[SearchResult],
min_confidence: int = 70
) -> List[SearchResult]:
"""Filter results by minimum confidence score."""
filtered = [r for r in results if r.confidence >= min_confidence]
logger.info(f"Filtered by confidence >= {min_confidence}: {len(filtered)} results")
return filtered
def filter_by_risk_score(
self,
results: List[SearchResult],
min_score: float = 0.0
) -> List[SearchResult]:
"""Filter results by minimum risk score."""
filtered = [r for r in results if r.risk_score >= min_score]
logger.info(f"Filtered by risk >= {min_score}: {len(filtered)} results")
return filtered
def filter_whitelist(
self,
results: List[SearchResult],
whitelist_ips: Optional[List[str]] = None,
whitelist_domains: Optional[List[str]] = None
) -> List[SearchResult]:
"""Filter out whitelisted IPs and domains."""
if not whitelist_ips and not whitelist_domains:
return results
whitelist_ips = set(whitelist_ips or [])
whitelist_domains = set(whitelist_domains or [])
filtered = []
for result in results:
if result.ip in whitelist_ips:
continue
if result.hostname and result.hostname in whitelist_domains:
continue
# Check if hostname ends with any whitelisted domain
if result.hostname:
skip = False
for domain in whitelist_domains:
if result.hostname.endswith(f".{domain}") or result.hostname == domain:
skip = True
break
if skip:
continue
filtered.append(result)
excluded = len(results) - len(filtered)
if excluded > 0:
logger.info(f"Excluded {excluded} whitelisted results")
return filtered
def group_by_ip(
self,
results: List[SearchResult]
) -> Dict[str, List[SearchResult]]:
"""Group results by IP address."""
grouped = defaultdict(list)
for result in results:
grouped[result.ip].append(result)
return dict(grouped)
def group_by_service(
self,
results: List[SearchResult]
) -> Dict[str, List[SearchResult]]:
"""Group results by service type."""
grouped = defaultdict(list)
for result in results:
service = result.service or "unknown"
grouped[service].append(result)
return dict(grouped)
def get_statistics(self, results: List[SearchResult]) -> Dict[str, Any]:
"""
Get aggregate statistics for results.
Args:
results: List of search results
Returns:
Statistics dictionary
"""
if not results:
return {
'total_results': 0,
'unique_ips': 0,
'unique_hostnames': 0,
'engines_used': [],
'vulnerability_counts': {},
'risk_distribution': {},
'top_services': []
}
# Count unique IPs and hostnames
unique_ips = set(r.ip for r in results)
unique_hostnames = set(r.hostname for r in results if r.hostname)
# Count engines
engines = set()
for r in results:
if r.metadata.get('source_engines'):
engines.update(r.metadata['source_engines'])
else:
engines.add(r.source_engine)
# Count vulnerabilities
vuln_counts = defaultdict(int)
for r in results:
for vuln in r.vulnerabilities:
vuln_counts[vuln] += 1
# Risk distribution
risk_dist = {
'critical': len([r for r in results if r.risk_score >= 9.0]),
'high': len([r for r in results if 7.0 <= r.risk_score < 9.0]),
'medium': len([r for r in results if 4.0 <= r.risk_score < 7.0]),
'low': len([r for r in results if r.risk_score < 4.0])
}
# Top services
service_counts = defaultdict(int)
for r in results:
service_counts[r.service or "unknown"] += 1
top_services = sorted(
service_counts.items(),
key=lambda x: x[1],
reverse=True
)[:10]
return {
'total_results': len(results),
'unique_ips': len(unique_ips),
'unique_hostnames': len(unique_hostnames),
'engines_used': list(engines),
'vulnerability_counts': dict(vuln_counts),
'risk_distribution': risk_dist,
'top_services': top_services,
'average_risk_score': sum(r.risk_score for r in results) / len(results)
}
+313
View File
@@ -0,0 +1,313 @@
"""Risk scoring engine for AASRT."""
from typing import Any, Dict, List
from .vulnerability_assessor import Vulnerability
from src.engines import SearchResult
from src.utils.logger import get_logger
logger = get_logger(__name__)
class RiskScorer:
"""Calculates risk scores for targets based on vulnerabilities."""
# Severity weights for scoring
SEVERITY_WEIGHTS = {
'CRITICAL': 1.5,
'HIGH': 1.2,
'MEDIUM': 1.0,
'LOW': 0.5,
'INFO': 0.1
}
# Context multipliers
CONTEXT_MULTIPLIERS = {
'public_internet': 1.2,
'no_waf': 1.1,
'known_vulnerable_version': 1.3,
'ai_agent': 1.2, # AI agents may have additional risk
'clawsec_cve': 1.4, # Known ClawSec CVE vulnerability
'clawsec_critical': 1.5, # Critical ClawSec CVE
}
def __init__(self, config: Dict[str, Any] = None):
"""
Initialize RiskScorer.
Args:
config: Configuration options
"""
self.config = config or {}
def calculate_score(
self,
vulnerabilities: List[Vulnerability],
context: Dict[str, Any] = None
) -> Dict[str, Any]:
"""
Calculate risk score based on vulnerabilities.
Formula:
- Base score: Highest CVSS score found
- Adjusted: base * (1 + 0.1 * critical_count)
- Context multipliers applied
- Capped at 10.0
Args:
vulnerabilities: List of discovered vulnerabilities
context: Additional context (public_internet, etc.)
Returns:
Risk assessment dictionary
"""
if not vulnerabilities:
return {
'overall_score': 0.0,
'severity_breakdown': {
'critical': 0, 'high': 0, 'medium': 0, 'low': 0, 'info': 0
},
'exploitability': 'NONE',
'impact': 'NONE',
'confidence': 100
}
context = context or {}
# Get base score (highest CVSS)
base_score = max(v.cvss_score for v in vulnerabilities)
# Count by severity
severity_counts = self._count_severities(vulnerabilities)
# Apply vulnerability count multiplier
critical_count = severity_counts['critical']
high_count = severity_counts['high']
# Increase score based on multiple vulnerabilities
adjusted_score = base_score * (1.0 + (0.1 * critical_count) + (0.05 * high_count))
# Apply context multipliers
for ctx_key, multiplier in self.CONTEXT_MULTIPLIERS.items():
if context.get(ctx_key, False):
adjusted_score *= multiplier
# Cap at 10.0
final_score = min(adjusted_score, 10.0)
# Determine exploitability
exploitability = self._calculate_exploitability(vulnerabilities, critical_count)
# Determine impact
impact = self._calculate_impact(vulnerabilities)
return {
'overall_score': round(final_score, 1),
'severity_breakdown': severity_counts,
'exploitability': exploitability,
'impact': impact,
'confidence': self._calculate_confidence(vulnerabilities),
'contributing_factors': self._get_contributing_factors(vulnerabilities)
}
def _count_severities(self, vulnerabilities: List[Vulnerability]) -> Dict[str, int]:
"""Count vulnerabilities by severity level."""
counts = {'critical': 0, 'high': 0, 'medium': 0, 'low': 0, 'info': 0}
for v in vulnerabilities:
severity_key = v.severity.lower()
if severity_key in counts:
counts[severity_key] += 1
return counts
def _calculate_exploitability(
self,
vulnerabilities: List[Vulnerability],
critical_count: int
) -> str:
"""Determine overall exploitability level."""
if critical_count >= 2:
return 'CRITICAL'
elif critical_count >= 1:
return 'HIGH'
# Check for easily exploitable vulnerabilities
easy_exploit = ['api_key_exposure', 'no_authentication', 'shell_access']
for v in vulnerabilities:
if any(indicator in v.check_name for indicator in easy_exploit):
return 'HIGH'
high_count = len([v for v in vulnerabilities if v.severity == 'HIGH'])
if high_count >= 2:
return 'MEDIUM'
return 'LOW'
def _calculate_impact(self, vulnerabilities: List[Vulnerability]) -> str:
"""Determine potential impact level."""
# Check for high-impact vulnerabilities
high_impact_indicators = [
'api_key_exposure',
'shell_access',
'database_exposed',
'admin_panel'
]
for v in vulnerabilities:
if any(indicator in v.check_name for indicator in high_impact_indicators):
return 'HIGH'
if any(v.cvss_score >= 7.0 for v in vulnerabilities):
return 'MEDIUM'
return 'LOW'
def _calculate_confidence(self, vulnerabilities: List[Vulnerability]) -> int:
"""Calculate confidence in the assessment."""
if not vulnerabilities:
return 100
# Start with high confidence
confidence = 100
# Reduce confidence for potential false positives
for v in vulnerabilities:
if 'potential' in v.check_name or 'possible' in v.description.lower():
confidence -= 10
return max(confidence, 0)
def _get_contributing_factors(self, vulnerabilities: List[Vulnerability]) -> List[str]:
"""Get list of main contributing factors to the risk score."""
factors = []
for v in vulnerabilities:
if v.severity in ['CRITICAL', 'HIGH']:
factors.append(f"{v.severity}: {v.description}")
return factors[:5] # Top 5 factors
def score_result(self, result: SearchResult, vulnerabilities: List[Vulnerability]) -> SearchResult:
"""
Apply risk score to a SearchResult.
Args:
result: SearchResult to score
vulnerabilities: Assessed vulnerabilities
Returns:
Updated SearchResult with risk score
"""
# Build context from result metadata
context = {
'public_internet': True, # Assume public if found via search
'ai_agent': self._is_ai_agent(result),
'clawsec_cve': self._has_clawsec_cve(result),
'clawsec_critical': self._has_critical_clawsec_cve(result)
}
# Check for WAF
http_info = result.metadata.get('http') or {}
http_headers = http_info.get('headers', {})
if not any(waf in str(http_headers).lower() for waf in ['cloudflare', 'akamai', 'fastly']):
context['no_waf'] = True
# Calculate score
risk_data = self.calculate_score(vulnerabilities, context)
# Update result
result.risk_score = risk_data['overall_score']
result.metadata['risk_assessment'] = risk_data
result.vulnerabilities = [v.check_name for v in vulnerabilities]
return result
def _has_clawsec_cve(self, result: SearchResult) -> bool:
"""Check if result has any ClawSec CVE associations."""
return bool(result.metadata.get('clawsec_advisories'))
def _has_critical_clawsec_cve(self, result: SearchResult) -> bool:
"""Check if result has a critical ClawSec CVE."""
advisories = result.metadata.get('clawsec_advisories', [])
return any(a.get('severity') == 'CRITICAL' for a in advisories)
def _is_ai_agent(self, result: SearchResult) -> bool:
"""Check if result appears to be an AI agent."""
ai_indicators = [
'clawdbot', 'autogpt', 'langchain', 'openai',
'anthropic', 'claude', 'gpt', 'agent'
]
http_info = result.metadata.get('http') or {}
http_title = http_info.get('title') or ''
text = (
(result.banner or '') +
(result.service or '') +
str(http_title)
).lower()
return any(indicator in text for indicator in ai_indicators)
def categorize_results(
self,
results: List[SearchResult]
) -> Dict[str, List[SearchResult]]:
"""
Categorize results by risk level.
Args:
results: List of scored results
Returns:
Dictionary with risk level categories
"""
categories = {
'critical': [],
'high': [],
'medium': [],
'low': []
}
for result in results:
if result.risk_score >= 9.0:
categories['critical'].append(result)
elif result.risk_score >= 7.0:
categories['high'].append(result)
elif result.risk_score >= 4.0:
categories['medium'].append(result)
else:
categories['low'].append(result)
return categories
def get_summary(self, results: List[SearchResult]) -> Dict[str, Any]:
"""
Get risk summary for a set of results.
Args:
results: List of scored results
Returns:
Summary statistics
"""
if not results:
return {
'total': 0,
'average_score': 0.0,
'max_score': 0.0,
'distribution': {'critical': 0, 'high': 0, 'medium': 0, 'low': 0}
}
categories = self.categorize_results(results)
scores = [r.risk_score for r in results]
return {
'total': len(results),
'average_score': round(sum(scores) / len(scores), 1),
'max_score': max(scores),
'distribution': {
'critical': len(categories['critical']),
'high': len(categories['high']),
'medium': len(categories['medium']),
'low': len(categories['low'])
}
}
+441
View File
@@ -0,0 +1,441 @@
"""Vulnerability assessment engine for AASRT."""
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from src.engines import SearchResult
from src.utils.logger import get_logger
if TYPE_CHECKING:
from src.enrichment import ThreatEnricher
logger = get_logger(__name__)
@dataclass
class Vulnerability:
"""Represents a discovered vulnerability."""
check_name: str
severity: str # CRITICAL, HIGH, MEDIUM, LOW, INFO
cvss_score: float
description: str
evidence: Dict[str, Any] = field(default_factory=dict)
remediation: Optional[str] = None
cwe_id: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
'check_name': self.check_name,
'severity': self.severity,
'cvss_score': self.cvss_score,
'description': self.description,
'evidence': self.evidence,
'remediation': self.remediation,
'cwe_id': self.cwe_id
}
class VulnerabilityAssessor:
"""Performs passive vulnerability assessment on search results."""
# API key patterns for detection
API_KEY_PATTERNS = {
'anthropic': {
'pattern': r'sk-ant-[a-zA-Z0-9-_]{20,}',
'description': 'Anthropic API key exposed',
'cvss': 10.0
},
'openai': {
'pattern': r'sk-[a-zA-Z0-9]{32,}',
'description': 'OpenAI API key exposed',
'cvss': 10.0
},
'aws_access_key': {
'pattern': r'AKIA[0-9A-Z]{16}',
'description': 'AWS Access Key ID exposed',
'cvss': 9.8
},
'aws_secret': {
'pattern': r'(?<![A-Za-z0-9/+=])[A-Za-z0-9/+=]{40}(?![A-Za-z0-9/+=])',
'description': 'Potential AWS Secret Key exposed',
'cvss': 9.8
},
'github_token': {
'pattern': r'ghp_[a-zA-Z0-9]{36}',
'description': 'GitHub Personal Access Token exposed',
'cvss': 9.5
},
'google_api': {
'pattern': r'AIza[0-9A-Za-z\-_]{35}',
'description': 'Google API key exposed',
'cvss': 7.5
},
'stripe': {
'pattern': r'sk_live_[0-9a-zA-Z]{24}',
'description': 'Stripe Secret Key exposed',
'cvss': 9.8
}
}
# Dangerous functionality patterns
DANGEROUS_PATTERNS = {
'shell_access': {
'patterns': [r'/shell', r'/exec', r'/execute', r'/api/execute', r'/cmd'],
'description': 'Shell command execution endpoint detected',
'cvss': 9.9,
'severity': 'CRITICAL'
},
'debug_mode': {
'patterns': [r'DEBUG\s*[=:]\s*[Tt]rue', r'debug\s*mode', r'stack\s*trace'],
'description': 'Debug mode appears to be enabled',
'cvss': 7.5,
'severity': 'HIGH'
},
'file_upload': {
'patterns': [r'/upload', r'/api/files', r'multipart/form-data'],
'description': 'File upload functionality detected',
'cvss': 7.8,
'severity': 'HIGH'
},
'admin_panel': {
'patterns': [r'/admin', r'admin\s*panel', r'administrator'],
'description': 'Admin panel potentially exposed',
'cvss': 8.5,
'severity': 'HIGH'
},
'database_exposed': {
'patterns': [r'mongodb://', r'mysql://', r'postgresql://', r'redis://'],
'description': 'Database connection string exposed',
'cvss': 9.5,
'severity': 'CRITICAL'
}
}
# Information disclosure patterns
INFO_DISCLOSURE_PATTERNS = {
'env_file': {
'patterns': [r'\.env', r'environment\s*variables?'],
'description': 'Environment file or variables exposed',
'cvss': 8.0,
'severity': 'HIGH'
},
'config_file': {
'patterns': [r'config\.json', r'settings\.py', r'application\.yml'],
'description': 'Configuration file exposed',
'cvss': 7.5,
'severity': 'HIGH'
},
'git_exposed': {
'patterns': [r'\.git/', r'\.git/config'],
'description': 'Git repository exposed',
'cvss': 7.0,
'severity': 'MEDIUM'
},
'source_code': {
'patterns': [r'\.py$', r'\.js$', r'\.php$'],
'description': 'Source code files potentially exposed',
'cvss': 6.5,
'severity': 'MEDIUM'
}
}
def __init__(self, config: Optional[Dict[str, Any]] = None, threat_enricher: Optional['ThreatEnricher'] = None):
"""
Initialize VulnerabilityAssessor.
Args:
config: Configuration dictionary
threat_enricher: Optional ThreatEnricher for ClawSec integration
"""
self.config = config or {}
self.passive_only = self.config.get('passive_only', True)
self.threat_enricher = threat_enricher
def assess(self, result: SearchResult) -> List[Vulnerability]:
"""
Perform vulnerability assessment on a search result.
Args:
result: SearchResult to assess
Returns:
List of discovered vulnerabilities
"""
vulnerabilities = []
# Check for API key exposure in banner
if result.banner:
vulnerabilities.extend(self._check_api_keys(result.banner))
# Check for dangerous functionality
vulnerabilities.extend(self._check_dangerous_functionality(result))
# Check for information disclosure
vulnerabilities.extend(self._check_information_disclosure(result))
# Check SSL/TLS issues
vulnerabilities.extend(self._check_ssl_issues(result))
# Check for authentication issues (based on metadata)
vulnerabilities.extend(self._check_authentication(result))
# Add pre-existing vulnerability indicators
for vuln_name in result.vulnerabilities:
if not any(v.check_name == vuln_name for v in vulnerabilities):
vulnerabilities.append(self._create_from_indicator(vuln_name))
logger.debug(f"Assessed {result.ip}:{result.port} - {len(vulnerabilities)} vulnerabilities")
return vulnerabilities
def assess_batch(self, results: List[SearchResult]) -> Dict[str, List[Vulnerability]]:
"""
Assess multiple results.
Args:
results: List of SearchResults
Returns:
Dictionary mapping result keys to vulnerability lists
"""
assessments = {}
for result in results:
key = f"{result.ip}:{result.port}"
assessments[key] = self.assess(result)
return assessments
def assess_with_intel(self, result: SearchResult) -> List[Vulnerability]:
"""
Perform vulnerability assessment enhanced with threat intelligence.
1. Enrich result with ClawSec CVE data
2. Run standard passive checks
3. Create Vulnerability objects for matched CVEs
Args:
result: SearchResult to assess
Returns:
List of discovered vulnerabilities including ClawSec CVEs
"""
# First, enrich the result if threat enricher is available
if self.threat_enricher:
result = self.threat_enricher.enrich(result)
# Run standard assessment
vulnerabilities = self.assess(result)
# Add ClawSec CVE vulnerabilities
if self.threat_enricher:
clawsec_vulns = self._create_clawsec_vulnerabilities(result)
vulnerabilities.extend(clawsec_vulns)
return vulnerabilities
def _create_clawsec_vulnerabilities(self, result: SearchResult) -> List[Vulnerability]:
"""
Convert ClawSec advisory data to Vulnerability objects.
Args:
result: SearchResult with clawsec_advisories in metadata
Returns:
List of Vulnerability objects from ClawSec data
"""
vulns = []
clawsec_data = result.metadata.get('clawsec_advisories', [])
for advisory in clawsec_data:
vulns.append(Vulnerability(
check_name=f"clawsec_{advisory['cve_id']}",
severity=advisory.get('severity', 'MEDIUM'),
cvss_score=advisory.get('cvss_score', 7.0),
description=f"[ClawSec] {advisory.get('title', 'Known vulnerability')}",
evidence={
'cve_id': advisory['cve_id'],
'source': 'ClawSec',
'vuln_type': advisory.get('vuln_type', 'unknown'),
'nvd_url': advisory.get('nvd_url')
},
remediation=advisory.get('action', 'See ClawSec advisory for remediation steps'),
cwe_id=advisory.get('cwe_id')
))
return vulns
def _check_api_keys(self, text: str) -> List[Vulnerability]:
"""Check for exposed API keys in text."""
vulnerabilities = []
for key_type, config in self.API_KEY_PATTERNS.items():
if re.search(config['pattern'], text, re.IGNORECASE):
vulnerabilities.append(Vulnerability(
check_name=f"api_key_exposure_{key_type}",
severity="CRITICAL",
cvss_score=config['cvss'],
description=config['description'],
evidence={'pattern_matched': key_type},
remediation="Immediately rotate the exposed API key and remove from public-facing content",
cwe_id="CWE-798"
))
return vulnerabilities
def _check_dangerous_functionality(self, result: SearchResult) -> List[Vulnerability]:
"""Check for dangerous functionality indicators."""
vulnerabilities = []
http_info = result.metadata.get('http') or {}
text = (result.banner or '') + str(http_info)
for check_name, config in self.DANGEROUS_PATTERNS.items():
for pattern in config['patterns']:
if re.search(pattern, text, re.IGNORECASE):
vulnerabilities.append(Vulnerability(
check_name=check_name,
severity=config['severity'],
cvss_score=config['cvss'],
description=config['description'],
evidence={'pattern': pattern},
remediation=self._get_remediation(check_name)
))
break # Only add once per check
return vulnerabilities
def _check_information_disclosure(self, result: SearchResult) -> List[Vulnerability]:
"""Check for information disclosure."""
vulnerabilities = []
text = (result.banner or '') + str(result.metadata)
for check_name, config in self.INFO_DISCLOSURE_PATTERNS.items():
for pattern in config['patterns']:
if re.search(pattern, text, re.IGNORECASE):
vulnerabilities.append(Vulnerability(
check_name=f"info_disclosure_{check_name}",
severity=config['severity'],
cvss_score=config['cvss'],
description=config['description'],
evidence={'pattern': pattern},
remediation="Remove or restrict access to sensitive files"
))
break
return vulnerabilities
def _check_ssl_issues(self, result: SearchResult) -> List[Vulnerability]:
"""Check for SSL/TLS issues."""
vulnerabilities = []
ssl_info = result.metadata.get('ssl') or {}
if not ssl_info:
# No SSL on HTTPS port might be an issue
if result.port in [443, 8443]:
vulnerabilities.append(Vulnerability(
check_name="no_ssl_on_https_port",
severity="MEDIUM",
cvss_score=5.3,
description="HTTPS port without SSL/TLS",
remediation="Configure proper SSL/TLS certificate"
))
return vulnerabilities
cert = ssl_info.get('cert') or {}
# Check for expired certificate
if cert.get('expired', False):
vulnerabilities.append(Vulnerability(
check_name="expired_ssl_certificate",
severity="MEDIUM",
cvss_score=5.0,
description="SSL certificate has expired",
remediation="Renew SSL certificate",
cwe_id="CWE-295"
))
# Check for self-signed certificate
if cert.get('self_signed', False):
vulnerabilities.append(Vulnerability(
check_name="self_signed_certificate",
severity="LOW",
cvss_score=3.0,
description="Self-signed SSL certificate detected",
remediation="Use a certificate from a trusted CA"
))
return vulnerabilities
def _check_authentication(self, result: SearchResult) -> List[Vulnerability]:
"""Check for authentication issues."""
vulnerabilities = []
http_info = result.metadata.get('http') or {}
if not http_info:
return vulnerabilities
# Check for missing authentication on sensitive endpoints
status = http_info.get('status')
if status == 200:
# 200 OK on root might indicate no auth
title = http_info.get('title') or ''
title = title.lower()
if any(term in title for term in ['dashboard', 'admin', 'control panel']):
vulnerabilities.append(Vulnerability(
check_name="no_authentication",
severity="CRITICAL",
cvss_score=9.1,
description="Dashboard accessible without authentication",
evidence={'http_title': http_info.get('title')},
remediation="Implement authentication mechanism",
cwe_id="CWE-306"
))
return vulnerabilities
def _create_from_indicator(self, indicator: str) -> Vulnerability:
"""Create a Vulnerability from a string indicator."""
# Map common indicators to vulnerabilities
indicator_map = {
'debug_mode_enabled': ('DEBUG', 'HIGH', 7.5, "Debug mode is enabled"),
'potential_api_key_exposure': ('API Keys', 'CRITICAL', 9.0, "Potential API key exposure detected"),
'expired_ssl_certificate': ('SSL', 'MEDIUM', 5.0, "Expired SSL certificate"),
'no_security_txt': ('Config', 'LOW', 2.0, "No security.txt file found"),
'self_signed_certificate': ('SSL', 'LOW', 3.0, "Self-signed certificate"),
}
if indicator in indicator_map:
category, severity, cvss, desc = indicator_map[indicator]
return Vulnerability(
check_name=indicator,
severity=severity,
cvss_score=cvss,
description=desc
)
# Default for unknown indicators
return Vulnerability(
check_name=indicator,
severity="INFO",
cvss_score=1.0,
description=f"Indicator detected: {indicator}"
)
def _get_remediation(self, check_name: str) -> str:
"""Get remediation advice for a vulnerability."""
remediations = {
'shell_access': "Disable or restrict shell execution endpoints. Implement authentication and authorization.",
'debug_mode': "Disable debug mode in production environments.",
'file_upload': "Implement file type validation, size limits, and malware scanning.",
'admin_panel': "Restrict admin panel access to authorized networks. Implement strong authentication.",
'database_exposed': "Remove database connection strings from public-facing content. Use environment variables.",
}
return remediations.get(check_name, "Review and remediate the identified issue.")
def get_severity_counts(self, vulnerabilities: List[Vulnerability]) -> Dict[str, int]:
"""Get count of vulnerabilities by severity."""
counts = {'CRITICAL': 0, 'HIGH': 0, 'MEDIUM': 0, 'LOW': 0, 'INFO': 0}
for vuln in vulnerabilities:
if vuln.severity in counts:
counts[vuln.severity] += 1
return counts
+10
View File
@@ -0,0 +1,10 @@
"""Search engine modules for AASRT."""
from .base import BaseSearchEngine, SearchResult
from .shodan_engine import ShodanEngine
__all__ = [
'BaseSearchEngine',
'SearchResult',
'ShodanEngine'
]
+183
View File
@@ -0,0 +1,183 @@
"""Abstract base class for search engine integrations."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import time
from src.utils.logger import get_logger
from src.utils.exceptions import RateLimitException
logger = get_logger(__name__)
@dataclass
class SearchResult:
"""Represents a single search result from any engine."""
ip: str
port: int
hostname: Optional[str] = None
service: Optional[str] = None
banner: Optional[str] = None
vulnerabilities: List[str] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
source_engine: Optional[str] = None
timestamp: Optional[str] = None
risk_score: float = 0.0
confidence: int = 100
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
'ip': self.ip,
'port': self.port,
'hostname': self.hostname,
'service': self.service,
'banner': self.banner,
'vulnerabilities': self.vulnerabilities,
'metadata': self.metadata,
'source_engine': self.source_engine,
'timestamp': self.timestamp,
'risk_score': self.risk_score,
'confidence': self.confidence
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'SearchResult':
"""Create from dictionary."""
return cls(
ip=data.get('ip', ''),
port=data.get('port', 0),
hostname=data.get('hostname'),
service=data.get('service'),
banner=data.get('banner'),
vulnerabilities=data.get('vulnerabilities', []),
metadata=data.get('metadata', {}),
source_engine=data.get('source_engine'),
timestamp=data.get('timestamp'),
risk_score=data.get('risk_score', 0.0),
confidence=data.get('confidence', 100)
)
class BaseSearchEngine(ABC):
"""Abstract base class for all search engine integrations."""
def __init__(
self,
api_key: str,
rate_limit: float = 1.0,
timeout: int = 30,
max_results: int = 100
):
"""
Initialize the search engine.
Args:
api_key: API key for authentication
rate_limit: Maximum queries per second
timeout: Request timeout in seconds
max_results: Maximum results to return per query
"""
self.api_key = api_key
self.rate_limit = rate_limit
self.timeout = timeout
self.max_results = max_results
self._last_request_time = 0.0
self._request_count = 0
@property
@abstractmethod
def name(self) -> str:
"""Return the engine name."""
pass
@abstractmethod
def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
"""
Execute a search query and return results.
Args:
query: Search query string
max_results: Maximum number of results to return (overrides default)
Returns:
List of SearchResult objects
Raises:
APIException: If API call fails
RateLimitException: If rate limit exceeded
"""
pass
@abstractmethod
def validate_credentials(self) -> bool:
"""
Validate API credentials.
Returns:
True if credentials are valid
Raises:
AuthenticationException: If credentials are invalid
"""
pass
@abstractmethod
def get_quota_info(self) -> Dict[str, Any]:
"""
Get API quota/usage information.
Returns:
Dictionary with quota information
"""
pass
def _rate_limit_wait(self) -> None:
"""Enforce rate limiting between requests."""
if self.rate_limit <= 0:
return
min_interval = 1.0 / self.rate_limit
elapsed = time.time() - self._last_request_time
if elapsed < min_interval:
wait_time = min_interval - elapsed
logger.debug(f"Rate limiting: waiting {wait_time:.2f}s")
time.sleep(wait_time)
self._last_request_time = time.time()
self._request_count += 1
def _check_rate_limit(self) -> None:
"""Check if rate limit is being approached."""
# This can be overridden by specific engines with their own rate limit logic
pass
def _parse_result(self, raw_result: Dict[str, Any]) -> SearchResult:
"""
Parse a raw API result into a SearchResult.
Args:
raw_result: Raw result from API
Returns:
SearchResult object
"""
# Default implementation - should be overridden by specific engines
return SearchResult(
ip=raw_result.get('ip', ''),
port=raw_result.get('port', 0),
source_engine=self.name
)
def get_stats(self) -> Dict[str, Any]:
"""Get engine statistics."""
return {
'engine': self.name,
'request_count': self._request_count,
'rate_limit': self.rate_limit,
'timeout': self.timeout,
'max_results': self.max_results
}
+589
View File
@@ -0,0 +1,589 @@
"""
Shodan search engine integration for AASRT.
This module provides a production-ready integration with the Shodan API
for security reconnaissance. Features include:
- Automatic retry with exponential backoff for transient failures
- Rate limiting to prevent API quota exhaustion
- Comprehensive error handling with specific exception types
- Detailed logging for debugging and monitoring
- Graceful degradation when API is unavailable
Example:
>>> from src.engines.shodan_engine import ShodanEngine
>>> engine = ShodanEngine(api_key="your_key")
>>> engine.validate_credentials()
True
>>> results = engine.search("http.html:clawdbot", max_results=10)
"""
from typing import Any, Callable, Dict, List, Optional, TypeVar
from datetime import datetime
from functools import wraps
import time
import socket
import shodan
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
RetryError
)
from .base import BaseSearchEngine, SearchResult
from src.utils.logger import get_logger
from src.utils.validators import validate_ip, sanitize_output
from src.utils.exceptions import (
APIException,
RateLimitException,
AuthenticationException,
TimeoutException
)
logger = get_logger(__name__)
# Type variable for generic retry decorator
T = TypeVar('T')
# =============================================================================
# Retry Configuration
# =============================================================================
# Exceptions that should trigger a retry (transient failures)
RETRYABLE_EXCEPTIONS = (
socket.timeout,
ConnectionError,
ConnectionResetError,
TimeoutError,
)
# Maximum number of retry attempts
MAX_RETRY_ATTEMPTS = 3
# Base delay for exponential backoff (seconds)
RETRY_BASE_DELAY = 2
# Maximum delay between retries (seconds)
RETRY_MAX_DELAY = 30
def with_retry(func: Callable[..., T]) -> Callable[..., T]:
"""
Decorator that adds retry logic with exponential backoff.
Retries on transient network errors but not on authentication
or validation errors.
Args:
func: Function to wrap with retry logic.
Returns:
Wrapped function with retry capability.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> T:
last_exception = None
for attempt in range(1, MAX_RETRY_ATTEMPTS + 1):
try:
return func(*args, **kwargs)
except RETRYABLE_EXCEPTIONS as e:
last_exception = e
if attempt < MAX_RETRY_ATTEMPTS:
delay = min(RETRY_BASE_DELAY ** attempt, RETRY_MAX_DELAY)
logger.warning(
f"Retry {attempt}/{MAX_RETRY_ATTEMPTS} for {func.__name__} "
f"after {delay}s delay. Error: {e}"
)
time.sleep(delay)
else:
logger.error(
f"All {MAX_RETRY_ATTEMPTS} retries exhausted for {func.__name__}. "
f"Last error: {e}"
)
except (AuthenticationException, RateLimitException):
# Don't retry auth or rate limit errors
raise
except shodan.APIError as e:
error_msg = str(e).lower()
# Don't retry permanent errors
if "invalid api key" in error_msg:
raise AuthenticationException(
"Invalid Shodan API key",
engine="shodan"
)
if "rate limit" in error_msg:
raise RateLimitException(
f"Shodan rate limit exceeded: {e}",
engine="shodan"
)
# Retry other API errors
last_exception = e
if attempt < MAX_RETRY_ATTEMPTS:
delay = min(RETRY_BASE_DELAY ** attempt, RETRY_MAX_DELAY)
logger.warning(
f"Retry {attempt}/{MAX_RETRY_ATTEMPTS} for {func.__name__} "
f"after {delay}s delay. API Error: {e}"
)
time.sleep(delay)
# All retries exhausted
if last_exception:
raise APIException(
f"Operation failed after {MAX_RETRY_ATTEMPTS} retries: {last_exception}",
engine="shodan"
)
raise APIException("Unexpected retry failure", engine="shodan")
return wrapper
class ShodanEngine(BaseSearchEngine):
"""
Shodan search engine integration for security reconnaissance.
This class provides a production-ready interface to the Shodan API with:
- Automatic rate limiting to respect API quotas
- Retry logic with exponential backoff for transient failures
- Comprehensive error handling and logging
- Result parsing with vulnerability detection
Attributes:
name: Engine identifier ("shodan").
api_key: Shodan API key (masked in logs).
rate_limit: Maximum requests per second.
timeout: Request timeout in seconds.
max_results: Default maximum results per search.
Example:
>>> engine = ShodanEngine(api_key="your_key")
>>> if engine.validate_credentials():
... results = engine.search("http.html:agent", max_results=50)
... for result in results:
... print(f"{result.ip}:{result.port}")
"""
def __init__(
self,
api_key: str,
rate_limit: float = 1.0,
timeout: int = 30,
max_results: int = 100
) -> None:
"""
Initialize Shodan engine with API credentials.
Args:
api_key: Shodan API key from https://account.shodan.io/.
Never log or expose this value.
rate_limit: Maximum queries per second. Default 1.0 to respect
Shodan's free tier limits.
timeout: Request timeout in seconds. Increase for slow connections.
max_results: Maximum results per query. Higher values consume
more query credits.
Raises:
ValueError: If api_key is empty or None.
"""
if not api_key or not api_key.strip():
raise ValueError("Shodan API key is required")
super().__init__(api_key, rate_limit, timeout, max_results)
self._client = shodan.Shodan(api_key)
self._api_key_preview = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "***"
logger.debug(f"ShodanEngine initialized with key: {self._api_key_preview}")
@property
def name(self) -> str:
"""Return engine identifier."""
return "shodan"
@with_retry
def validate_credentials(self) -> bool:
"""
Validate Shodan API credentials by making a test API call.
This method performs a lightweight API call to verify the API key
is valid and has not been revoked.
Returns:
True if credentials are valid and API is accessible.
Raises:
AuthenticationException: If API key is invalid or revoked.
APIException: If API call fails for other reasons.
Example:
>>> engine = ShodanEngine(api_key="your_key")
>>> try:
... engine.validate_credentials()
... print("API key is valid")
... except AuthenticationException:
... print("Invalid API key")
"""
try:
info = self._client.info()
plan = info.get('plan', 'unknown')
credits = info.get('query_credits', 0)
logger.info(
f"Shodan API validated. Plan: {plan}, "
f"Query credits: {credits}"
)
return True
except shodan.APIError as e:
error_msg = str(e)
if "Invalid API key" in error_msg:
logger.error("Shodan authentication failed: Invalid API key")
raise AuthenticationException(
"Invalid Shodan API key",
engine=self.name
)
logger.error(f"Shodan API validation error: {sanitize_output(error_msg)}")
raise APIException(f"Shodan API error: {e}", engine=self.name)
@with_retry
def get_quota_info(self) -> Dict[str, Any]:
"""
Get Shodan API quota and usage information.
Returns:
Dictionary containing:
- engine: Engine name ("shodan")
- plan: API plan type (e.g., "dev", "edu", "corp")
- query_credits: Remaining query credits
- scan_credits: Remaining scan credits
- monitored_ips: Number of monitored IPs
- unlocked: Whether account has unlocked features
- error: Error message if call failed (optional)
Note:
This call does not consume query credits.
"""
try:
info = self._client.info()
quota = {
'engine': self.name,
'plan': info.get('plan', 'unknown'),
'query_credits': info.get('query_credits', 0),
'scan_credits': info.get('scan_credits', 0),
'monitored_ips': info.get('monitored_ips', 0),
'unlocked': info.get('unlocked', False),
'timestamp': datetime.utcnow().isoformat()
}
logger.debug(f"Shodan quota retrieved: {quota['query_credits']} credits remaining")
return quota
except shodan.APIError as e:
logger.error(f"Failed to get Shodan quota: {sanitize_output(str(e))}")
return {
'engine': self.name,
'error': str(e),
'timestamp': datetime.utcnow().isoformat()
}
def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
"""
Execute a Shodan search query with automatic pagination.
This method handles pagination automatically, respecting rate limits
and the specified maximum results. Each page consumes one query credit.
Args:
query: Shodan search query string. Supports Shodan's query syntax
including filters like http.html:, port:, country:, etc.
max_results: Maximum number of results to return. Defaults to
the engine's max_results setting. Set to None for default.
Returns:
List of SearchResult objects containing parsed Shodan data.
May return fewer results than max_results if not enough matches.
Raises:
APIException: If API call fails after all retries.
RateLimitException: If rate limit is exceeded.
AuthenticationException: If API key is invalid.
ValidationException: If query is invalid.
Example:
>>> results = engine.search("http.html:clawdbot", max_results=50)
>>> for r in results:
... print(f"{r.ip}:{r.port} - {r.service}")
Note:
- Shodan returns max 100 results per page
- Multiple pages consume multiple query credits
- Consider using count() first to check total results
"""
# Validate and sanitize query
if not query or not query.strip():
raise APIException("Search query cannot be empty", engine=self.name)
query = query.strip()
limit = max_results or self.max_results
# Log sanitized query (remove potential sensitive data)
safe_query = sanitize_output(query)
logger.info(f"Executing Shodan search: {safe_query} (limit: {limit})")
results: List[SearchResult] = []
page = 1
total_pages = 0
start_time = time.time()
try:
while len(results) < limit:
# Apply rate limiting before each request
self._rate_limit_wait()
# Execute search with retry logic
response = self._execute_search_page(query, page)
if response is None or not response.get('matches'):
logger.debug(f"No more matches at page {page}")
break
# Parse matches
matches = response.get('matches', [])
for match in matches:
if len(results) >= limit:
break
try:
result = self._parse_result(match)
results.append(result)
except Exception as e:
# Log but continue on parse errors
logger.warning(f"Failed to parse result: {e}")
continue
# Check pagination limits
total = response.get('total', 0)
total_pages = (total + 99) // 100 # Ceiling division
if len(results) >= total or len(results) >= limit:
break
page += 1
# Safety limit to prevent infinite loops
if page > 100:
logger.warning("Reached maximum page limit (100)")
break
# Log completion stats
elapsed = time.time() - start_time
logger.info(
f"Shodan search complete: {len(results)} results "
f"from {page} pages in {elapsed:.2f}s"
)
return results
except (AuthenticationException, RateLimitException):
# Re-raise known exceptions without wrapping
raise
except shodan.APIError as e:
error_msg = str(e).lower()
if "rate limit" in error_msg:
logger.error("Shodan rate limit exceeded during search")
raise RateLimitException(
f"Shodan rate limit exceeded: {e}",
engine=self.name
)
elif "invalid api key" in error_msg:
logger.error("Shodan authentication failed during search")
raise AuthenticationException(
"Invalid Shodan API key",
engine=self.name
)
else:
logger.error(f"Shodan API error: {sanitize_output(str(e))}")
raise APIException(
f"Shodan search failed: {e}",
engine=self.name
)
except Exception as e:
logger.exception(f"Unexpected error in Shodan search: {e}")
raise APIException(
f"Shodan search error: {type(e).__name__}: {e}",
engine=self.name
)
@with_retry
def _execute_search_page(self, query: str, page: int) -> Optional[Dict[str, Any]]:
"""
Execute a single page of Shodan search with retry logic.
Args:
query: Search query string.
page: Page number (1-indexed).
Returns:
Shodan API response dictionary or None on failure.
"""
logger.debug(f"Fetching Shodan results page {page}")
return self._client.search(query, page=page)
def _parse_result(self, match: Dict[str, Any]) -> SearchResult:
"""
Parse a Shodan match into a SearchResult.
Args:
match: Raw Shodan match data
Returns:
SearchResult object
"""
# Extract vulnerability indicators from data
vulnerabilities = []
data = match.get('data', '')
# Check for common vulnerability indicators
if 'debug' in data.lower() or 'DEBUG=True' in data:
vulnerabilities.append('debug_mode_enabled')
if 'api_key' in data.lower() or 'apikey' in data.lower():
vulnerabilities.append('potential_api_key_exposure')
ssl_data = match.get('ssl') or {}
ssl_cert = ssl_data.get('cert') or {}
if ssl_cert.get('expired', False):
vulnerabilities.append('expired_ssl_certificate')
# Check HTTP response for issues
http_data = match.get('http') or {}
if http_data:
if not http_data.get('securitytxt'):
vulnerabilities.append('no_security_txt')
# Build metadata
location_data = match.get('location') or {}
metadata = {
'asn': match.get('asn'),
'isp': match.get('isp'),
'org': match.get('org'),
'os': match.get('os'),
'transport': match.get('transport'),
'product': match.get('product'),
'version': match.get('version'),
'cpe': match.get('cpe', []),
'http': http_data,
'ssl': ssl_data,
'location': {
'country': location_data.get('country_name'),
'city': location_data.get('city'),
'latitude': location_data.get('latitude'),
'longitude': location_data.get('longitude')
}
}
# Extract hostnames
hostnames = match.get('hostnames', [])
hostname = hostnames[0] if hostnames else None
return SearchResult(
ip=match.get('ip_str', ''),
port=match.get('port', 0),
hostname=hostname,
service=match.get('product') or match.get('_shodan', {}).get('module'),
banner=data[:1000] if data else None, # Truncate long banners
vulnerabilities=vulnerabilities,
metadata=metadata,
source_engine=self.name,
timestamp=match.get('timestamp', datetime.utcnow().isoformat())
)
@with_retry
def host_info(self, ip: str) -> Dict[str, Any]:
"""
Get detailed information about a specific host.
This method retrieves comprehensive information about a host
including all open ports, services, banners, and historical data.
Args:
ip: IP address to lookup. Must be a valid IPv4 address.
Returns:
Dictionary containing:
- ip_str: IP address as string
- ports: List of open ports
- data: List of service banners per port
- hostnames: List of hostnames
- vulns: List of vulnerabilities (if any)
- location: Geographic information
Raises:
APIException: If lookup fails.
ValidationException: If IP address is invalid.
Example:
>>> info = engine.host_info("8.8.8.8")
>>> print(f"Ports: {info.get('ports', [])}")
"""
# Validate IP address
try:
validate_ip(ip)
except Exception as e:
raise APIException(f"Invalid IP address: {ip}", engine=self.name)
self._rate_limit_wait()
logger.debug(f"Looking up host info for: {ip}")
try:
host_data = self._client.host(ip)
logger.info(f"Retrieved host info for {ip}: {len(host_data.get('ports', []))} ports")
return host_data
except shodan.APIError as e:
error_msg = str(e).lower()
if "no information available" in error_msg:
logger.info(f"No Shodan data available for {ip}")
return {'ip_str': ip, 'ports': [], 'data': []}
logger.error(f"Failed to get host info for {ip}: {sanitize_output(str(e))}")
raise APIException(f"Shodan host lookup failed: {e}", engine=self.name)
@with_retry
def count(self, query: str) -> int:
"""
Get the count of results for a query without consuming query credits.
Use this method to estimate result count before running a full search
to avoid consuming query credits unnecessarily.
Args:
query: Shodan search query string.
Returns:
Estimated number of matching results. Returns 0 on error.
Note:
- Does not consume query credits
- Count may be approximate for large result sets
- Useful for validating queries before running searches
Example:
>>> count = engine.count("http.html:clawdbot")
>>> if count > 0:
... results = engine.search("http.html:clawdbot")
"""
if not query or not query.strip():
logger.warning("Empty query provided to count()")
return 0
self._rate_limit_wait()
logger.debug(f"Counting results for query: {sanitize_output(query)}")
try:
result = self._client.count(query)
total = result.get('total', 0)
logger.info(f"Query '{sanitize_output(query)}' has {total} results")
return total
except shodan.APIError as e:
logger.error(f"Failed to count results: {sanitize_output(str(e))}")
return 0
except Exception as e:
logger.exception(f"Unexpected error in count: {e}")
return 0
+19
View File
@@ -0,0 +1,19 @@
"""Enrichment modules for AASRT.
This module contains data enrichment capabilities:
- ClawSec threat intelligence integration
- (Future) WHOIS lookups
- (Future) Geolocation
- (Future) SSL/TLS certificate analysis
- (Future) DNS records
"""
from .clawsec_feed import ClawSecFeedManager, ClawSecFeed, ClawSecAdvisory
from .threat_enricher import ThreatEnricher
__all__ = [
'ClawSecFeedManager',
'ClawSecFeed',
'ClawSecAdvisory',
'ThreatEnricher'
]
+380
View File
@@ -0,0 +1,380 @@
"""ClawSec Threat Intelligence Feed Manager for AASRT."""
import json
import os
import threading
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
from src.utils.logger import get_logger
logger = get_logger(__name__)
@dataclass
class ClawSecAdvisory:
"""Represents a single ClawSec CVE advisory."""
cve_id: str
severity: str # CRITICAL, HIGH, MEDIUM, LOW
vuln_type: str # e.g., "prompt_injection", "missing_authentication"
cvss_score: float
title: str
description: str
affected: List[str] = field(default_factory=list)
action: str = ""
nvd_url: Optional[str] = None
cwe_id: Optional[str] = None
published_date: Optional[datetime] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
'cve_id': self.cve_id,
'severity': self.severity,
'vuln_type': self.vuln_type,
'cvss_score': self.cvss_score,
'title': self.title,
'description': self.description,
'affected': self.affected,
'action': self.action,
'nvd_url': self.nvd_url,
'cwe_id': self.cwe_id,
'published_date': self.published_date.isoformat() if self.published_date else None
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ClawSecAdvisory':
"""Create from dictionary."""
published = data.get('published')
if published and isinstance(published, str):
try:
published = datetime.fromisoformat(published.replace('Z', '+00:00'))
except:
published = None
return cls(
cve_id=data.get('id', ''),
severity=data.get('severity', 'MEDIUM').upper(),
vuln_type=data.get('type', 'unknown'),
cvss_score=float(data.get('cvss_score', 0.0)),
title=data.get('title', ''),
description=data.get('description', ''),
affected=data.get('affected', []),
action=data.get('action', ''),
nvd_url=data.get('nvd_url'),
cwe_id=data.get('nvd_category_id'),
published_date=published
)
@dataclass
class ClawSecFeed:
"""Container for the full ClawSec advisory feed."""
advisories: List[ClawSecAdvisory]
last_updated: datetime
feed_version: str
total_count: int
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for caching."""
return {
'advisories': [a.to_dict() for a in self.advisories],
'last_updated': self.last_updated.isoformat(),
'feed_version': self.feed_version,
'total_count': self.total_count
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ClawSecFeed':
"""Create from dictionary."""
return cls(
advisories=[ClawSecAdvisory.from_dict(a) for a in data.get('advisories', [])],
last_updated=datetime.fromisoformat(data.get('last_updated', datetime.utcnow().isoformat())),
feed_version=data.get('feed_version', '0.0.0'),
total_count=data.get('total_count', 0)
)
class ClawSecFeedManager:
"""
Manages ClawSec threat intelligence feed with caching and offline support.
Features:
- HTTP fetch with configurable timeout
- Local file caching for offline mode
- Advisory matching by product/version/banner
- Non-blocking background updates
"""
DEFAULT_FEED_URL = "https://clawsec.prompt.security/advisories/feed.json"
DEFAULT_CACHE_FILE = "./data/clawsec_cache.json"
DEFAULT_TTL = 86400 # 24 hours
def __init__(self, config=None):
"""
Initialize ClawSecFeedManager.
Args:
config: Configuration object with clawsec settings
"""
self.config = config
# Get configuration values
if config:
clawsec_config = config.get('clawsec', default={})
self.feed_url = clawsec_config.get('feed_url', self.DEFAULT_FEED_URL)
self.cache_file = clawsec_config.get('cache_file', self.DEFAULT_CACHE_FILE)
self.cache_ttl = clawsec_config.get('cache_ttl_seconds', self.DEFAULT_TTL)
self.offline_mode = clawsec_config.get('offline_mode', False)
self.timeout = clawsec_config.get('timeout', 30)
else:
self.feed_url = self.DEFAULT_FEED_URL
self.cache_file = self.DEFAULT_CACHE_FILE
self.cache_ttl = self.DEFAULT_TTL
self.offline_mode = False
self.timeout = 30
self._cache: Optional[ClawSecFeed] = None
self._cache_timestamp: Optional[datetime] = None
self._lock = threading.Lock()
def fetch_feed(self, force_refresh: bool = False) -> Optional[ClawSecFeed]:
"""
Fetch the ClawSec advisory feed.
Args:
force_refresh: Force fetch from URL even if cache is valid
Returns:
ClawSecFeed object or None if fetch fails
"""
# Check cache first
if not force_refresh and self.is_cache_valid():
logger.debug("Using cached ClawSec feed")
return self._cache
# In offline mode, only use cache
if self.offline_mode:
logger.info("ClawSec offline mode - using cached data only")
return self.get_cached_feed()
try:
logger.info(f"Fetching ClawSec feed from {self.feed_url}")
response = requests.get(self.feed_url, timeout=self.timeout)
response.raise_for_status()
data = response.json()
feed = self._parse_feed(data)
with self._lock:
self._cache = feed
self._cache_timestamp = datetime.utcnow()
# Persist to disk
self.save_cache()
logger.info(f"ClawSec feed loaded: {feed.total_count} advisories")
return feed
except requests.RequestException as e:
logger.warning(f"Failed to fetch ClawSec feed: {e}")
# Fall back to cache
return self.get_cached_feed()
except (json.JSONDecodeError, KeyError) as e:
logger.error(f"Failed to parse ClawSec feed: {e}")
return self.get_cached_feed()
def _parse_feed(self, data: Dict[str, Any]) -> ClawSecFeed:
"""Parse raw feed JSON into ClawSecFeed object."""
advisories = []
for advisory_data in data.get('advisories', []):
try:
advisory = ClawSecAdvisory.from_dict(advisory_data)
advisories.append(advisory)
except Exception as e:
logger.warning(f"Failed to parse advisory: {e}")
continue
return ClawSecFeed(
advisories=advisories,
last_updated=datetime.utcnow(),
feed_version=data.get('version', '0.0.0'),
total_count=len(advisories)
)
def get_cached_feed(self) -> Optional[ClawSecFeed]:
"""Return cached feed without network call."""
if self._cache:
return self._cache
# Try loading from disk
self.load_cache()
return self._cache
def is_cache_valid(self) -> bool:
"""Check if cache is within TTL."""
if not self._cache or not self._cache_timestamp:
return False
age = datetime.utcnow() - self._cache_timestamp
return age.total_seconds() < self.cache_ttl
def save_cache(self) -> None:
"""Persist cache to local file for offline mode."""
if not self._cache:
return
try:
cache_path = Path(self.cache_file)
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_data = {
'feed': self._cache.to_dict(),
'cached_at': datetime.utcnow().isoformat()
}
with open(cache_path, 'w') as f:
json.dump(cache_data, f, indent=2)
logger.debug(f"ClawSec cache saved to {self.cache_file}")
except Exception as e:
logger.warning(f"Failed to save ClawSec cache: {e}")
def load_cache(self) -> bool:
"""Load cache from local file."""
try:
cache_path = Path(self.cache_file)
if not cache_path.exists():
return False
with open(cache_path, 'r') as f:
cache_data = json.load(f)
self._cache = ClawSecFeed.from_dict(cache_data.get('feed', {}))
cached_at = cache_data.get('cached_at')
if cached_at:
self._cache_timestamp = datetime.fromisoformat(cached_at)
logger.info(f"ClawSec cache loaded: {self._cache.total_count} advisories")
return True
except Exception as e:
logger.warning(f"Failed to load ClawSec cache: {e}")
return False
def match_advisories(
self,
product: Optional[str] = None,
version: Optional[str] = None,
banner: Optional[str] = None
) -> List[ClawSecAdvisory]:
"""
Find matching advisories for a product/version/banner.
Matching strategies (in order):
1. Exact product name match in affected list
2. Fuzzy product match (clawdbot, clawbot, claw-bot)
3. Banner text contains product from affected
Args:
product: Product name to match
version: Version string to check
banner: Banner text to search
Returns:
List of matching ClawSecAdvisory objects
"""
feed = self.get_cached_feed()
if not feed:
return []
matches = []
product_lower = (product or '').lower()
banner_lower = (banner or '').lower()
# AI agent keywords to look for
ai_keywords = ['clawdbot', 'clawbot', 'moltbot', 'openclaw', 'autogpt', 'langchain']
for advisory in feed.advisories:
matched = False
# Check each affected product
for affected in advisory.affected:
affected_lower = affected.lower()
# Strategy 1: Direct product match
if product_lower and product_lower in affected_lower:
matched = True
break
# Strategy 2: Check AI keywords in affected and product/banner
for keyword in ai_keywords:
if keyword in affected_lower:
if keyword in product_lower or keyword in banner_lower:
matched = True
break
if matched:
break
# Strategy 3: Banner contains affected product
if banner_lower:
# Extract product name from affected (e.g., "ClawdBot < 2.0" -> "clawdbot")
affected_product = affected_lower.split('<')[0].split('>')[0].strip()
if affected_product and affected_product in banner_lower:
matched = True
break
if matched and advisory not in matches:
matches.append(advisory)
logger.debug(f"ClawSec matched {len(matches)} advisories for product={product}")
return matches
def background_refresh(self) -> None:
"""Start background thread to refresh feed."""
def _refresh():
try:
self.fetch_feed(force_refresh=True)
except Exception as e:
logger.warning(f"Background ClawSec refresh failed: {e}")
thread = threading.Thread(target=_refresh, daemon=True)
thread.start()
logger.debug("ClawSec background refresh started")
def get_statistics(self) -> Dict[str, Any]:
"""Get feed statistics for UI display."""
feed = self.get_cached_feed()
if not feed:
return {
'total_advisories': 0,
'critical_count': 0,
'high_count': 0,
'last_updated': None,
'is_stale': True
}
severity_counts = {'CRITICAL': 0, 'HIGH': 0, 'MEDIUM': 0, 'LOW': 0}
for advisory in feed.advisories:
if advisory.severity in severity_counts:
severity_counts[advisory.severity] += 1
return {
'total_advisories': feed.total_count,
'critical_count': severity_counts['CRITICAL'],
'high_count': severity_counts['HIGH'],
'medium_count': severity_counts['MEDIUM'],
'low_count': severity_counts['LOW'],
'last_updated': feed.last_updated.isoformat() if feed.last_updated else None,
'feed_version': feed.feed_version,
'is_stale': not self.is_cache_valid()
}
+228
View File
@@ -0,0 +1,228 @@
"""Threat Intelligence Enrichment for AASRT."""
from typing import Any, Dict, List, Optional, Tuple
from src.engines import SearchResult
from src.utils.logger import get_logger
from .clawsec_feed import ClawSecAdvisory, ClawSecFeedManager
logger = get_logger(__name__)
class ThreatEnricher:
"""
Enriches SearchResult objects with ClawSec threat intelligence.
Responsibilities:
- Match results against ClawSec advisories
- Add CVE metadata to result.metadata
- Inject ClawSec vulnerabilities into result.vulnerabilities
"""
def __init__(self, feed_manager: ClawSecFeedManager, config=None):
"""
Initialize ThreatEnricher.
Args:
feed_manager: ClawSecFeedManager instance
config: Optional configuration object
"""
self.feed_manager = feed_manager
self.config = config
def enrich(self, result: SearchResult) -> SearchResult:
"""
Enrich a single result with threat intelligence.
Args:
result: SearchResult to enrich
Returns:
Enriched SearchResult with ClawSec metadata
"""
# Extract product info from result
product, version = self._extract_product_info(result)
banner = result.banner or ''
# Get HTTP title if available
http_info = result.metadata.get('http', {}) or {}
title = http_info.get('title') or ''
if title:
banner = f"{banner} {title}"
# Match against ClawSec advisories
advisories = self.feed_manager.match_advisories(
product=product,
version=version,
banner=banner
)
if advisories:
result = self._add_cve_context(result, advisories)
logger.debug(f"Enriched {result.ip}:{result.port} with {len(advisories)} ClawSec advisories")
return result
def enrich_batch(self, results: List[SearchResult]) -> List[SearchResult]:
"""
Enrich multiple results efficiently.
Args:
results: List of SearchResults to enrich
Returns:
List of enriched SearchResults
"""
enriched = []
for result in results:
enriched.append(self.enrich(result))
return enriched
def _extract_product_info(self, result: SearchResult) -> Tuple[Optional[str], Optional[str]]:
"""
Extract product name and version from result metadata.
Args:
result: SearchResult to analyze
Returns:
Tuple of (product_name, version) or (None, None)
"""
product = None
version = None
# Check metadata for product info
metadata = result.metadata if isinstance(result.metadata, dict) else {}
# Try product field directly
if 'product' in metadata:
product = metadata['product']
# Try version field
if 'version' in metadata:
version = metadata['version']
# Check HTTP info
http_info = metadata.get('http') or {}
if http_info:
title = http_info.get('title') or ''
# Look for AI agent keywords in title
ai_products = {
'clawdbot': 'ClawdBot',
'moltbot': 'MoltBot',
'autogpt': 'AutoGPT',
'langchain': 'LangChain',
'openclaw': 'OpenClaw'
}
for keyword, name in ai_products.items():
if title and keyword in title.lower():
product = name
break
# Check service name
if not product and result.service:
service_lower = result.service.lower()
for keyword in ['clawdbot', 'moltbot', 'autogpt', 'langchain']:
if keyword in service_lower:
product = result.service
break
# Check banner for version patterns
if result.banner and not version:
import re
version_patterns = [
r'v?(\d+\.\d+(?:\.\d+)?)', # v1.2.3 or 1.2.3
r'version[:\s]+(\d+\.\d+(?:\.\d+)?)', # version: 1.2.3
]
for pattern in version_patterns:
match = re.search(pattern, result.banner, re.IGNORECASE)
if match:
version = match.group(1)
break
return product, version
def _add_cve_context(
self,
result: SearchResult,
advisories: List[ClawSecAdvisory]
) -> SearchResult:
"""
Add CVE information to result metadata and vulnerabilities.
Args:
result: SearchResult to update
advisories: List of matched ClawSecAdvisory objects
Returns:
Updated SearchResult
"""
# Add ClawSec advisories to metadata
clawsec_data = []
for advisory in advisories:
clawsec_data.append({
'cve_id': advisory.cve_id,
'severity': advisory.severity,
'cvss_score': advisory.cvss_score,
'title': advisory.title,
'vuln_type': advisory.vuln_type,
'action': advisory.action,
'nvd_url': advisory.nvd_url,
'cwe_id': advisory.cwe_id
})
result.metadata['clawsec_advisories'] = clawsec_data
# Track highest severity for quick access
severity_order = {'CRITICAL': 4, 'HIGH': 3, 'MEDIUM': 2, 'LOW': 1}
highest_severity = max(
(a.severity for a in advisories),
key=lambda s: severity_order.get(s, 0),
default='LOW'
)
result.metadata['clawsec_severity'] = highest_severity
# Add CVE IDs to vulnerabilities list
for advisory in advisories:
vuln_id = f"clawsec_{advisory.cve_id}"
if vuln_id not in result.vulnerabilities:
result.vulnerabilities.append(vuln_id)
return result
def get_enrichment_stats(self, results: List[SearchResult]) -> Dict[str, Any]:
"""
Get statistics about enrichment for a set of results.
Args:
results: List of enriched SearchResults
Returns:
Dictionary with enrichment statistics
"""
enriched_count = 0
total_cves = 0
severity_counts = {'CRITICAL': 0, 'HIGH': 0, 'MEDIUM': 0, 'LOW': 0}
cve_list = set()
for result in results:
advisories = result.metadata.get('clawsec_advisories', [])
if advisories:
enriched_count += 1
total_cves += len(advisories)
for advisory in advisories:
cve_list.add(advisory['cve_id'])
severity = advisory.get('severity', 'LOW')
if severity in severity_counts:
severity_counts[severity] += 1
return {
'enriched_results': enriched_count,
'total_results': len(results),
'enrichment_rate': (enriched_count / len(results) * 100) if results else 0,
'unique_cves': len(cve_list),
'total_cve_matches': total_cves,
'severity_breakdown': severity_counts,
'cve_ids': list(cve_list)
}
+689
View File
@@ -0,0 +1,689 @@
"""
CLI entry point for AASRT - AI Agent Security Reconnaissance Tool.
This module provides the command-line interface for AASRT with:
- Shodan-based security reconnaissance scanning
- Vulnerability assessment and risk scoring
- Report generation (JSON/CSV)
- Database storage and history tracking
- Signal handling for graceful shutdown
Usage:
python -m src.main status # Check API status
python -m src.main scan --template clawdbot_instances
python -m src.main history # View scan history
python -m src.main templates # List available templates
Environment Variables:
SHODAN_API_KEY: Required for scanning operations
AASRT_LOG_LEVEL: Logging level (DEBUG, INFO, WARNING, ERROR)
AASRT_DEBUG: Enable debug mode (true/false)
Exit Codes:
0: Success
1: Error (invalid arguments, API errors, etc.)
130: Interrupted by user (SIGINT/Ctrl+C)
"""
import atexit
import signal
import sys
import time
import uuid
from typing import Any, Dict, Optional
import click
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from src import __version__
from src.utils.config import Config
from src.utils.logger import setup_logger, get_logger
from src.core.query_manager import QueryManager
from src.core.result_aggregator import ResultAggregator
from src.core.vulnerability_assessor import VulnerabilityAssessor
from src.core.risk_scorer import RiskScorer
from src.storage.database import Database
from src.reporting import JSONReporter, CSVReporter, ScanReport
# =============================================================================
# Global State
# =============================================================================
console = Console()
_shutdown_requested = False
_active_database: Optional[Database] = None
# =============================================================================
# Signal Handlers
# =============================================================================
def _signal_handler(signum: int, frame: Any) -> None:
"""
Handle interrupt signals for graceful shutdown.
Args:
signum: Signal number received.
frame: Current stack frame (unused).
"""
global _shutdown_requested
signal_name = signal.Signals(signum).name
if _shutdown_requested:
# Second interrupt - force exit
console.print("\n[red]Force shutdown requested. Exiting immediately.[/red]")
sys.exit(130)
_shutdown_requested = True
console.print(f"\n[yellow]Received {signal_name}. Shutting down gracefully...[/yellow]")
console.print("[dim]Press Ctrl+C again to force quit.[/dim]")
def _cleanup() -> None:
"""
Cleanup function called on exit.
Closes database connections and performs cleanup.
"""
global _active_database
if _active_database:
try:
_active_database.close()
except Exception:
pass # Ignore errors during cleanup
def is_shutdown_requested() -> bool:
"""
Check if shutdown has been requested.
Returns:
True if a shutdown signal was received.
"""
return _shutdown_requested
# Register signal handlers
signal.signal(signal.SIGINT, _signal_handler)
signal.signal(signal.SIGTERM, _signal_handler)
atexit.register(_cleanup)
# =============================================================================
# Legal Disclaimer
# =============================================================================
LEGAL_DISCLAIMER = """
[bold red]WARNING: LEGAL DISCLAIMER[/bold red]
This tool is for [bold]authorized security research and defensive purposes only[/bold].
Unauthorized access to computer systems is illegal under:
- CFAA (Computer Fraud and Abuse Act) - United States
- Computer Misuse Act - United Kingdom
- Similar laws worldwide
By proceeding, you acknowledge that:
1. You have authorization to scan target systems
2. You will comply with all applicable laws and terms of service
3. You will responsibly disclose findings
4. You will not exploit discovered vulnerabilities
[bold yellow]The authors are not responsible for misuse of this tool.[/bold yellow]
"""
# =============================================================================
# CLI Command Group
# =============================================================================
@click.group()
@click.version_option(version=__version__, prog_name="AASRT")
@click.option('--config', '-c', type=click.Path(exists=True), help='Path to config file')
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose output')
@click.pass_context
def cli(ctx: click.Context, config: Optional[str], verbose: bool) -> None:
"""
AI Agent Security Reconnaissance Tool (AASRT).
Discover and assess exposed AI agent implementations using Shodan.
Use 'aasrt --help' for command list or 'aasrt COMMAND --help' for command details.
"""
ctx.ensure_object(dict)
# Initialize configuration
try:
ctx.obj['config'] = Config(config)
except Exception as e:
console.print(f"[red]Failed to load configuration: {e}[/red]")
sys.exit(1)
# Setup logging
log_level = 'DEBUG' if verbose else ctx.obj['config'].get('logging', 'level', default='INFO')
log_file = ctx.obj['config'].get('logging', 'file')
try:
setup_logger('aasrt', level=log_level, log_file=log_file)
except Exception as e:
console.print(f"[yellow]Warning: Could not setup logging: {e}[/yellow]")
ctx.obj['verbose'] = verbose
# Log startup in debug mode
logger = get_logger('aasrt')
logger.debug(f"AASRT v{__version__} starting (verbose={verbose})")
# =============================================================================
# Scan Command
# =============================================================================
@cli.command()
@click.option('--query', '-q', help='Custom Shodan search query')
@click.option('--template', '-t', help='Use predefined query template')
@click.option('--max-results', '-m', default=100, type=int, help='Max results to retrieve (1-10000)')
@click.option('--output', '-o', type=click.Path(), help='Output file path')
@click.option('--format', '-f', 'output_format',
type=click.Choice(['json', 'csv', 'both']),
default='json', help='Output format')
@click.option('--no-assess', is_flag=True, help='Skip vulnerability assessment')
@click.option('--save-db/--no-save-db', default=True, help='Save results to database')
@click.option('--yes', '-y', is_flag=True, help='Skip legal disclaimer confirmation')
@click.pass_context
def scan(
ctx: click.Context,
query: Optional[str],
template: Optional[str],
max_results: int,
output: Optional[str],
output_format: str,
no_assess: bool,
save_db: bool,
yes: bool
) -> None:
"""
Perform a security reconnaissance scan using Shodan.
Searches for exposed AI agent implementations and assesses their
security posture using passive analysis techniques.
Examples:
aasrt scan --template clawdbot_instances
aasrt scan --query 'http.title:"AutoGPT"'
aasrt scan -t exposed_env_files -m 50 -f csv
"""
global _active_database
config = ctx.obj['config']
logger = get_logger('aasrt')
logger.info(f"Starting scan command (template={template}, query={query[:50] if query else None})")
# Display legal disclaimer
if not yes:
console.print(Panel(LEGAL_DISCLAIMER, title="Legal Notice", border_style="red"))
if not click.confirm('\nDo you agree to the terms above?', default=False):
console.print('[red]Scan aborted. You must agree to terms of use.[/red]')
logger.info("Scan aborted: User declined legal disclaimer")
sys.exit(1)
# Validate max_results
if max_results < 1:
console.print('[red]Error: max-results must be at least 1[/red]')
sys.exit(1)
if max_results > 10000:
console.print('[yellow]Warning: Limiting max-results to 10000[/yellow]')
max_results = 10000
# Validate inputs
if not query and not template:
console.print('[yellow]No query or template specified. Using default template: clawdbot_instances[/yellow]')
template = 'clawdbot_instances'
# Check for shutdown before heavy operations
if is_shutdown_requested():
console.print('[yellow]Scan cancelled due to shutdown request.[/yellow]')
sys.exit(130)
# Initialize query manager
try:
query_manager = QueryManager(config)
except Exception as e:
console.print(f'[red]Failed to initialize query manager: {e}[/red]')
logger.error(f"Query manager initialization failed: {e}")
sys.exit(1)
# Check if Shodan is available
if not query_manager.is_available():
console.print('[red]Shodan is not available. Please check your API key in .env file.[/red]')
console.print('[dim]Set SHODAN_API_KEY environment variable or add to .env file.[/dim]')
sys.exit(1)
console.print('\n[green]Starting Shodan scan...[/green]')
logger.info(f"Scan started: template={template}, max_results={max_results}")
# Generate scan ID
scan_id = str(uuid.uuid4())
start_time = time.time()
# Execute scan with interrupt checking
all_results = []
scan_error = None
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
console=console
) as progress:
if template:
task = progress.add_task(f"[cyan]Scanning with template: {template}...", total=100)
try:
if not is_shutdown_requested():
all_results = query_manager.execute_template(template, max_results=max_results)
progress.update(task, completed=100)
except KeyboardInterrupt:
console.print('\n[yellow]Scan interrupted by user.[/yellow]')
scan_error = "Interrupted"
except Exception as e:
console.print(f'[red]Template execution failed: {e}[/red]')
logger.error(f"Template execution error: {e}", exc_info=True)
scan_error = str(e)
else:
task = progress.add_task("[cyan]Executing query...", total=100)
try:
if not is_shutdown_requested():
all_results = query_manager.execute_query(query, max_results=max_results)
progress.update(task, completed=100)
except KeyboardInterrupt:
console.print('\n[yellow]Scan interrupted by user.[/yellow]')
scan_error = "Interrupted"
except Exception as e:
console.print(f'[red]Query execution failed: {e}[/red]')
logger.error(f"Query execution error: {e}", exc_info=True)
scan_error = str(e)
# Check if scan was interrupted or had errors
if is_shutdown_requested():
console.print('[yellow]Scan was interrupted. Saving partial results...[/yellow]')
# Aggregate and deduplicate results
console.print('\n[cyan]Aggregating results...[/cyan]')
aggregator = ResultAggregator()
unique_results = aggregator.aggregate({'shodan': all_results})
console.print(f'Found [green]{len(unique_results)}[/green] unique results')
logger.info(f"Aggregated {len(unique_results)} unique results from {len(all_results)} total")
# Vulnerability assessment (skip if shutdown requested)
if not no_assess and unique_results and not is_shutdown_requested():
console.print('\n[cyan]Assessing vulnerabilities...[/cyan]')
assessor = VulnerabilityAssessor(config.get('vulnerability_checks', default={}))
scorer = RiskScorer()
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
console=console
) as progress:
task = progress.add_task("[cyan]Analyzing...", total=len(unique_results))
for result in unique_results:
if is_shutdown_requested():
console.print('[yellow]Assessment interrupted.[/yellow]')
break
try:
vulns = assessor.assess(result)
scorer.score_result(result, vulns)
except Exception as e:
logger.warning(f"Failed to assess result {result.ip}: {e}")
progress.advance(task)
# Calculate duration
duration = time.time() - start_time
# Determine final status
final_status = 'completed'
if scan_error:
final_status = 'failed' if not unique_results else 'partial'
elif is_shutdown_requested():
final_status = 'partial'
# Create report
report = ScanReport.from_results(
scan_id=scan_id,
results=unique_results,
engines=['shodan'],
query=query,
template_name=template,
duration=duration
)
# Display summary
_display_summary(report)
# Save to database
if save_db:
try:
db = Database(config)
_active_database = db # Track for cleanup
scan_record = db.create_scan(
engines=['shodan'],
query=query,
template_name=template
)
if unique_results:
db.add_findings(scan_record.scan_id, unique_results)
db.update_scan(
scan_record.scan_id,
status=final_status,
total_results=len(unique_results),
duration_seconds=duration
)
console.print(f'\n[green]Results saved to database. Scan ID: {scan_record.scan_id}[/green]')
logger.info(f"Saved scan {scan_record.scan_id} with {len(unique_results)} findings")
except Exception as e:
console.print(f'[yellow]Warning: Failed to save to database: {e}[/yellow]')
logger.error(f"Database save error: {e}", exc_info=True)
# Generate reports
output_dir = config.get('reporting', 'output_dir', default='./reports')
try:
if output_format in ['json', 'both']:
json_reporter = JSONReporter(output_dir)
json_path = json_reporter.generate(report, output)
console.print(f'[green]JSON report: {json_path}[/green]')
if output_format in ['csv', 'both']:
csv_reporter = CSVReporter(output_dir)
csv_path = csv_reporter.generate(report, output)
console.print(f'[green]CSV report: {csv_path}[/green]')
except Exception as e:
console.print(f'[yellow]Warning: Failed to generate report: {e}[/yellow]')
logger.error(f"Report generation error: {e}", exc_info=True)
# Final status message
if final_status == 'completed':
console.print(f'\n[bold green]Scan completed in {duration:.1f} seconds[/bold green]')
elif final_status == 'partial':
console.print(f'\n[bold yellow]Scan partially completed in {duration:.1f} seconds[/bold yellow]')
else:
console.print(f'\n[bold red]Scan failed after {duration:.1f} seconds[/bold red]')
sys.exit(1)
# =============================================================================
# Helper Functions
# =============================================================================
def _display_summary(report: ScanReport) -> None:
"""
Display scan summary in a formatted table.
Renders a Rich-formatted summary including:
- Scan ID and duration
- Total results and average risk score
- Risk distribution table
- Top 5 highest risk findings
Args:
report: ScanReport object with scan results.
"""
console.print('\n')
# Summary panel
summary_text = f"""
[bold]Scan ID:[/bold] {report.scan_id[:8]}...
[bold]Duration:[/bold] {report.duration_seconds:.1f}s
[bold]Total Results:[/bold] {report.total_results}
[bold]Average Risk Score:[/bold] {report.average_risk_score}/10
"""
console.print(Panel(summary_text, title="Scan Summary", border_style="green"))
# Risk distribution table
table = Table(title="Risk Distribution")
table.add_column("Severity", style="bold")
table.add_column("Count", justify="right")
table.add_row("[red]Critical[/red]", str(report.critical_findings))
table.add_row("[orange1]High[/orange1]", str(report.high_findings))
table.add_row("[yellow]Medium[/yellow]", str(report.medium_findings))
table.add_row("[green]Low[/green]", str(report.low_findings))
console.print(table)
# Top findings
if report.findings:
console.print('\n[bold]Top 5 Highest Risk Findings:[/bold]')
top_findings = sorted(
report.findings,
key=lambda x: x.get('risk_score', 0),
reverse=True
)[:5]
for i, finding in enumerate(top_findings, 1):
risk = finding.get('risk_score', 0)
ip = finding.get('target_ip', 'N/A')
port = finding.get('target_port', 'N/A')
hostname = finding.get('target_hostname', '')
# Color-code by CVSS-like severity
if risk >= 9.0:
color = 'red' # Critical
elif risk >= 7.0:
color = 'orange1' # High
elif risk >= 4.0:
color = 'yellow' # Medium
else:
color = 'green' # Low
target = f"{ip}:{port}"
if hostname:
target += f" ({hostname})"
console.print(f" {i}. [{color}]{target}[/{color}] - Risk: [{color}]{risk}[/{color}]")
@cli.command()
@click.option('--scan-id', '-s', help='Generate report for specific scan')
@click.option('--format', '-f', 'output_format',
type=click.Choice(['json', 'csv', 'both']),
default='json', help='Output format')
@click.option('--output', '-o', type=click.Path(), help='Output file path')
@click.pass_context
def report(ctx, scan_id, output_format, output):
"""Generate a report from a previous scan."""
config = ctx.obj['config']
try:
db = Database(config)
except Exception as e:
console.print(f'[red]Failed to connect to database: {e}[/red]')
sys.exit(1)
if scan_id:
scan = db.get_scan(scan_id)
if not scan:
console.print(f'[red]Scan not found: {scan_id}[/red]')
sys.exit(1)
scans = [scan]
else:
scans = db.get_recent_scans(limit=1)
if not scans:
console.print('[yellow]No scans found in database.[/yellow]')
sys.exit(0)
scan = scans[0]
findings = db.get_findings(scan_id=scan.scan_id)
report_data = ScanReport.from_scan(scan, findings)
output_dir = config.get('reporting', 'output_dir', default='./reports')
if output_format in ['json', 'both']:
json_reporter = JSONReporter(output_dir)
json_path = json_reporter.generate(report_data, output)
console.print(f'[green]JSON report: {json_path}[/green]')
if output_format in ['csv', 'both']:
csv_reporter = CSVReporter(output_dir)
csv_path = csv_reporter.generate(report_data, output)
console.print(f'[green]CSV report: {csv_path}[/green]')
@cli.command()
@click.pass_context
def status(ctx):
"""Show status of Shodan API configuration."""
config = ctx.obj['config']
console.print('\n[bold]Shodan API Status[/bold]\n')
try:
query_manager = QueryManager(config)
except Exception as e:
console.print(f'[red]Failed to initialize: {e}[/red]')
return
# Validate Shodan
table = Table(title="Engine Status")
table.add_column("Engine", style="bold")
table.add_column("Status")
table.add_column("Details")
if query_manager.is_available():
is_valid = query_manager.validate_engine()
if is_valid:
quota = query_manager.get_quota_info()
status_str = "[green]OK[/green]"
details = f"Credits: {quota.get('query_credits', 'N/A')}, Plan: {quota.get('plan', 'N/A')}"
else:
status_str = "[red]Invalid[/red]"
details = "API key validation failed"
else:
status_str = "[red]Not Configured[/red]"
details = "Add SHODAN_API_KEY to .env file"
table.add_row("Shodan", status_str, details)
console.print(table)
# Available templates
templates = query_manager.get_available_templates()
console.print(f'\n[bold]Available Query Templates:[/bold] {len(templates)}')
for template in sorted(templates):
console.print(f' - {template}')
@cli.command()
@click.option('--limit', '-l', default=10, help='Number of recent scans to show')
@click.pass_context
def history(ctx, limit):
"""Show scan history from database."""
config = ctx.obj['config']
try:
db = Database(config)
scans = db.get_recent_scans(limit=limit)
except Exception as e:
console.print(f'[red]Failed to access database: {e}[/red]')
return
if not scans:
console.print('[yellow]No scans found in database.[/yellow]')
return
table = Table(title=f"Recent Scans (Last {limit})")
table.add_column("Scan ID", style="cyan")
table.add_column("Timestamp")
table.add_column("Template/Query")
table.add_column("Results", justify="right")
table.add_column("Status")
for scan in scans:
scan_id = scan.scan_id[:8] + "..."
timestamp = scan.timestamp.strftime("%Y-%m-%d %H:%M") if scan.timestamp else "N/A"
query_info = scan.template_name or (scan.query[:30] + "..." if scan.query and len(scan.query) > 30 else scan.query) or "N/A"
status_color = "green" if scan.status == "completed" else "yellow" if scan.status == "running" else "red"
status_str = f"[{status_color}]{scan.status}[/{status_color}]"
table.add_row(scan_id, timestamp, query_info, str(scan.total_results), status_str)
console.print(table)
# Show database stats
stats = db.get_statistics()
console.print(f'\n[bold]Database Statistics:[/bold]')
console.print(f' Total Scans: {stats["total_scans"]}')
console.print(f' Total Findings: {stats["total_findings"]}')
console.print(f' Unique IPs: {stats["unique_ips"]}')
@cli.command()
@click.pass_context
def templates(ctx):
"""List available query templates."""
config = ctx.obj['config']
try:
query_manager = QueryManager(config)
except Exception as e:
console.print(f'[red]Failed to initialize: {e}[/red]')
return
templates = query_manager.get_available_templates()
console.print('\n[bold]Available Shodan Query Templates[/bold]\n')
table = Table()
table.add_column("Template Name", style="cyan")
table.add_column("Queries")
for template_name in sorted(templates):
queries = query_manager.templates.get(template_name, [])
query_count = len(queries)
table.add_row(template_name, f"{query_count} queries")
console.print(table)
console.print('\n[dim]Use with: aasrt scan --template <template_name>[/dim]')
# =============================================================================
# Entry Point
# =============================================================================
def main() -> None:
"""
Main entry point for AASRT CLI.
Initializes the Click command group and handles top-level exceptions.
Called when running `python -m src.main` or `aasrt` command.
Exit Codes:
0: Success
1: Error
130: Interrupted by user
"""
try:
cli(obj={})
except KeyboardInterrupt:
console.print("\n[yellow]Operation cancelled by user.[/yellow]")
sys.exit(130)
except Exception as e:
logger = get_logger('aasrt')
logger.exception(f"Unexpected error: {e}")
console.print(f"\n[red]Unexpected error: {e}[/red]")
console.print("[dim]Check logs for details.[/dim]")
sys.exit(1)
if __name__ == '__main__':
main()
+7
View File
@@ -0,0 +1,7 @@
"""Reporting modules for AASRT."""
from .base import BaseReporter, ScanReport
from .json_reporter import JSONReporter
from .csv_reporter import CSVReporter
__all__ = ['BaseReporter', 'ScanReport', 'JSONReporter', 'CSVReporter']
+199
View File
@@ -0,0 +1,199 @@
"""Base reporter class for AASRT."""
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from src.engines import SearchResult
from src.storage.database import Scan, Finding
@dataclass
class ScanReport:
"""Container for scan report data."""
scan_id: str
timestamp: datetime
engines_used: List[str]
query: Optional[str] = None
template_name: Optional[str] = None
total_results: int = 0
duration_seconds: float = 0.0
status: str = "completed"
# Summary statistics
critical_findings: int = 0
high_findings: int = 0
medium_findings: int = 0
low_findings: int = 0
average_risk_score: float = 0.0
# Detailed findings
findings: List[Dict[str, Any]] = field(default_factory=list)
# Additional metadata
metadata: Dict[str, Any] = field(default_factory=dict)
@classmethod
def from_scan(
cls,
scan: Scan,
findings: List[Finding]
) -> 'ScanReport':
"""Create ScanReport from database objects."""
import json
# Calculate severity counts
critical = sum(1 for f in findings if f.risk_score >= 9.0)
high = sum(1 for f in findings if 7.0 <= f.risk_score < 9.0)
medium = sum(1 for f in findings if 4.0 <= f.risk_score < 7.0)
low = sum(1 for f in findings if f.risk_score < 4.0)
# Calculate average risk
avg_risk = sum(f.risk_score for f in findings) / len(findings) if findings else 0.0
return cls(
scan_id=scan.scan_id,
timestamp=scan.timestamp,
engines_used=json.loads(scan.engines_used) if scan.engines_used else [],
query=scan.query,
template_name=scan.template_name,
total_results=len(findings),
duration_seconds=scan.duration_seconds or 0.0,
status=scan.status,
critical_findings=critical,
high_findings=high,
medium_findings=medium,
low_findings=low,
average_risk_score=round(avg_risk, 1),
findings=[f.to_dict() for f in findings],
metadata=json.loads(scan.metadata) if scan.metadata else {}
)
@classmethod
def from_results(
cls,
scan_id: str,
results: List[SearchResult],
engines: List[str],
query: Optional[str] = None,
template_name: Optional[str] = None,
duration: float = 0.0
) -> 'ScanReport':
"""Create ScanReport from search results."""
# Calculate severity counts
critical = sum(1 for r in results if r.risk_score >= 9.0)
high = sum(1 for r in results if 7.0 <= r.risk_score < 9.0)
medium = sum(1 for r in results if 4.0 <= r.risk_score < 7.0)
low = sum(1 for r in results if r.risk_score < 4.0)
# Calculate average risk
avg_risk = sum(r.risk_score for r in results) / len(results) if results else 0.0
return cls(
scan_id=scan_id,
timestamp=datetime.utcnow(),
engines_used=engines,
query=query,
template_name=template_name,
total_results=len(results),
duration_seconds=duration,
status="completed",
critical_findings=critical,
high_findings=high,
medium_findings=medium,
low_findings=low,
average_risk_score=round(avg_risk, 1),
findings=[r.to_dict() for r in results]
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
'scan_metadata': {
'scan_id': self.scan_id,
'timestamp': self.timestamp.isoformat() if self.timestamp else None,
'engines_used': self.engines_used,
'query': self.query,
'template_name': self.template_name,
'total_results': self.total_results,
'duration_seconds': self.duration_seconds,
'status': self.status
},
'summary': {
'critical_findings': self.critical_findings,
'high_findings': self.high_findings,
'medium_findings': self.medium_findings,
'low_findings': self.low_findings,
'average_risk_score': self.average_risk_score
},
'findings': self.findings,
'metadata': self.metadata
}
class BaseReporter(ABC):
"""Abstract base class for reporters."""
def __init__(self, output_dir: str = "./reports"):
"""
Initialize reporter.
Args:
output_dir: Directory for report output
"""
self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)
@property
@abstractmethod
def format_name(self) -> str:
"""Return the format name (e.g., 'json', 'csv')."""
pass
@property
@abstractmethod
def file_extension(self) -> str:
"""Return the file extension."""
pass
@abstractmethod
def generate(self, report: ScanReport, filename: Optional[str] = None) -> str:
"""
Generate a report file.
Args:
report: ScanReport data
filename: Optional custom filename (without extension)
Returns:
Path to generated report file
"""
pass
@abstractmethod
def generate_string(self, report: ScanReport) -> str:
"""
Generate report as a string.
Args:
report: ScanReport data
Returns:
Report content as string
"""
pass
def get_filename(self, scan_id: str, custom_name: Optional[str] = None) -> str:
"""Generate a filename for the report."""
if custom_name:
return f"{custom_name}.{self.file_extension}"
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
return f"scan_{scan_id[:8]}_{timestamp}.{self.file_extension}"
def get_filepath(self, filename: str) -> str:
"""Get full file path for a report."""
return os.path.join(self.output_dir, filename)
+221
View File
@@ -0,0 +1,221 @@
"""CSV report generator for AASRT."""
import csv
import io
from typing import List, Optional
from .base import BaseReporter, ScanReport
from src.utils.logger import get_logger
logger = get_logger(__name__)
class CSVReporter(BaseReporter):
"""Generates CSV format reports."""
# Default columns for findings export
DEFAULT_COLUMNS = [
'target_ip',
'target_port',
'target_hostname',
'service',
'risk_score',
'vulnerabilities',
'source_engine',
'first_seen',
'status',
'confidence'
]
def __init__(
self,
output_dir: str = "./reports",
columns: Optional[List[str]] = None,
include_metadata: bool = False
):
"""
Initialize CSV reporter.
Args:
output_dir: Output directory for reports
columns: Custom columns to include
include_metadata: Whether to include metadata columns
"""
super().__init__(output_dir)
self.columns = columns or self.DEFAULT_COLUMNS.copy()
self.include_metadata = include_metadata
if include_metadata:
self.columns.extend(['location', 'isp', 'asn'])
@property
def format_name(self) -> str:
return "csv"
@property
def file_extension(self) -> str:
return "csv"
def generate(self, report: ScanReport, filename: Optional[str] = None) -> str:
"""
Generate CSV report file.
Args:
report: ScanReport data
filename: Optional custom filename
Returns:
Path to generated report file
"""
output_filename = self.get_filename(report.scan_id, filename)
filepath = self.get_filepath(output_filename)
content = self.generate_string(report)
with open(filepath, 'w', encoding='utf-8', newline='') as f:
f.write(content)
logger.info(f"Generated CSV report: {filepath}")
return filepath
def generate_string(self, report: ScanReport) -> str:
"""
Generate CSV report as string.
Args:
report: ScanReport data
Returns:
CSV content as string
"""
output = io.StringIO()
writer = csv.DictWriter(output, fieldnames=self.columns, extrasaction='ignore')
# Write header
writer.writeheader()
# Write findings
for finding in report.findings:
row = self._format_finding(finding)
writer.writerow(row)
return output.getvalue()
def _format_finding(self, finding: dict) -> dict:
"""Format a finding for CSV output."""
row = {}
for col in self.columns:
if col in finding:
value = finding[col]
# Convert lists to comma-separated strings
if isinstance(value, list):
value = '; '.join(str(v) for v in value)
# Convert dicts to string representation
elif isinstance(value, dict):
value = str(value)
row[col] = value
elif col == 'location':
# Extract from metadata
metadata = finding.get('metadata', {})
location = metadata.get('location', {})
if isinstance(location, dict):
row[col] = f"{location.get('country', '')}, {location.get('city', '')}"
else:
row[col] = ''
elif col == 'isp':
metadata = finding.get('metadata', {})
row[col] = metadata.get('isp', '')
elif col == 'asn':
metadata = finding.get('metadata', {})
row[col] = metadata.get('asn', '')
else:
row[col] = ''
return row
def generate_summary(self, report: ScanReport, filename: Optional[str] = None) -> str:
"""
Generate a summary CSV file.
Args:
report: ScanReport data
filename: Optional custom filename
Returns:
Path to generated file
"""
summary_filename = filename or f"summary_{report.scan_id[:8]}"
if not summary_filename.endswith('.csv'):
summary_filename = f"{summary_filename}.csv"
filepath = self.get_filepath(summary_filename)
output = io.StringIO()
writer = csv.writer(output)
# Write summary as key-value pairs
writer.writerow(['Metric', 'Value'])
writer.writerow(['Scan ID', report.scan_id])
writer.writerow(['Timestamp', report.timestamp.isoformat() if report.timestamp else ''])
writer.writerow(['Engines Used', ', '.join(report.engines_used)])
writer.writerow(['Query', report.query or ''])
writer.writerow(['Template', report.template_name or ''])
writer.writerow(['Total Results', report.total_results])
writer.writerow(['Duration (seconds)', report.duration_seconds])
writer.writerow(['Critical Findings', report.critical_findings])
writer.writerow(['High Findings', report.high_findings])
writer.writerow(['Medium Findings', report.medium_findings])
writer.writerow(['Low Findings', report.low_findings])
writer.writerow(['Average Risk Score', report.average_risk_score])
with open(filepath, 'w', encoding='utf-8', newline='') as f:
f.write(output.getvalue())
logger.info(f"Generated CSV summary: {filepath}")
return filepath
def generate_vulnerability_report(self, report: ScanReport, filename: Optional[str] = None) -> str:
"""
Generate a vulnerability-focused CSV report.
Args:
report: ScanReport data
filename: Optional custom filename
Returns:
Path to generated file
"""
vuln_filename = filename or f"vulnerabilities_{report.scan_id[:8]}"
if not vuln_filename.endswith('.csv'):
vuln_filename = f"{vuln_filename}.csv"
filepath = self.get_filepath(vuln_filename)
output = io.StringIO()
writer = csv.writer(output)
# Header
writer.writerow(['Target IP', 'Port', 'Hostname', 'Vulnerability', 'Risk Score'])
# Write vulnerability rows
for finding in report.findings:
ip = finding.get('target_ip', '')
port = finding.get('target_port', '')
hostname = finding.get('target_hostname', '')
risk_score = finding.get('risk_score', 0)
vulns = finding.get('vulnerabilities', [])
if vulns:
for vuln in vulns:
writer.writerow([ip, port, hostname, vuln, risk_score])
else:
writer.writerow([ip, port, hostname, 'None detected', risk_score])
with open(filepath, 'w', encoding='utf-8', newline='') as f:
f.write(output.getvalue())
logger.info(f"Generated vulnerability CSV: {filepath}")
return filepath
+122
View File
@@ -0,0 +1,122 @@
"""JSON report generator for AASRT."""
import json
from typing import Optional
from .base import BaseReporter, ScanReport
from src.utils.logger import get_logger
logger = get_logger(__name__)
class JSONReporter(BaseReporter):
"""Generates JSON format reports."""
def __init__(self, output_dir: str = "./reports", pretty: bool = True):
"""
Initialize JSON reporter.
Args:
output_dir: Output directory for reports
pretty: Whether to format JSON with indentation
"""
super().__init__(output_dir)
self.pretty = pretty
@property
def format_name(self) -> str:
return "json"
@property
def file_extension(self) -> str:
return "json"
def generate(self, report: ScanReport, filename: Optional[str] = None) -> str:
"""
Generate JSON report file.
Args:
report: ScanReport data
filename: Optional custom filename
Returns:
Path to generated report file
"""
output_filename = self.get_filename(report.scan_id, filename)
filepath = self.get_filepath(output_filename)
content = self.generate_string(report)
with open(filepath, 'w', encoding='utf-8') as f:
f.write(content)
logger.info(f"Generated JSON report: {filepath}")
return filepath
def generate_string(self, report: ScanReport) -> str:
"""
Generate JSON report as string.
Args:
report: ScanReport data
Returns:
JSON string
"""
data = report.to_dict()
# Add report metadata
data['report_metadata'] = {
'format': 'json',
'version': '1.0',
'generated_by': 'AASRT (AI Agent Security Reconnaissance Tool)'
}
if self.pretty:
return json.dumps(data, indent=2, default=str, ensure_ascii=False)
else:
return json.dumps(data, default=str, ensure_ascii=False)
def generate_summary(self, report: ScanReport) -> str:
"""
Generate a summary-only JSON report.
Args:
report: ScanReport data
Returns:
JSON string with summary only
"""
summary = {
'scan_id': report.scan_id,
'timestamp': report.timestamp.isoformat() if report.timestamp else None,
'engines_used': report.engines_used,
'total_results': report.total_results,
'summary': {
'critical_findings': report.critical_findings,
'high_findings': report.high_findings,
'medium_findings': report.medium_findings,
'low_findings': report.low_findings,
'average_risk_score': report.average_risk_score
}
}
if self.pretty:
return json.dumps(summary, indent=2, default=str)
else:
return json.dumps(summary, default=str)
def generate_findings_only(self, report: ScanReport) -> str:
"""
Generate JSON with findings only (no metadata).
Args:
report: ScanReport data
Returns:
JSON string with findings array
"""
if self.pretty:
return json.dumps(report.findings, indent=2, default=str, ensure_ascii=False)
else:
return json.dumps(report.findings, default=str, ensure_ascii=False)
+5
View File
@@ -0,0 +1,5 @@
"""Storage modules for AASRT."""
from .database import Database, Scan, Finding
__all__ = ['Database', 'Scan', 'Finding']
+806
View File
@@ -0,0 +1,806 @@
"""
Database storage layer for AASRT.
This module provides a production-ready database layer with:
- Connection pooling for efficient resource usage
- Automatic retry logic for transient failures
- Context managers for proper session cleanup
- Support for SQLite (default) and PostgreSQL
- Comprehensive logging and error handling
Example:
>>> from src.storage.database import Database
>>> db = Database()
>>> scan = db.create_scan(engines=["shodan"], query="http.html:agent")
>>> db.add_findings(scan.scan_id, results)
>>> db.update_scan(scan.scan_id, status="completed")
"""
import json
import os
import uuid
import time
from contextlib import contextmanager
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Callable, Dict, Generator, List, Optional, TypeVar
from sqlalchemy import (
create_engine, Column, String, Integer, Float, DateTime,
Text, Boolean, ForeignKey, Index, event
)
from sqlalchemy.orm import declarative_base, sessionmaker, relationship, Session, scoped_session
from sqlalchemy.pool import QueuePool, StaticPool
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from src.engines import SearchResult
from src.utils.config import Config
from src.utils.logger import get_logger
logger = get_logger(__name__)
Base = declarative_base()
# =============================================================================
# Retry Configuration
# =============================================================================
T = TypeVar('T')
# Maximum retry attempts for transient database errors
MAX_DB_RETRIES = 3
# Base delay for exponential backoff (seconds)
DB_RETRY_BASE_DELAY = 0.5
# Exceptions that should trigger a retry
RETRYABLE_DB_EXCEPTIONS = (OperationalError,)
def with_db_retry(func: Callable[..., T]) -> Callable[..., T]:
"""
Decorator that adds retry logic for transient database errors.
Retries on connection errors and deadlocks but not on
constraint violations or other permanent errors.
Args:
func: Database function to wrap with retry logic.
Returns:
Wrapped function with retry capability.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> T:
last_exception = None
for attempt in range(1, MAX_DB_RETRIES + 1):
try:
return func(*args, **kwargs)
except RETRYABLE_DB_EXCEPTIONS as e:
last_exception = e
if attempt < MAX_DB_RETRIES:
delay = DB_RETRY_BASE_DELAY * (2 ** (attempt - 1))
logger.warning(
f"Database retry {attempt}/{MAX_DB_RETRIES} for {func.__name__} "
f"after {delay:.2f}s. Error: {e}"
)
time.sleep(delay)
else:
logger.error(
f"All {MAX_DB_RETRIES} database retries exhausted for {func.__name__}"
)
except IntegrityError as e:
# Don't retry constraint violations
logger.error(f"Database integrity error in {func.__name__}: {e}")
raise
except SQLAlchemyError as e:
# Log and re-raise other SQLAlchemy errors
logger.error(f"Database error in {func.__name__}: {e}")
raise
# All retries exhausted
if last_exception:
raise last_exception
raise SQLAlchemyError(f"Unexpected database error in {func.__name__}")
return wrapper
class Scan(Base):
"""Scan record model."""
__tablename__ = 'scans'
scan_id = Column(String(36), primary_key=True)
timestamp = Column(DateTime, nullable=False, default=datetime.utcnow)
engines_used = Column(Text) # JSON array
query = Column(Text)
template_name = Column(String(255))
total_results = Column(Integer, default=0)
duration_seconds = Column(Float)
status = Column(String(50), default='running') # running, completed, failed, partial
extra_data = Column(Text) # JSON
# Relationships
findings = relationship("Finding", back_populates="scan", cascade="all, delete-orphan")
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
'scan_id': self.scan_id,
'timestamp': self.timestamp.isoformat() if self.timestamp else None,
'engines_used': json.loads(self.engines_used) if self.engines_used else [],
'query': self.query,
'template_name': self.template_name,
'total_results': self.total_results,
'duration_seconds': self.duration_seconds,
'status': self.status,
'metadata': json.loads(self.extra_data) if self.extra_data else {}
}
class Finding(Base):
"""Finding record model."""
__tablename__ = 'findings'
finding_id = Column(String(36), primary_key=True)
scan_id = Column(String(36), ForeignKey('scans.scan_id'), nullable=False)
source_engine = Column(String(50))
target_ip = Column(String(45), nullable=False) # Support IPv6
target_port = Column(Integer, nullable=False)
target_hostname = Column(String(255))
service = Column(String(255))
banner = Column(Text)
risk_score = Column(Float, default=0.0)
vulnerabilities = Column(Text) # JSON array
first_seen = Column(DateTime, default=datetime.utcnow)
last_seen = Column(DateTime, default=datetime.utcnow)
status = Column(String(50), default='new') # new, confirmed, false_positive, remediated
confidence = Column(Integer, default=100)
extra_data = Column(Text) # JSON
# Relationships
scan = relationship("Scan", back_populates="findings")
# Indexes
__table_args__ = (
Index('idx_findings_risk', risk_score.desc()),
Index('idx_findings_timestamp', first_seen.desc()),
Index('idx_findings_ip', target_ip),
Index('idx_findings_status', status),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
'finding_id': self.finding_id,
'scan_id': self.scan_id,
'source_engine': self.source_engine,
'target_ip': self.target_ip,
'target_port': self.target_port,
'target_hostname': self.target_hostname,
'service': self.service,
'banner': self.banner,
'risk_score': self.risk_score,
'vulnerabilities': json.loads(self.vulnerabilities) if self.vulnerabilities else [],
'first_seen': self.first_seen.isoformat() if self.first_seen else None,
'last_seen': self.last_seen.isoformat() if self.last_seen else None,
'status': self.status,
'confidence': self.confidence,
'metadata': json.loads(self.extra_data) if self.extra_data else {}
}
@classmethod
def from_search_result(cls, result: SearchResult, scan_id: str) -> 'Finding':
"""Create Finding from SearchResult."""
return cls(
finding_id=str(uuid.uuid4()),
scan_id=scan_id,
source_engine=result.source_engine,
target_ip=result.ip,
target_port=result.port,
target_hostname=result.hostname,
service=result.service,
banner=result.banner,
risk_score=result.risk_score,
vulnerabilities=json.dumps(result.vulnerabilities),
confidence=result.confidence,
extra_data=json.dumps(result.metadata)
)
class Database:
"""
Database manager for AASRT with connection pooling and retry logic.
This class provides a thread-safe database layer with:
- Connection pooling for efficient resource usage
- Automatic retry on transient failures
- Context managers for proper session cleanup
- Support for SQLite and PostgreSQL
Attributes:
config: Configuration instance.
engine: SQLAlchemy engine with connection pool.
Session: Scoped session factory.
Example:
>>> db = Database()
>>> with db.session_scope() as session:
... scan = Scan(scan_id="123", ...)
... session.add(scan)
>>> # Session is automatically committed and closed
"""
# Connection pool settings
POOL_SIZE = 5
MAX_OVERFLOW = 10
POOL_TIMEOUT = 30
POOL_RECYCLE = 3600 # Recycle connections after 1 hour
def __init__(self, config: Optional[Config] = None) -> None:
"""
Initialize database connection with connection pooling.
Args:
config: Configuration instance. If None, uses default Config.
Raises:
SQLAlchemyError: If database connection fails.
"""
self.config = config or Config()
self.engine = None
self.Session = None
self._db_type: str = "unknown"
self._initialize()
def _initialize(self) -> None:
"""
Initialize database connection, pooling, and create tables.
Sets up connection pooling appropriate for the database type:
- SQLite: Uses StaticPool for thread safety
- PostgreSQL: Uses QueuePool with configurable size
"""
self._db_type = self.config.get('database', 'type', default='sqlite')
if self._db_type == 'sqlite':
db_path = self.config.get('database', 'sqlite', 'path', default='./data/scanner.db')
# Ensure directory exists
os.makedirs(os.path.dirname(db_path), exist_ok=True)
connection_string = f"sqlite:///{db_path}"
# SQLite configuration - use StaticPool for thread safety
# Also enable WAL mode for better concurrent access
self.engine = create_engine(
connection_string,
echo=False,
poolclass=StaticPool,
connect_args={
"check_same_thread": False,
"timeout": 30
}
)
# Enable WAL mode for better concurrent access
@event.listens_for(self.engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
else:
# PostgreSQL with connection pooling
host = self.config.get('database', 'postgresql', 'host', default='localhost')
port = self.config.get('database', 'postgresql', 'port', default=5432)
database = self.config.get('database', 'postgresql', 'database', default='aasrt')
user = self.config.get('database', 'postgresql', 'user')
password = self.config.get('database', 'postgresql', 'password')
ssl_mode = self.config.get('database', 'postgresql', 'ssl_mode', default='prefer')
# Mask password in logs
safe_conn_str = f"postgresql://{user}:***@{host}:{port}/{database}"
connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={ssl_mode}"
self.engine = create_engine(
connection_string,
echo=False,
poolclass=QueuePool,
pool_size=self.POOL_SIZE,
max_overflow=self.MAX_OVERFLOW,
pool_timeout=self.POOL_TIMEOUT,
pool_recycle=self.POOL_RECYCLE,
pool_pre_ping=True # Verify connections before use
)
logger.debug(f"PostgreSQL connection: {safe_conn_str}")
# Use scoped_session for thread safety
self.Session = scoped_session(sessionmaker(bind=self.engine))
# Create tables
Base.metadata.create_all(self.engine)
logger.info(f"Database initialized: {self._db_type}")
@contextmanager
def session_scope(self) -> Generator[Session, None, None]:
"""
Provide a transactional scope around a series of operations.
This context manager handles session lifecycle:
- Creates a new session
- Commits on success
- Rolls back on exception
- Always closes the session
Yields:
SQLAlchemy Session object.
Raises:
SQLAlchemyError: On database errors (after rollback).
Example:
>>> with db.session_scope() as session:
... session.add(Scan(...))
... # Automatically committed if no exception
"""
session = self.Session()
try:
yield session
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Database session error, rolling back: {e}")
raise
finally:
session.close()
def get_session(self) -> Session:
"""
Get a database session (legacy method).
Note:
Prefer using session_scope() context manager for new code.
This method is kept for backward compatibility.
Returns:
SQLAlchemy Session object.
"""
return self.Session()
def close(self) -> None:
"""
Close all database connections and cleanup resources.
Call this method during application shutdown to properly
release database connections.
"""
if self.Session:
self.Session.remove()
if self.engine:
self.engine.dispose()
logger.info("Database connections closed")
def health_check(self) -> Dict[str, Any]:
"""
Perform a health check on the database connection.
Returns:
Dictionary with health status:
- healthy: bool indicating if database is accessible
- db_type: Database type (sqlite/postgresql)
- latency_ms: Response time in milliseconds
- error: Error message if unhealthy (optional)
"""
start_time = time.time()
try:
with self.session_scope() as session:
# Simple query to verify connection
session.execute("SELECT 1")
latency = (time.time() - start_time) * 1000
return {
"healthy": True,
"db_type": self._db_type,
"latency_ms": round(latency, 2),
"pool_size": getattr(self.engine.pool, 'size', lambda: 'N/A')() if hasattr(self.engine, 'pool') else 'N/A'
}
except Exception as e:
latency = (time.time() - start_time) * 1000
logger.error(f"Database health check failed: {e}")
return {
"healthy": False,
"db_type": self._db_type,
"latency_ms": round(latency, 2),
"error": str(e)
}
# =========================================================================
# Scan Operations
# =========================================================================
@with_db_retry
def create_scan(
self,
engines: List[str],
query: Optional[str] = None,
template_name: Optional[str] = None
) -> Scan:
"""
Create a new scan record in the database.
Args:
engines: List of engine names used for the scan (e.g., ["shodan"]).
query: Search query string (if using custom query).
template_name: Template name (if using predefined template).
Returns:
Created Scan object with generated scan_id.
Raises:
SQLAlchemyError: If database operation fails.
Example:
>>> scan = db.create_scan(engines=["shodan"], template_name="clawdbot")
>>> print(scan.scan_id)
"""
scan = Scan(
scan_id=str(uuid.uuid4()),
timestamp=datetime.utcnow(),
engines_used=json.dumps(engines),
query=query,
template_name=template_name,
status='running'
)
with self.session_scope() as session:
session.add(scan)
# Flush to ensure data is written before expunge
session.flush()
logger.info(f"Created scan: {scan.scan_id}")
# Need to expunge to use outside session
session.expunge(scan)
return scan
@with_db_retry
def update_scan(
self,
scan_id: str,
status: Optional[str] = None,
total_results: Optional[int] = None,
duration_seconds: Optional[float] = None,
metadata: Optional[Dict] = None
) -> Optional[Scan]:
"""
Update a scan record with new values.
Args:
scan_id: UUID of the scan to update.
status: New status (running, completed, failed, partial).
total_results: Number of results found.
duration_seconds: Total scan duration.
metadata: Additional metadata to merge.
Returns:
Updated Scan object, or None if scan not found.
"""
with self.session_scope() as session:
scan = session.query(Scan).filter(Scan.scan_id == scan_id).first()
if not scan:
logger.warning(f"Scan not found for update: {scan_id}")
return None
if status:
scan.status = status
if total_results is not None:
scan.total_results = total_results
if duration_seconds is not None:
scan.duration_seconds = duration_seconds
if metadata:
existing = json.loads(scan.extra_data) if scan.extra_data else {}
existing.update(metadata)
scan.extra_data = json.dumps(existing)
# Flush to ensure changes are written before expunge
session.flush()
logger.debug(f"Updated scan {scan_id}: status={status}, results={total_results}")
session.expunge(scan)
return scan
@with_db_retry
def get_scan(self, scan_id: str) -> Optional[Scan]:
"""
Get a scan by its UUID.
Args:
scan_id: UUID of the scan.
Returns:
Scan object or None if not found.
"""
with self.session_scope() as session:
scan = session.query(Scan).filter(Scan.scan_id == scan_id).first()
if scan:
session.expunge(scan)
return scan
@with_db_retry
def get_recent_scans(self, limit: int = 10) -> List[Scan]:
"""
Get the most recent scans.
Args:
limit: Maximum number of scans to return.
Returns:
List of Scan objects ordered by timestamp descending.
"""
with self.session_scope() as session:
scans = session.query(Scan).order_by(Scan.timestamp.desc()).limit(limit).all()
for scan in scans:
session.expunge(scan)
return scans
@with_db_retry
def delete_scan(self, scan_id: str) -> bool:
"""
Delete a scan and all its associated findings.
Args:
scan_id: UUID of the scan to delete.
Returns:
True if scan was deleted, False if not found.
"""
with self.session_scope() as session:
scan = session.query(Scan).filter(Scan.scan_id == scan_id).first()
if scan:
session.delete(scan)
session.commit()
logger.info(f"Deleted scan: {scan_id}")
return True
return False
# =========================================================================
# Finding Operations
# =========================================================================
@with_db_retry
def add_findings(self, scan_id: str, results: List[SearchResult]) -> int:
"""
Add findings from search results to the database.
Args:
scan_id: Parent scan UUID.
results: List of SearchResult objects to store.
Returns:
Number of findings successfully added.
Raises:
SQLAlchemyError: If database operation fails.
Note:
Findings are added in batches for efficiency.
"""
if not results:
logger.debug(f"No findings to add for scan {scan_id}")
return 0
with self.session_scope() as session:
count = 0
for result in results:
try:
finding = Finding.from_search_result(result, scan_id)
session.add(finding)
count += 1
except Exception as e:
logger.warning(f"Failed to create finding from result: {e}")
continue
logger.info(f"Added {count} findings to scan {scan_id}")
return count
@with_db_retry
def get_findings(
self,
scan_id: Optional[str] = None,
min_risk_score: Optional[float] = None,
status: Optional[str] = None,
limit: int = 100,
offset: int = 0
) -> List[Finding]:
"""
Get findings with optional filters.
Args:
scan_id: Filter by scan UUID.
min_risk_score: Minimum risk score (0.0-10.0).
status: Finding status filter (new, confirmed, false_positive, remediated).
limit: Maximum results to return.
offset: Number of results to skip (for pagination).
Returns:
List of Finding objects matching filters, ordered by risk score descending.
"""
with self.session_scope() as session:
query = session.query(Finding)
if scan_id:
query = query.filter(Finding.scan_id == scan_id)
if min_risk_score is not None:
query = query.filter(Finding.risk_score >= min_risk_score)
if status:
query = query.filter(Finding.status == status)
query = query.order_by(Finding.risk_score.desc())
findings = query.offset(offset).limit(limit).all()
for finding in findings:
session.expunge(finding)
return findings
@with_db_retry
def get_finding(self, finding_id: str) -> Optional[Finding]:
"""
Get a single finding by its UUID.
Args:
finding_id: UUID of the finding.
Returns:
Finding object or None if not found.
"""
with self.session_scope() as session:
finding = session.query(Finding).filter(Finding.finding_id == finding_id).first()
if finding:
session.expunge(finding)
return finding
@with_db_retry
def update_finding_status(self, finding_id: str, status: str) -> bool:
"""
Update the status of a finding.
Args:
finding_id: UUID of the finding.
status: New status (new, confirmed, false_positive, remediated).
Returns:
True if finding was updated, False if not found.
"""
with self.session_scope() as session:
finding = session.query(Finding).filter(Finding.finding_id == finding_id).first()
if finding:
finding.status = status
finding.last_seen = datetime.utcnow()
logger.debug(f"Updated finding {finding_id} status to {status}")
return True
logger.warning(f"Finding not found for status update: {finding_id}")
return False
@with_db_retry
def get_finding_by_target(self, ip: str, port: int) -> Optional[Finding]:
"""
Get the most recent finding for a specific target.
Args:
ip: Target IP address.
port: Target port number.
Returns:
Most recent Finding for the target, or None if not found.
"""
with self.session_scope() as session:
finding = session.query(Finding).filter(
Finding.target_ip == ip,
Finding.target_port == port
).order_by(Finding.last_seen.desc()).first()
if finding:
session.expunge(finding)
return finding
# =========================================================================
# Statistics and Maintenance
# =========================================================================
@with_db_retry
def get_statistics(self) -> Dict[str, Any]:
"""
Get overall database statistics.
Returns:
Dictionary containing:
- total_scans: Total number of scans
- total_findings: Total number of findings
- unique_ips: Count of unique IP addresses
- risk_distribution: Dict with critical/high/medium/low counts
- last_scan_time: Timestamp of most recent scan (or None)
Example:
>>> stats = db.get_statistics()
>>> print(f"Critical findings: {stats['risk_distribution']['critical']}")
"""
with self.session_scope() as session:
total_scans = session.query(Scan).count()
total_findings = session.query(Finding).count()
# Risk distribution using CVSS-like thresholds
critical = session.query(Finding).filter(Finding.risk_score >= 9.0).count()
high = session.query(Finding).filter(
Finding.risk_score >= 7.0,
Finding.risk_score < 9.0
).count()
medium = session.query(Finding).filter(
Finding.risk_score >= 4.0,
Finding.risk_score < 7.0
).count()
low = session.query(Finding).filter(Finding.risk_score < 4.0).count()
# Unique IPs discovered
unique_ips = session.query(Finding.target_ip).distinct().count()
# Last scan timestamp
last_scan = session.query(Scan).order_by(Scan.timestamp.desc()).first()
last_scan_time = last_scan.timestamp.isoformat() if last_scan else None
return {
'total_scans': total_scans,
'total_findings': total_findings,
'unique_ips': unique_ips,
'risk_distribution': {
'critical': critical,
'high': high,
'medium': medium,
'low': low
},
'last_scan_time': last_scan_time
}
@with_db_retry
def cleanup_old_data(self, days: int = 90) -> int:
"""
Remove scan data older than specified days.
This is a maintenance operation that removes old scans and their
associated findings to manage database size.
Args:
days: Age threshold in days. Scans older than this will be deleted.
Default is 90 days.
Returns:
Number of scans deleted (findings are cascade deleted).
Raises:
ValueError: If days is less than 1.
SQLAlchemyError: If database operation fails.
Example:
>>> # Remove data older than 30 days
>>> deleted = db.cleanup_old_data(days=30)
>>> print(f"Removed {deleted} old scans")
"""
if days < 1:
raise ValueError("Days must be at least 1")
cutoff = datetime.utcnow() - timedelta(days=days)
logger.info(f"Cleaning up data older than {cutoff.isoformat()}")
with self.session_scope() as session:
# Count first for logging
old_scans = session.query(Scan).filter(Scan.timestamp < cutoff).all()
count = len(old_scans)
if count == 0:
logger.info("No old data to clean up")
return 0
for scan in old_scans:
session.delete(scan)
logger.info(f"Cleaned up {count} scans older than {days} days")
return count
+26
View File
@@ -0,0 +1,26 @@
"""Utility modules for AASRT."""
from .config import Config
from .logger import setup_logger, get_logger
from .exceptions import (
AASRTException,
APIException,
RateLimitException,
ConfigurationException,
ValidationException
)
from .validators import validate_ip, validate_domain, validate_query
__all__ = [
'Config',
'setup_logger',
'get_logger',
'AASRTException',
'APIException',
'RateLimitException',
'ConfigurationException',
'ValidationException',
'validate_ip',
'validate_domain',
'validate_query'
]
+513
View File
@@ -0,0 +1,513 @@
"""
Configuration management for AASRT.
This module provides a production-ready configuration management system with:
- Singleton pattern for global configuration access
- YAML file loading with deep merging
- Environment variable overrides
- Validation of required settings
- Support for structured logging configuration
- Health check capabilities
Configuration priority (highest to lowest):
1. Environment variables
2. YAML configuration file
3. Default values
Example:
>>> from src.utils.config import Config
>>> config = Config()
>>> shodan_key = config.get_shodan_key()
>>> log_level = config.get('logging', 'level', default='INFO')
Environment Variables:
SHODAN_API_KEY: Required Shodan API key
AASRT_LOG_LEVEL: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
AASRT_ENVIRONMENT: Deployment environment (development, staging, production)
AASRT_DEBUG: Enable debug mode (true/false)
DB_TYPE: Database type (sqlite, postgresql)
DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD: PostgreSQL settings
"""
import os
import secrets
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
import yaml
from dotenv import load_dotenv
from .exceptions import ConfigurationException
from .logger import get_logger
logger = get_logger(__name__)
# =============================================================================
# Validation Constants
# =============================================================================
VALID_LOG_LEVELS: Set[str] = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
VALID_ENVIRONMENTS: Set[str] = {"development", "staging", "production"}
VALID_DB_TYPES: Set[str] = {"sqlite", "postgresql", "mysql"}
REQUIRED_SETTINGS: List[str] = [] # API key is optional until scan is run
class Config:
"""
Configuration manager for AASRT with singleton pattern.
This class provides centralized configuration management with:
- Thread-safe singleton access
- YAML file configuration
- Environment variable overrides
- Validation of critical settings
- Health check for configuration state
Attributes:
_instance: Singleton instance.
_config: Configuration dictionary.
_initialized: Flag indicating initialization status.
_config_path: Path to loaded configuration file.
_environment: Current deployment environment.
Example:
>>> config = Config()
>>> api_key = config.get_shodan_key()
>>> if not api_key:
... print("Warning: Shodan API key not configured")
"""
_instance: Optional['Config'] = None
_config: Dict[str, Any] = {}
def __new__(cls, config_path: Optional[str] = None):
"""
Singleton pattern implementation.
Args:
config_path: Optional path to YAML configuration file.
Returns:
Singleton Config instance.
"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, config_path: Optional[str] = None) -> None:
"""
Initialize configuration from multiple sources.
Configuration is loaded in order of priority:
1. Default values
2. YAML configuration file
3. Environment variables (highest priority)
Args:
config_path: Path to YAML configuration file.
If not provided, searches common locations.
Raises:
ConfigurationException: If YAML file is malformed.
"""
if self._initialized:
return
# Load environment variables from .env file
load_dotenv()
# Store metadata
self._config_path: Optional[str] = None
self._environment: str = os.getenv('AASRT_ENVIRONMENT', 'development')
self._validation_errors: List[str] = []
# Default configuration
self._config = self._get_defaults()
# Load from file if provided
if config_path:
self._load_from_file(config_path)
else:
# Try to find config file in common locations
for path in ['config.yaml', 'config.yml', './config/config.yaml']:
if os.path.exists(path):
self._load_from_file(path)
break
# Override with environment variables
self._load_from_env()
# Validate configuration
self._validate_config()
self._initialized = True
logger.info(f"Configuration initialized (environment: {self._environment})")
def _get_defaults(self) -> Dict[str, Any]:
"""Get default configuration values."""
return {
'shodan': {
'enabled': True,
'rate_limit': 1,
'max_results': 100,
'timeout': 30
},
'vulnerability_checks': {
'enabled': True,
'passive_only': True,
'timeout_per_check': 10
},
'reporting': {
'formats': ['json', 'csv'],
'output_dir': './reports',
'anonymize_by_default': False
},
'filtering': {
'whitelist_ips': [],
'whitelist_domains': [],
'min_confidence_score': 70,
'exclude_honeypots': True
},
'logging': {
'level': 'INFO',
'file': './logs/scanner.log',
'max_size_mb': 100,
'backup_count': 5
},
'database': {
'type': 'sqlite',
'sqlite': {
'path': './data/scanner.db'
}
},
'api_keys': {},
'clawsec': {
'enabled': True,
'feed_url': 'https://clawsec.prompt.security/advisories/feed.json',
'cache_ttl_seconds': 86400, # 24 hours
'cache_file': './data/clawsec_cache.json',
'offline_mode': False,
'timeout': 30,
'auto_refresh': True
}
}
def _load_from_file(self, path: str) -> None:
"""
Load configuration from YAML file.
Args:
path: Path to YAML configuration file.
Raises:
ConfigurationException: If YAML is malformed.
"""
try:
with open(path, 'r') as f:
file_config = yaml.safe_load(f)
if file_config:
self._deep_merge(self._config, file_config)
self._config_path = path
logger.info(f"Loaded configuration from {path}")
except FileNotFoundError:
logger.warning(f"Configuration file not found: {path}")
except yaml.YAMLError as e:
raise ConfigurationException(f"Invalid YAML in configuration file: {e}")
def _load_from_env(self) -> None:
"""
Load settings from environment variables.
Environment variables override file-based configuration.
This method handles all supported environment variables.
"""
# Load Shodan API key
shodan_key = os.getenv('SHODAN_API_KEY')
if shodan_key:
self._set_nested(('api_keys', 'shodan'), shodan_key)
# Load log level if set
log_level = os.getenv('AASRT_LOG_LEVEL', '').upper()
if log_level and log_level in VALID_LOG_LEVELS:
self._set_nested(('logging', 'level'), log_level)
elif log_level:
logger.warning(f"Invalid log level '{log_level}', using default")
# Load environment setting
env = os.getenv('AASRT_ENVIRONMENT', '').lower()
if env and env in VALID_ENVIRONMENTS:
self._environment = env
# Load debug flag
debug = os.getenv('AASRT_DEBUG', '').lower()
if debug in ('true', '1', 'yes'):
self._set_nested(('logging', 'level'), 'DEBUG')
# Load database settings from environment
db_type = os.getenv('DB_TYPE', '').lower()
if db_type and db_type in VALID_DB_TYPES:
self._set_nested(('database', 'type'), db_type)
# PostgreSQL settings from environment
if os.getenv('DB_HOST'):
self._set_nested(('database', 'postgresql', 'host'), os.getenv('DB_HOST'))
if os.getenv('DB_PORT'):
try:
port = int(os.getenv('DB_PORT'))
self._set_nested(('database', 'postgresql', 'port'), port)
except ValueError:
logger.warning("Invalid DB_PORT, using default")
if os.getenv('DB_NAME'):
self._set_nested(('database', 'postgresql', 'database'), os.getenv('DB_NAME'))
if os.getenv('DB_USER'):
self._set_nested(('database', 'postgresql', 'user'), os.getenv('DB_USER'))
if os.getenv('DB_PASSWORD'):
self._set_nested(('database', 'postgresql', 'password'), os.getenv('DB_PASSWORD'))
if os.getenv('DB_SSL_MODE'):
self._set_nested(('database', 'postgresql', 'ssl_mode'), os.getenv('DB_SSL_MODE'))
# Max results limit
max_results = os.getenv('AASRT_MAX_RESULTS')
if max_results:
try:
self._set_nested(('shodan', 'max_results'), int(max_results))
except ValueError:
logger.warning("Invalid AASRT_MAX_RESULTS, using default")
def _validate_config(self) -> None:
"""
Validate configuration settings.
Checks for valid values and logs warnings for potential issues.
Does not raise exceptions to allow graceful degradation.
"""
self._validation_errors = []
# Validate log level
log_level = self.get('logging', 'level', default='INFO')
if log_level.upper() not in VALID_LOG_LEVELS:
self._validation_errors.append(f"Invalid log level: {log_level}")
# Validate database type
db_type = self.get('database', 'type', default='sqlite')
if db_type.lower() not in VALID_DB_TYPES:
self._validation_errors.append(f"Invalid database type: {db_type}")
# Validate max results is positive
max_results = self.get('shodan', 'max_results', default=100)
if not isinstance(max_results, int) or max_results < 1:
self._validation_errors.append(f"Invalid max_results: {max_results}")
# Check for Shodan API key (warning, not error)
if not self.get_shodan_key():
logger.debug("Shodan API key not configured - scans will require it")
# Log validation errors
for error in self._validation_errors:
logger.warning(f"Configuration validation: {error}")
def _deep_merge(self, base: Dict, overlay: Dict) -> None:
"""
Deep merge overlay dictionary into base dictionary.
Args:
base: Base dictionary to merge into (modified in place).
overlay: Overlay dictionary to merge from.
"""
for key, value in overlay.items():
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
self._deep_merge(base[key], value)
else:
base[key] = value
def _set_nested(self, path: tuple, value: Any) -> None:
"""
Set a nested configuration value by key path.
Args:
path: Tuple of keys representing the path.
value: Value to set at the path.
"""
current = self._config
for key in path[:-1]:
if key not in current:
current[key] = {}
current = current[key]
current[path[-1]] = value
def get(self, *keys: str, default: Any = None) -> Any:
"""
Get a configuration value by nested keys.
Args:
*keys: Nested keys to traverse (e.g., 'database', 'type').
default: Default value if path not found.
Returns:
Configuration value or default.
Example:
>>> config.get('shodan', 'max_results', default=100)
100
"""
current = self._config
for key in keys:
if isinstance(current, dict) and key in current:
current = current[key]
else:
return default
return current
def get_shodan_key(self) -> Optional[str]:
"""
Get Shodan API key.
Returns:
Shodan API key string, or None if not configured.
"""
return self.get('api_keys', 'shodan')
def get_shodan_config(self) -> Dict[str, Any]:
"""
Get Shodan configuration dictionary.
Returns:
Dictionary with Shodan settings (enabled, rate_limit, max_results, timeout).
"""
return self.get('shodan', default={})
def get_clawsec_config(self) -> Dict[str, Any]:
"""
Get ClawSec configuration dictionary.
Returns:
Dictionary with ClawSec settings.
"""
return self.get('clawsec', default={})
def is_clawsec_enabled(self) -> bool:
"""
Check if ClawSec integration is enabled.
Returns:
True if ClawSec vulnerability lookup is enabled.
"""
return self.get('clawsec', 'enabled', default=True)
def get_database_config(self) -> Dict[str, Any]:
"""
Get database configuration.
Returns:
Dictionary with database settings.
"""
return self.get('database', default={})
def get_logging_config(self) -> Dict[str, Any]:
"""
Get logging configuration.
Returns:
Dictionary with logging settings (level, file, max_size_mb, backup_count).
"""
return self.get('logging', default={})
@property
def environment(self) -> str:
"""
Get current deployment environment.
Returns:
Environment string (development, staging, production).
"""
return self._environment
@property
def is_production(self) -> bool:
"""
Check if running in production environment.
Returns:
True if environment is 'production'.
"""
return self._environment == 'production'
@property
def is_debug(self) -> bool:
"""
Check if debug mode is enabled.
Returns:
True if log level is DEBUG.
"""
return self.get('logging', 'level', default='INFO').upper() == 'DEBUG'
@property
def all(self) -> Dict[str, Any]:
"""
Get all configuration as dictionary.
Returns:
Copy of complete configuration dictionary.
"""
return self._config.copy()
def reload(self, config_path: Optional[str] = None) -> None:
"""
Reload configuration from file and environment.
Use this to refresh configuration without restarting the application.
Args:
config_path: Optional path to configuration file.
If None, uses previously loaded file path.
"""
logger.info("Reloading configuration...")
self._initialized = False
self._config = self._get_defaults()
# Use new path or fall back to previously loaded path
path_to_load = config_path or self._config_path
if path_to_load:
self._load_from_file(path_to_load)
self._load_from_env()
self._validate_config()
self._initialized = True
logger.info("Configuration reloaded successfully")
def health_check(self) -> Dict[str, Any]:
"""
Perform a health check on configuration.
Returns:
Dictionary with health status:
- healthy: bool indicating if configuration is valid
- environment: Current deployment environment
- config_file: Path to loaded config file (if any)
- validation_errors: List of validation errors
- shodan_configured: Whether Shodan API key is set
- clawsec_enabled: Whether ClawSec is enabled
"""
return {
"healthy": len(self._validation_errors) == 0,
"environment": self._environment,
"config_file": self._config_path,
"validation_errors": self._validation_errors.copy(),
"shodan_configured": bool(self.get_shodan_key()),
"clawsec_enabled": self.is_clawsec_enabled(),
"log_level": self.get('logging', 'level', default='INFO'),
"database_type": self.get('database', 'type', default='sqlite')
}
@staticmethod
def reset_instance() -> None:
"""
Reset the singleton instance (for testing).
Warning:
This should only be used in tests. It will cause any
existing references to the old instance to be stale.
"""
Config._instance = None
+51
View File
@@ -0,0 +1,51 @@
"""Custom exceptions for AASRT."""
class AASRTException(Exception):
"""Base exception for AASRT."""
pass
class APIException(AASRTException):
"""Raised when API call fails."""
def __init__(self, message: str, engine: str = None, status_code: int = None):
self.engine = engine
self.status_code = status_code
super().__init__(message)
class RateLimitException(AASRTException):
"""Raised when rate limit is exceeded."""
def __init__(self, message: str, engine: str = None, retry_after: int = None):
self.engine = engine
self.retry_after = retry_after
super().__init__(message)
class ConfigurationException(AASRTException):
"""Raised when configuration is invalid."""
pass
class ValidationException(AASRTException):
"""Raised when input validation fails."""
pass
class AuthenticationException(AASRTException):
"""Raised when authentication fails."""
def __init__(self, message: str, engine: str = None):
self.engine = engine
super().__init__(message)
class TimeoutException(AASRTException):
"""Raised when a request times out."""
def __init__(self, message: str, engine: str = None, timeout: int = None):
self.engine = engine
self.timeout = timeout
super().__init__(message)
+91
View File
@@ -0,0 +1,91 @@
"""Logging setup for AASRT."""
import logging
import os
from logging.handlers import RotatingFileHandler
from typing import Optional
_loggers = {}
def setup_logger(
name: str = "aasrt",
level: str = "INFO",
log_file: Optional[str] = None,
max_size_mb: int = 100,
backup_count: int = 5
) -> logging.Logger:
"""
Setup and configure a logger.
Args:
name: Logger name
level: Log level (DEBUG, INFO, WARNING, ERROR)
log_file: Path to log file (optional)
max_size_mb: Max log file size in MB
backup_count: Number of backup files to keep
Returns:
Configured logger instance
"""
if name in _loggers:
return _loggers[name]
logger = logging.getLogger(name)
logger.setLevel(getattr(logging, level.upper(), logging.INFO))
# Prevent duplicate handlers
if logger.handlers:
return logger
# Console handler with colors
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
# Format with colors for console
console_format = logging.Formatter(
'%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
console_handler.setFormatter(console_format)
logger.addHandler(console_handler)
# File handler if log_file specified
if log_file:
# Ensure directory exists
log_dir = os.path.dirname(log_file)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
file_handler = RotatingFileHandler(
log_file,
maxBytes=max_size_mb * 1024 * 1024,
backupCount=backup_count
)
file_handler.setLevel(logging.DEBUG)
file_format = logging.Formatter(
'%(asctime)s | %(levelname)-8s | %(name)s | %(filename)s:%(lineno)d | %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_format)
logger.addHandler(file_handler)
_loggers[name] = logger
return logger
def get_logger(name: str = "aasrt") -> logging.Logger:
"""
Get an existing logger or create a new one.
Args:
name: Logger name
Returns:
Logger instance
"""
if name in _loggers:
return _loggers[name]
return setup_logger(name)
+583
View File
@@ -0,0 +1,583 @@
"""
Input validation utilities for AASRT.
This module provides comprehensive input validation and sanitization functions
for security-sensitive operations including:
- IP address and domain validation
- Port number and query string validation
- File path sanitization (directory traversal prevention)
- API key format validation
- Template name whitelist validation
- Configuration value validation
All validators raise ValidationException on invalid input with descriptive
error messages for debugging.
Example:
>>> from src.utils.validators import validate_ip, validate_file_path
>>> validate_ip("192.168.1.1") # Returns True
>>> validate_file_path("../../../etc/passwd") # Raises ValidationException
"""
import re
import os
import ipaddress
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
import validators
from .exceptions import ValidationException
# =============================================================================
# Constants
# =============================================================================
# Valid log levels for configuration
VALID_LOG_LEVELS: Set[str] = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
# Valid environment names
VALID_ENVIRONMENTS: Set[str] = {"development", "staging", "production"}
# Valid database types
VALID_DB_TYPES: Set[str] = {"sqlite", "postgresql", "mysql"}
# Valid report formats
VALID_REPORT_FORMATS: Set[str] = {"json", "csv", "html", "pdf"}
# Valid query template names (whitelist)
VALID_TEMPLATES: Set[str] = {
"clawdbot_instances",
"autogpt_instances",
"langchain_agents",
"openai_agents",
"anthropic_agents",
"ai_agent_general",
"agent_gpt",
"babyagi_instances",
"crewai_instances",
"autogen_instances",
"superagi_instances",
"flowise_instances",
"dify_instances",
}
# Maximum limits for various inputs
MAX_QUERY_LENGTH: int = 2000
MAX_RESULTS_LIMIT: int = 10000
MIN_RESULTS_LIMIT: int = 1
MAX_PORT: int = 65535
MIN_PORT: int = 1
MAX_FILE_PATH_LENGTH: int = 4096
MAX_API_KEY_LENGTH: int = 256
# =============================================================================
# IP and Network Validators
# =============================================================================
def validate_ip(ip: str) -> bool:
"""
Validate an IP address (IPv4 or IPv6).
Args:
ip: IP address string to validate.
Returns:
True if the IP address is valid.
Raises:
ValidationException: If IP is None, empty, or invalid format.
Example:
>>> validate_ip("192.168.1.1")
True
>>> validate_ip("2001:db8::1")
True
>>> validate_ip("invalid")
ValidationException: Invalid IP address: invalid
"""
if ip is None:
raise ValidationException("IP address cannot be None")
if not isinstance(ip, str):
raise ValidationException(f"IP address must be a string, got {type(ip).__name__}")
ip = ip.strip()
if not ip:
raise ValidationException("IP address cannot be empty")
try:
ipaddress.ip_address(ip)
return True
except ValueError:
raise ValidationException(f"Invalid IP address: {ip}")
def validate_domain(domain: str) -> bool:
"""
Validate a domain name.
Args:
domain: Domain name string
Returns:
True if valid
Raises:
ValidationException: If domain is invalid
"""
if validators.domain(domain):
return True
raise ValidationException(f"Invalid domain: {domain}")
def validate_query(query: str, engine: str) -> bool:
"""
Validate a search query for a specific engine.
Args:
query: Search query string
engine: Search engine name
Returns:
True if valid
Raises:
ValidationException: If query is invalid
"""
if not query or not query.strip():
raise ValidationException("Query cannot be empty")
# Check for potentially dangerous characters
dangerous_patterns = [
r'[<>]', # Script injection attempts
r'\x00', # Null bytes
]
for pattern in dangerous_patterns:
if re.search(pattern, query):
raise ValidationException(f"Query contains invalid characters: {pattern}")
# Engine-specific validation
if engine == "shodan":
# Shodan queries should be reasonable length
if len(query) > 1000:
raise ValidationException("Shodan query too long (max 1000 chars)")
elif engine == "censys":
# Censys queries should be reasonable length
if len(query) > 2000:
raise ValidationException("Censys query too long (max 2000 chars)")
return True
def validate_port(port: int) -> bool:
"""
Validate a port number.
Args:
port: Port number
Returns:
True if valid
Raises:
ValidationException: If port is invalid
"""
if not isinstance(port, int) or port < 1 or port > 65535:
raise ValidationException(f"Invalid port number: {port}")
return True
def validate_api_key(api_key: str, engine: str) -> bool:
"""
Validate API key format for a specific engine.
Args:
api_key: API key string
engine: Search engine name
Returns:
True if valid
Raises:
ValidationException: If API key format is invalid
"""
if not api_key or not api_key.strip():
raise ValidationException(f"API key for {engine} cannot be empty")
# Basic format validation (not checking actual validity)
if engine == "shodan":
# Shodan API keys are typically 32 characters
if len(api_key) < 20:
raise ValidationException("Shodan API key appears too short")
return True
def sanitize_output(text: str) -> str:
"""
Sanitize text for safe output (remove potential secrets).
This function redacts sensitive patterns like API keys, passwords, and
authentication tokens to prevent accidental exposure in logs or output.
Args:
text: Text to sanitize.
Returns:
Sanitized text with sensitive data replaced by REDACTED markers.
Example:
>>> sanitize_output("key: sk-ant-abc123...")
'key: sk-ant-***REDACTED***'
"""
if text is None:
return ""
if not isinstance(text, str):
text = str(text)
# Patterns for sensitive data (order matters - more specific first)
patterns = [
# Anthropic API keys
(r'sk-ant-[a-zA-Z0-9-_]{20,}', 'sk-ant-***REDACTED***'),
# OpenAI API keys
(r'sk-[a-zA-Z0-9]{40,}', 'sk-***REDACTED***'),
# AWS Access Key
(r'AKIA[0-9A-Z]{16}', 'AKIA***REDACTED***'),
# AWS Secret Key
(r'(?i)aws_secret_access_key["\s:=]+["\']?[A-Za-z0-9/+=]{40}', 'aws_secret_access_key=***REDACTED***'),
# GitHub tokens
(r'ghp_[a-zA-Z0-9]{36}', 'ghp_***REDACTED***'),
(r'gho_[a-zA-Z0-9]{36}', 'gho_***REDACTED***'),
# Google API keys
(r'AIza[0-9A-Za-z-_]{35}', 'AIza***REDACTED***'),
# Stripe keys
(r'sk_live_[a-zA-Z0-9]{24,}', 'sk_live_***REDACTED***'),
(r'sk_test_[a-zA-Z0-9]{24,}', 'sk_test_***REDACTED***'),
# Shodan API key (32 hex chars)
(r'[a-fA-F0-9]{32}', '***REDACTED_KEY***'),
# Generic password patterns
(r'password["\s:=]+["\']?[\w@#$%^&*!?]+', 'password=***REDACTED***'),
(r'passwd["\s:=]+["\']?[\w@#$%^&*!?]+', 'passwd=***REDACTED***'),
(r'secret["\s:=]+["\']?[\w@#$%^&*!?]+', 'secret=***REDACTED***'),
# Bearer tokens
(r'Bearer\s+[a-zA-Z0-9._-]+', 'Bearer ***REDACTED***'),
# Basic auth
(r'Basic\s+[a-zA-Z0-9+/=]+', 'Basic ***REDACTED***'),
]
result = text
for pattern, replacement in patterns:
result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)
return result
# =============================================================================
# File Path Validators
# =============================================================================
def validate_file_path(
path: str,
must_exist: bool = False,
allow_absolute: bool = True,
base_dir: Optional[str] = None
) -> str:
"""
Validate and sanitize a file path to prevent directory traversal attacks.
Args:
path: File path to validate.
must_exist: If True, the file must exist.
allow_absolute: If True, allow absolute paths.
base_dir: If provided, ensure path is within this directory.
Returns:
Sanitized, normalized file path.
Raises:
ValidationException: If path is invalid or potentially dangerous.
Example:
>>> validate_file_path("reports/scan.json")
'reports/scan.json'
>>> validate_file_path("../../../etc/passwd")
ValidationException: Path traversal detected
"""
if path is None:
raise ValidationException("File path cannot be None")
if not isinstance(path, str):
raise ValidationException(f"File path must be a string, got {type(path).__name__}")
path = path.strip()
if not path:
raise ValidationException("File path cannot be empty")
if len(path) > MAX_FILE_PATH_LENGTH:
raise ValidationException(f"File path too long (max {MAX_FILE_PATH_LENGTH} chars)")
# Check for null bytes (security risk)
if '\x00' in path:
raise ValidationException("File path contains null bytes")
# Normalize the path
try:
normalized = os.path.normpath(path)
except Exception as e:
raise ValidationException(f"Invalid file path: {e}")
# Check for directory traversal
if '..' in normalized.split(os.sep):
raise ValidationException("Path traversal detected: '..' not allowed")
# Check absolute path restriction
if not allow_absolute and os.path.isabs(normalized):
raise ValidationException("Absolute paths not allowed")
# Check if within base directory
if base_dir:
base_dir = os.path.abspath(base_dir)
full_path = os.path.abspath(os.path.join(base_dir, normalized))
if not full_path.startswith(base_dir):
raise ValidationException("Path escapes base directory")
# Check existence if required
if must_exist and not os.path.exists(path):
raise ValidationException(f"File does not exist: {path}")
return normalized
# =============================================================================
# Template and Configuration Validators
# =============================================================================
def validate_template_name(template: str) -> bool:
"""
Validate a query template name against the whitelist.
Args:
template: Template name to validate.
Returns:
True if template is valid.
Raises:
ValidationException: If template is not in the allowed list.
Example:
>>> validate_template_name("clawdbot_instances")
True
>>> validate_template_name("malicious_query")
ValidationException: Invalid template name
"""
if template is None:
raise ValidationException("Template name cannot be None")
template = template.strip().lower()
if not template:
raise ValidationException("Template name cannot be empty")
if template not in VALID_TEMPLATES:
valid_list = ", ".join(sorted(VALID_TEMPLATES))
raise ValidationException(
f"Invalid template name: '{template}'. Valid templates: {valid_list}"
)
return True
def validate_max_results(max_results: Union[int, str]) -> int:
"""
Validate and normalize max_results parameter.
Args:
max_results: Maximum number of results (int or string).
Returns:
Validated integer value.
Raises:
ValidationException: If value is invalid or out of range.
Example:
>>> validate_max_results(100)
100
>>> validate_max_results("50")
50
>>> validate_max_results(-1)
ValidationException: max_results must be positive
"""
if max_results is None:
raise ValidationException("max_results cannot be None")
# Convert string to int if needed
if isinstance(max_results, str):
try:
max_results = int(max_results.strip())
except ValueError:
raise ValidationException(f"max_results must be a number, got: '{max_results}'")
if not isinstance(max_results, int):
raise ValidationException(f"max_results must be an integer, got {type(max_results).__name__}")
if max_results < MIN_RESULTS_LIMIT:
raise ValidationException(f"max_results must be at least {MIN_RESULTS_LIMIT}")
if max_results > MAX_RESULTS_LIMIT:
raise ValidationException(f"max_results cannot exceed {MAX_RESULTS_LIMIT}")
return max_results
def validate_log_level(level: str) -> str:
"""
Validate a log level string.
Args:
level: Log level string.
Returns:
Normalized uppercase log level.
Raises:
ValidationException: If log level is invalid.
"""
if level is None:
raise ValidationException("Log level cannot be None")
level = str(level).strip().upper()
if level not in VALID_LOG_LEVELS:
valid_list = ", ".join(sorted(VALID_LOG_LEVELS))
raise ValidationException(f"Invalid log level: '{level}'. Valid levels: {valid_list}")
return level
def validate_environment(env: str) -> str:
"""
Validate an environment name.
Args:
env: Environment name string.
Returns:
Normalized lowercase environment name.
Raises:
ValidationException: If environment is invalid.
"""
if env is None:
raise ValidationException("Environment cannot be None")
env = str(env).strip().lower()
if env not in VALID_ENVIRONMENTS:
valid_list = ", ".join(sorted(VALID_ENVIRONMENTS))
raise ValidationException(f"Invalid environment: '{env}'. Valid environments: {valid_list}")
return env
def validate_db_type(db_type: str) -> str:
"""
Validate a database type.
Args:
db_type: Database type string.
Returns:
Normalized lowercase database type.
Raises:
ValidationException: If database type is invalid.
"""
if db_type is None:
raise ValidationException("Database type cannot be None")
db_type = str(db_type).strip().lower()
if db_type not in VALID_DB_TYPES:
valid_list = ", ".join(sorted(VALID_DB_TYPES))
raise ValidationException(f"Invalid database type: '{db_type}'. Valid types: {valid_list}")
return db_type
# =============================================================================
# Batch Validation Helpers
# =============================================================================
def validate_config_dict(config: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate a configuration dictionary.
Args:
config: Configuration dictionary to validate.
Returns:
Validated configuration dictionary.
Raises:
ValidationException: If any configuration value is invalid.
"""
validated = {}
# Validate log level if present
if 'logging' in config and 'level' in config['logging']:
config['logging']['level'] = validate_log_level(config['logging']['level'])
# Validate database type if present
if 'database' in config and 'type' in config['database']:
config['database']['type'] = validate_db_type(config['database']['type'])
# Validate max_results if present
if 'shodan' in config and 'max_results' in config['shodan']:
config['shodan']['max_results'] = validate_max_results(config['shodan']['max_results'])
return config
def is_safe_string(text: str, max_length: int = 1000) -> bool:
"""
Check if a string is safe (no injection attempts).
Args:
text: Text to check.
max_length: Maximum allowed length.
Returns:
True if string appears safe, False otherwise.
"""
if text is None:
return False
if len(text) > max_length:
return False
# Check for null bytes
if '\x00' in text:
return False
# Check for common injection patterns
dangerous_patterns = [
r'<script',
r'javascript:',
r'on\w+\s*=',
r'\x00',
r'<!--',
r'--\s*>',
]
for pattern in dangerous_patterns:
if re.search(pattern, text, re.IGNORECASE):
return False
return True