Initial commit: AASRT v1.0.0 - AI Agent Security Reconnaissance Tool

This commit is contained in:
swethab
2026-02-10 10:53:31 -05:00
commit a714a3399b
61 changed files with 14858 additions and 0 deletions
+4
View File
@@ -0,0 +1,4 @@
"""AI Agent Security Reconnaissance Tool (AASRT)"""
__version__ = "1.0.0"
__author__ = "AGK"
+12
View File
@@ -0,0 +1,12 @@
"""Alert modules for AASRT.
This module will contain alerting capabilities:
- Email notifications
- Slack webhooks
- Discord webhooks
- Telegram bot integration
These are planned for Phase 3 implementation.
"""
__all__ = []
+14
View File
@@ -0,0 +1,14 @@
"""Core engine components for AASRT."""
from .query_manager import QueryManager
from .result_aggregator import ResultAggregator
from .vulnerability_assessor import VulnerabilityAssessor, Vulnerability
from .risk_scorer import RiskScorer
__all__ = [
'QueryManager',
'ResultAggregator',
'VulnerabilityAssessor',
'Vulnerability',
'RiskScorer'
]
+252
View File
@@ -0,0 +1,252 @@
"""Query management and execution for AASRT."""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
import yaml
from src.engines import SearchResult, ShodanEngine
from src.utils.config import Config
from src.utils.logger import get_logger
from src.utils.exceptions import APIException, ConfigurationException
logger = get_logger(__name__)
class QueryManager:
    """Manages search queries using Shodan.

    Maintains a registry of named query templates (built-in defaults plus
    any custom templates found under ``queries/``) and executes them through
    a :class:`ShodanEngine` built from the supplied configuration.
    """

    # Built-in query templates: template name -> list of Shodan query strings.
    DEFAULT_TEMPLATES = {
        "clawdbot_instances": [
            'http.title:"ClawdBot Dashboard"',
            'http.html:"ClawdBot" port:3000',
            'product:"ClawdBot"'
        ],
        "autogpt_instances": [
            'http.title:"Auto-GPT"',
            'http.html:"autogpt" port:8000'
        ],
        "langchain_agents": [
            'http.html:"langchain" http.html:"agent"',
            'product:"LangChain"'
        ],
        "openai_exposed": [
            'http.title:"OpenAI Playground"',
            'http.html:"sk-" http.html:"openai"'
        ],
        "exposed_env_files": [
            'http.html:".env" http.html:"API_KEY"',
            'http.title:"Index of" http.html:".env"'
        ],
        "debug_mode": [
            'http.html:"DEBUG=True"',
            'http.html:"development mode"',
            'http.html:"stack trace"'
        ],
        "ai_dashboards": [
            'http.title:"AI Dashboard"',
            'http.title:"LLM" http.html:"chat"',
            'http.html:"anthropic" http.html:"claude"'
        ],
        "jupyter_notebooks": [
            'http.title:"Jupyter Notebook"',
            'http.title:"JupyterLab"',
            'http.html:"jupyter" port:8888'
        ],
        "streamlit_apps": [
            'http.html:"streamlit"',
            'http.title:"Streamlit"'
        ]
    }

    def __init__(self, config: Optional[Config] = None):
        """
        Initialize QueryManager.

        Args:
            config: Configuration instance (a default Config is created
                when omitted)
        """
        self.config = config or Config()
        self.engine: Optional[ShodanEngine] = None
        # Copy so per-instance custom templates never mutate the class default.
        self.templates: Dict[str, List[str]] = self.DEFAULT_TEMPLATES.copy()
        self._initialize_engine()
        self._load_custom_templates()

    def _initialize_engine(self) -> None:
        """Initialize the Shodan engine if an API key is configured."""
        api_key = self.config.get_shodan_key()
        if api_key:
            shodan_config = self.config.get_shodan_config()
            self.engine = ShodanEngine(
                api_key=api_key,
                rate_limit=shodan_config.get('rate_limit', 1.0),
                timeout=shodan_config.get('timeout', 30),
                max_results=shodan_config.get('max_results', 100)
            )
            logger.info("Shodan engine initialized")
        else:
            # Leave self.engine as None; query methods will refuse to run.
            logger.warning("Shodan API key not provided")

    def _load_custom_templates(self) -> None:
        """Load custom query templates from YAML files in ``queries/``.

        Each file may provide its queries either as a plain list or as a
        dict with a ``shodan`` key. Unreadable or unparseable files are
        logged and skipped so one bad file cannot break startup.
        """
        queries_dir = Path("queries")
        if not queries_dir.exists():
            return
        # Accept both common YAML extensions.
        for yaml_file in list(queries_dir.glob("*.yaml")) + list(queries_dir.glob("*.yml")):
            try:
                with open(yaml_file, 'r', encoding='utf-8') as f:
                    data = yaml.safe_load(f)
            except (OSError, yaml.YAMLError) as e:
                # OSError covers unreadable files, which previously escaped
                # the handler and aborted initialization.
                logger.error(f"Failed to parse {yaml_file}: {e}")
                continue
            if not data or 'queries' not in data:
                continue
            template_name = yaml_file.stem
            # Support both list format and dict format.
            queries = data['queries']
            if isinstance(queries, dict) and 'shodan' in queries:
                self.templates[template_name] = queries['shodan']
            elif isinstance(queries, list):
                self.templates[template_name] = queries
            logger.debug(f"Loaded query template: {template_name}")

    def is_available(self) -> bool:
        """Check if Shodan engine is available."""
        return self.engine is not None

    def get_available_templates(self) -> List[str]:
        """Get list of available query template names."""
        return list(self.templates.keys())

    def validate_engine(self) -> bool:
        """
        Validate Shodan credentials.

        Returns:
            True if credentials are valid, False when the engine is missing
            or validation fails for any reason.
        """
        if not self.engine:
            return False
        try:
            return self.engine.validate_credentials()
        except Exception as e:
            # Best effort: a validation failure should never crash a scan.
            logger.error(f"Failed to validate Shodan: {e}")
            return False

    def get_quota_info(self) -> Dict[str, Any]:
        """Get Shodan API quota information (or an error dict if no engine)."""
        if not self.engine:
            return {'error': 'Engine not initialized'}
        return self.engine.get_quota_info()

    def execute_query(
        self,
        query: str,
        max_results: Optional[int] = None
    ) -> List[SearchResult]:
        """
        Execute a search query.

        Args:
            query: Shodan search query
            max_results: Maximum results to return

        Returns:
            List of SearchResult objects

        Raises:
            ConfigurationException: If the engine was never initialized.
            APIException: Propagated from the underlying search call.
        """
        if not self.engine:
            raise ConfigurationException("Shodan engine not initialized. Check your API key.")
        try:
            results = self.engine.search(query, max_results)
            logger.info(f"Query returned {len(results)} results")
            return results
        except APIException as e:
            logger.error(f"Query failed: {e}")
            raise

    def execute_template(
        self,
        template_name: str,
        max_results: Optional[int] = None
    ) -> List[SearchResult]:
        """
        Execute all queries from a template.

        Args:
            template_name: Name of the query template
            max_results: Maximum results per query

        Returns:
            Combined list of results from all queries

        Raises:
            ConfigurationException: If the template or engine is missing.
        """
        if template_name not in self.templates:
            raise ConfigurationException(f"Template not found: {template_name}")
        if not self.engine:
            raise ConfigurationException("Shodan engine not initialized. Check your API key.")
        queries = self.templates[template_name]
        all_results = []
        for query in queries:
            try:
                results = self.engine.search(query, max_results)
                all_results.extend(results)
            except APIException as e:
                # A single failing query should not abort the template run.
                logger.error(f"Query failed: {query} - {e}")
        logger.info(f"Template '{template_name}' returned {len(all_results)} total results")
        return all_results

    def count_results(self, query: str) -> int:
        """
        Get count of results for a query without consuming credits.

        Args:
            query: Search query

        Returns:
            Number of results (0 when no engine is available)
        """
        if not self.engine:
            return 0
        return self.engine.count(query)

    def add_custom_template(self, name: str, queries: List[str]) -> None:
        """
        Add a custom query template.

        Args:
            name: Template name
            queries: List of Shodan queries
        """
        self.templates[name] = queries
        logger.info(f"Added custom template: {name}")

    def save_template(self, name: str, path: Optional[str] = None) -> None:
        """
        Save a template to a YAML file.

        Args:
            name: Template name
            path: Output file path (default: queries/{name}.yaml)

        Raises:
            ConfigurationException: If the template does not exist.
        """
        if name not in self.templates:
            raise ConfigurationException(f"Template not found: {name}")
        output_path = path or f"queries/{name}.yaml"
        # os.makedirs('') raises FileNotFoundError, so only create a
        # directory when the output path actually contains one.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        template_data = {
            'name': name,
            'description': f"Query template for {name}",
            'queries': self.templates[name]
        }
        with open(output_path, 'w', encoding='utf-8') as f:
            yaml.dump(template_data, f, default_flow_style=False)
        logger.info(f"Saved template to {output_path}")
+304
View File
@@ -0,0 +1,304 @@
"""Result aggregation and deduplication for AASRT."""
from typing import Any, Dict, List, Optional, Set
from datetime import datetime
from collections import defaultdict
from src.engines import SearchResult
from src.utils.logger import get_logger
logger = get_logger(__name__)
class ResultAggregator:
    """Aggregates and deduplicates search results from multiple engines."""
    def __init__(
        self,
        dedupe_by: str = "ip_port",
        merge_metadata: bool = True,
        prefer_engine: Optional[str] = None
    ):
        """
        Initialize ResultAggregator.
        Args:
            dedupe_by: Deduplication key ("ip_port", "ip", or "hostname")
            merge_metadata: Whether to merge metadata from duplicate results
            prefer_engine: Preferred engine when resolving conflicts
        """
        self.dedupe_by = dedupe_by
        self.merge_metadata = merge_metadata
        self.prefer_engine = prefer_engine
    def aggregate(
        self,
        results: Dict[str, List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Aggregate results from multiple engines.
        Args:
            results: Dictionary mapping engine names to result lists
        Returns:
            Deduplicated and merged list of results
        """
        all_results = []
        # Flatten results, tagging each result with the engine it came from
        # so the merge step can honour prefer_engine.
        for engine_name, engine_results in results.items():
            for result in engine_results:
                result.source_engine = engine_name
                all_results.append(result)
        logger.info(f"Aggregating {len(all_results)} total results")
        # Deduplicate
        deduplicated = self._deduplicate(all_results)
        logger.info(f"After deduplication: {len(deduplicated)} unique results")
        return deduplicated
    def _get_dedupe_key(self, result: SearchResult) -> str:
        """Get deduplication key for a result (controlled by self.dedupe_by)."""
        if self.dedupe_by == "ip_port":
            return f"{result.ip}:{result.port}"
        elif self.dedupe_by == "ip":
            return result.ip
        elif self.dedupe_by == "hostname":
            # Fall back to IP when the result carries no hostname.
            return result.hostname or result.ip
        else:
            # Unknown strategy: default to ip:port.
            return f"{result.ip}:{result.port}"
    def _deduplicate(self, results: List[SearchResult]) -> List[SearchResult]:
        """Deduplicate results based on configured key."""
        seen: Dict[str, SearchResult] = {}
        for result in results:
            key = self._get_dedupe_key(result)
            if key not in seen:
                seen[key] = result
            else:
                # Merge with existing result
                existing = seen[key]
                seen[key] = self._merge_results(existing, result)
        # dict preserves insertion order, so first-seen ordering is kept.
        return list(seen.values())
    def _merge_results(
        self,
        existing: SearchResult,
        new: SearchResult
    ) -> SearchResult:
        """
        Merge two results for the same target.
        Args:
            existing: Existing result
            new: New result to merge
        Returns:
            Merged result
        """
        # Prefer result from preferred engine
        if self.prefer_engine:
            if new.source_engine == self.prefer_engine:
                base = new
                other = existing
            else:
                # NOTE: if neither side is from the preferred engine, the
                # existing (first-seen) result wins.
                base = existing
                other = new
        else:
            # Default: prefer result with more information
            if len(new.metadata) > len(existing.metadata):
                base = new
                other = existing
            else:
                base = existing
                other = new
        # Merge vulnerabilities (union); ordering is not guaranteed after set().
        merged_vulns = list(set(base.vulnerabilities + other.vulnerabilities))
        # Merge metadata if enabled
        if self.merge_metadata:
            # base's keys win on conflict (later unpacking overrides earlier).
            merged_metadata = {**other.metadata, **base.metadata}
            # Track source engines
            engines = set()
            if base.metadata.get('source_engines'):
                engines.update(base.metadata['source_engines'])
            if other.metadata.get('source_engines'):
                engines.update(other.metadata['source_engines'])
            engines.add(base.source_engine)
            engines.add(other.source_engine)
            merged_metadata['source_engines'] = list(engines)
        else:
            merged_metadata = base.metadata
        # Take highest risk score
        risk_score = max(base.risk_score, other.risk_score)
        # Take highest confidence
        confidence = max(base.confidence, other.confidence)
        return SearchResult(
            ip=base.ip,
            port=base.port,
            hostname=base.hostname or other.hostname,
            service=base.service or other.service,
            banner=base.banner or other.banner,
            vulnerabilities=merged_vulns,
            metadata=merged_metadata,
            source_engine=base.source_engine,
            timestamp=base.timestamp,
            risk_score=risk_score,
            confidence=confidence
        )
    def filter_by_confidence(
        self,
        results: List[SearchResult],
        min_confidence: int = 70
    ) -> List[SearchResult]:
        """Filter results by minimum confidence score (inclusive)."""
        filtered = [r for r in results if r.confidence >= min_confidence]
        logger.info(f"Filtered by confidence >= {min_confidence}: {len(filtered)} results")
        return filtered
    def filter_by_risk_score(
        self,
        results: List[SearchResult],
        min_score: float = 0.0
    ) -> List[SearchResult]:
        """Filter results by minimum risk score (inclusive)."""
        filtered = [r for r in results if r.risk_score >= min_score]
        logger.info(f"Filtered by risk >= {min_score}: {len(filtered)} results")
        return filtered
    def filter_whitelist(
        self,
        results: List[SearchResult],
        whitelist_ips: Optional[List[str]] = None,
        whitelist_domains: Optional[List[str]] = None
    ) -> List[SearchResult]:
        """Filter out whitelisted IPs and domains (exact IPs, exact domains,
        and subdomains of whitelisted domains)."""
        if not whitelist_ips and not whitelist_domains:
            return results
        # Convert to sets for O(1) membership tests.
        whitelist_ips = set(whitelist_ips or [])
        whitelist_domains = set(whitelist_domains or [])
        filtered = []
        for result in results:
            if result.ip in whitelist_ips:
                continue
            if result.hostname and result.hostname in whitelist_domains:
                continue
            # Check if hostname ends with any whitelisted domain
            if result.hostname:
                skip = False
                for domain in whitelist_domains:
                    if result.hostname.endswith(f".{domain}") or result.hostname == domain:
                        skip = True
                        break
                if skip:
                    continue
            filtered.append(result)
        excluded = len(results) - len(filtered)
        if excluded > 0:
            logger.info(f"Excluded {excluded} whitelisted results")
        return filtered
    def group_by_ip(
        self,
        results: List[SearchResult]
    ) -> Dict[str, List[SearchResult]]:
        """Group results by IP address."""
        grouped = defaultdict(list)
        for result in results:
            grouped[result.ip].append(result)
        return dict(grouped)
    def group_by_service(
        self,
        results: List[SearchResult]
    ) -> Dict[str, List[SearchResult]]:
        """Group results by service type ("unknown" when unset)."""
        grouped = defaultdict(list)
        for result in results:
            service = result.service or "unknown"
            grouped[service].append(result)
        return dict(grouped)
    def get_statistics(self, results: List[SearchResult]) -> Dict[str, Any]:
        """
        Get aggregate statistics for results.
        Args:
            results: List of search results
        Returns:
            Statistics dictionary (note: 'average_risk_score' is only present
            when results is non-empty)
        """
        if not results:
            return {
                'total_results': 0,
                'unique_ips': 0,
                'unique_hostnames': 0,
                'engines_used': [],
                'vulnerability_counts': {},
                'risk_distribution': {},
                'top_services': []
            }
        # Count unique IPs and hostnames
        unique_ips = set(r.ip for r in results)
        unique_hostnames = set(r.hostname for r in results if r.hostname)
        # Count engines; merged results carry a 'source_engines' list,
        # unmerged ones only their single source_engine.
        engines = set()
        for r in results:
            if r.metadata.get('source_engines'):
                engines.update(r.metadata['source_engines'])
            else:
                engines.add(r.source_engine)
        # Count vulnerabilities
        vuln_counts = defaultdict(int)
        for r in results:
            for vuln in r.vulnerabilities:
                vuln_counts[vuln] += 1
        # Risk distribution (buckets mirror RiskScorer.categorize_results)
        risk_dist = {
            'critical': len([r for r in results if r.risk_score >= 9.0]),
            'high': len([r for r in results if 7.0 <= r.risk_score < 9.0]),
            'medium': len([r for r in results if 4.0 <= r.risk_score < 7.0]),
            'low': len([r for r in results if r.risk_score < 4.0])
        }
        # Top services
        service_counts = defaultdict(int)
        for r in results:
            service_counts[r.service or "unknown"] += 1
        top_services = sorted(
            service_counts.items(),
            key=lambda x: x[1],
            reverse=True
        )[:10]
        return {
            'total_results': len(results),
            'unique_ips': len(unique_ips),
            'unique_hostnames': len(unique_hostnames),
            'engines_used': list(engines),
            'vulnerability_counts': dict(vuln_counts),
            'risk_distribution': risk_dist,
            'top_services': top_services,
            'average_risk_score': sum(r.risk_score for r in results) / len(results)
        }
+313
View File
@@ -0,0 +1,313 @@
"""Risk scoring engine for AASRT."""
from typing import Any, Dict, List
from .vulnerability_assessor import Vulnerability
from src.engines import SearchResult
from src.utils.logger import get_logger
logger = get_logger(__name__)
class RiskScorer:
    """Calculates risk scores for targets based on vulnerabilities."""

    # Severity weights for scoring.
    # NOTE(review): not referenced by calculate_score, which works directly
    # from CVSS values and severity counts; kept for API compatibility.
    SEVERITY_WEIGHTS = {
        'CRITICAL': 1.5,
        'HIGH': 1.2,
        'MEDIUM': 1.0,
        'LOW': 0.5,
        'INFO': 0.1
    }

    # Context multipliers, applied when the matching flag is truthy in the
    # ``context`` dict handed to calculate_score.
    CONTEXT_MULTIPLIERS = {
        'public_internet': 1.2,
        'no_waf': 1.1,
        'known_vulnerable_version': 1.3,
        'ai_agent': 1.2,  # AI agents may have additional risk
        'clawsec_cve': 1.4,  # Known ClawSec CVE vulnerability
        'clawsec_critical': 1.5,  # Critical ClawSec CVE
    }

    def __init__(self, config: Dict[str, Any] = None):
        """
        Initialize RiskScorer.

        Args:
            config: Configuration options (optional; defaults to empty dict)
        """
        self.config = config or {}

    def calculate_score(
        self,
        vulnerabilities: List[Vulnerability],
        context: Dict[str, Any] = None
    ) -> Dict[str, Any]:
        """
        Calculate risk score based on vulnerabilities.

        Formula:
            - Base score: Highest CVSS score found
            - Adjusted: base * (1 + 0.1 * critical_count + 0.05 * high_count)
            - Context multipliers applied
            - Capped at 10.0

        Args:
            vulnerabilities: List of discovered vulnerabilities
            context: Additional context flags (public_internet, no_waf, ...)

        Returns:
            Risk assessment dictionary with a stable schema:
            overall_score, severity_breakdown, exploitability, impact,
            confidence, contributing_factors.
        """
        if not vulnerabilities:
            # Keep the same key set as the non-empty path so consumers can
            # always read 'contributing_factors' without a KeyError.
            return {
                'overall_score': 0.0,
                'severity_breakdown': {
                    'critical': 0, 'high': 0, 'medium': 0, 'low': 0, 'info': 0
                },
                'exploitability': 'NONE',
                'impact': 'NONE',
                'confidence': 100,
                'contributing_factors': []
            }
        context = context or {}
        # Get base score (highest CVSS)
        base_score = max(v.cvss_score for v in vulnerabilities)
        # Count by severity
        severity_counts = self._count_severities(vulnerabilities)
        critical_count = severity_counts['critical']
        high_count = severity_counts['high']
        # Increase score based on multiple critical/high vulnerabilities
        adjusted_score = base_score * (1.0 + (0.1 * critical_count) + (0.05 * high_count))
        # Apply context multipliers for every truthy context flag
        for ctx_key, multiplier in self.CONTEXT_MULTIPLIERS.items():
            if context.get(ctx_key, False):
                adjusted_score *= multiplier
        # Cap at 10.0
        final_score = min(adjusted_score, 10.0)
        return {
            'overall_score': round(final_score, 1),
            'severity_breakdown': severity_counts,
            'exploitability': self._calculate_exploitability(vulnerabilities, critical_count),
            'impact': self._calculate_impact(vulnerabilities),
            'confidence': self._calculate_confidence(vulnerabilities),
            'contributing_factors': self._get_contributing_factors(vulnerabilities)
        }

    def _count_severities(self, vulnerabilities: List[Vulnerability]) -> Dict[str, int]:
        """Count vulnerabilities by severity level (lower-cased keys)."""
        counts = {'critical': 0, 'high': 0, 'medium': 0, 'low': 0, 'info': 0}
        for v in vulnerabilities:
            severity_key = v.severity.lower()
            # Unknown severities are ignored rather than counted.
            if severity_key in counts:
                counts[severity_key] += 1
        return counts

    def _calculate_exploitability(
        self,
        vulnerabilities: List[Vulnerability],
        critical_count: int
    ) -> str:
        """Determine overall exploitability level (CRITICAL/HIGH/MEDIUM/LOW)."""
        if critical_count >= 2:
            return 'CRITICAL'
        elif critical_count >= 1:
            return 'HIGH'
        # Check for easily exploitable vulnerabilities by check name
        easy_exploit = ['api_key_exposure', 'no_authentication', 'shell_access']
        for v in vulnerabilities:
            if any(indicator in v.check_name for indicator in easy_exploit):
                return 'HIGH'
        high_count = len([v for v in vulnerabilities if v.severity == 'HIGH'])
        if high_count >= 2:
            return 'MEDIUM'
        return 'LOW'

    def _calculate_impact(self, vulnerabilities: List[Vulnerability]) -> str:
        """Determine potential impact level (HIGH/MEDIUM/LOW)."""
        # Check for high-impact vulnerabilities by check name
        high_impact_indicators = [
            'api_key_exposure',
            'shell_access',
            'database_exposed',
            'admin_panel'
        ]
        for v in vulnerabilities:
            if any(indicator in v.check_name for indicator in high_impact_indicators):
                return 'HIGH'
        if any(v.cvss_score >= 7.0 for v in vulnerabilities):
            return 'MEDIUM'
        return 'LOW'

    def _calculate_confidence(self, vulnerabilities: List[Vulnerability]) -> int:
        """Calculate confidence in the assessment (0-100)."""
        if not vulnerabilities:
            return 100
        # Start with high confidence
        confidence = 100
        # Reduce confidence 10 points per likely false positive
        for v in vulnerabilities:
            if 'potential' in v.check_name or 'possible' in v.description.lower():
                confidence -= 10
        return max(confidence, 0)

    def _get_contributing_factors(self, vulnerabilities: List[Vulnerability]) -> List[str]:
        """Get list of main contributing factors to the risk score."""
        factors = []
        for v in vulnerabilities:
            if v.severity in ['CRITICAL', 'HIGH']:
                factors.append(f"{v.severity}: {v.description}")
        return factors[:5]  # Top 5 factors

    def score_result(self, result: SearchResult, vulnerabilities: List[Vulnerability]) -> SearchResult:
        """
        Apply risk score to a SearchResult.

        Args:
            result: SearchResult to score (mutated in place and returned)
            vulnerabilities: Assessed vulnerabilities

        Returns:
            Updated SearchResult with risk score
        """
        # Build context from result metadata
        context = {
            'public_internet': True,  # Assume public if found via search
            'ai_agent': self._is_ai_agent(result),
            'clawsec_cve': self._has_clawsec_cve(result),
            'clawsec_critical': self._has_critical_clawsec_cve(result)
        }
        # Heuristic WAF check: look for known CDN/WAF names in the headers
        http_info = result.metadata.get('http') or {}
        http_headers = http_info.get('headers', {})
        if not any(waf in str(http_headers).lower() for waf in ['cloudflare', 'akamai', 'fastly']):
            context['no_waf'] = True
        # Calculate score and attach it to the result
        risk_data = self.calculate_score(vulnerabilities, context)
        result.risk_score = risk_data['overall_score']
        result.metadata['risk_assessment'] = risk_data
        result.vulnerabilities = [v.check_name for v in vulnerabilities]
        return result

    def _has_clawsec_cve(self, result: SearchResult) -> bool:
        """Check if result has any ClawSec CVE associations."""
        return bool(result.metadata.get('clawsec_advisories'))

    def _has_critical_clawsec_cve(self, result: SearchResult) -> bool:
        """Check if result has a critical ClawSec CVE."""
        advisories = result.metadata.get('clawsec_advisories', [])
        return any(a.get('severity') == 'CRITICAL' for a in advisories)

    def _is_ai_agent(self, result: SearchResult) -> bool:
        """Heuristically check if result appears to be an AI agent (substring
        match over banner, service name and HTTP title)."""
        ai_indicators = [
            'clawdbot', 'autogpt', 'langchain', 'openai',
            'anthropic', 'claude', 'gpt', 'agent'
        ]
        http_info = result.metadata.get('http') or {}
        http_title = http_info.get('title') or ''
        text = (
            (result.banner or '') +
            (result.service or '') +
            str(http_title)
        ).lower()
        return any(indicator in text for indicator in ai_indicators)

    def categorize_results(
        self,
        results: List[SearchResult]
    ) -> Dict[str, List[SearchResult]]:
        """
        Categorize results by risk level.

        Args:
            results: List of scored results

        Returns:
            Dictionary with risk level categories
            (critical >= 9.0, high >= 7.0, medium >= 4.0, low otherwise)
        """
        categories = {
            'critical': [],
            'high': [],
            'medium': [],
            'low': []
        }
        for result in results:
            if result.risk_score >= 9.0:
                categories['critical'].append(result)
            elif result.risk_score >= 7.0:
                categories['high'].append(result)
            elif result.risk_score >= 4.0:
                categories['medium'].append(result)
            else:
                categories['low'].append(result)
        return categories

    def get_summary(self, results: List[SearchResult]) -> Dict[str, Any]:
        """
        Get risk summary for a set of results.

        Args:
            results: List of scored results

        Returns:
            Summary statistics (total, average_score, max_score, distribution)
        """
        if not results:
            return {
                'total': 0,
                'average_score': 0.0,
                'max_score': 0.0,
                'distribution': {'critical': 0, 'high': 0, 'medium': 0, 'low': 0}
            }
        categories = self.categorize_results(results)
        scores = [r.risk_score for r in results]
        return {
            'total': len(results),
            'average_score': round(sum(scores) / len(scores), 1),
            'max_score': max(scores),
            'distribution': {
                'critical': len(categories['critical']),
                'high': len(categories['high']),
                'medium': len(categories['medium']),
                'low': len(categories['low'])
            }
        }
+441
View File
@@ -0,0 +1,441 @@
"""Vulnerability assessment engine for AASRT."""
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from src.engines import SearchResult
from src.utils.logger import get_logger
if TYPE_CHECKING:
from src.enrichment import ThreatEnricher
logger = get_logger(__name__)
@dataclass
class Vulnerability:
    """A single finding produced by the vulnerability assessor."""
    check_name: str
    severity: str  # One of: CRITICAL, HIGH, MEDIUM, LOW, INFO
    cvss_score: float
    description: str
    evidence: Dict[str, Any] = field(default_factory=dict)
    remediation: Optional[str] = None
    cwe_id: Optional[str] = None
    def to_dict(self) -> Dict[str, Any]:
        """Serialize this finding as a plain dictionary (field order preserved)."""
        keys = (
            'check_name',
            'severity',
            'cvss_score',
            'description',
            'evidence',
            'remediation',
            'cwe_id',
        )
        return {key: getattr(self, key) for key in keys}
class VulnerabilityAssessor:
"""Performs passive vulnerability assessment on search results."""
# API key patterns for detection
API_KEY_PATTERNS = {
'anthropic': {
'pattern': r'sk-ant-[a-zA-Z0-9-_]{20,}',
'description': 'Anthropic API key exposed',
'cvss': 10.0
},
'openai': {
'pattern': r'sk-[a-zA-Z0-9]{32,}',
'description': 'OpenAI API key exposed',
'cvss': 10.0
},
'aws_access_key': {
'pattern': r'AKIA[0-9A-Z]{16}',
'description': 'AWS Access Key ID exposed',
'cvss': 9.8
},
'aws_secret': {
'pattern': r'(?<![A-Za-z0-9/+=])[A-Za-z0-9/+=]{40}(?![A-Za-z0-9/+=])',
'description': 'Potential AWS Secret Key exposed',
'cvss': 9.8
},
'github_token': {
'pattern': r'ghp_[a-zA-Z0-9]{36}',
'description': 'GitHub Personal Access Token exposed',
'cvss': 9.5
},
'google_api': {
'pattern': r'AIza[0-9A-Za-z\-_]{35}',
'description': 'Google API key exposed',
'cvss': 7.5
},
'stripe': {
'pattern': r'sk_live_[0-9a-zA-Z]{24}',
'description': 'Stripe Secret Key exposed',
'cvss': 9.8
}
}
# Dangerous functionality patterns
DANGEROUS_PATTERNS = {
'shell_access': {
'patterns': [r'/shell', r'/exec', r'/execute', r'/api/execute', r'/cmd'],
'description': 'Shell command execution endpoint detected',
'cvss': 9.9,
'severity': 'CRITICAL'
},
'debug_mode': {
'patterns': [r'DEBUG\s*[=:]\s*[Tt]rue', r'debug\s*mode', r'stack\s*trace'],
'description': 'Debug mode appears to be enabled',
'cvss': 7.5,
'severity': 'HIGH'
},
'file_upload': {
'patterns': [r'/upload', r'/api/files', r'multipart/form-data'],
'description': 'File upload functionality detected',
'cvss': 7.8,
'severity': 'HIGH'
},
'admin_panel': {
'patterns': [r'/admin', r'admin\s*panel', r'administrator'],
'description': 'Admin panel potentially exposed',
'cvss': 8.5,
'severity': 'HIGH'
},
'database_exposed': {
'patterns': [r'mongodb://', r'mysql://', r'postgresql://', r'redis://'],
'description': 'Database connection string exposed',
'cvss': 9.5,
'severity': 'CRITICAL'
}
}
# Information disclosure patterns
INFO_DISCLOSURE_PATTERNS = {
'env_file': {
'patterns': [r'\.env', r'environment\s*variables?'],
'description': 'Environment file or variables exposed',
'cvss': 8.0,
'severity': 'HIGH'
},
'config_file': {
'patterns': [r'config\.json', r'settings\.py', r'application\.yml'],
'description': 'Configuration file exposed',
'cvss': 7.5,
'severity': 'HIGH'
},
'git_exposed': {
'patterns': [r'\.git/', r'\.git/config'],
'description': 'Git repository exposed',
'cvss': 7.0,
'severity': 'MEDIUM'
},
'source_code': {
'patterns': [r'\.py$', r'\.js$', r'\.php$'],
'description': 'Source code files potentially exposed',
'cvss': 6.5,
'severity': 'MEDIUM'
}
}
def __init__(self, config: Optional[Dict[str, Any]] = None, threat_enricher: Optional['ThreatEnricher'] = None):
"""
Initialize VulnerabilityAssessor.
Args:
config: Configuration dictionary
threat_enricher: Optional ThreatEnricher for ClawSec integration
"""
self.config = config or {}
self.passive_only = self.config.get('passive_only', True)
self.threat_enricher = threat_enricher
def assess(self, result: SearchResult) -> List[Vulnerability]:
"""
Perform vulnerability assessment on a search result.
Args:
result: SearchResult to assess
Returns:
List of discovered vulnerabilities
"""
vulnerabilities = []
# Check for API key exposure in banner
if result.banner:
vulnerabilities.extend(self._check_api_keys(result.banner))
# Check for dangerous functionality
vulnerabilities.extend(self._check_dangerous_functionality(result))
# Check for information disclosure
vulnerabilities.extend(self._check_information_disclosure(result))
# Check SSL/TLS issues
vulnerabilities.extend(self._check_ssl_issues(result))
# Check for authentication issues (based on metadata)
vulnerabilities.extend(self._check_authentication(result))
# Add pre-existing vulnerability indicators
for vuln_name in result.vulnerabilities:
if not any(v.check_name == vuln_name for v in vulnerabilities):
vulnerabilities.append(self._create_from_indicator(vuln_name))
logger.debug(f"Assessed {result.ip}:{result.port} - {len(vulnerabilities)} vulnerabilities")
return vulnerabilities
def assess_batch(self, results: List[SearchResult]) -> Dict[str, List[Vulnerability]]:
"""
Assess multiple results.
Args:
results: List of SearchResults
Returns:
Dictionary mapping result keys to vulnerability lists
"""
assessments = {}
for result in results:
key = f"{result.ip}:{result.port}"
assessments[key] = self.assess(result)
return assessments
def assess_with_intel(self, result: SearchResult) -> List[Vulnerability]:
"""
Perform vulnerability assessment enhanced with threat intelligence.
1. Enrich result with ClawSec CVE data
2. Run standard passive checks
3. Create Vulnerability objects for matched CVEs
Args:
result: SearchResult to assess
Returns:
List of discovered vulnerabilities including ClawSec CVEs
"""
# First, enrich the result if threat enricher is available
if self.threat_enricher:
result = self.threat_enricher.enrich(result)
# Run standard assessment
vulnerabilities = self.assess(result)
# Add ClawSec CVE vulnerabilities
if self.threat_enricher:
clawsec_vulns = self._create_clawsec_vulnerabilities(result)
vulnerabilities.extend(clawsec_vulns)
return vulnerabilities
def _create_clawsec_vulnerabilities(self, result: SearchResult) -> List[Vulnerability]:
"""
Convert ClawSec advisory data to Vulnerability objects.
Args:
result: SearchResult with clawsec_advisories in metadata
Returns:
List of Vulnerability objects from ClawSec data
"""
vulns = []
clawsec_data = result.metadata.get('clawsec_advisories', [])
for advisory in clawsec_data:
vulns.append(Vulnerability(
check_name=f"clawsec_{advisory['cve_id']}",
severity=advisory.get('severity', 'MEDIUM'),
cvss_score=advisory.get('cvss_score', 7.0),
description=f"[ClawSec] {advisory.get('title', 'Known vulnerability')}",
evidence={
'cve_id': advisory['cve_id'],
'source': 'ClawSec',
'vuln_type': advisory.get('vuln_type', 'unknown'),
'nvd_url': advisory.get('nvd_url')
},
remediation=advisory.get('action', 'See ClawSec advisory for remediation steps'),
cwe_id=advisory.get('cwe_id')
))
return vulns
def _check_api_keys(self, text: str) -> List[Vulnerability]:
"""Check for exposed API keys in text."""
vulnerabilities = []
for key_type, config in self.API_KEY_PATTERNS.items():
if re.search(config['pattern'], text, re.IGNORECASE):
vulnerabilities.append(Vulnerability(
check_name=f"api_key_exposure_{key_type}",
severity="CRITICAL",
cvss_score=config['cvss'],
description=config['description'],
evidence={'pattern_matched': key_type},
remediation="Immediately rotate the exposed API key and remove from public-facing content",
cwe_id="CWE-798"
))
return vulnerabilities
def _check_dangerous_functionality(self, result: SearchResult) -> List[Vulnerability]:
"""Check for dangerous functionality indicators."""
vulnerabilities = []
http_info = result.metadata.get('http') or {}
text = (result.banner or '') + str(http_info)
for check_name, config in self.DANGEROUS_PATTERNS.items():
for pattern in config['patterns']:
if re.search(pattern, text, re.IGNORECASE):
vulnerabilities.append(Vulnerability(
check_name=check_name,
severity=config['severity'],
cvss_score=config['cvss'],
description=config['description'],
evidence={'pattern': pattern},
remediation=self._get_remediation(check_name)
))
break # Only add once per check
return vulnerabilities
def _check_information_disclosure(self, result: SearchResult) -> List[Vulnerability]:
"""Check for information disclosure."""
vulnerabilities = []
text = (result.banner or '') + str(result.metadata)
for check_name, config in self.INFO_DISCLOSURE_PATTERNS.items():
for pattern in config['patterns']:
if re.search(pattern, text, re.IGNORECASE):
vulnerabilities.append(Vulnerability(
check_name=f"info_disclosure_{check_name}",
severity=config['severity'],
cvss_score=config['cvss'],
description=config['description'],
evidence={'pattern': pattern},
remediation="Remove or restrict access to sensitive files"
))
break
return vulnerabilities
def _check_ssl_issues(self, result: SearchResult) -> List[Vulnerability]:
"""Check for SSL/TLS issues."""
vulnerabilities = []
ssl_info = result.metadata.get('ssl') or {}
if not ssl_info:
# No SSL on HTTPS port might be an issue
if result.port in [443, 8443]:
vulnerabilities.append(Vulnerability(
check_name="no_ssl_on_https_port",
severity="MEDIUM",
cvss_score=5.3,
description="HTTPS port without SSL/TLS",
remediation="Configure proper SSL/TLS certificate"
))
return vulnerabilities
cert = ssl_info.get('cert') or {}
# Check for expired certificate
if cert.get('expired', False):
vulnerabilities.append(Vulnerability(
check_name="expired_ssl_certificate",
severity="MEDIUM",
cvss_score=5.0,
description="SSL certificate has expired",
remediation="Renew SSL certificate",
cwe_id="CWE-295"
))
# Check for self-signed certificate
if cert.get('self_signed', False):
vulnerabilities.append(Vulnerability(
check_name="self_signed_certificate",
severity="LOW",
cvss_score=3.0,
description="Self-signed SSL certificate detected",
remediation="Use a certificate from a trusted CA"
))
return vulnerabilities
def _check_authentication(self, result: SearchResult) -> List[Vulnerability]:
    """Heuristically flag admin-style pages that respond without authentication."""
    http_meta = result.metadata.get('http') or {}
    if not http_meta:
        return []

    # Anything other than a plain 200 OK on the root is not treated as open access.
    if http_meta.get('status') != 200:
        return []

    lowered_title = (http_meta.get('title') or '').lower()
    sensitive_terms = ('dashboard', 'admin', 'control panel')
    if not any(term in lowered_title for term in sensitive_terms):
        return []

    return [Vulnerability(
        check_name="no_authentication",
        severity="CRITICAL",
        cvss_score=9.1,
        description="Dashboard accessible without authentication",
        evidence={'http_title': http_meta.get('title')},
        remediation="Implement authentication mechanism",
        cwe_id="CWE-306"
    )]
def _create_from_indicator(self, indicator: str) -> Vulnerability:
    """Translate a string indicator emitted by a search engine into a Vulnerability."""
    # (category, severity, cvss, description) for each recognized indicator.
    known = {
        'debug_mode_enabled': ('DEBUG', 'HIGH', 7.5, "Debug mode is enabled"),
        'potential_api_key_exposure': ('API Keys', 'CRITICAL', 9.0, "Potential API key exposure detected"),
        'expired_ssl_certificate': ('SSL', 'MEDIUM', 5.0, "Expired SSL certificate"),
        'no_security_txt': ('Config', 'LOW', 2.0, "No security.txt file found"),
        'self_signed_certificate': ('SSL', 'LOW', 3.0, "Self-signed certificate"),
    }

    entry = known.get(indicator)
    if entry is None:
        # Unknown indicators are preserved as informational findings.
        return Vulnerability(
            check_name=indicator,
            severity="INFO",
            cvss_score=1.0,
            description=f"Indicator detected: {indicator}"
        )

    _category, level, score, summary = entry
    return Vulnerability(
        check_name=indicator,
        severity=level,
        cvss_score=score,
        description=summary
    )
def _get_remediation(self, check_name: str) -> str:
"""Get remediation advice for a vulnerability."""
remediations = {
'shell_access': "Disable or restrict shell execution endpoints. Implement authentication and authorization.",
'debug_mode': "Disable debug mode in production environments.",
'file_upload': "Implement file type validation, size limits, and malware scanning.",
'admin_panel': "Restrict admin panel access to authorized networks. Implement strong authentication.",
'database_exposed': "Remove database connection strings from public-facing content. Use environment variables.",
}
return remediations.get(check_name, "Review and remediate the identified issue.")
def get_severity_counts(self, vulnerabilities: List[Vulnerability]) -> Dict[str, int]:
    """Tally vulnerabilities into the five recognized severity buckets.

    Unrecognized severities are silently ignored; all buckets are always
    present in the returned mapping.
    """
    tally = dict.fromkeys(('CRITICAL', 'HIGH', 'MEDIUM', 'LOW', 'INFO'), 0)
    for item in vulnerabilities:
        if item.severity in tally:
            tally[item.severity] += 1
    return tally
+10
View File
@@ -0,0 +1,10 @@
"""Search engine modules for AASRT."""
from .base import BaseSearchEngine, SearchResult
from .shodan_engine import ShodanEngine
__all__ = [
'BaseSearchEngine',
'SearchResult',
'ShodanEngine'
]
+183
View File
@@ -0,0 +1,183 @@
"""Abstract base class for search engine integrations."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import time
from src.utils.logger import get_logger
from src.utils.exceptions import RateLimitException
logger = get_logger(__name__)
@dataclass
class SearchResult:
    """Represents a single search result from any engine."""
    ip: str
    port: int
    hostname: Optional[str] = None
    service: Optional[str] = None
    banner: Optional[str] = None
    vulnerabilities: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    source_engine: Optional[str] = None
    timestamp: Optional[str] = None
    risk_score: float = 0.0
    confidence: int = 100

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary (keys in declaration order)."""
        names = (
            'ip', 'port', 'hostname', 'service', 'banner', 'vulnerabilities',
            'metadata', 'source_engine', 'timestamp', 'risk_score', 'confidence'
        )
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SearchResult':
        """Create from dictionary, tolerating missing keys via field defaults."""
        fallbacks = {
            'ip': '',
            'port': 0,
            'hostname': None,
            'service': None,
            'banner': None,
            'vulnerabilities': [],
            'metadata': {},
            'source_engine': None,
            'timestamp': None,
            'risk_score': 0.0,
            'confidence': 100,
        }
        kwargs = {key: data.get(key, default) for key, default in fallbacks.items()}
        return cls(**kwargs)
class BaseSearchEngine(ABC):
    """Abstract base class for all search engine integrations.

    Concrete engines (e.g. Shodan) implement ``search``,
    ``validate_credentials`` and ``get_quota_info``; this base supplies
    shared construction, simple client-side rate limiting, and statistics.
    """

    def __init__(
        self,
        api_key: str,
        rate_limit: float = 1.0,
        timeout: int = 30,
        max_results: int = 100
    ):
        """
        Initialize the search engine.

        Args:
            api_key: API key for authentication
            rate_limit: Maximum queries per second (values <= 0 disable
                client-side rate limiting)
            timeout: Request timeout in seconds
            max_results: Maximum results to return per query
        """
        self.api_key = api_key
        self.rate_limit = rate_limit
        self.timeout = timeout
        self.max_results = max_results
        # Monotonic-clock timestamp of the last request; 0.0 means "never",
        # so the first request is never throttled.
        self._last_request_time = 0.0
        self._request_count = 0

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the engine name."""
        pass

    @abstractmethod
    def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
        """
        Execute a search query and return results.

        Args:
            query: Search query string
            max_results: Maximum number of results to return (overrides default)

        Returns:
            List of SearchResult objects

        Raises:
            APIException: If API call fails
            RateLimitException: If rate limit exceeded
        """
        pass

    @abstractmethod
    def validate_credentials(self) -> bool:
        """
        Validate API credentials.

        Returns:
            True if credentials are valid

        Raises:
            AuthenticationException: If credentials are invalid
        """
        pass

    @abstractmethod
    def get_quota_info(self) -> Dict[str, Any]:
        """
        Get API quota/usage information.

        Returns:
            Dictionary with quota information
        """
        pass

    def _rate_limit_wait(self) -> None:
        """Enforce rate limiting between requests.

        Fix: intervals are measured with ``time.monotonic()`` instead of
        ``time.time()`` so wall-clock adjustments (NTP, DST) cannot produce
        negative or inflated elapsed times and break the throttle.
        """
        if self.rate_limit <= 0:
            return
        min_interval = 1.0 / self.rate_limit
        elapsed = time.monotonic() - self._last_request_time
        if elapsed < min_interval:
            wait_time = min_interval - elapsed
            logger.debug(f"Rate limiting: waiting {wait_time:.2f}s")
            time.sleep(wait_time)
        self._last_request_time = time.monotonic()
        self._request_count += 1

    def _check_rate_limit(self) -> None:
        """Check if rate limit is being approached."""
        # This can be overridden by specific engines with their own rate limit logic
        pass

    def _parse_result(self, raw_result: Dict[str, Any]) -> SearchResult:
        """
        Parse a raw API result into a SearchResult.

        Args:
            raw_result: Raw result from API

        Returns:
            SearchResult object
        """
        # Default implementation - should be overridden by specific engines
        return SearchResult(
            ip=raw_result.get('ip', ''),
            port=raw_result.get('port', 0),
            source_engine=self.name
        )

    def get_stats(self) -> Dict[str, Any]:
        """Get engine statistics (request count and configured limits)."""
        return {
            'engine': self.name,
            'request_count': self._request_count,
            'rate_limit': self.rate_limit,
            'timeout': self.timeout,
            'max_results': self.max_results
        }
+589
View File
@@ -0,0 +1,589 @@
"""
Shodan search engine integration for AASRT.
This module provides a production-ready integration with the Shodan API
for security reconnaissance. Features include:
- Automatic retry with exponential backoff for transient failures
- Rate limiting to prevent API quota exhaustion
- Comprehensive error handling with specific exception types
- Detailed logging for debugging and monitoring
- Graceful degradation when API is unavailable
Example:
>>> from src.engines.shodan_engine import ShodanEngine
>>> engine = ShodanEngine(api_key="your_key")
>>> engine.validate_credentials()
True
>>> results = engine.search("http.html:clawdbot", max_results=10)
"""
from typing import Any, Callable, Dict, List, Optional, TypeVar
from datetime import datetime
from functools import wraps
import time
import socket
import shodan
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
RetryError
)
from .base import BaseSearchEngine, SearchResult
from src.utils.logger import get_logger
from src.utils.validators import validate_ip, sanitize_output
from src.utils.exceptions import (
APIException,
RateLimitException,
AuthenticationException,
TimeoutException
)
logger = get_logger(__name__)
# Type variable for the generic retry decorator's return type.
T = TypeVar('T')
# =============================================================================
# Retry Configuration
# =============================================================================
# Exceptions that should trigger a retry (transient network failures only;
# authentication and rate-limit errors are deliberately excluded).
RETRYABLE_EXCEPTIONS = (
    socket.timeout,
    ConnectionError,
    ConnectionResetError,
    TimeoutError,
)
# Maximum number of retry attempts per call.
MAX_RETRY_ATTEMPTS = 3
# Base delay for exponential backoff (seconds); attempt N waits
# RETRY_BASE_DELAY ** N, capped at RETRY_MAX_DELAY.
RETRY_BASE_DELAY = 2
# Maximum delay between retries (seconds)
RETRY_MAX_DELAY = 30
def with_retry(func: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator that adds retry logic with exponential backoff.

    Retries on transient network errors (RETRYABLE_EXCEPTIONS) and on
    generic Shodan API errors, but not on authentication, rate-limit,
    or invalid-key errors, which are permanent and re-raised immediately.

    Args:
        func: Function to wrap with retry logic.

    Returns:
        Wrapped function with retry capability.

    Raises:
        AuthenticationException: On an invalid Shodan API key (no retry).
        RateLimitException: When Shodan reports rate limiting (no retry).
        APIException: When all retries are exhausted.
    """
    @wraps(func)
    def wrapper(*args, **kwargs) -> T:
        last_exception = None
        # Attempts are 1-indexed so log messages read naturally ("Retry 1/3").
        for attempt in range(1, MAX_RETRY_ATTEMPTS + 1):
            try:
                return func(*args, **kwargs)
            except RETRYABLE_EXCEPTIONS as e:
                last_exception = e
                if attempt < MAX_RETRY_ATTEMPTS:
                    # Exponential backoff: 2, 4, 8... seconds, capped at RETRY_MAX_DELAY.
                    delay = min(RETRY_BASE_DELAY ** attempt, RETRY_MAX_DELAY)
                    logger.warning(
                        f"Retry {attempt}/{MAX_RETRY_ATTEMPTS} for {func.__name__} "
                        f"after {delay}s delay. Error: {e}"
                    )
                    time.sleep(delay)
                else:
                    logger.error(
                        f"All {MAX_RETRY_ATTEMPTS} retries exhausted for {func.__name__}. "
                        f"Last error: {e}"
                    )
            except (AuthenticationException, RateLimitException):
                # Don't retry auth or rate limit errors
                raise
            except shodan.APIError as e:
                error_msg = str(e).lower()
                # Don't retry permanent errors
                if "invalid api key" in error_msg:
                    raise AuthenticationException(
                        "Invalid Shodan API key",
                        engine="shodan"
                    )
                if "rate limit" in error_msg:
                    raise RateLimitException(
                        f"Shodan rate limit exceeded: {e}",
                        engine="shodan"
                    )
                # Retry other API errors
                last_exception = e
                if attempt < MAX_RETRY_ATTEMPTS:
                    delay = min(RETRY_BASE_DELAY ** attempt, RETRY_MAX_DELAY)
                    logger.warning(
                        f"Retry {attempt}/{MAX_RETRY_ATTEMPTS} for {func.__name__} "
                        f"after {delay}s delay. API Error: {e}"
                    )
                    time.sleep(delay)
        # All retries exhausted: surface the last seen error wrapped as APIException.
        if last_exception:
            raise APIException(
                f"Operation failed after {MAX_RETRY_ATTEMPTS} retries: {last_exception}",
                engine="shodan"
            )
        # Defensive: only reachable if the loop body neither returned nor recorded
        # an exception, which should not happen in practice.
        raise APIException("Unexpected retry failure", engine="shodan")
    return wrapper
class ShodanEngine(BaseSearchEngine):
    """
    Shodan search engine integration for security reconnaissance.

    This class provides a production-ready interface to the Shodan API with:
    - Automatic rate limiting to respect API quotas
    - Retry logic with exponential backoff for transient failures
    - Comprehensive error handling and logging
    - Result parsing with vulnerability detection

    Attributes:
        name: Engine identifier ("shodan").
        api_key: Shodan API key (masked in logs).
        rate_limit: Maximum requests per second.
        timeout: Request timeout in seconds.
        max_results: Default maximum results per search.

    Example:
        >>> engine = ShodanEngine(api_key="your_key")
        >>> if engine.validate_credentials():
        ...     results = engine.search("http.html:agent", max_results=50)
        ...     for result in results:
        ...         print(f"{result.ip}:{result.port}")
    """

    def __init__(
        self,
        api_key: str,
        rate_limit: float = 1.0,
        timeout: int = 30,
        max_results: int = 100
    ) -> None:
        """
        Initialize Shodan engine with API credentials.

        Args:
            api_key: Shodan API key from https://account.shodan.io/.
                Never log or expose this value.
            rate_limit: Maximum queries per second. Default 1.0 to respect
                Shodan's free tier limits.
            timeout: Request timeout in seconds. Increase for slow connections.
            max_results: Maximum results per query. Higher values consume
                more query credits.

        Raises:
            ValueError: If api_key is empty or None.
        """
        if not api_key or not api_key.strip():
            raise ValueError("Shodan API key is required")
        super().__init__(api_key, rate_limit, timeout, max_results)
        self._client = shodan.Shodan(api_key)
        # Only the first/last 4 characters ever reach the logs.
        self._api_key_preview = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "***"
        logger.debug(f"ShodanEngine initialized with key: {self._api_key_preview}")

    @property
    def name(self) -> str:
        """Return engine identifier."""
        return "shodan"

    @with_retry
    def validate_credentials(self) -> bool:
        """
        Validate Shodan API credentials by making a test API call.

        This method performs a lightweight API call to verify the API key
        is valid and has not been revoked.

        Returns:
            True if credentials are valid and API is accessible.

        Raises:
            AuthenticationException: If API key is invalid or revoked.
            APIException: If API call fails for other reasons.

        Example:
            >>> engine = ShodanEngine(api_key="your_key")
            >>> try:
            ...     engine.validate_credentials()
            ...     print("API key is valid")
            ... except AuthenticationException:
            ...     print("Invalid API key")
        """
        try:
            info = self._client.info()
            plan = info.get('plan', 'unknown')
            credits = info.get('query_credits', 0)
            logger.info(
                f"Shodan API validated. Plan: {plan}, "
                f"Query credits: {credits}"
            )
            return True
        except shodan.APIError as e:
            error_msg = str(e)
            if "Invalid API key" in error_msg:
                logger.error("Shodan authentication failed: Invalid API key")
                raise AuthenticationException(
                    "Invalid Shodan API key",
                    engine=self.name
                )
            # Any other APIError is surfaced (sanitized) and wrapped.
            logger.error(f"Shodan API validation error: {sanitize_output(error_msg)}")
            raise APIException(f"Shodan API error: {e}", engine=self.name)

    @with_retry
    def get_quota_info(self) -> Dict[str, Any]:
        """
        Get Shodan API quota and usage information.

        Returns:
            Dictionary containing:
                - engine: Engine name ("shodan")
                - plan: API plan type (e.g., "dev", "edu", "corp")
                - query_credits: Remaining query credits
                - scan_credits: Remaining scan credits
                - monitored_ips: Number of monitored IPs
                - unlocked: Whether account has unlocked features
                - error: Error message if call failed (optional)

        Note:
            This call does not consume query credits. On APIError the method
            degrades gracefully and returns an error dictionary instead of
            raising.
        """
        try:
            info = self._client.info()
            quota = {
                'engine': self.name,
                'plan': info.get('plan', 'unknown'),
                'query_credits': info.get('query_credits', 0),
                'scan_credits': info.get('scan_credits', 0),
                'monitored_ips': info.get('monitored_ips', 0),
                'unlocked': info.get('unlocked', False),
                'timestamp': datetime.utcnow().isoformat()
            }
            logger.debug(f"Shodan quota retrieved: {quota['query_credits']} credits remaining")
            return quota
        except shodan.APIError as e:
            logger.error(f"Failed to get Shodan quota: {sanitize_output(str(e))}")
            return {
                'engine': self.name,
                'error': str(e),
                'timestamp': datetime.utcnow().isoformat()
            }

    def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
        """
        Execute a Shodan search query with automatic pagination.

        This method handles pagination automatically, respecting rate limits
        and the specified maximum results. Each page consumes one query credit.

        Args:
            query: Shodan search query string. Supports Shodan's query syntax
                including filters like http.html:, port:, country:, etc.
            max_results: Maximum number of results to return. Defaults to
                the engine's max_results setting. Set to None for default.

        Returns:
            List of SearchResult objects containing parsed Shodan data.
            May return fewer results than max_results if not enough matches.

        Raises:
            APIException: If API call fails after all retries.
            RateLimitException: If rate limit is exceeded.
            AuthenticationException: If API key is invalid.
            ValidationException: If query is invalid.

        Example:
            >>> results = engine.search("http.html:clawdbot", max_results=50)
            >>> for r in results:
            ...     print(f"{r.ip}:{r.port} - {r.service}")

        Note:
            - Shodan returns max 100 results per page
            - Multiple pages consume multiple query credits
            - Consider using count() first to check total results
        """
        # Validate and sanitize query
        if not query or not query.strip():
            raise APIException("Search query cannot be empty", engine=self.name)
        query = query.strip()
        limit = max_results or self.max_results
        # Log sanitized query (remove potential sensitive data)
        safe_query = sanitize_output(query)
        logger.info(f"Executing Shodan search: {safe_query} (limit: {limit})")
        results: List[SearchResult] = []
        page = 1
        total_pages = 0
        start_time = time.time()
        try:
            while len(results) < limit:
                # Apply rate limiting before each request
                self._rate_limit_wait()
                # Execute search with retry logic
                response = self._execute_search_page(query, page)
                if response is None or not response.get('matches'):
                    logger.debug(f"No more matches at page {page}")
                    break
                # Parse matches
                matches = response.get('matches', [])
                for match in matches:
                    if len(results) >= limit:
                        break
                    try:
                        result = self._parse_result(match)
                        results.append(result)
                    except Exception as e:
                        # Log but continue on parse errors
                        logger.warning(f"Failed to parse result: {e}")
                        continue
                # Check pagination limits
                total = response.get('total', 0)
                # NOTE(review): total_pages is computed but never read afterwards;
                # kept as-is since this is a documentation-only pass.
                total_pages = (total + 99) // 100  # Ceiling division
                if len(results) >= total or len(results) >= limit:
                    break
                page += 1
                # Safety limit to prevent infinite loops
                if page > 100:
                    logger.warning("Reached maximum page limit (100)")
                    break
            # Log completion stats
            elapsed = time.time() - start_time
            logger.info(
                f"Shodan search complete: {len(results)} results "
                f"from {page} pages in {elapsed:.2f}s"
            )
            return results
        except (AuthenticationException, RateLimitException):
            # Re-raise known exceptions without wrapping
            raise
        except shodan.APIError as e:
            # Classify the API error by message text, mirroring with_retry.
            error_msg = str(e).lower()
            if "rate limit" in error_msg:
                logger.error("Shodan rate limit exceeded during search")
                raise RateLimitException(
                    f"Shodan rate limit exceeded: {e}",
                    engine=self.name
                )
            elif "invalid api key" in error_msg:
                logger.error("Shodan authentication failed during search")
                raise AuthenticationException(
                    "Invalid Shodan API key",
                    engine=self.name
                )
            else:
                logger.error(f"Shodan API error: {sanitize_output(str(e))}")
                raise APIException(
                    f"Shodan search failed: {e}",
                    engine=self.name
                )
        except Exception as e:
            logger.exception(f"Unexpected error in Shodan search: {e}")
            raise APIException(
                f"Shodan search error: {type(e).__name__}: {e}",
                engine=self.name
            )

    @with_retry
    def _execute_search_page(self, query: str, page: int) -> Optional[Dict[str, Any]]:
        """
        Execute a single page of Shodan search with retry logic.

        Args:
            query: Search query string.
            page: Page number (1-indexed).

        Returns:
            Shodan API response dictionary or None on failure.
        """
        logger.debug(f"Fetching Shodan results page {page}")
        return self._client.search(query, page=page)

    def _parse_result(self, match: Dict[str, Any]) -> SearchResult:
        """
        Parse a Shodan match into a SearchResult.

        Args:
            match: Raw Shodan match data (assumes the standard Shodan banner
                schema: ip_str, port, data, ssl, http, location, hostnames).

        Returns:
            SearchResult object with heuristic vulnerability indicators.
        """
        # Extract vulnerability indicators from data
        vulnerabilities = []
        data = match.get('data', '')
        # Check for common vulnerability indicators (cheap substring heuristics
        # on the raw banner text; false positives are possible).
        if 'debug' in data.lower() or 'DEBUG=True' in data:
            vulnerabilities.append('debug_mode_enabled')
        if 'api_key' in data.lower() or 'apikey' in data.lower():
            vulnerabilities.append('potential_api_key_exposure')
        ssl_data = match.get('ssl') or {}
        ssl_cert = ssl_data.get('cert') or {}
        if ssl_cert.get('expired', False):
            vulnerabilities.append('expired_ssl_certificate')
        # Check HTTP response for issues
        http_data = match.get('http') or {}
        if http_data:
            if not http_data.get('securitytxt'):
                vulnerabilities.append('no_security_txt')
        # Build metadata
        location_data = match.get('location') or {}
        metadata = {
            'asn': match.get('asn'),
            'isp': match.get('isp'),
            'org': match.get('org'),
            'os': match.get('os'),
            'transport': match.get('transport'),
            'product': match.get('product'),
            'version': match.get('version'),
            'cpe': match.get('cpe', []),
            'http': http_data,
            'ssl': ssl_data,
            'location': {
                'country': location_data.get('country_name'),
                'city': location_data.get('city'),
                'latitude': location_data.get('latitude'),
                'longitude': location_data.get('longitude')
            }
        }
        # Extract hostnames (first hostname wins when several are reported)
        hostnames = match.get('hostnames', [])
        hostname = hostnames[0] if hostnames else None
        return SearchResult(
            ip=match.get('ip_str', ''),
            port=match.get('port', 0),
            hostname=hostname,
            service=match.get('product') or match.get('_shodan', {}).get('module'),
            banner=data[:1000] if data else None,  # Truncate long banners
            vulnerabilities=vulnerabilities,
            metadata=metadata,
            source_engine=self.name,
            timestamp=match.get('timestamp', datetime.utcnow().isoformat())
        )

    @with_retry
    def host_info(self, ip: str) -> Dict[str, Any]:
        """
        Get detailed information about a specific host.

        This method retrieves comprehensive information about a host
        including all open ports, services, banners, and historical data.

        Args:
            ip: IP address to lookup. Must be a valid IPv4 address.

        Returns:
            Dictionary containing:
                - ip_str: IP address as string
                - ports: List of open ports
                - data: List of service banners per port
                - hostnames: List of hostnames
                - vulns: List of vulnerabilities (if any)
                - location: Geographic information

        Raises:
            APIException: If lookup fails.
            ValidationException: If IP address is invalid.

        Example:
            >>> info = engine.host_info("8.8.8.8")
            >>> print(f"Ports: {info.get('ports', [])}")
        """
        # Validate IP address before spending an API call on it.
        try:
            validate_ip(ip)
        except Exception as e:
            raise APIException(f"Invalid IP address: {ip}", engine=self.name)
        self._rate_limit_wait()
        logger.debug(f"Looking up host info for: {ip}")
        try:
            host_data = self._client.host(ip)
            logger.info(f"Retrieved host info for {ip}: {len(host_data.get('ports', []))} ports")
            return host_data
        except shodan.APIError as e:
            error_msg = str(e).lower()
            # "No information available" is not an error: return an empty record.
            if "no information available" in error_msg:
                logger.info(f"No Shodan data available for {ip}")
                return {'ip_str': ip, 'ports': [], 'data': []}
            logger.error(f"Failed to get host info for {ip}: {sanitize_output(str(e))}")
            raise APIException(f"Shodan host lookup failed: {e}", engine=self.name)

    @with_retry
    def count(self, query: str) -> int:
        """
        Get the count of results for a query without consuming query credits.

        Use this method to estimate result count before running a full search
        to avoid consuming query credits unnecessarily.

        Args:
            query: Shodan search query string.

        Returns:
            Estimated number of matching results. Returns 0 on error.

        Note:
            - Does not consume query credits
            - Count may be approximate for large result sets
            - Useful for validating queries before running searches

        Example:
            >>> count = engine.count("http.html:clawdbot")
            >>> if count > 0:
            ...     results = engine.search("http.html:clawdbot")
        """
        if not query or not query.strip():
            logger.warning("Empty query provided to count()")
            return 0
        self._rate_limit_wait()
        logger.debug(f"Counting results for query: {sanitize_output(query)}")
        try:
            result = self._client.count(query)
            total = result.get('total', 0)
            logger.info(f"Query '{sanitize_output(query)}' has {total} results")
            return total
        except shodan.APIError as e:
            # Best-effort: a failed count degrades to 0 rather than raising.
            logger.error(f"Failed to count results: {sanitize_output(str(e))}")
            return 0
        except Exception as e:
            logger.exception(f"Unexpected error in count: {e}")
            return 0
+19
View File
@@ -0,0 +1,19 @@
"""Enrichment modules for AASRT.
This module contains data enrichment capabilities:
- ClawSec threat intelligence integration
- (Future) WHOIS lookups
- (Future) Geolocation
- (Future) SSL/TLS certificate analysis
- (Future) DNS records
"""
from .clawsec_feed import ClawSecFeedManager, ClawSecFeed, ClawSecAdvisory
from .threat_enricher import ThreatEnricher
__all__ = [
'ClawSecFeedManager',
'ClawSecFeed',
'ClawSecAdvisory',
'ThreatEnricher'
]
+380
View File
@@ -0,0 +1,380 @@
"""ClawSec Threat Intelligence Feed Manager for AASRT."""
import json
import os
import threading
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
from src.utils.logger import get_logger
logger = get_logger(__name__)
@dataclass
class ClawSecAdvisory:
    """Represents a single ClawSec CVE advisory."""
    cve_id: str
    severity: str  # CRITICAL, HIGH, MEDIUM, LOW
    vuln_type: str  # e.g., "prompt_injection", "missing_authentication"
    cvss_score: float
    title: str
    description: str
    affected: List[str] = field(default_factory=list)
    action: str = ""
    nvd_url: Optional[str] = None
    cwe_id: Optional[str] = None
    published_date: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary."""
        return {
            'cve_id': self.cve_id,
            'severity': self.severity,
            'vuln_type': self.vuln_type,
            'cvss_score': self.cvss_score,
            'title': self.title,
            'description': self.description,
            'affected': self.affected,
            'action': self.action,
            'nvd_url': self.nvd_url,
            'cwe_id': self.cwe_id,
            'published_date': self.published_date.isoformat() if self.published_date else None
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ClawSecAdvisory':
        """Create an advisory from one raw feed entry.

        Severity is upper-cased and CVSS coerced to float; an unparseable
        `published` timestamp is treated as missing rather than aborting the
        whole feed load.

        Args:
            data: One advisory entry from the ClawSec feed JSON.

        Returns:
            ClawSecAdvisory with normalized fields.
        """
        published = data.get('published')
        if published and isinstance(published, str):
            try:
                # The 'Z' suffix is not accepted by fromisoformat() before 3.11.
                published = datetime.fromisoformat(published.replace('Z', '+00:00'))
            except (ValueError, TypeError):
                # Fix: was a bare `except:`, which would also swallow
                # KeyboardInterrupt/SystemExit.
                published = None
        return cls(
            cve_id=data.get('id', ''),
            severity=data.get('severity', 'MEDIUM').upper(),
            vuln_type=data.get('type', 'unknown'),
            cvss_score=float(data.get('cvss_score', 0.0)),
            title=data.get('title', ''),
            description=data.get('description', ''),
            affected=data.get('affected', []),
            action=data.get('action', ''),
            nvd_url=data.get('nvd_url'),
            cwe_id=data.get('nvd_category_id'),
            published_date=published
        )
@dataclass
class ClawSecFeed:
    """Container for the full ClawSec advisory feed."""
    advisories: List[ClawSecAdvisory]
    last_updated: datetime
    feed_version: str
    total_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for caching."""
        return {
            'advisories': [entry.to_dict() for entry in self.advisories],
            'last_updated': self.last_updated.isoformat(),
            'feed_version': self.feed_version,
            'total_count': self.total_count
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ClawSecFeed':
        """Create from dictionary."""
        # Missing timestamp falls back to "now", matching dict.get semantics.
        if 'last_updated' in data:
            stamp = data['last_updated']
        else:
            stamp = datetime.utcnow().isoformat()
        return cls(
            advisories=[ClawSecAdvisory.from_dict(raw) for raw in data.get('advisories', [])],
            last_updated=datetime.fromisoformat(stamp),
            feed_version=data.get('feed_version', '0.0.0'),
            total_count=data.get('total_count', 0)
        )
class ClawSecFeedManager:
"""
Manages ClawSec threat intelligence feed with caching and offline support.
Features:
- HTTP fetch with configurable timeout
- Local file caching for offline mode
- Advisory matching by product/version/banner
- Non-blocking background updates
"""
DEFAULT_FEED_URL = "https://clawsec.prompt.security/advisories/feed.json"
DEFAULT_CACHE_FILE = "./data/clawsec_cache.json"
DEFAULT_TTL = 86400 # 24 hours
def __init__(self, config=None):
"""
Initialize ClawSecFeedManager.
Args:
config: Configuration object with clawsec settings
"""
self.config = config
# Get configuration values
if config:
clawsec_config = config.get('clawsec', default={})
self.feed_url = clawsec_config.get('feed_url', self.DEFAULT_FEED_URL)
self.cache_file = clawsec_config.get('cache_file', self.DEFAULT_CACHE_FILE)
self.cache_ttl = clawsec_config.get('cache_ttl_seconds', self.DEFAULT_TTL)
self.offline_mode = clawsec_config.get('offline_mode', False)
self.timeout = clawsec_config.get('timeout', 30)
else:
self.feed_url = self.DEFAULT_FEED_URL
self.cache_file = self.DEFAULT_CACHE_FILE
self.cache_ttl = self.DEFAULT_TTL
self.offline_mode = False
self.timeout = 30
self._cache: Optional[ClawSecFeed] = None
self._cache_timestamp: Optional[datetime] = None
self._lock = threading.Lock()
def fetch_feed(self, force_refresh: bool = False) -> Optional[ClawSecFeed]:
"""
Fetch the ClawSec advisory feed.
Args:
force_refresh: Force fetch from URL even if cache is valid
Returns:
ClawSecFeed object or None if fetch fails
"""
# Check cache first
if not force_refresh and self.is_cache_valid():
logger.debug("Using cached ClawSec feed")
return self._cache
# In offline mode, only use cache
if self.offline_mode:
logger.info("ClawSec offline mode - using cached data only")
return self.get_cached_feed()
try:
logger.info(f"Fetching ClawSec feed from {self.feed_url}")
response = requests.get(self.feed_url, timeout=self.timeout)
response.raise_for_status()
data = response.json()
feed = self._parse_feed(data)
with self._lock:
self._cache = feed
self._cache_timestamp = datetime.utcnow()
# Persist to disk
self.save_cache()
logger.info(f"ClawSec feed loaded: {feed.total_count} advisories")
return feed
except requests.RequestException as e:
logger.warning(f"Failed to fetch ClawSec feed: {e}")
# Fall back to cache
return self.get_cached_feed()
except (json.JSONDecodeError, KeyError) as e:
logger.error(f"Failed to parse ClawSec feed: {e}")
return self.get_cached_feed()
def _parse_feed(self, data: Dict[str, Any]) -> ClawSecFeed:
"""Parse raw feed JSON into ClawSecFeed object."""
advisories = []
for advisory_data in data.get('advisories', []):
try:
advisory = ClawSecAdvisory.from_dict(advisory_data)
advisories.append(advisory)
except Exception as e:
logger.warning(f"Failed to parse advisory: {e}")
continue
return ClawSecFeed(
advisories=advisories,
last_updated=datetime.utcnow(),
feed_version=data.get('version', '0.0.0'),
total_count=len(advisories)
)
def get_cached_feed(self) -> Optional[ClawSecFeed]:
"""Return cached feed without network call."""
if self._cache:
return self._cache
# Try loading from disk
self.load_cache()
return self._cache
def is_cache_valid(self) -> bool:
"""Check if cache is within TTL."""
if not self._cache or not self._cache_timestamp:
return False
age = datetime.utcnow() - self._cache_timestamp
return age.total_seconds() < self.cache_ttl
def save_cache(self) -> None:
"""Persist cache to local file for offline mode."""
if not self._cache:
return
try:
cache_path = Path(self.cache_file)
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_data = {
'feed': self._cache.to_dict(),
'cached_at': datetime.utcnow().isoformat()
}
with open(cache_path, 'w') as f:
json.dump(cache_data, f, indent=2)
logger.debug(f"ClawSec cache saved to {self.cache_file}")
except Exception as e:
logger.warning(f"Failed to save ClawSec cache: {e}")
def load_cache(self) -> bool:
"""Load cache from local file."""
try:
cache_path = Path(self.cache_file)
if not cache_path.exists():
return False
with open(cache_path, 'r') as f:
cache_data = json.load(f)
self._cache = ClawSecFeed.from_dict(cache_data.get('feed', {}))
cached_at = cache_data.get('cached_at')
if cached_at:
self._cache_timestamp = datetime.fromisoformat(cached_at)
logger.info(f"ClawSec cache loaded: {self._cache.total_count} advisories")
return True
except Exception as e:
logger.warning(f"Failed to load ClawSec cache: {e}")
return False
def match_advisories(
self,
product: Optional[str] = None,
version: Optional[str] = None,
banner: Optional[str] = None
) -> List[ClawSecAdvisory]:
"""
Find matching advisories for a product/version/banner.
Matching strategies (in order):
1. Exact product name match in affected list
2. Fuzzy product match (clawdbot, clawbot, claw-bot)
3. Banner text contains product from affected
Args:
product: Product name to match
version: Version string to check
banner: Banner text to search
Returns:
List of matching ClawSecAdvisory objects
"""
feed = self.get_cached_feed()
if not feed:
return []
matches = []
product_lower = (product or '').lower()
banner_lower = (banner or '').lower()
# AI agent keywords to look for
ai_keywords = ['clawdbot', 'clawbot', 'moltbot', 'openclaw', 'autogpt', 'langchain']
for advisory in feed.advisories:
matched = False
# Check each affected product
for affected in advisory.affected:
affected_lower = affected.lower()
# Strategy 1: Direct product match
if product_lower and product_lower in affected_lower:
matched = True
break
# Strategy 2: Check AI keywords in affected and product/banner
for keyword in ai_keywords:
if keyword in affected_lower:
if keyword in product_lower or keyword in banner_lower:
matched = True
break
if matched:
break
# Strategy 3: Banner contains affected product
if banner_lower:
# Extract product name from affected (e.g., "ClawdBot < 2.0" -> "clawdbot")
affected_product = affected_lower.split('<')[0].split('>')[0].strip()
if affected_product and affected_product in banner_lower:
matched = True
break
if matched and advisory not in matches:
matches.append(advisory)
logger.debug(f"ClawSec matched {len(matches)} advisories for product={product}")
return matches
def background_refresh(self) -> None:
    """Kick off a daemon thread that force-refreshes the advisory feed.

    Failures are logged and swallowed so a refresh error can never take
    down the caller.
    """
    def _worker():
        try:
            self.fetch_feed(force_refresh=True)
        except Exception as exc:
            logger.warning(f"Background ClawSec refresh failed: {exc}")

    threading.Thread(target=_worker, daemon=True).start()
    logger.debug("ClawSec background refresh started")
def get_statistics(self) -> Dict[str, Any]:
    """Get feed statistics for UI display.

    Returns a summary dict; when no feed is cached, a minimal
    obviously-stale summary is returned instead.
    """
    feed = self.get_cached_feed()
    if not feed:
        return {
            'total_advisories': 0,
            'critical_count': 0,
            'high_count': 0,
            'last_updated': None,
            'is_stale': True
        }
    # Tally advisories by severity; unknown severities are ignored.
    tallies = dict.fromkeys(('CRITICAL', 'HIGH', 'MEDIUM', 'LOW'), 0)
    for advisory in feed.advisories:
        if advisory.severity in tallies:
            tallies[advisory.severity] += 1
    return {
        'total_advisories': feed.total_count,
        'critical_count': tallies['CRITICAL'],
        'high_count': tallies['HIGH'],
        'medium_count': tallies['MEDIUM'],
        'low_count': tallies['LOW'],
        'last_updated': feed.last_updated.isoformat() if feed.last_updated else None,
        'feed_version': feed.feed_version,
        'is_stale': not self.is_cache_valid()
    }
+228
View File
@@ -0,0 +1,228 @@
"""Threat Intelligence Enrichment for AASRT."""
import re
from typing import Any, Dict, List, Optional, Tuple

from src.engines import SearchResult
from src.utils.logger import get_logger

from .clawsec_feed import ClawSecAdvisory, ClawSecFeedManager
logger = get_logger(__name__)
class ThreatEnricher:
    """
    Enriches SearchResult objects with ClawSec threat intelligence.

    Responsibilities:
    - Match results against ClawSec advisories
    - Add CVE metadata to result.metadata
    - Inject ClawSec vulnerabilities into result.vulnerabilities
    """

    # Known AI-agent products: lowercase keyword -> canonical display name.
    # Insertion order matters: the first keyword found in an HTTP title wins.
    AI_PRODUCTS = {
        'clawdbot': 'ClawdBot',
        'moltbot': 'MoltBot',
        'autogpt': 'AutoGPT',
        'langchain': 'LangChain',
        'openclaw': 'OpenClaw'
    }

    # Keywords checked against the service name. No canonical mapping here:
    # the raw service string itself is reported as the product.
    SERVICE_KEYWORDS = ('clawdbot', 'moltbot', 'autogpt', 'langchain')

    # Version extraction patterns, tried in order. Compiled once at class
    # creation instead of importing `re` and recompiling on every call.
    _VERSION_PATTERNS = [
        re.compile(r'v?(\d+\.\d+(?:\.\d+)?)', re.IGNORECASE),             # v1.2.3 or 1.2.3
        re.compile(r'version[:\s]+(\d+\.\d+(?:\.\d+)?)', re.IGNORECASE),  # version: 1.2.3
    ]

    def __init__(self, feed_manager: "ClawSecFeedManager", config=None):
        """
        Initialize ThreatEnricher.

        Args:
            feed_manager: ClawSecFeedManager instance
            config: Optional configuration object
        """
        self.feed_manager = feed_manager
        self.config = config

    def enrich(self, result: "SearchResult") -> "SearchResult":
        """
        Enrich a single result with threat intelligence.

        Args:
            result: SearchResult to enrich

        Returns:
            Enriched SearchResult with ClawSec metadata
        """
        product, version = self._extract_product_info(result)
        # Fold the HTTP title into the banner so keyword matching sees both.
        banner = result.banner or ''
        http_info = result.metadata.get('http', {}) or {}
        title = http_info.get('title') or ''
        if title:
            banner = f"{banner} {title}"
        advisories = self.feed_manager.match_advisories(
            product=product,
            version=version,
            banner=banner
        )
        if advisories:
            result = self._add_cve_context(result, advisories)
            logger.debug(f"Enriched {result.ip}:{result.port} with {len(advisories)} ClawSec advisories")
        return result

    def enrich_batch(self, results: List["SearchResult"]) -> List["SearchResult"]:
        """
        Enrich multiple results efficiently.

        Args:
            results: List of SearchResults to enrich

        Returns:
            List of enriched SearchResults
        """
        return [self.enrich(result) for result in results]

    def _extract_product_info(self, result: "SearchResult") -> Tuple[Optional[str], Optional[str]]:
        """
        Extract product name and version from result metadata.

        Args:
            result: SearchResult to analyze

        Returns:
            Tuple of (product_name, version) or (None, None)
        """
        product = None
        version = None
        # Tolerate non-dict metadata (treated as empty).
        metadata = result.metadata if isinstance(result.metadata, dict) else {}
        # Direct product/version fields take effect first...
        if 'product' in metadata:
            product = metadata['product']
        if 'version' in metadata:
            version = metadata['version']
        # ...but an AI-agent keyword in the HTTP title overrides the product
        # (preserves original precedence).
        http_info = metadata.get('http') or {}
        if http_info:
            title = http_info.get('title') or ''
            for keyword, name in self.AI_PRODUCTS.items():
                if title and keyword in title.lower():
                    product = name
                    break
        # Fall back to the service name when nothing matched so far.
        if not product and result.service:
            service_lower = result.service.lower()
            for keyword in self.SERVICE_KEYWORDS:
                if keyword in service_lower:
                    product = result.service
                    break
        # Pull a version out of the banner if we still don't have one.
        if result.banner and not version:
            for pattern in self._VERSION_PATTERNS:
                match = pattern.search(result.banner)
                if match:
                    version = match.group(1)
                    break
        return product, version

    def _add_cve_context(
        self,
        result: "SearchResult",
        advisories: List["ClawSecAdvisory"]
    ) -> "SearchResult":
        """
        Add CVE information to result metadata and vulnerabilities.

        Args:
            result: SearchResult to update
            advisories: List of matched ClawSecAdvisory objects (non-empty)

        Returns:
            Updated SearchResult
        """
        # Serialize matched advisories into metadata.
        clawsec_data = []
        for advisory in advisories:
            clawsec_data.append({
                'cve_id': advisory.cve_id,
                'severity': advisory.severity,
                'cvss_score': advisory.cvss_score,
                'title': advisory.title,
                'vuln_type': advisory.vuln_type,
                'action': advisory.action,
                'nvd_url': advisory.nvd_url,
                'cwe_id': advisory.cwe_id
            })
        result.metadata['clawsec_advisories'] = clawsec_data
        # Track highest severity for quick access; unknown labels rank lowest.
        severity_order = {'CRITICAL': 4, 'HIGH': 3, 'MEDIUM': 2, 'LOW': 1}
        highest_severity = max(
            (a.severity for a in advisories),
            key=lambda s: severity_order.get(s, 0),
            default='LOW'
        )
        result.metadata['clawsec_severity'] = highest_severity
        # Append synthetic vuln IDs, avoiding duplicates.
        for advisory in advisories:
            vuln_id = f"clawsec_{advisory.cve_id}"
            if vuln_id not in result.vulnerabilities:
                result.vulnerabilities.append(vuln_id)
        return result

    def get_enrichment_stats(self, results: List["SearchResult"]) -> Dict[str, Any]:
        """
        Get statistics about enrichment for a set of results.

        Args:
            results: List of enriched SearchResults

        Returns:
            Dictionary with enrichment statistics
        """
        enriched_count = 0
        total_cves = 0
        severity_counts = {'CRITICAL': 0, 'HIGH': 0, 'MEDIUM': 0, 'LOW': 0}
        cve_list = set()
        for result in results:
            advisories = result.metadata.get('clawsec_advisories', [])
            if advisories:
                enriched_count += 1
                total_cves += len(advisories)
                for advisory in advisories:
                    cve_list.add(advisory['cve_id'])
                    severity = advisory.get('severity', 'LOW')
                    if severity in severity_counts:
                        severity_counts[severity] += 1
        return {
            'enriched_results': enriched_count,
            'total_results': len(results),
            'enrichment_rate': (enriched_count / len(results) * 100) if results else 0,
            'unique_cves': len(cve_list),
            'total_cve_matches': total_cves,
            'severity_breakdown': severity_counts,
            'cve_ids': list(cve_list)
        }
+689
View File
@@ -0,0 +1,689 @@
"""
CLI entry point for AASRT - AI Agent Security Reconnaissance Tool.
This module provides the command-line interface for AASRT with:
- Shodan-based security reconnaissance scanning
- Vulnerability assessment and risk scoring
- Report generation (JSON/CSV)
- Database storage and history tracking
- Signal handling for graceful shutdown
Usage:
python -m src.main status # Check API status
python -m src.main scan --template clawdbot_instances
python -m src.main history # View scan history
python -m src.main templates # List available templates
Environment Variables:
SHODAN_API_KEY: Required for scanning operations
AASRT_LOG_LEVEL: Logging level (DEBUG, INFO, WARNING, ERROR)
AASRT_DEBUG: Enable debug mode (true/false)
Exit Codes:
0: Success
1: Error (invalid arguments, API errors, etc.)
130: Interrupted by user (SIGINT/Ctrl+C)
"""
import atexit
import signal
import sys
import time
import uuid
from typing import Any, Dict, Optional
import click
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from src import __version__
from src.utils.config import Config
from src.utils.logger import setup_logger, get_logger
from src.core.query_manager import QueryManager
from src.core.result_aggregator import ResultAggregator
from src.core.vulnerability_assessor import VulnerabilityAssessor
from src.core.risk_scorer import RiskScorer
from src.storage.database import Database
from src.reporting import JSONReporter, CSVReporter, ScanReport
# =============================================================================
# Global State
# =============================================================================
console = Console()  # shared Rich console used by every command for output
_shutdown_requested = False  # flipped by _signal_handler on first SIGINT/SIGTERM
_active_database: Optional[Database] = None  # tracked so _cleanup() can close it at exit
# =============================================================================
# Signal Handlers
# =============================================================================
def _signal_handler(signum: int, frame: Any) -> None:
    """Handle SIGINT/SIGTERM: first signal requests a graceful stop,
    a second one exits immediately with status 130.

    Args:
        signum: Signal number received.
        frame: Current stack frame (unused).
    """
    global _shutdown_requested
    name = signal.Signals(signum).name
    if _shutdown_requested:
        # Already shutting down — user insists, bail out now.
        console.print("\n[red]Force shutdown requested. Exiting immediately.[/red]")
        sys.exit(130)
    _shutdown_requested = True
    console.print(f"\n[yellow]Received {name}. Shutting down gracefully...[/yellow]")
    console.print("[dim]Press Ctrl+C again to force quit.[/dim]")
def _cleanup() -> None:
    """Exit hook: close the active database connection, if any.

    Errors are deliberately ignored — the interpreter is going away.
    """
    db = _active_database
    if db:
        try:
            db.close()
        except Exception:
            pass  # best-effort cleanup only
def is_shutdown_requested() -> bool:
    """Report whether a shutdown signal (SIGINT/SIGTERM) has been received."""
    return _shutdown_requested
# Register signal handlers
# NOTE: these run at import time, so importing this module installs the
# handlers and the atexit hook as a side effect.
signal.signal(signal.SIGINT, _signal_handler)
signal.signal(signal.SIGTERM, _signal_handler)
atexit.register(_cleanup)
# =============================================================================
# Legal Disclaimer
# =============================================================================
# Rendered inside a rich Panel before every scan unless --yes is passed;
# the [bold]/[red]/... tags are rich console markup.
LEGAL_DISCLAIMER = """
[bold red]WARNING: LEGAL DISCLAIMER[/bold red]
This tool is for [bold]authorized security research and defensive purposes only[/bold].
Unauthorized access to computer systems is illegal under:
- CFAA (Computer Fraud and Abuse Act) - United States
- Computer Misuse Act - United Kingdom
- Similar laws worldwide
By proceeding, you acknowledge that:
1. You have authorization to scan target systems
2. You will comply with all applicable laws and terms of service
3. You will responsibly disclose findings
4. You will not exploit discovered vulnerabilities
[bold yellow]The authors are not responsible for misuse of this tool.[/bold yellow]
"""
# =============================================================================
# CLI Command Group
# =============================================================================
@click.group()
@click.version_option(version=__version__, prog_name="AASRT")
@click.option('--config', '-c', type=click.Path(exists=True), help='Path to config file')
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose output')
@click.pass_context
def cli(ctx: click.Context, config: Optional[str], verbose: bool) -> None:
    """
    AI Agent Security Reconnaissance Tool (AASRT).

    Discover and assess exposed AI agent implementations using Shodan.

    Use 'aasrt --help' for command list or 'aasrt COMMAND --help' for command details.
    """
    ctx.ensure_object(dict)
    # Configuration must load; without it no subcommand can run.
    try:
        cfg = Config(config)
    except Exception as exc:
        console.print(f"[red]Failed to load configuration: {exc}[/red]")
        sys.exit(1)
    ctx.obj['config'] = cfg
    # --verbose wins over the configured log level.
    level = 'DEBUG' if verbose else cfg.get('logging', 'level', default='INFO')
    try:
        setup_logger('aasrt', level=level, log_file=cfg.get('logging', 'file'))
    except Exception as exc:
        # Logging problems are non-fatal; commands still work without a log file.
        console.print(f"[yellow]Warning: Could not setup logging: {exc}[/yellow]")
    ctx.obj['verbose'] = verbose
    get_logger('aasrt').debug(f"AASRT v{__version__} starting (verbose={verbose})")
# =============================================================================
# Scan Command
# =============================================================================
@cli.command()
@click.option('--query', '-q', help='Custom Shodan search query')
@click.option('--template', '-t', help='Use predefined query template')
@click.option('--max-results', '-m', default=100, type=int, help='Max results to retrieve (1-10000)')
@click.option('--output', '-o', type=click.Path(), help='Output file path')
@click.option('--format', '-f', 'output_format',
              type=click.Choice(['json', 'csv', 'both']),
              default='json', help='Output format')
@click.option('--no-assess', is_flag=True, help='Skip vulnerability assessment')
@click.option('--save-db/--no-save-db', default=True, help='Save results to database')
@click.option('--yes', '-y', is_flag=True, help='Skip legal disclaimer confirmation')
@click.pass_context
def scan(
    ctx: click.Context,
    query: Optional[str],
    template: Optional[str],
    max_results: int,
    output: Optional[str],
    output_format: str,
    no_assess: bool,
    save_db: bool,
    yes: bool
) -> None:
    """
    Perform a security reconnaissance scan using Shodan.

    Searches for exposed AI agent implementations and assesses their
    security posture using passive analysis techniques.

    Examples:
        aasrt scan --template clawdbot_instances
        aasrt scan --query 'http.title:"AutoGPT"'
        aasrt scan -t exposed_env_files -m 50 -f csv
    """
    global _active_database
    config = ctx.obj['config']
    logger = get_logger('aasrt')
    logger.info(f"Starting scan command (template={template}, query={query[:50] if query else None})")
    # Display legal disclaimer
    if not yes:
        console.print(Panel(LEGAL_DISCLAIMER, title="Legal Notice", border_style="red"))
        if not click.confirm('\nDo you agree to the terms above?', default=False):
            console.print('[red]Scan aborted. You must agree to terms of use.[/red]')
            logger.info("Scan aborted: User declined legal disclaimer")
            sys.exit(1)
    # Validate max_results
    if max_results < 1:
        console.print('[red]Error: max-results must be at least 1[/red]')
        sys.exit(1)
    if max_results > 10000:
        # Clamp rather than fail: oversized requests are a common typo.
        console.print('[yellow]Warning: Limiting max-results to 10000[/yellow]')
        max_results = 10000
    # Validate inputs
    if not query and not template:
        console.print('[yellow]No query or template specified. Using default template: clawdbot_instances[/yellow]')
        template = 'clawdbot_instances'
    # Check for shutdown before heavy operations
    if is_shutdown_requested():
        console.print('[yellow]Scan cancelled due to shutdown request.[/yellow]')
        sys.exit(130)
    # Initialize query manager
    try:
        query_manager = QueryManager(config)
    except Exception as e:
        console.print(f'[red]Failed to initialize query manager: {e}[/red]')
        logger.error(f"Query manager initialization failed: {e}")
        sys.exit(1)
    # Check if Shodan is available
    if not query_manager.is_available():
        console.print('[red]Shodan is not available. Please check your API key in .env file.[/red]')
        console.print('[dim]Set SHODAN_API_KEY environment variable or add to .env file.[/dim]')
        sys.exit(1)
    console.print('\n[green]Starting Shodan scan...[/green]')
    logger.info(f"Scan started: template={template}, max_results={max_results}")
    # Generate scan ID
    scan_id = str(uuid.uuid4())
    start_time = time.time()
    # Execute scan with interrupt checking; interrupts/errors are recorded in
    # scan_error and partial results are kept.
    all_results = []
    scan_error = None
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        console=console
    ) as progress:
        if template:
            task = progress.add_task(f"[cyan]Scanning with template: {template}...", total=100)
            try:
                if not is_shutdown_requested():
                    all_results = query_manager.execute_template(template, max_results=max_results)
                progress.update(task, completed=100)
            except KeyboardInterrupt:
                console.print('\n[yellow]Scan interrupted by user.[/yellow]')
                scan_error = "Interrupted"
            except Exception as e:
                console.print(f'[red]Template execution failed: {e}[/red]')
                logger.error(f"Template execution error: {e}", exc_info=True)
                scan_error = str(e)
        else:
            task = progress.add_task("[cyan]Executing query...", total=100)
            try:
                if not is_shutdown_requested():
                    all_results = query_manager.execute_query(query, max_results=max_results)
                progress.update(task, completed=100)
            except KeyboardInterrupt:
                console.print('\n[yellow]Scan interrupted by user.[/yellow]')
                scan_error = "Interrupted"
            except Exception as e:
                console.print(f'[red]Query execution failed: {e}[/red]')
                logger.error(f"Query execution error: {e}", exc_info=True)
                scan_error = str(e)
    # Check if scan was interrupted or had errors
    if is_shutdown_requested():
        console.print('[yellow]Scan was interrupted. Saving partial results...[/yellow]')
    # Aggregate and deduplicate results
    console.print('\n[cyan]Aggregating results...[/cyan]')
    aggregator = ResultAggregator()
    unique_results = aggregator.aggregate({'shodan': all_results})
    console.print(f'Found [green]{len(unique_results)}[/green] unique results')
    logger.info(f"Aggregated {len(unique_results)} unique results from {len(all_results)} total")
    # Vulnerability assessment (skip if shutdown requested)
    if not no_assess and unique_results and not is_shutdown_requested():
        console.print('\n[cyan]Assessing vulnerabilities...[/cyan]')
        assessor = VulnerabilityAssessor(config.get('vulnerability_checks', default={}))
        scorer = RiskScorer()
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            console=console
        ) as progress:
            task = progress.add_task("[cyan]Analyzing...", total=len(unique_results))
            for result in unique_results:
                # Re-check the shutdown flag per result so Ctrl+C is responsive.
                if is_shutdown_requested():
                    console.print('[yellow]Assessment interrupted.[/yellow]')
                    break
                try:
                    vulns = assessor.assess(result)
                    scorer.score_result(result, vulns)
                except Exception as e:
                    # One bad result must not abort the whole assessment pass.
                    logger.warning(f"Failed to assess result {result.ip}: {e}")
                progress.advance(task)
    # Calculate duration
    duration = time.time() - start_time
    # Determine final status: 'partial' when interrupted or when an error
    # still left some results; 'failed' only when an error yielded nothing.
    final_status = 'completed'
    if scan_error:
        final_status = 'failed' if not unique_results else 'partial'
    elif is_shutdown_requested():
        final_status = 'partial'
    # Create report
    report = ScanReport.from_results(
        scan_id=scan_id,
        results=unique_results,
        engines=['shodan'],
        query=query,
        template_name=template,
        duration=duration
    )
    # Display summary
    _display_summary(report)
    # Save to database
    # NOTE(review): the DB assigns its own scan_record.scan_id, distinct from
    # the uuid used for the report above — confirm this divergence is intended.
    if save_db:
        try:
            db = Database(config)
            _active_database = db  # Track for cleanup
            scan_record = db.create_scan(
                engines=['shodan'],
                query=query,
                template_name=template
            )
            if unique_results:
                db.add_findings(scan_record.scan_id, unique_results)
            db.update_scan(
                scan_record.scan_id,
                status=final_status,
                total_results=len(unique_results),
                duration_seconds=duration
            )
            console.print(f'\n[green]Results saved to database. Scan ID: {scan_record.scan_id}[/green]')
            logger.info(f"Saved scan {scan_record.scan_id} with {len(unique_results)} findings")
        except Exception as e:
            # Persistence failures are downgraded to warnings: the report
            # files below are still generated.
            console.print(f'[yellow]Warning: Failed to save to database: {e}[/yellow]')
            logger.error(f"Database save error: {e}", exc_info=True)
    # Generate reports
    output_dir = config.get('reporting', 'output_dir', default='./reports')
    try:
        if output_format in ['json', 'both']:
            json_reporter = JSONReporter(output_dir)
            json_path = json_reporter.generate(report, output)
            console.print(f'[green]JSON report: {json_path}[/green]')
        if output_format in ['csv', 'both']:
            csv_reporter = CSVReporter(output_dir)
            csv_path = csv_reporter.generate(report, output)
            console.print(f'[green]CSV report: {csv_path}[/green]')
    except Exception as e:
        console.print(f'[yellow]Warning: Failed to generate report: {e}[/yellow]')
        logger.error(f"Report generation error: {e}", exc_info=True)
    # Final status message; only a full failure exits non-zero.
    if final_status == 'completed':
        console.print(f'\n[bold green]Scan completed in {duration:.1f} seconds[/bold green]')
    elif final_status == 'partial':
        console.print(f'\n[bold yellow]Scan partially completed in {duration:.1f} seconds[/bold yellow]')
    else:
        console.print(f'\n[bold red]Scan failed after {duration:.1f} seconds[/bold red]')
        sys.exit(1)
# =============================================================================
# Helper Functions
# =============================================================================
def _display_summary(report: ScanReport) -> None:
    """Render a Rich-formatted summary of a scan.

    Shows the scan ID and duration, total results and average risk score,
    a severity-distribution table, and the five highest-risk findings.

    Args:
        report: ScanReport object with scan results.
    """
    console.print('\n')
    # Summary panel
    summary_text = f"""
[bold]Scan ID:[/bold] {report.scan_id[:8]}...
[bold]Duration:[/bold] {report.duration_seconds:.1f}s
[bold]Total Results:[/bold] {report.total_results}
[bold]Average Risk Score:[/bold] {report.average_risk_score}/10
"""
    console.print(Panel(summary_text, title="Scan Summary", border_style="green"))
    # Severity distribution
    dist = Table(title="Risk Distribution")
    dist.add_column("Severity", style="bold")
    dist.add_column("Count", justify="right")
    for label, count in (
        ("[red]Critical[/red]", report.critical_findings),
        ("[orange1]High[/orange1]", report.high_findings),
        ("[yellow]Medium[/yellow]", report.medium_findings),
        ("[green]Low[/green]", report.low_findings),
    ):
        dist.add_row(label, str(count))
    console.print(dist)
    if not report.findings:
        return
    # Top five findings by risk score, highest first.
    console.print('\n[bold]Top 5 Highest Risk Findings:[/bold]')
    ranked = sorted(report.findings, key=lambda f: f.get('risk_score', 0), reverse=True)
    for i, finding in enumerate(ranked[:5], 1):
        risk = finding.get('risk_score', 0)
        target = f"{finding.get('target_ip', 'N/A')}:{finding.get('target_port', 'N/A')}"
        hostname = finding.get('target_hostname', '')
        if hostname:
            target += f" ({hostname})"
        # CVSS-style severity banding for the color.
        if risk >= 9.0:
            color = 'red'
        elif risk >= 7.0:
            color = 'orange1'
        elif risk >= 4.0:
            color = 'yellow'
        else:
            color = 'green'
        console.print(f" {i}. [{color}]{target}[/{color}] - Risk: [{color}]{risk}[/{color}]")
@cli.command()
@click.option('--scan-id', '-s', help='Generate report for specific scan')
@click.option('--format', '-f', 'output_format',
              type=click.Choice(['json', 'csv', 'both']),
              default='json', help='Output format')
@click.option('--output', '-o', type=click.Path(), help='Output file path')
@click.pass_context
def report(ctx, scan_id, output_format, output):
    """Generate a report from a previous scan."""
    config = ctx.obj['config']
    try:
        db = Database(config)
    except Exception as exc:
        console.print(f'[red]Failed to connect to database: {exc}[/red]')
        sys.exit(1)
    # Resolve which scan to report on: explicit ID, or the most recent one.
    if scan_id:
        scan = db.get_scan(scan_id)
        if not scan:
            console.print(f'[red]Scan not found: {scan_id}[/red]')
            sys.exit(1)
    else:
        recent = db.get_recent_scans(limit=1)
        if not recent:
            console.print('[yellow]No scans found in database.[/yellow]')
            sys.exit(0)
        scan = recent[0]
    findings = db.get_findings(scan_id=scan.scan_id)
    report_data = ScanReport.from_scan(scan, findings)
    output_dir = config.get('reporting', 'output_dir', default='./reports')
    if output_format in ('json', 'both'):
        json_path = JSONReporter(output_dir).generate(report_data, output)
        console.print(f'[green]JSON report: {json_path}[/green]')
    if output_format in ('csv', 'both'):
        csv_path = CSVReporter(output_dir).generate(report_data, output)
        console.print(f'[green]CSV report: {csv_path}[/green]')
@cli.command()
@click.pass_context
def status(ctx):
    """Show status of Shodan API configuration."""
    config = ctx.obj['config']
    console.print('\n[bold]Shodan API Status[/bold]\n')
    try:
        manager = QueryManager(config)
    except Exception as exc:
        console.print(f'[red]Failed to initialize: {exc}[/red]')
        return
    table = Table(title="Engine Status")
    table.add_column("Engine", style="bold")
    table.add_column("Status")
    table.add_column("Details")
    # Three states: missing key, configured-but-invalid, valid.
    if not manager.is_available():
        status_str = "[red]Not Configured[/red]"
        details = "Add SHODAN_API_KEY to .env file"
    elif manager.validate_engine():
        quota = manager.get_quota_info()
        status_str = "[green]OK[/green]"
        details = f"Credits: {quota.get('query_credits', 'N/A')}, Plan: {quota.get('plan', 'N/A')}"
    else:
        status_str = "[red]Invalid[/red]"
        details = "API key validation failed"
    table.add_row("Shodan", status_str, details)
    console.print(table)
    # List the query templates the manager knows about.
    names = manager.get_available_templates()
    console.print(f'\n[bold]Available Query Templates:[/bold] {len(names)}')
    for name in sorted(names):
        console.print(f' - {name}')
@cli.command()
@click.option('--limit', '-l', default=10, help='Number of recent scans to show')
@click.pass_context
def history(ctx, limit):
    """Show scan history from database."""
    config = ctx.obj['config']
    try:
        db = Database(config)
        scans = db.get_recent_scans(limit=limit)
    except Exception as exc:
        console.print(f'[red]Failed to access database: {exc}[/red]')
        return
    if not scans:
        console.print('[yellow]No scans found in database.[/yellow]')
        return
    table = Table(title=f"Recent Scans (Last {limit})")
    table.add_column("Scan ID", style="cyan")
    table.add_column("Timestamp")
    table.add_column("Template/Query")
    table.add_column("Results", justify="right")
    table.add_column("Status")
    for entry in scans:
        short_id = entry.scan_id[:8] + "..."
        when = entry.timestamp.strftime("%Y-%m-%d %H:%M") if entry.timestamp else "N/A"
        # Prefer the template name; otherwise a (possibly truncated) query.
        if entry.template_name:
            query_info = entry.template_name
        elif entry.query and len(entry.query) > 30:
            query_info = entry.query[:30] + "..."
        elif entry.query:
            query_info = entry.query
        else:
            query_info = "N/A"
        if entry.status == "completed":
            color = "green"
        elif entry.status == "running":
            color = "yellow"
        else:
            color = "red"
        table.add_row(short_id, when, query_info, str(entry.total_results),
                      f"[{color}]{entry.status}[/{color}]")
    console.print(table)
    # Show database stats
    stats = db.get_statistics()
    console.print(f'\n[bold]Database Statistics:[/bold]')
    console.print(f' Total Scans: {stats["total_scans"]}')
    console.print(f' Total Findings: {stats["total_findings"]}')
    console.print(f' Unique IPs: {stats["unique_ips"]}')
@cli.command()
@click.pass_context
def templates(ctx):
    """List available query templates."""
    config = ctx.obj['config']
    try:
        manager = QueryManager(config)
    except Exception as exc:
        console.print(f'[red]Failed to initialize: {exc}[/red]')
        return
    names = manager.get_available_templates()
    console.print('\n[bold]Available Shodan Query Templates[/bold]\n')
    table = Table()
    table.add_column("Template Name", style="cyan")
    table.add_column("Queries")
    for name in sorted(names):
        count = len(manager.templates.get(name, []))
        table.add_row(name, f"{count} queries")
    console.print(table)
    console.print('\n[dim]Use with: aasrt scan --template <template_name>[/dim]')
# =============================================================================
# Entry Point
# =============================================================================
def main() -> None:
    """
    Main entry point for AASRT CLI.

    Runs the Click command group and maps top-level failures onto exit
    codes: 0 success, 1 unexpected error, 130 user interrupt.
    """
    try:
        cli(obj={})
    except KeyboardInterrupt:
        console.print("\n[yellow]Operation cancelled by user.[/yellow]")
        sys.exit(130)
    except Exception as exc:
        # Last-resort handler: log full traceback, show a short message.
        get_logger('aasrt').exception(f"Unexpected error: {exc}")
        console.print(f"\n[red]Unexpected error: {exc}[/red]")
        console.print("[dim]Check logs for details.[/dim]")
        sys.exit(1)


if __name__ == '__main__':
    main()
+7
View File
@@ -0,0 +1,7 @@
"""Reporting modules for AASRT."""
from .base import BaseReporter, ScanReport
from .json_reporter import JSONReporter
from .csv_reporter import CSVReporter
__all__ = ['BaseReporter', 'ScanReport', 'JSONReporter', 'CSVReporter']
+199
View File
@@ -0,0 +1,199 @@
"""Base reporter class for AASRT."""
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from src.engines import SearchResult
from src.storage.database import Scan, Finding
@dataclass
class ScanReport:
    """Container for scan report data: scan identity, severity summary,
    and the detailed findings list."""
    scan_id: str
    timestamp: datetime
    engines_used: List[str]
    query: Optional[str] = None
    template_name: Optional[str] = None
    total_results: int = 0
    duration_seconds: float = 0.0
    status: str = "completed"
    # Summary statistics
    critical_findings: int = 0
    high_findings: int = 0
    medium_findings: int = 0
    low_findings: int = 0
    average_risk_score: float = 0.0
    # Detailed findings
    findings: List[Dict[str, Any]] = field(default_factory=list)
    # Additional metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_scan(cls, scan: Scan, findings: List[Finding]) -> 'ScanReport':
        """Build a ScanReport from persisted Scan/Finding database rows."""
        import json
        scores = [f.risk_score for f in findings]
        # Bucket by CVSS-style thresholds: >=9 critical, >=7 high, >=4 medium.
        critical = sum(1 for s in scores if s >= 9.0)
        high = sum(1 for s in scores if 7.0 <= s < 9.0)
        medium = sum(1 for s in scores if 4.0 <= s < 7.0)
        low = sum(1 for s in scores if s < 4.0)
        avg_risk = sum(scores) / len(scores) if scores else 0.0
        return cls(
            scan_id=scan.scan_id,
            timestamp=scan.timestamp,
            engines_used=json.loads(scan.engines_used) if scan.engines_used else [],
            query=scan.query,
            template_name=scan.template_name,
            total_results=len(findings),
            duration_seconds=scan.duration_seconds or 0.0,
            status=scan.status,
            critical_findings=critical,
            high_findings=high,
            medium_findings=medium,
            low_findings=low,
            average_risk_score=round(avg_risk, 1),
            findings=[f.to_dict() for f in findings],
            metadata=json.loads(scan.metadata) if scan.metadata else {}
        )

    @classmethod
    def from_results(
        cls,
        scan_id: str,
        results: List[SearchResult],
        engines: List[str],
        query: Optional[str] = None,
        template_name: Optional[str] = None,
        duration: float = 0.0
    ) -> 'ScanReport':
        """Build a ScanReport from in-memory SearchResult objects."""
        scores = [r.risk_score for r in results]
        # Same CVSS-style bucketing as from_scan().
        critical = sum(1 for s in scores if s >= 9.0)
        high = sum(1 for s in scores if 7.0 <= s < 9.0)
        medium = sum(1 for s in scores if 4.0 <= s < 7.0)
        low = sum(1 for s in scores if s < 4.0)
        avg_risk = sum(scores) / len(scores) if scores else 0.0
        return cls(
            scan_id=scan_id,
            timestamp=datetime.utcnow(),
            engines_used=engines,
            query=query,
            template_name=template_name,
            total_results=len(results),
            duration_seconds=duration,
            status="completed",
            critical_findings=critical,
            high_findings=high,
            medium_findings=medium,
            low_findings=low,
            average_risk_score=round(avg_risk, 1),
            findings=[r.to_dict() for r in results]
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the report into a JSON-friendly nested dict."""
        scan_metadata = {
            'scan_id': self.scan_id,
            'timestamp': self.timestamp.isoformat() if self.timestamp else None,
            'engines_used': self.engines_used,
            'query': self.query,
            'template_name': self.template_name,
            'total_results': self.total_results,
            'duration_seconds': self.duration_seconds,
            'status': self.status
        }
        summary = {
            'critical_findings': self.critical_findings,
            'high_findings': self.high_findings,
            'medium_findings': self.medium_findings,
            'low_findings': self.low_findings,
            'average_risk_score': self.average_risk_score
        }
        return {
            'scan_metadata': scan_metadata,
            'summary': summary,
            'findings': self.findings,
            'metadata': self.metadata
        }
class BaseReporter(ABC):
    """Abstract base class for report generators.

    Subclasses declare their format name and file extension and implement
    generation to a file and to a string.
    """

    def __init__(self, output_dir: str = "./reports"):
        """
        Initialize reporter and ensure the output directory exists.

        Args:
            output_dir: Directory for report output
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    @property
    @abstractmethod
    def format_name(self) -> str:
        """The format name (e.g. 'json', 'csv')."""

    @property
    @abstractmethod
    def file_extension(self) -> str:
        """The file extension used for generated reports."""

    @abstractmethod
    def generate(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Write a report file.

        Args:
            report: ScanReport data
            filename: Optional custom filename (without extension)

        Returns:
            Path to generated report file
        """

    @abstractmethod
    def generate_string(self, report: ScanReport) -> str:
        """
        Render the report as a string.

        Args:
            report: ScanReport data

        Returns:
            Report content as string
        """

    def get_filename(self, scan_id: str, custom_name: Optional[str] = None) -> str:
        """Build the output filename: custom name, or scan-id + UTC timestamp."""
        if custom_name:
            return f"{custom_name}.{self.file_extension}"
        stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        return f"scan_{scan_id[:8]}_{stamp}.{self.file_extension}"

    def get_filepath(self, filename: str) -> str:
        """Join a filename onto the configured output directory."""
        return os.path.join(self.output_dir, filename)
+221
View File
@@ -0,0 +1,221 @@
"""CSV report generator for AASRT."""
import csv
import io
from typing import List, Optional
from .base import BaseReporter, ScanReport
from src.utils.logger import get_logger
logger = get_logger(__name__)
class CSVReporter(BaseReporter):
    """Generates CSV format reports."""

    # Default columns for findings export
    DEFAULT_COLUMNS = [
        'target_ip',
        'target_port',
        'target_hostname',
        'service',
        'risk_score',
        'vulnerabilities',
        'source_engine',
        'first_seen',
        'status',
        'confidence'
    ]

    def __init__(
        self,
        output_dir: str = "./reports",
        columns: Optional[List[str]] = None,
        include_metadata: bool = False
    ):
        """
        Initialize CSV reporter.

        Args:
            output_dir: Output directory for reports
            columns: Custom columns to include
            include_metadata: Whether to include metadata columns
        """
        super().__init__(output_dir)
        # Always take a private copy: the previous code aliased a
        # caller-supplied `columns` list, so enabling include_metadata
        # mutated the caller's list in place.
        self.columns = list(columns) if columns else list(self.DEFAULT_COLUMNS)
        self.include_metadata = include_metadata
        if include_metadata:
            self.columns.extend(['location', 'isp', 'asn'])

    @property
    def format_name(self) -> str:
        """Return the format name."""
        return "csv"

    @property
    def file_extension(self) -> str:
        """Return the file extension."""
        return "csv"

    def generate(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Generate CSV report file.

        Args:
            report: ScanReport data
            filename: Optional custom filename

        Returns:
            Path to generated report file
        """
        output_filename = self.get_filename(report.scan_id, filename)
        filepath = self.get_filepath(output_filename)
        content = self.generate_string(report)
        # newline='' prevents the platform layer from doubling the CRLF
        # line endings csv already emits.
        with open(filepath, 'w', encoding='utf-8', newline='') as f:
            f.write(content)
        logger.info(f"Generated CSV report: {filepath}")
        return filepath

    def generate_string(self, report: ScanReport) -> str:
        """
        Generate CSV report as string.

        Args:
            report: ScanReport data

        Returns:
            CSV content as string
        """
        output = io.StringIO()
        # extrasaction='ignore' drops finding keys that are not in
        # self.columns instead of raising ValueError.
        writer = csv.DictWriter(output, fieldnames=self.columns, extrasaction='ignore')
        # Write header
        writer.writeheader()
        # Write findings
        for finding in report.findings:
            row = self._format_finding(finding)
            writer.writerow(row)
        return output.getvalue()

    def _format_finding(self, finding: dict) -> dict:
        """Format a finding dict into one flat CSV row.

        Columns present on the finding are copied (lists joined with
        '; ', dicts stringified); the metadata-derived columns
        location/isp/asn are extracted from finding['metadata']; any
        other missing column becomes an empty string.
        """
        row = {}
        for col in self.columns:
            if col in finding:
                value = finding[col]
                # Convert lists to comma-separated strings
                if isinstance(value, list):
                    value = '; '.join(str(v) for v in value)
                # Convert dicts to string representation
                elif isinstance(value, dict):
                    value = str(value)
                row[col] = value
            elif col == 'location':
                # Extract from metadata
                metadata = finding.get('metadata', {})
                location = metadata.get('location', {})
                if isinstance(location, dict):
                    row[col] = f"{location.get('country', '')}, {location.get('city', '')}"
                else:
                    row[col] = ''
            elif col == 'isp':
                metadata = finding.get('metadata', {})
                row[col] = metadata.get('isp', '')
            elif col == 'asn':
                metadata = finding.get('metadata', {})
                row[col] = metadata.get('asn', '')
            else:
                row[col] = ''
        return row

    def generate_summary(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Generate a summary CSV file (metric/value pairs).

        Args:
            report: ScanReport data
            filename: Optional custom filename

        Returns:
            Path to generated file
        """
        summary_filename = filename or f"summary_{report.scan_id[:8]}"
        if not summary_filename.endswith('.csv'):
            summary_filename = f"{summary_filename}.csv"
        filepath = self.get_filepath(summary_filename)
        output = io.StringIO()
        writer = csv.writer(output)
        # Write summary as key-value pairs
        writer.writerow(['Metric', 'Value'])
        writer.writerow(['Scan ID', report.scan_id])
        writer.writerow(['Timestamp', report.timestamp.isoformat() if report.timestamp else ''])
        writer.writerow(['Engines Used', ', '.join(report.engines_used)])
        writer.writerow(['Query', report.query or ''])
        writer.writerow(['Template', report.template_name or ''])
        writer.writerow(['Total Results', report.total_results])
        writer.writerow(['Duration (seconds)', report.duration_seconds])
        writer.writerow(['Critical Findings', report.critical_findings])
        writer.writerow(['High Findings', report.high_findings])
        writer.writerow(['Medium Findings', report.medium_findings])
        writer.writerow(['Low Findings', report.low_findings])
        writer.writerow(['Average Risk Score', report.average_risk_score])
        with open(filepath, 'w', encoding='utf-8', newline='') as f:
            f.write(output.getvalue())
        logger.info(f"Generated CSV summary: {filepath}")
        return filepath

    def generate_vulnerability_report(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Generate a vulnerability-focused CSV report: one row per
        (target, vulnerability) pair.

        Args:
            report: ScanReport data
            filename: Optional custom filename

        Returns:
            Path to generated file
        """
        vuln_filename = filename or f"vulnerabilities_{report.scan_id[:8]}"
        if not vuln_filename.endswith('.csv'):
            vuln_filename = f"{vuln_filename}.csv"
        filepath = self.get_filepath(vuln_filename)
        output = io.StringIO()
        writer = csv.writer(output)
        # Header
        writer.writerow(['Target IP', 'Port', 'Hostname', 'Vulnerability', 'Risk Score'])
        # Write vulnerability rows; targets with no vulns still get a row
        # so every finding is represented.
        for finding in report.findings:
            ip = finding.get('target_ip', '')
            port = finding.get('target_port', '')
            hostname = finding.get('target_hostname', '')
            risk_score = finding.get('risk_score', 0)
            vulns = finding.get('vulnerabilities', [])
            if vulns:
                for vuln in vulns:
                    writer.writerow([ip, port, hostname, vuln, risk_score])
            else:
                writer.writerow([ip, port, hostname, 'None detected', risk_score])
        with open(filepath, 'w', encoding='utf-8', newline='') as f:
            f.write(output.getvalue())
        logger.info(f"Generated vulnerability CSV: {filepath}")
        return filepath
+122
View File
@@ -0,0 +1,122 @@
"""JSON report generator for AASRT."""
import json
from typing import Optional
from .base import BaseReporter, ScanReport
from src.utils.logger import get_logger
logger = get_logger(__name__)
class JSONReporter(BaseReporter):
    """Generates JSON format reports."""

    def __init__(self, output_dir: str = "./reports", pretty: bool = True):
        """
        Initialize JSON reporter.

        Args:
            output_dir: Output directory for reports
            pretty: Whether to format JSON with indentation
        """
        super().__init__(output_dir)
        self.pretty = pretty

    @property
    def format_name(self) -> str:
        """Return the format name."""
        return "json"

    @property
    def file_extension(self) -> str:
        """Return the file extension."""
        return "json"

    def generate(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Generate JSON report file.

        Args:
            report: ScanReport data
            filename: Optional custom filename

        Returns:
            Path to generated report file
        """
        output_filename = self.get_filename(report.scan_id, filename)
        filepath = self.get_filepath(output_filename)
        content = self.generate_string(report)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(content)
        logger.info(f"Generated JSON report: {filepath}")
        return filepath

    def generate_string(self, report: ScanReport) -> str:
        """
        Generate JSON report as string.

        Args:
            report: ScanReport data

        Returns:
            JSON string
        """
        data = report.to_dict()
        # Add report metadata
        data['report_metadata'] = {
            'format': 'json',
            'version': '1.0',
            'generated_by': 'AASRT (AI Agent Security Reconnaissance Tool)'
        }
        # default=str stringifies non-JSON types (e.g. datetimes);
        # ensure_ascii=False keeps non-ASCII banners readable.
        if self.pretty:
            return json.dumps(data, indent=2, default=str, ensure_ascii=False)
        else:
            return json.dumps(data, default=str, ensure_ascii=False)

    def generate_summary(self, report: ScanReport) -> str:
        """
        Generate a summary-only JSON report.

        Args:
            report: ScanReport data

        Returns:
            JSON string with summary only
        """
        summary = {
            'scan_id': report.scan_id,
            'timestamp': report.timestamp.isoformat() if report.timestamp else None,
            'engines_used': report.engines_used,
            'total_results': report.total_results,
            'summary': {
                'critical_findings': report.critical_findings,
                'high_findings': report.high_findings,
                'medium_findings': report.medium_findings,
                'low_findings': report.low_findings,
                'average_risk_score': report.average_risk_score
            }
        }
        # ensure_ascii=False for consistency with the other serializers
        # in this class (previously omitted here only).
        if self.pretty:
            return json.dumps(summary, indent=2, default=str, ensure_ascii=False)
        else:
            return json.dumps(summary, default=str, ensure_ascii=False)

    def generate_findings_only(self, report: ScanReport) -> str:
        """
        Generate JSON with findings only (no metadata).

        Args:
            report: ScanReport data

        Returns:
            JSON string with findings array
        """
        if self.pretty:
            return json.dumps(report.findings, indent=2, default=str, ensure_ascii=False)
        else:
            return json.dumps(report.findings, default=str, ensure_ascii=False)
+5
View File
@@ -0,0 +1,5 @@
"""Storage modules for AASRT."""
from .database import Database, Scan, Finding
__all__ = ['Database', 'Scan', 'Finding']
+806
View File
@@ -0,0 +1,806 @@
"""
Database storage layer for AASRT.
This module provides a production-ready database layer with:
- Connection pooling for efficient resource usage
- Automatic retry logic for transient failures
- Context managers for proper session cleanup
- Support for SQLite (default) and PostgreSQL
- Comprehensive logging and error handling
Example:
>>> from src.storage.database import Database
>>> db = Database()
>>> scan = db.create_scan(engines=["shodan"], query="http.html:agent")
>>> db.add_findings(scan.scan_id, results)
>>> db.update_scan(scan.scan_id, status="completed")
"""
import json
import os
import uuid
import time
from contextlib import contextmanager
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Callable, Dict, Generator, List, Optional, TypeVar
from sqlalchemy import (
    create_engine, text, Column, String, Integer, Float, DateTime,
    Text, Boolean, ForeignKey, Index, event
)
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from sqlalchemy.orm import declarative_base, sessionmaker, relationship, Session, scoped_session
from sqlalchemy.pool import QueuePool, StaticPool
from src.engines import SearchResult
from src.utils.config import Config
from src.utils.logger import get_logger
logger = get_logger(__name__)
Base = declarative_base()
# =============================================================================
# Retry Configuration
# =============================================================================
T = TypeVar('T')
# Maximum retry attempts for transient database errors
MAX_DB_RETRIES = 3
# Base delay for exponential backoff (seconds)
DB_RETRY_BASE_DELAY = 0.5
# Exceptions that should trigger a retry
RETRYABLE_DB_EXCEPTIONS = (OperationalError,)
def with_db_retry(func: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator that adds retry logic for transient database errors.

    Retries on connection errors and deadlocks but not on
    constraint violations or other permanent errors.

    Args:
        func: Database function to wrap with retry logic.

    Returns:
        Wrapped function with retry capability.
    """
    @wraps(func)
    def wrapper(*args, **kwargs) -> T:
        failure = None
        attempt = 0
        while attempt < MAX_DB_RETRIES:
            attempt += 1
            try:
                return func(*args, **kwargs)
            except IntegrityError as exc:
                # Constraint violations are permanent: never retry.
                logger.error(f"Database integrity error in {func.__name__}: {exc}")
                raise
            except RETRYABLE_DB_EXCEPTIONS as exc:
                failure = exc
                if attempt >= MAX_DB_RETRIES:
                    logger.error(
                        f"All {MAX_DB_RETRIES} database retries exhausted for {func.__name__}"
                    )
                else:
                    # Exponential backoff: base * 2^(attempt-1)
                    pause = DB_RETRY_BASE_DELAY * (2 ** (attempt - 1))
                    logger.warning(
                        f"Database retry {attempt}/{MAX_DB_RETRIES} for {func.__name__} "
                        f"after {pause:.2f}s. Error: {exc}"
                    )
                    time.sleep(pause)
            except SQLAlchemyError as exc:
                # Log and re-raise other SQLAlchemy errors
                logger.error(f"Database error in {func.__name__}: {exc}")
                raise
        # All retries exhausted
        if failure is not None:
            raise failure
        raise SQLAlchemyError(f"Unexpected database error in {func.__name__}")
    return wrapper
class Scan(Base):
    """Scan record model.

    One row per scan run. List/dict payloads (engines_used, extra_data)
    are stored as JSON text so the schema works on both SQLite and
    PostgreSQL.
    """
    __tablename__ = 'scans'

    # UUID4 string assigned by Database.create_scan()
    scan_id = Column(String(36), primary_key=True)
    timestamp = Column(DateTime, nullable=False, default=datetime.utcnow)
    engines_used = Column(Text)  # JSON array
    query = Column(Text)
    template_name = Column(String(255))
    total_results = Column(Integer, default=0)
    duration_seconds = Column(Float)
    status = Column(String(50), default='running')  # running, completed, failed, partial
    extra_data = Column(Text)  # JSON
    # Relationships: deleting a scan cascade-deletes its findings
    findings = relationship("Finding", back_populates="scan", cascade="all, delete-orphan")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        JSON-encoded columns are decoded back into Python objects and the
        timestamp is rendered as an ISO-8601 string (or None).
        """
        return {
            'scan_id': self.scan_id,
            'timestamp': self.timestamp.isoformat() if self.timestamp else None,
            'engines_used': json.loads(self.engines_used) if self.engines_used else [],
            'query': self.query,
            'template_name': self.template_name,
            'total_results': self.total_results,
            'duration_seconds': self.duration_seconds,
            'status': self.status,
            # stored as 'extra_data' to avoid SQLAlchemy's reserved 'metadata'
            'metadata': json.loads(self.extra_data) if self.extra_data else {}
        }
class Finding(Base):
    """Finding record model.

    One row per discovered (target, port) result, linked to its parent
    Scan. JSON payloads (vulnerabilities, extra_data) are stored as text.
    """
    __tablename__ = 'findings'

    # UUID4 string assigned in from_search_result()
    finding_id = Column(String(36), primary_key=True)
    scan_id = Column(String(36), ForeignKey('scans.scan_id'), nullable=False)
    source_engine = Column(String(50))
    target_ip = Column(String(45), nullable=False)  # Support IPv6
    target_port = Column(Integer, nullable=False)
    target_hostname = Column(String(255))
    service = Column(String(255))
    banner = Column(Text)
    risk_score = Column(Float, default=0.0)
    vulnerabilities = Column(Text)  # JSON array
    first_seen = Column(DateTime, default=datetime.utcnow)
    last_seen = Column(DateTime, default=datetime.utcnow)
    status = Column(String(50), default='new')  # new, confirmed, false_positive, remediated
    confidence = Column(Integer, default=100)
    extra_data = Column(Text)  # JSON
    # Relationships
    scan = relationship("Scan", back_populates="findings")
    # Indexes matching the common query paths in Database.get_findings()
    # and get_finding_by_target()
    __table_args__ = (
        Index('idx_findings_risk', risk_score.desc()),
        Index('idx_findings_timestamp', first_seen.desc()),
        Index('idx_findings_ip', target_ip),
        Index('idx_findings_status', status),
    )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        JSON-encoded columns are decoded and datetimes rendered as
        ISO-8601 strings (or None).
        """
        return {
            'finding_id': self.finding_id,
            'scan_id': self.scan_id,
            'source_engine': self.source_engine,
            'target_ip': self.target_ip,
            'target_port': self.target_port,
            'target_hostname': self.target_hostname,
            'service': self.service,
            'banner': self.banner,
            'risk_score': self.risk_score,
            'vulnerabilities': json.loads(self.vulnerabilities) if self.vulnerabilities else [],
            'first_seen': self.first_seen.isoformat() if self.first_seen else None,
            'last_seen': self.last_seen.isoformat() if self.last_seen else None,
            'status': self.status,
            'confidence': self.confidence,
            'metadata': json.loads(self.extra_data) if self.extra_data else {}
        }

    @classmethod
    def from_search_result(cls, result: SearchResult, scan_id: str) -> 'Finding':
        """Create Finding from SearchResult.

        Assigns a fresh UUID; first_seen/last_seen fall back to their
        column defaults (now) since they are not set here.
        """
        return cls(
            finding_id=str(uuid.uuid4()),
            scan_id=scan_id,
            source_engine=result.source_engine,
            target_ip=result.ip,
            target_port=result.port,
            target_hostname=result.hostname,
            service=result.service,
            banner=result.banner,
            risk_score=result.risk_score,
            vulnerabilities=json.dumps(result.vulnerabilities),
            confidence=result.confidence,
            extra_data=json.dumps(result.metadata)
        )
class Database:
    """
    Database manager for AASRT with connection pooling and retry logic.

    This class provides a thread-safe database layer with:
    - Connection pooling for efficient resource usage
    - Automatic retry on transient failures
    - Context managers for proper session cleanup
    - Support for SQLite and PostgreSQL

    Attributes:
        config: Configuration instance.
        engine: SQLAlchemy engine with connection pool.
        Session: Scoped session factory.

    Example:
        >>> db = Database()
        >>> with db.session_scope() as session:
        ...     scan = Scan(scan_id="123", ...)
        ...     session.add(scan)
        >>> # Session is automatically committed and closed
    """

    # Connection pool settings
    POOL_SIZE = 5
    MAX_OVERFLOW = 10
    POOL_TIMEOUT = 30
    POOL_RECYCLE = 3600  # Recycle connections after 1 hour

    def __init__(self, config: Optional[Config] = None) -> None:
        """
        Initialize database connection with connection pooling.

        Args:
            config: Configuration instance. If None, uses default Config.

        Raises:
            SQLAlchemyError: If database connection fails.
        """
        self.config = config or Config()
        self.engine = None
        self.Session = None
        self._db_type: str = "unknown"
        self._initialize()

    def _initialize(self) -> None:
        """
        Initialize database connection, pooling, and create tables.

        Sets up connection pooling appropriate for the database type:
        - SQLite: Uses StaticPool for thread safety
        - PostgreSQL: Uses QueuePool with configurable size
        """
        self._db_type = self.config.get('database', 'type', default='sqlite')
        if self._db_type == 'sqlite':
            db_path = self.config.get('database', 'sqlite', 'path', default='./data/scanner.db')
            # Ensure the parent directory exists. Guard against a bare
            # filename like "scanner.db": its dirname is "", which
            # os.makedirs() rejects with FileNotFoundError.
            db_dir = os.path.dirname(db_path)
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)
            connection_string = f"sqlite:///{db_path}"
            # SQLite configuration - use StaticPool for thread safety
            self.engine = create_engine(
                connection_string,
                echo=False,
                poolclass=StaticPool,
                connect_args={
                    "check_same_thread": False,
                    "timeout": 30
                }
            )

            # Enable WAL mode for better concurrent access
            @event.listens_for(self.engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                cursor = dbapi_connection.cursor()
                cursor.execute("PRAGMA journal_mode=WAL")
                cursor.execute("PRAGMA synchronous=NORMAL")
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.close()
        else:
            # PostgreSQL with connection pooling
            host = self.config.get('database', 'postgresql', 'host', default='localhost')
            port = self.config.get('database', 'postgresql', 'port', default=5432)
            database = self.config.get('database', 'postgresql', 'database', default='aasrt')
            user = self.config.get('database', 'postgresql', 'user')
            password = self.config.get('database', 'postgresql', 'password')
            ssl_mode = self.config.get('database', 'postgresql', 'ssl_mode', default='prefer')
            # Mask password in logs
            safe_conn_str = f"postgresql://{user}:***@{host}:{port}/{database}"
            connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={ssl_mode}"
            self.engine = create_engine(
                connection_string,
                echo=False,
                poolclass=QueuePool,
                pool_size=self.POOL_SIZE,
                max_overflow=self.MAX_OVERFLOW,
                pool_timeout=self.POOL_TIMEOUT,
                pool_recycle=self.POOL_RECYCLE,
                pool_pre_ping=True  # Verify connections before use
            )
            logger.debug(f"PostgreSQL connection: {safe_conn_str}")
        # Use scoped_session for thread safety
        self.Session = scoped_session(sessionmaker(bind=self.engine))
        # Create tables
        Base.metadata.create_all(self.engine)
        logger.info(f"Database initialized: {self._db_type}")

    @contextmanager
    def session_scope(self) -> Generator[Session, None, None]:
        """
        Provide a transactional scope around a series of operations.

        Creates a session, commits on success, rolls back on exception,
        and always closes the session.

        Yields:
            SQLAlchemy Session object.

        Raises:
            SQLAlchemyError: On database errors (after rollback).

        Example:
            >>> with db.session_scope() as session:
            ...     session.add(Scan(...))
            ...     # Automatically committed if no exception
        """
        session = self.Session()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Database session error, rolling back: {e}")
            raise
        finally:
            session.close()

    def get_session(self) -> Session:
        """
        Get a database session (legacy method).

        Note:
            Prefer using session_scope() context manager for new code.
            This method is kept for backward compatibility.

        Returns:
            SQLAlchemy Session object.
        """
        return self.Session()

    def close(self) -> None:
        """
        Close all database connections and cleanup resources.

        Call this method during application shutdown to properly
        release database connections.
        """
        if self.Session:
            self.Session.remove()
        if self.engine:
            self.engine.dispose()
        logger.info("Database connections closed")

    def health_check(self) -> Dict[str, Any]:
        """
        Perform a health check on the database connection.

        Returns:
            Dictionary with health status:
                - healthy: bool indicating if database is accessible
                - db_type: Database type (sqlite/postgresql)
                - latency_ms: Response time in milliseconds
                - error: Error message if unhealthy (optional)
        """
        start_time = time.time()
        try:
            with self.session_scope() as session:
                # Textual SQL must be wrapped in text(): SQLAlchemy 1.4
                # deprecates and 2.x rejects passing a bare string to
                # Session.execute().
                session.execute(text("SELECT 1"))
            latency = (time.time() - start_time) * 1000
            return {
                "healthy": True,
                "db_type": self._db_type,
                "latency_ms": round(latency, 2),
                "pool_size": getattr(self.engine.pool, 'size', lambda: 'N/A')() if hasattr(self.engine, 'pool') else 'N/A'
            }
        except Exception as e:
            latency = (time.time() - start_time) * 1000
            logger.error(f"Database health check failed: {e}")
            return {
                "healthy": False,
                "db_type": self._db_type,
                "latency_ms": round(latency, 2),
                "error": str(e)
            }

    # =========================================================================
    # Scan Operations
    # =========================================================================

    @with_db_retry
    def create_scan(
        self,
        engines: List[str],
        query: Optional[str] = None,
        template_name: Optional[str] = None
    ) -> Scan:
        """
        Create a new scan record in the database.

        Args:
            engines: List of engine names used for the scan (e.g., ["shodan"]).
            query: Search query string (if using custom query).
            template_name: Template name (if using predefined template).

        Returns:
            Created Scan object with generated scan_id.

        Raises:
            SQLAlchemyError: If database operation fails.

        Example:
            >>> scan = db.create_scan(engines=["shodan"], template_name="clawdbot")
            >>> print(scan.scan_id)
        """
        scan = Scan(
            scan_id=str(uuid.uuid4()),
            timestamp=datetime.utcnow(),
            engines_used=json.dumps(engines),
            query=query,
            template_name=template_name,
            status='running'
        )
        with self.session_scope() as session:
            session.add(scan)
            # Flush to ensure data is written before expunge
            session.flush()
            logger.info(f"Created scan: {scan.scan_id}")
            # Need to expunge to use outside session
            session.expunge(scan)
        return scan

    @with_db_retry
    def update_scan(
        self,
        scan_id: str,
        status: Optional[str] = None,
        total_results: Optional[int] = None,
        duration_seconds: Optional[float] = None,
        metadata: Optional[Dict] = None
    ) -> Optional[Scan]:
        """
        Update a scan record with new values.

        Args:
            scan_id: UUID of the scan to update.
            status: New status (running, completed, failed, partial).
            total_results: Number of results found.
            duration_seconds: Total scan duration.
            metadata: Additional metadata to merge.

        Returns:
            Updated Scan object, or None if scan not found.
        """
        with self.session_scope() as session:
            scan = session.query(Scan).filter(Scan.scan_id == scan_id).first()
            if not scan:
                logger.warning(f"Scan not found for update: {scan_id}")
                return None
            if status:
                scan.status = status
            if total_results is not None:
                scan.total_results = total_results
            if duration_seconds is not None:
                scan.duration_seconds = duration_seconds
            if metadata:
                # Merge new metadata keys into the stored JSON blob
                existing = json.loads(scan.extra_data) if scan.extra_data else {}
                existing.update(metadata)
                scan.extra_data = json.dumps(existing)
            # Flush to ensure changes are written before expunge
            session.flush()
            logger.debug(f"Updated scan {scan_id}: status={status}, results={total_results}")
            session.expunge(scan)
            return scan

    @with_db_retry
    def get_scan(self, scan_id: str) -> Optional[Scan]:
        """
        Get a scan by its UUID.

        Args:
            scan_id: UUID of the scan.

        Returns:
            Scan object or None if not found.
        """
        with self.session_scope() as session:
            scan = session.query(Scan).filter(Scan.scan_id == scan_id).first()
            if scan:
                session.expunge(scan)
            return scan

    @with_db_retry
    def get_recent_scans(self, limit: int = 10) -> List[Scan]:
        """
        Get the most recent scans.

        Args:
            limit: Maximum number of scans to return.

        Returns:
            List of Scan objects ordered by timestamp descending.
        """
        with self.session_scope() as session:
            scans = session.query(Scan).order_by(Scan.timestamp.desc()).limit(limit).all()
            for scan in scans:
                session.expunge(scan)
            return scans

    @with_db_retry
    def delete_scan(self, scan_id: str) -> bool:
        """
        Delete a scan and all its associated findings.

        Args:
            scan_id: UUID of the scan to delete.

        Returns:
            True if scan was deleted, False if not found.
        """
        with self.session_scope() as session:
            scan = session.query(Scan).filter(Scan.scan_id == scan_id).first()
            if scan:
                # Findings are removed too via the cascade on
                # Scan.findings. session_scope() commits on exit, so no
                # explicit commit is needed here.
                session.delete(scan)
                logger.info(f"Deleted scan: {scan_id}")
                return True
            return False

    # =========================================================================
    # Finding Operations
    # =========================================================================

    @with_db_retry
    def add_findings(self, scan_id: str, results: List[SearchResult]) -> int:
        """
        Add findings from search results to the database.

        Args:
            scan_id: Parent scan UUID.
            results: List of SearchResult objects to store.

        Returns:
            Number of findings successfully added.

        Raises:
            SQLAlchemyError: If database operation fails.

        Note:
            Findings are added in batches for efficiency; individual
            malformed results are skipped with a warning.
        """
        if not results:
            logger.debug(f"No findings to add for scan {scan_id}")
            return 0
        with self.session_scope() as session:
            count = 0
            for result in results:
                try:
                    finding = Finding.from_search_result(result, scan_id)
                    session.add(finding)
                    count += 1
                except Exception as e:
                    logger.warning(f"Failed to create finding from result: {e}")
                    continue
            logger.info(f"Added {count} findings to scan {scan_id}")
            return count

    @with_db_retry
    def get_findings(
        self,
        scan_id: Optional[str] = None,
        min_risk_score: Optional[float] = None,
        status: Optional[str] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[Finding]:
        """
        Get findings with optional filters.

        Args:
            scan_id: Filter by scan UUID.
            min_risk_score: Minimum risk score (0.0-10.0).
            status: Finding status filter (new, confirmed, false_positive, remediated).
            limit: Maximum results to return.
            offset: Number of results to skip (for pagination).

        Returns:
            List of Finding objects matching filters, ordered by risk score descending.
        """
        with self.session_scope() as session:
            query = session.query(Finding)
            if scan_id:
                query = query.filter(Finding.scan_id == scan_id)
            if min_risk_score is not None:
                query = query.filter(Finding.risk_score >= min_risk_score)
            if status:
                query = query.filter(Finding.status == status)
            query = query.order_by(Finding.risk_score.desc())
            findings = query.offset(offset).limit(limit).all()
            for finding in findings:
                session.expunge(finding)
            return findings

    @with_db_retry
    def get_finding(self, finding_id: str) -> Optional[Finding]:
        """
        Get a single finding by its UUID.

        Args:
            finding_id: UUID of the finding.

        Returns:
            Finding object or None if not found.
        """
        with self.session_scope() as session:
            finding = session.query(Finding).filter(Finding.finding_id == finding_id).first()
            if finding:
                session.expunge(finding)
            return finding

    @with_db_retry
    def update_finding_status(self, finding_id: str, status: str) -> bool:
        """
        Update the status of a finding.

        Args:
            finding_id: UUID of the finding.
            status: New status (new, confirmed, false_positive, remediated).

        Returns:
            True if finding was updated, False if not found.
        """
        with self.session_scope() as session:
            finding = session.query(Finding).filter(Finding.finding_id == finding_id).first()
            if finding:
                finding.status = status
                finding.last_seen = datetime.utcnow()
                logger.debug(f"Updated finding {finding_id} status to {status}")
                return True
            logger.warning(f"Finding not found for status update: {finding_id}")
            return False

    @with_db_retry
    def get_finding_by_target(self, ip: str, port: int) -> Optional[Finding]:
        """
        Get the most recent finding for a specific target.

        Args:
            ip: Target IP address.
            port: Target port number.

        Returns:
            Most recent Finding for the target, or None if not found.
        """
        with self.session_scope() as session:
            finding = session.query(Finding).filter(
                Finding.target_ip == ip,
                Finding.target_port == port
            ).order_by(Finding.last_seen.desc()).first()
            if finding:
                session.expunge(finding)
            return finding

    # =========================================================================
    # Statistics and Maintenance
    # =========================================================================

    @with_db_retry
    def get_statistics(self) -> Dict[str, Any]:
        """
        Get overall database statistics.

        Returns:
            Dictionary containing:
                - total_scans: Total number of scans
                - total_findings: Total number of findings
                - unique_ips: Count of unique IP addresses
                - risk_distribution: Dict with critical/high/medium/low counts
                - last_scan_time: Timestamp of most recent scan (or None)

        Example:
            >>> stats = db.get_statistics()
            >>> print(f"Critical findings: {stats['risk_distribution']['critical']}")
        """
        with self.session_scope() as session:
            total_scans = session.query(Scan).count()
            total_findings = session.query(Finding).count()
            # Risk distribution using CVSS-like thresholds
            critical = session.query(Finding).filter(Finding.risk_score >= 9.0).count()
            high = session.query(Finding).filter(
                Finding.risk_score >= 7.0,
                Finding.risk_score < 9.0
            ).count()
            medium = session.query(Finding).filter(
                Finding.risk_score >= 4.0,
                Finding.risk_score < 7.0
            ).count()
            low = session.query(Finding).filter(Finding.risk_score < 4.0).count()
            # Unique IPs discovered
            unique_ips = session.query(Finding.target_ip).distinct().count()
            # Last scan timestamp
            last_scan = session.query(Scan).order_by(Scan.timestamp.desc()).first()
            last_scan_time = last_scan.timestamp.isoformat() if last_scan else None
            return {
                'total_scans': total_scans,
                'total_findings': total_findings,
                'unique_ips': unique_ips,
                'risk_distribution': {
                    'critical': critical,
                    'high': high,
                    'medium': medium,
                    'low': low
                },
                'last_scan_time': last_scan_time
            }

    @with_db_retry
    def cleanup_old_data(self, days: int = 90) -> int:
        """
        Remove scan data older than specified days.

        This is a maintenance operation that removes old scans and their
        associated findings to manage database size.

        Args:
            days: Age threshold in days. Scans older than this will be deleted.
                Default is 90 days.

        Returns:
            Number of scans deleted (findings are cascade deleted).

        Raises:
            ValueError: If days is less than 1.
            SQLAlchemyError: If database operation fails.

        Example:
            >>> # Remove data older than 30 days
            >>> deleted = db.cleanup_old_data(days=30)
            >>> print(f"Removed {deleted} old scans")
        """
        if days < 1:
            raise ValueError("Days must be at least 1")
        cutoff = datetime.utcnow() - timedelta(days=days)
        logger.info(f"Cleaning up data older than {cutoff.isoformat()}")
        with self.session_scope() as session:
            # Count first for logging
            old_scans = session.query(Scan).filter(Scan.timestamp < cutoff).all()
            count = len(old_scans)
            if count == 0:
                logger.info("No old data to clean up")
                return 0
            for scan in old_scans:
                session.delete(scan)
            logger.info(f"Cleaned up {count} scans older than {days} days")
            return count
+26
View File
@@ -0,0 +1,26 @@
"""Utility modules for AASRT."""
from .config import Config
from .logger import setup_logger, get_logger
from .exceptions import (
AASRTException,
APIException,
RateLimitException,
ConfigurationException,
ValidationException
)
from .validators import validate_ip, validate_domain, validate_query
__all__ = [
'Config',
'setup_logger',
'get_logger',
'AASRTException',
'APIException',
'RateLimitException',
'ConfigurationException',
'ValidationException',
'validate_ip',
'validate_domain',
'validate_query'
]
+513
View File
@@ -0,0 +1,513 @@
"""
Configuration management for AASRT.
This module provides a production-ready configuration management system with:
- Singleton pattern for global configuration access
- YAML file loading with deep merging
- Environment variable overrides
- Validation of required settings
- Support for structured logging configuration
- Health check capabilities
Configuration priority (highest to lowest):
1. Environment variables
2. YAML configuration file
3. Default values
Example:
>>> from src.utils.config import Config
>>> config = Config()
>>> shodan_key = config.get_shodan_key()
>>> log_level = config.get('logging', 'level', default='INFO')
Environment Variables:
SHODAN_API_KEY: Required Shodan API key
AASRT_LOG_LEVEL: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
AASRT_ENVIRONMENT: Deployment environment (development, staging, production)
AASRT_DEBUG: Enable debug mode (true/false)
DB_TYPE: Database type (sqlite, postgresql)
DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD: PostgreSQL settings
"""
import os
import secrets
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

import yaml
from dotenv import load_dotenv

from .exceptions import ConfigurationException
from .logger import get_logger
logger = get_logger(__name__)  # module-level logger for configuration events
# =============================================================================
# Validation Constants
# =============================================================================
# Accepted values for settings checked in Config._validate_config and
# Config._load_from_env; invalid values are logged and defaults are kept.
VALID_LOG_LEVELS: Set[str] = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
VALID_ENVIRONMENTS: Set[str] = {"development", "staging", "production"}
VALID_DB_TYPES: Set[str] = {"sqlite", "postgresql", "mysql"}
REQUIRED_SETTINGS: List[str] = []  # API key is optional until scan is run
class Config:
    """
    Configuration manager for AASRT with a thread-safe singleton pattern.

    Configuration sources, highest priority first:
        1. Environment variables
        2. YAML configuration file
        3. Built-in defaults

    Attributes:
        _instance: Singleton instance (class level).
        _config: Configuration dictionary.
        _lock: Guards first-time singleton creation across threads.
        _initialized: Flag indicating initialization status.
        _config_path: Path to loaded configuration file, if any.
        _environment: Current deployment environment.

    Example:
        >>> config = Config()
        >>> api_key = config.get_shodan_key()
        >>> if not api_key:
        ...     print("Warning: Shodan API key not configured")
    """

    _instance: Optional['Config'] = None
    _config: Dict[str, Any] = {}
    # Lock for double-checked locking in __new__. The class docstring always
    # promised thread safety, but instance creation was previously unguarded.
    _lock = threading.Lock()

    def __new__(cls, config_path: Optional[str] = None):
        """
        Thread-safe singleton creation (double-checked locking).

        Args:
            config_path: Optional path to YAML configuration file
                (consumed by __init__, ignored here).

        Returns:
            The singleton Config instance.
        """
        if cls._instance is None:
            with cls._lock:
                # Re-check inside the lock: another thread may have created
                # the instance between the unlocked test and acquisition.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self, config_path: Optional[str] = None) -> None:
        """
        Initialize configuration from defaults, file, and environment.

        Loaded in order of priority: defaults, then YAML file, then
        environment variables (highest priority).

        Args:
            config_path: Path to YAML configuration file. If not provided,
                common locations are searched (config.yaml, config.yml,
                ./config/config.yaml).

        Raises:
            ConfigurationException: If the YAML file is malformed.
        """
        # Singleton: only the first construction performs initialization.
        if self._initialized:
            return
        # Pull variables from a local .env file, if present.
        load_dotenv()
        # Store metadata
        self._config_path: Optional[str] = None
        self._environment: str = os.getenv('AASRT_ENVIRONMENT', 'development')
        self._validation_errors: List[str] = []
        # Lowest priority: built-in defaults.
        self._config = self._get_defaults()
        # Middle priority: YAML file (explicit path wins over discovery).
        if config_path:
            self._load_from_file(config_path)
        else:
            for path in ['config.yaml', 'config.yml', './config/config.yaml']:
                if os.path.exists(path):
                    self._load_from_file(path)
                    break
        # Highest priority: environment variables.
        self._load_from_env()
        # Validate configuration (logs warnings, does not raise).
        self._validate_config()
        self._initialized = True
        logger.info(f"Configuration initialized (environment: {self._environment})")

    def _get_defaults(self) -> Dict[str, Any]:
        """Return the built-in default configuration values."""
        return {
            'shodan': {
                'enabled': True,
                'rate_limit': 1,
                'max_results': 100,
                'timeout': 30
            },
            'vulnerability_checks': {
                'enabled': True,
                'passive_only': True,
                'timeout_per_check': 10
            },
            'reporting': {
                'formats': ['json', 'csv'],
                'output_dir': './reports',
                'anonymize_by_default': False
            },
            'filtering': {
                'whitelist_ips': [],
                'whitelist_domains': [],
                'min_confidence_score': 70,
                'exclude_honeypots': True
            },
            'logging': {
                'level': 'INFO',
                'file': './logs/scanner.log',
                'max_size_mb': 100,
                'backup_count': 5
            },
            'database': {
                'type': 'sqlite',
                'sqlite': {
                    'path': './data/scanner.db'
                }
            },
            'api_keys': {},
            'clawsec': {
                'enabled': True,
                'feed_url': 'https://clawsec.prompt.security/advisories/feed.json',
                'cache_ttl_seconds': 86400,  # 24 hours
                'cache_file': './data/clawsec_cache.json',
                'offline_mode': False,
                'timeout': 30,
                'auto_refresh': True
            }
        }

    def _load_from_file(self, path: str) -> None:
        """
        Load configuration from a YAML file and merge it over defaults.

        Args:
            path: Path to YAML configuration file.

        Raises:
            ConfigurationException: If the YAML is malformed.
        """
        try:
            with open(path, 'r') as f:
                file_config = yaml.safe_load(f)
            if file_config:
                self._deep_merge(self._config, file_config)
                self._config_path = path
                logger.info(f"Loaded configuration from {path}")
        except FileNotFoundError:
            # Missing file is non-fatal: defaults and env vars still apply.
            logger.warning(f"Configuration file not found: {path}")
        except yaml.YAMLError as e:
            raise ConfigurationException(f"Invalid YAML in configuration file: {e}")

    def _load_from_env(self) -> None:
        """
        Load settings from environment variables.

        Environment variables override file-based configuration; invalid
        values are logged and ignored rather than raising.
        """
        # Shodan API key
        shodan_key = os.getenv('SHODAN_API_KEY')
        if shodan_key:
            self._set_nested(('api_keys', 'shodan'), shodan_key)
        # Log level (only if it names a real level)
        log_level = os.getenv('AASRT_LOG_LEVEL', '').upper()
        if log_level and log_level in VALID_LOG_LEVELS:
            self._set_nested(('logging', 'level'), log_level)
        elif log_level:
            logger.warning(f"Invalid log level '{log_level}', using default")
        # Deployment environment
        env = os.getenv('AASRT_ENVIRONMENT', '').lower()
        if env and env in VALID_ENVIRONMENTS:
            self._environment = env
        # Debug flag forces DEBUG log level
        debug = os.getenv('AASRT_DEBUG', '').lower()
        if debug in ('true', '1', 'yes'):
            self._set_nested(('logging', 'level'), 'DEBUG')
        # Database type
        db_type = os.getenv('DB_TYPE', '').lower()
        if db_type and db_type in VALID_DB_TYPES:
            self._set_nested(('database', 'type'), db_type)
        # PostgreSQL connection settings
        if os.getenv('DB_HOST'):
            self._set_nested(('database', 'postgresql', 'host'), os.getenv('DB_HOST'))
        if os.getenv('DB_PORT'):
            try:
                port = int(os.getenv('DB_PORT'))
                self._set_nested(('database', 'postgresql', 'port'), port)
            except ValueError:
                logger.warning("Invalid DB_PORT, using default")
        if os.getenv('DB_NAME'):
            self._set_nested(('database', 'postgresql', 'database'), os.getenv('DB_NAME'))
        if os.getenv('DB_USER'):
            self._set_nested(('database', 'postgresql', 'user'), os.getenv('DB_USER'))
        if os.getenv('DB_PASSWORD'):
            self._set_nested(('database', 'postgresql', 'password'), os.getenv('DB_PASSWORD'))
        if os.getenv('DB_SSL_MODE'):
            self._set_nested(('database', 'postgresql', 'ssl_mode'), os.getenv('DB_SSL_MODE'))
        # Max results limit
        max_results = os.getenv('AASRT_MAX_RESULTS')
        if max_results:
            try:
                self._set_nested(('shodan', 'max_results'), int(max_results))
            except ValueError:
                logger.warning("Invalid AASRT_MAX_RESULTS, using default")

    def _validate_config(self) -> None:
        """
        Validate configuration settings.

        Records problems in self._validation_errors and logs warnings;
        never raises, to allow graceful degradation.
        """
        self._validation_errors = []
        # Validate log level; str() guards against non-string YAML values
        # (e.g. `level: 10`), which previously raised AttributeError.
        log_level = self.get('logging', 'level', default='INFO')
        if str(log_level).upper() not in VALID_LOG_LEVELS:
            self._validation_errors.append(f"Invalid log level: {log_level}")
        # Validate database type (same str() guard).
        db_type = self.get('database', 'type', default='sqlite')
        if str(db_type).lower() not in VALID_DB_TYPES:
            self._validation_errors.append(f"Invalid database type: {db_type}")
        # Validate max results is a positive integer
        max_results = self.get('shodan', 'max_results', default=100)
        if not isinstance(max_results, int) or max_results < 1:
            self._validation_errors.append(f"Invalid max_results: {max_results}")
        # Missing Shodan API key is a warning, not an error.
        if not self.get_shodan_key():
            logger.debug("Shodan API key not configured - scans will require it")
        for error in self._validation_errors:
            logger.warning(f"Configuration validation: {error}")

    def _deep_merge(self, base: Dict, overlay: Dict) -> None:
        """
        Recursively merge ``overlay`` into ``base`` (in place).

        Nested dicts are merged key-by-key; any other value type in
        overlay replaces the base value wholesale.

        Args:
            base: Base dictionary to merge into (modified in place).
            overlay: Overlay dictionary to merge from.
        """
        for key, value in overlay.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                self._deep_merge(base[key], value)
            else:
                base[key] = value

    def _set_nested(self, path: tuple, value: Any) -> None:
        """
        Set a nested configuration value, creating intermediate dicts.

        Args:
            path: Tuple of keys representing the path.
            value: Value to set at the path.
        """
        current = self._config
        for key in path[:-1]:
            if key not in current:
                current[key] = {}
            current = current[key]
        current[path[-1]] = value

    def get(self, *keys: str, default: Any = None) -> Any:
        """
        Get a configuration value by nested keys.

        Args:
            *keys: Nested keys to traverse (e.g. 'database', 'type').
            default: Value returned when the path does not exist.

        Returns:
            Configuration value or ``default``.

        Example:
            >>> config.get('shodan', 'max_results', default=100)
            100
        """
        current = self._config
        for key in keys:
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                return default
        return current

    def get_shodan_key(self) -> Optional[str]:
        """Return the Shodan API key, or None if not configured."""
        return self.get('api_keys', 'shodan')

    def get_shodan_config(self) -> Dict[str, Any]:
        """Return the Shodan settings dict (enabled, rate_limit, max_results, timeout)."""
        return self.get('shodan', default={})

    def get_clawsec_config(self) -> Dict[str, Any]:
        """Return the ClawSec settings dict."""
        return self.get('clawsec', default={})

    def is_clawsec_enabled(self) -> bool:
        """Return True if ClawSec vulnerability lookup is enabled."""
        return self.get('clawsec', 'enabled', default=True)

    def get_database_config(self) -> Dict[str, Any]:
        """Return the database settings dict."""
        return self.get('database', default={})

    def get_logging_config(self) -> Dict[str, Any]:
        """Return the logging settings dict (level, file, max_size_mb, backup_count)."""
        return self.get('logging', default={})

    @property
    def environment(self) -> str:
        """Current deployment environment (development, staging, production)."""
        return self._environment

    @property
    def is_production(self) -> bool:
        """True if the environment is 'production'."""
        return self._environment == 'production'

    @property
    def is_debug(self) -> bool:
        """True if the effective log level is DEBUG."""
        # str() guards against non-string YAML values.
        return str(self.get('logging', 'level', default='INFO')).upper() == 'DEBUG'

    @property
    def all(self) -> Dict[str, Any]:
        """Shallow copy of the complete configuration dictionary."""
        return self._config.copy()

    def reload(self, config_path: Optional[str] = None) -> None:
        """
        Reload configuration from file and environment without restarting.

        Args:
            config_path: Optional path to configuration file. If None,
                the previously loaded file path is reused.
        """
        logger.info("Reloading configuration...")
        self._initialized = False
        self._config = self._get_defaults()
        # Use the new path or fall back to the previously loaded one.
        path_to_load = config_path or self._config_path
        if path_to_load:
            self._load_from_file(path_to_load)
        self._load_from_env()
        self._validate_config()
        self._initialized = True
        logger.info("Configuration reloaded successfully")

    def health_check(self) -> Dict[str, Any]:
        """
        Report configuration health.

        Returns:
            Dict with: healthy (no validation errors), environment,
            config_file, validation_errors, shodan_configured,
            clawsec_enabled, log_level, database_type.
        """
        return {
            "healthy": len(self._validation_errors) == 0,
            "environment": self._environment,
            "config_file": self._config_path,
            "validation_errors": self._validation_errors.copy(),
            "shodan_configured": bool(self.get_shodan_key()),
            "clawsec_enabled": self.is_clawsec_enabled(),
            "log_level": self.get('logging', 'level', default='INFO'),
            "database_type": self.get('database', 'type', default='sqlite')
        }

    @staticmethod
    def reset_instance() -> None:
        """
        Reset the singleton instance (for testing only).

        Warning:
            Existing references to the old instance become stale.
        """
        Config._instance = None
+51
View File
@@ -0,0 +1,51 @@
"""Custom exceptions for AASRT."""
class AASRTException(Exception):
    """Root of the AASRT exception hierarchy; catch this for any tool error."""
class APIException(AASRTException):
    """Raised when a call to an external API fails.

    Attributes:
        engine: Name of the search engine that failed, if known.
        status_code: HTTP status code returned, if any.
    """

    def __init__(self, message: str, engine: str = None, status_code: int = None):
        super().__init__(message)
        self.engine = engine
        self.status_code = status_code
class RateLimitException(AASRTException):
    """Raised when an engine's rate limit has been exceeded.

    Attributes:
        engine: Name of the rate-limited engine, if known.
        retry_after: Seconds to wait before retrying, if provided.
    """

    def __init__(self, message: str, engine: str = None, retry_after: int = None):
        super().__init__(message)
        self.engine = engine
        self.retry_after = retry_after
class ConfigurationException(AASRTException):
    """Raised when configuration is missing, malformed, or invalid."""
class ValidationException(AASRTException):
    """Raised when user- or config-supplied input fails validation."""
class AuthenticationException(AASRTException):
    """Raised when authentication with an engine fails.

    Attributes:
        engine: Name of the engine that rejected the credentials, if known.
    """

    def __init__(self, message: str, engine: str = None):
        super().__init__(message)
        self.engine = engine
class TimeoutException(AASRTException):
    """Raised when a request exceeds its time budget.

    Attributes:
        engine: Name of the engine the request targeted, if known.
        timeout: The timeout value (seconds) that was exceeded, if known.
    """

    def __init__(self, message: str, engine: str = None, timeout: int = None):
        super().__init__(message)
        self.engine = engine
        self.timeout = timeout
+91
View File
@@ -0,0 +1,91 @@
"""Logging setup for AASRT."""
import logging
import os
from logging.handlers import RotatingFileHandler
from typing import Optional
# Module-level cache of configured loggers, keyed by logger name; shared by
# setup_logger() and get_logger().
_loggers = {}
def setup_logger(
    name: str = "aasrt",
    level: str = "INFO",
    log_file: Optional[str] = None,
    max_size_mb: int = 100,
    backup_count: int = 5
) -> logging.Logger:
    """
    Setup and configure a logger.

    Note:
        Loggers are cached by name; repeated calls with the same name return
        the first configuration and ignore the new arguments.

    Args:
        name: Logger name
        level: Log level (DEBUG, INFO, WARNING, ERROR)
        log_file: Path to log file (optional)
        max_size_mb: Max log file size in MB before rotation
        backup_count: Number of rotated backup files to keep

    Returns:
        Configured logger instance
    """
    if name in _loggers:
        return _loggers[name]
    logger = logging.getLogger(name)
    # Unrecognized level names fall back to INFO.
    logger.setLevel(getattr(logging, level.upper(), logging.INFO))
    # Logger was already configured elsewhere: keep its handlers, but still
    # register it in the cache so get_logger() stays consistent (the original
    # early return skipped caching).
    if logger.handlers:
        _loggers[name] = logger
        return logger
    # Console handler; handler level is DEBUG so the logger level does the filtering.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_format = logging.Formatter(
        '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler.setFormatter(console_format)
    logger.addHandler(console_handler)
    # Rotating file handler if log_file specified
    if log_file:
        # Ensure the log directory exists before the handler opens the file.
        log_dir = os.path.dirname(log_file)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        file_handler = RotatingFileHandler(
            log_file,
            maxBytes=max_size_mb * 1024 * 1024,
            backupCount=backup_count
        )
        file_handler.setLevel(logging.DEBUG)
        # File format additionally records the source location.
        file_format = logging.Formatter(
            '%(asctime)s | %(levelname)-8s | %(name)s | %(filename)s:%(lineno)d | %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(file_format)
        logger.addHandler(file_handler)
    _loggers[name] = logger
    return logger
def get_logger(name: str = "aasrt") -> logging.Logger:
    """
    Return the cached logger for *name*, creating it on first use.

    Args:
        name: Logger name

    Returns:
        Logger instance
    """
    try:
        return _loggers[name]
    except KeyError:
        return setup_logger(name)
+583
View File
@@ -0,0 +1,583 @@
"""
Input validation utilities for AASRT.
This module provides comprehensive input validation and sanitization functions
for security-sensitive operations including:
- IP address and domain validation
- Port number and query string validation
- File path sanitization (directory traversal prevention)
- API key format validation
- Template name whitelist validation
- Configuration value validation
All validators raise ValidationException on invalid input with descriptive
error messages for debugging.
Example:
>>> from src.utils.validators import validate_ip, validate_file_path
>>> validate_ip("192.168.1.1") # Returns True
>>> validate_file_path("../../../etc/passwd") # Raises ValidationException
"""
import re
import os
import ipaddress
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
import validators
from .exceptions import ValidationException
# =============================================================================
# Constants
# =============================================================================
# Valid log levels for configuration
VALID_LOG_LEVELS: Set[str] = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
# Valid environment names
VALID_ENVIRONMENTS: Set[str] = {"development", "staging", "production"}
# Valid database types
VALID_DB_TYPES: Set[str] = {"sqlite", "postgresql", "mysql"}
# Valid report formats
VALID_REPORT_FORMATS: Set[str] = {"json", "csv", "html", "pdf"}
# Valid query template names (whitelist); validate_template_name rejects
# anything not listed here.
VALID_TEMPLATES: Set[str] = {
    "clawdbot_instances",
    "autogpt_instances",
    "langchain_agents",
    "openai_agents",
    "anthropic_agents",
    "ai_agent_general",
    "agent_gpt",
    "babyagi_instances",
    "crewai_instances",
    "autogen_instances",
    "superagi_instances",
    "flowise_instances",
    "dify_instances",
}
# Maximum limits for various inputs
# NOTE(review): validate_query and validate_port currently inline their own
# limits instead of referencing MAX_QUERY_LENGTH / MIN_PORT / MAX_PORT below.
MAX_QUERY_LENGTH: int = 2000
MAX_RESULTS_LIMIT: int = 10000
MIN_RESULTS_LIMIT: int = 1
MAX_PORT: int = 65535
MIN_PORT: int = 1
MAX_FILE_PATH_LENGTH: int = 4096
MAX_API_KEY_LENGTH: int = 256
# =============================================================================
# IP and Network Validators
# =============================================================================
def validate_ip(ip: str) -> bool:
    """
    Validate an IP address (IPv4 or IPv6).

    Args:
        ip: IP address string to validate.

    Returns:
        True if the IP address is valid.

    Raises:
        ValidationException: If IP is None, empty, not a string, or malformed.

    Example:
        >>> validate_ip("192.168.1.1")
        True
        >>> validate_ip("2001:db8::1")
        True
    """
    if ip is None:
        raise ValidationException("IP address cannot be None")
    if not isinstance(ip, str):
        raise ValidationException(f"IP address must be a string, got {type(ip).__name__}")
    candidate = ip.strip()
    if not candidate:
        raise ValidationException("IP address cannot be empty")
    try:
        # ipaddress accepts both IPv4 and IPv6 textual forms.
        ipaddress.ip_address(candidate)
    except ValueError:
        raise ValidationException(f"Invalid IP address: {candidate}")
    return True
def validate_domain(domain: str) -> bool:
    """
    Validate a domain name via the third-party ``validators`` package.

    Args:
        domain: Domain name string

    Returns:
        True if valid

    Raises:
        ValidationException: If domain is invalid
    """
    # validators.domain returns a falsy ValidationFailure on bad input.
    if not validators.domain(domain):
        raise ValidationException(f"Invalid domain: {domain}")
    return True
def validate_query(query: str, engine: str) -> bool:
    """
    Validate a search query for a specific engine.

    Args:
        query: Search query string
        engine: Search engine name ("shodan", "censys", or other)

    Returns:
        True if valid

    Raises:
        ValidationException: If the query is empty, contains dangerous
            characters, or exceeds the engine's length limit.
    """
    if not query or not query.strip():
        raise ValidationException("Query cannot be empty")
    # Reject script-injection markers and embedded null bytes.
    for pattern in (r'[<>]', r'\x00'):
        if re.search(pattern, query):
            raise ValidationException(f"Query contains invalid characters: {pattern}")
    # Per-engine length limits; unknown engines have no extra limit.
    limits = {
        "shodan": (1000, "Shodan query too long (max 1000 chars)"),
        "censys": (2000, "Censys query too long (max 2000 chars)"),
    }
    if engine in limits:
        max_len, message = limits[engine]
        if len(query) > max_len:
            raise ValidationException(message)
    return True
def validate_port(port: int) -> bool:
    """
    Validate a TCP/UDP port number.

    Args:
        port: Port number (must be an int in 1-65535).

    Returns:
        True if valid

    Raises:
        ValidationException: If port is not an integer in range 1-65535.
    """
    # bool is a subclass of int; reject it explicitly so validate_port(True)
    # is not silently accepted as port 1.
    if isinstance(port, bool) or not isinstance(port, int) or port < 1 or port > 65535:
        raise ValidationException(f"Invalid port number: {port}")
    return True
def validate_api_key(api_key: str, engine: str) -> bool:
    """
    Validate API key format for a specific engine.

    This checks only the shape of the key, not whether it actually works.

    Args:
        api_key: API key string
        engine: Search engine name

    Returns:
        True if valid

    Raises:
        ValidationException: If the key is empty or clearly malformed.
    """
    if not api_key or not api_key.strip():
        raise ValidationException(f"API key for {engine} cannot be empty")
    # Shodan keys are typically 32 characters; anything much shorter is suspect.
    if engine == "shodan" and len(api_key) < 20:
        raise ValidationException("Shodan API key appears too short")
    return True
def sanitize_output(text: str) -> str:
    """
    Redact likely secrets from text before it is logged or displayed.

    API keys, passwords, and authentication tokens matching known formats
    are replaced with REDACTED markers.

    Args:
        text: Text to sanitize.

    Returns:
        Sanitized text with sensitive data replaced by REDACTED markers.

    Example:
        >>> sanitize_output("key: sk-ant-abc123...")
        'key: sk-ant-***REDACTED***'
    """
    if text is None:
        return ""
    if not isinstance(text, str):
        text = str(text)
    # (pattern, replacement) pairs; more specific patterns come first so a
    # key is caught by its own rule before a generic one can partially match.
    redactions = (
        # Anthropic API keys
        (r'sk-ant-[a-zA-Z0-9-_]{20,}', 'sk-ant-***REDACTED***'),
        # OpenAI API keys
        (r'sk-[a-zA-Z0-9]{40,}', 'sk-***REDACTED***'),
        # AWS Access Key
        (r'AKIA[0-9A-Z]{16}', 'AKIA***REDACTED***'),
        # AWS Secret Key
        (r'(?i)aws_secret_access_key["\s:=]+["\']?[A-Za-z0-9/+=]{40}', 'aws_secret_access_key=***REDACTED***'),
        # GitHub tokens
        (r'ghp_[a-zA-Z0-9]{36}', 'ghp_***REDACTED***'),
        (r'gho_[a-zA-Z0-9]{36}', 'gho_***REDACTED***'),
        # Google API keys
        (r'AIza[0-9A-Za-z-_]{35}', 'AIza***REDACTED***'),
        # Stripe keys
        (r'sk_live_[a-zA-Z0-9]{24,}', 'sk_live_***REDACTED***'),
        (r'sk_test_[a-zA-Z0-9]{24,}', 'sk_test_***REDACTED***'),
        # Shodan API key (32 hex chars)
        (r'[a-fA-F0-9]{32}', '***REDACTED_KEY***'),
        # Generic password patterns
        (r'password["\s:=]+["\']?[\w@#$%^&*!?]+', 'password=***REDACTED***'),
        (r'passwd["\s:=]+["\']?[\w@#$%^&*!?]+', 'passwd=***REDACTED***'),
        (r'secret["\s:=]+["\']?[\w@#$%^&*!?]+', 'secret=***REDACTED***'),
        # Bearer tokens
        (r'Bearer\s+[a-zA-Z0-9._-]+', 'Bearer ***REDACTED***'),
        # Basic auth
        (r'Basic\s+[a-zA-Z0-9+/=]+', 'Basic ***REDACTED***'),
    )
    sanitized = text
    for pattern, replacement in redactions:
        sanitized = re.sub(pattern, replacement, sanitized, flags=re.IGNORECASE)
    return sanitized
# =============================================================================
# File Path Validators
# =============================================================================
def validate_file_path(
    path: str,
    must_exist: bool = False,
    allow_absolute: bool = True,
    base_dir: Optional[str] = None
) -> str:
    """
    Validate and sanitize a file path to prevent directory traversal attacks.

    Args:
        path: File path to validate.
        must_exist: If True, the file must exist.
        allow_absolute: If True, allow absolute paths.
        base_dir: If provided, ensure path is within this directory.

    Returns:
        Sanitized, normalized file path.

    Raises:
        ValidationException: If path is invalid or potentially dangerous.

    Example:
        >>> validate_file_path("reports/scan.json")
        'reports/scan.json'
    """
    if path is None:
        raise ValidationException("File path cannot be None")
    if not isinstance(path, str):
        raise ValidationException(f"File path must be a string, got {type(path).__name__}")
    path = path.strip()
    if not path:
        raise ValidationException("File path cannot be empty")
    if len(path) > MAX_FILE_PATH_LENGTH:
        raise ValidationException(f"File path too long (max {MAX_FILE_PATH_LENGTH} chars)")
    # Null bytes can truncate paths in lower-level APIs (security risk).
    if '\x00' in path:
        raise ValidationException("File path contains null bytes")
    try:
        normalized = os.path.normpath(path)
    except Exception as e:
        raise ValidationException(f"Invalid file path: {e}")
    # Reject any remaining '..' component after normalization.
    if '..' in normalized.split(os.sep):
        raise ValidationException("Path traversal detected: '..' not allowed")
    if not allow_absolute and os.path.isabs(normalized):
        raise ValidationException("Absolute paths not allowed")
    if base_dir:
        # Containment check via commonpath, not startswith: a plain prefix
        # test wrongly accepts sibling directories sharing a prefix
        # (e.g. base '/data' would "contain" '/database/x').
        base_abs = os.path.abspath(base_dir)
        full_path = os.path.abspath(os.path.join(base_abs, normalized))
        try:
            escapes = os.path.commonpath([base_abs, full_path]) != base_abs
        except ValueError:
            # Different drives (Windows) or mixed path kinds: cannot be inside.
            escapes = True
        if escapes:
            raise ValidationException("Path escapes base directory")
    if must_exist and not os.path.exists(path):
        raise ValidationException(f"File does not exist: {path}")
    return normalized
# =============================================================================
# Template and Configuration Validators
# =============================================================================
def validate_template_name(template: str) -> bool:
    """
    Validate a query template name against the whitelist.

    Args:
        template: Template name to validate (case-insensitive, trimmed).

    Returns:
        True if template is valid.

    Raises:
        ValidationException: If template is None, empty, or not whitelisted.

    Example:
        >>> validate_template_name("clawdbot_instances")
        True
    """
    if template is None:
        raise ValidationException("Template name cannot be None")
    normalized = template.strip().lower()
    if not normalized:
        raise ValidationException("Template name cannot be empty")
    if normalized in VALID_TEMPLATES:
        return True
    valid_list = ", ".join(sorted(VALID_TEMPLATES))
    raise ValidationException(
        f"Invalid template name: '{normalized}'. Valid templates: {valid_list}"
    )
def validate_max_results(max_results: Union[int, str]) -> int:
    """
    Validate and normalize the max_results parameter.

    Args:
        max_results: Maximum number of results (int or numeric string).

    Returns:
        Validated integer value.

    Raises:
        ValidationException: If value is None, non-numeric, a bool, or out
            of the [MIN_RESULTS_LIMIT, MAX_RESULTS_LIMIT] range.

    Example:
        >>> validate_max_results(100)
        100
        >>> validate_max_results("50")
        50
    """
    if max_results is None:
        raise ValidationException("max_results cannot be None")
    # Accept numeric strings for convenience (e.g. CLI/env input).
    if isinstance(max_results, str):
        try:
            max_results = int(max_results.strip())
        except ValueError:
            raise ValidationException(f"max_results must be a number, got: '{max_results}'")
    # bool is a subclass of int; reject it so True/False are not silently
    # treated as 1/0.
    if isinstance(max_results, bool) or not isinstance(max_results, int):
        raise ValidationException(f"max_results must be an integer, got {type(max_results).__name__}")
    if max_results < MIN_RESULTS_LIMIT:
        raise ValidationException(f"max_results must be at least {MIN_RESULTS_LIMIT}")
    if max_results > MAX_RESULTS_LIMIT:
        raise ValidationException(f"max_results cannot exceed {MAX_RESULTS_LIMIT}")
    return max_results
def validate_log_level(level: str) -> str:
    """
    Validate a log level string.

    Args:
        level: Log level string (case-insensitive).

    Returns:
        Normalized uppercase log level.

    Raises:
        ValidationException: If log level is None or not recognized.
    """
    if level is None:
        raise ValidationException("Log level cannot be None")
    normalized = str(level).strip().upper()
    if normalized in VALID_LOG_LEVELS:
        return normalized
    valid_list = ", ".join(sorted(VALID_LOG_LEVELS))
    raise ValidationException(f"Invalid log level: '{normalized}'. Valid levels: {valid_list}")
def validate_environment(env: str) -> str:
    """
    Validate a deployment environment name.

    Args:
        env: Environment name string (case-insensitive).

    Returns:
        Normalized lowercase environment name.

    Raises:
        ValidationException: If environment is None or not recognized.
    """
    if env is None:
        raise ValidationException("Environment cannot be None")
    normalized = str(env).strip().lower()
    if normalized in VALID_ENVIRONMENTS:
        return normalized
    valid_list = ", ".join(sorted(VALID_ENVIRONMENTS))
    raise ValidationException(f"Invalid environment: '{normalized}'. Valid environments: {valid_list}")
def validate_db_type(db_type: str) -> str:
    """
    Validate a database type.

    Args:
        db_type: Database type string (case-insensitive).

    Returns:
        Normalized lowercase database type.

    Raises:
        ValidationException: If database type is None or not supported.
    """
    if db_type is None:
        raise ValidationException("Database type cannot be None")
    normalized = str(db_type).strip().lower()
    if normalized in VALID_DB_TYPES:
        return normalized
    valid_list = ", ".join(sorted(VALID_DB_TYPES))
    raise ValidationException(f"Invalid database type: '{normalized}'. Valid types: {valid_list}")
# =============================================================================
# Batch Validation Helpers
# =============================================================================
def validate_config_dict(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate and normalize known fields of a configuration dictionary.

    Note:
        The dictionary is normalized *in place* (validated values are written
        back into ``config``) and the same object is returned. The previous
        version also created an unused local ``validated`` dict, which is
        removed here.

    Args:
        config: Configuration dictionary to validate.

    Returns:
        The same dictionary, with recognized values normalized.

    Raises:
        ValidationException: If any recognized configuration value is invalid.
    """
    # Normalize log level if present
    if 'logging' in config and 'level' in config['logging']:
        config['logging']['level'] = validate_log_level(config['logging']['level'])
    # Normalize database type if present
    if 'database' in config and 'type' in config['database']:
        config['database']['type'] = validate_db_type(config['database']['type'])
    # Normalize max_results if present
    if 'shodan' in config and 'max_results' in config['shodan']:
        config['shodan']['max_results'] = validate_max_results(config['shodan']['max_results'])
    return config
def is_safe_string(text: str, max_length: int = 1000) -> bool:
    """
    Check whether a string looks safe (no obvious injection attempts).

    Unlike the validate_* helpers this never raises; it returns a boolean.

    Args:
        text: Text to check.
        max_length: Maximum allowed length.

    Returns:
        True if string appears safe, False otherwise.
    """
    if text is None:
        return False
    if len(text) > max_length or '\x00' in text:
        return False
    # Known markup/script-injection markers, matched case-insensitively.
    suspicious = (
        r'<script',
        r'javascript:',
        r'on\w+\s*=',
        r'\x00',
        r'<!--',
        r'--\s*>',
    )
    return not any(re.search(pattern, text, re.IGNORECASE) for pattern in suspicious)