mirror of
https://github.com/0xsrb/AASRT.git
synced 2026-04-23 07:36:00 +02:00
Initial commit: AASRT v1.0.0 - AI Agent Security Reconnaissance Tool
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
"""AI Agent Security Reconnaissance Tool (AASRT)"""

# Package metadata consumed by tooling/CLI banners.
__version__ = "1.0.0"
__author__ = "AGK"
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Alert modules for AASRT.

This module will contain alerting capabilities:

- Email notifications
- Slack webhooks
- Discord webhooks
- Telegram bot integration

These are planned for Phase 3 implementation.
"""

# No public names yet; intentionally empty until Phase 3 lands.
__all__ = []
|
||||
@@ -0,0 +1,14 @@
|
||||
"""Core engine components for AASRT."""

from .query_manager import QueryManager
from .result_aggregator import ResultAggregator
from .vulnerability_assessor import VulnerabilityAssessor, Vulnerability
from .risk_scorer import RiskScorer

# Public API of the core package.
__all__ = [
    'QueryManager',
    'ResultAggregator',
    'VulnerabilityAssessor',
    'Vulnerability',
    'RiskScorer'
]
|
||||
@@ -0,0 +1,252 @@
|
||||
"""Query management and execution for AASRT."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from src.engines import SearchResult, ShodanEngine
|
||||
from src.utils.config import Config
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.exceptions import APIException, ConfigurationException
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class QueryManager:
    """Manages search queries using Shodan.

    Holds a set of named query templates (built-in plus any loaded from
    the local ``queries/`` directory) and executes them through a
    :class:`ShodanEngine`.
    """

    # Built-in query templates: template name -> list of Shodan queries
    # that are executed together when the template is run.
    DEFAULT_TEMPLATES = {
        "clawdbot_instances": [
            'http.title:"ClawdBot Dashboard"',
            'http.html:"ClawdBot" port:3000',
            'product:"ClawdBot"'
        ],
        "autogpt_instances": [
            'http.title:"Auto-GPT"',
            'http.html:"autogpt" port:8000'
        ],
        "langchain_agents": [
            'http.html:"langchain" http.html:"agent"',
            'product:"LangChain"'
        ],
        "openai_exposed": [
            'http.title:"OpenAI Playground"',
            'http.html:"sk-" http.html:"openai"'
        ],
        "exposed_env_files": [
            'http.html:".env" http.html:"API_KEY"',
            'http.title:"Index of" http.html:".env"'
        ],
        "debug_mode": [
            'http.html:"DEBUG=True"',
            'http.html:"development mode"',
            'http.html:"stack trace"'
        ],
        "ai_dashboards": [
            'http.title:"AI Dashboard"',
            'http.title:"LLM" http.html:"chat"',
            'http.html:"anthropic" http.html:"claude"'
        ],
        "jupyter_notebooks": [
            'http.title:"Jupyter Notebook"',
            'http.title:"JupyterLab"',
            'http.html:"jupyter" port:8888'
        ],
        "streamlit_apps": [
            'http.html:"streamlit"',
            'http.title:"Streamlit"'
        ]
    }

    def __init__(self, config: Optional[Config] = None):
        """
        Initialize QueryManager.

        Args:
            config: Configuration instance (a default Config is created
                when omitted)
        """
        self.config = config or Config()
        self.engine: Optional[ShodanEngine] = None
        # Copy so per-instance template additions never mutate the
        # class-level defaults shared by all instances.
        self.templates: Dict[str, List[str]] = self.DEFAULT_TEMPLATES.copy()

        self._initialize_engine()
        self._load_custom_templates()

    def _initialize_engine(self) -> None:
        """Create the Shodan engine if an API key is configured."""
        api_key = self.config.get_shodan_key()
        if api_key:
            shodan_config = self.config.get_shodan_config()
            self.engine = ShodanEngine(
                api_key=api_key,
                rate_limit=shodan_config.get('rate_limit', 1.0),
                timeout=shodan_config.get('timeout', 30),
                max_results=shodan_config.get('max_results', 100)
            )
            logger.info("Shodan engine initialized")
        else:
            # Engine stays None; query methods raise or no-op accordingly.
            logger.warning("Shodan API key not provided")

    def _load_custom_templates(self) -> None:
        """Load custom query templates from YAML files in ``queries/``.

        Accepts both ``*.yaml`` and ``*.yml`` extensions.  Unreadable or
        malformed files are logged and skipped rather than aborting startup.
        """
        queries_dir = Path("queries")
        if not queries_dir.exists():
            return

        for pattern in ("*.yaml", "*.yml"):
            # sorted() makes load order (and thus name collisions) deterministic.
            for yaml_file in sorted(queries_dir.glob(pattern)):
                try:
                    with open(yaml_file, 'r', encoding='utf-8') as f:
                        data = yaml.safe_load(f)
                except (OSError, yaml.YAMLError) as e:
                    # OSError: unreadable file; YAMLError: malformed content.
                    logger.error(f"Failed to parse {yaml_file}: {e}")
                    continue

                if not data or 'queries' not in data:
                    continue

                template_name = yaml_file.stem
                # Support both list format and dict format
                queries = data['queries']
                loaded: Optional[List[str]] = None
                if isinstance(queries, dict) and 'shodan' in queries:
                    loaded = queries['shodan']
                elif isinstance(queries, list):
                    loaded = queries

                # Only log success when a template was actually stored.
                if loaded is not None:
                    self.templates[template_name] = loaded
                    logger.debug(f"Loaded query template: {template_name}")

    def is_available(self) -> bool:
        """Check if Shodan engine is available."""
        return self.engine is not None

    def get_available_templates(self) -> List[str]:
        """Get list of available query template names."""
        return list(self.templates.keys())

    def validate_engine(self) -> bool:
        """
        Validate Shodan credentials.

        Returns:
            True if credentials are valid, False when the engine is not
            initialized or validation fails for any reason.
        """
        if not self.engine:
            return False
        try:
            return self.engine.validate_credentials()
        except Exception as e:
            # Best-effort check: any engine failure is reported as invalid.
            logger.error(f"Failed to validate Shodan: {e}")
            return False

    def get_quota_info(self) -> Dict[str, Any]:
        """Get Shodan API quota information."""
        if not self.engine:
            return {'error': 'Engine not initialized'}
        return self.engine.get_quota_info()

    def execute_query(
        self,
        query: str,
        max_results: Optional[int] = None
    ) -> List[SearchResult]:
        """
        Execute a search query.

        Args:
            query: Shodan search query
            max_results: Maximum results to return

        Returns:
            List of SearchResult objects

        Raises:
            ConfigurationException: If the engine is not initialized.
            APIException: If the underlying search fails.
        """
        if not self.engine:
            raise ConfigurationException("Shodan engine not initialized. Check your API key.")

        try:
            results = self.engine.search(query, max_results)
            logger.info(f"Query returned {len(results)} results")
            return results
        except APIException as e:
            logger.error(f"Query failed: {e}")
            raise

    def execute_template(
        self,
        template_name: str,
        max_results: Optional[int] = None
    ) -> List[SearchResult]:
        """
        Execute all queries from a template.

        Individual query failures are logged and skipped so one bad query
        does not lose the results of the others.

        Args:
            template_name: Name of the query template
            max_results: Maximum results per query

        Returns:
            Combined list of results from all queries

        Raises:
            ConfigurationException: If the template does not exist or the
                engine is not initialized.
        """
        if template_name not in self.templates:
            raise ConfigurationException(f"Template not found: {template_name}")

        if not self.engine:
            raise ConfigurationException("Shodan engine not initialized. Check your API key.")

        queries = self.templates[template_name]
        all_results = []

        for query in queries:
            try:
                results = self.engine.search(query, max_results)
                all_results.extend(results)
            except APIException as e:
                logger.error(f"Query failed: {query} - {e}")

        logger.info(f"Template '{template_name}' returned {len(all_results)} total results")
        return all_results

    def count_results(self, query: str) -> int:
        """
        Get count of results for a query without consuming credits.

        Args:
            query: Search query

        Returns:
            Number of results (0 when the engine is not initialized)
        """
        if not self.engine:
            return 0
        return self.engine.count(query)

    def add_custom_template(self, name: str, queries: List[str]) -> None:
        """
        Add a custom query template.

        Args:
            name: Template name (overwrites any existing template)
            queries: List of Shodan queries
        """
        self.templates[name] = queries
        logger.info(f"Added custom template: {name}")

    def save_template(self, name: str, path: Optional[str] = None) -> None:
        """
        Save a template to a YAML file.

        Args:
            name: Template name
            path: Output file path (default: queries/{name}.yaml)

        Raises:
            ConfigurationException: If the template does not exist.
        """
        if name not in self.templates:
            raise ConfigurationException(f"Template not found: {name}")

        output_path = Path(path or f"queries/{name}.yaml")
        # Path.parent is "." for a bare filename, so this never hits the
        # FileNotFoundError that os.makedirs(os.path.dirname('x.yaml'))
        # would raise on an empty dirname.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        template_data = {
            'name': name,
            'description': f"Query template for {name}",
            'queries': self.templates[name]
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            yaml.dump(template_data, f, default_flow_style=False)

        logger.info(f"Saved template to {output_path}")
|
||||
@@ -0,0 +1,304 @@
|
||||
"""Result aggregation and deduplication for AASRT."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
|
||||
from src.engines import SearchResult
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ResultAggregator:
    """Aggregates and deduplicates search results from multiple engines."""

    def __init__(
        self,
        dedupe_by: str = "ip_port",
        merge_metadata: bool = True,
        prefer_engine: Optional[str] = None
    ):
        """
        Initialize ResultAggregator.

        Args:
            dedupe_by: Deduplication key ("ip_port", "ip", or "hostname");
                any other value falls back to "ip_port"
            merge_metadata: Whether to merge metadata from duplicate results
            prefer_engine: Preferred engine when resolving conflicts
        """
        self.dedupe_by = dedupe_by
        self.merge_metadata = merge_metadata
        self.prefer_engine = prefer_engine

    def aggregate(
        self,
        results: Dict[str, List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Aggregate results from multiple engines.

        Args:
            results: Dictionary mapping engine names to result lists

        Returns:
            Deduplicated and merged list of results
        """
        all_results = []

        # Flatten results, tagging each with the engine it came from.
        # NOTE: mutates the SearchResult objects in place.
        for engine_name, engine_results in results.items():
            for result in engine_results:
                result.source_engine = engine_name
                all_results.append(result)

        logger.info(f"Aggregating {len(all_results)} total results")

        # Deduplicate
        deduplicated = self._deduplicate(all_results)

        logger.info(f"After deduplication: {len(deduplicated)} unique results")

        return deduplicated

    def _get_dedupe_key(self, result: SearchResult) -> str:
        """Get deduplication key for a result.

        Unrecognized ``dedupe_by`` values fall back to "ip:port".
        """
        if self.dedupe_by == "ip_port":
            return f"{result.ip}:{result.port}"
        elif self.dedupe_by == "ip":
            return result.ip
        elif self.dedupe_by == "hostname":
            # Hostname may be missing; fall back to IP so the key is usable.
            return result.hostname or result.ip
        else:
            return f"{result.ip}:{result.port}"

    def _deduplicate(self, results: List[SearchResult]) -> List[SearchResult]:
        """Deduplicate results based on the configured key, merging
        duplicates pairwise in encounter order."""
        seen: Dict[str, SearchResult] = {}

        for result in results:
            key = self._get_dedupe_key(result)

            if key not in seen:
                seen[key] = result
            else:
                # Merge with existing result
                existing = seen[key]
                seen[key] = self._merge_results(existing, result)

        return list(seen.values())

    def _merge_results(
        self,
        existing: SearchResult,
        new: SearchResult
    ) -> SearchResult:
        """
        Merge two results for the same target.

        The "base" result wins field-by-field conflicts; the "other"
        result only fills in gaps.

        Args:
            existing: Existing result
            new: New result to merge

        Returns:
            Merged result (a new SearchResult instance)
        """
        # Prefer result from preferred engine
        if self.prefer_engine:
            if new.source_engine == self.prefer_engine:
                base = new
                other = existing
            else:
                base = existing
                other = new
        else:
            # Default: prefer result with more information (metadata size
            # as a proxy; ties keep the existing result)
            if len(new.metadata) > len(existing.metadata):
                base = new
                other = existing
            else:
                base = existing
                other = new

        # Merge vulnerabilities (union; set() drops duplicates and order)
        merged_vulns = list(set(base.vulnerabilities + other.vulnerabilities))

        # Merge metadata if enabled
        if self.merge_metadata:
            # base's keys win on collision (base is unpacked last)
            merged_metadata = {**other.metadata, **base.metadata}
            # Track source engines
            engines = set()
            if base.metadata.get('source_engines'):
                engines.update(base.metadata['source_engines'])
            if other.metadata.get('source_engines'):
                engines.update(other.metadata['source_engines'])
            engines.add(base.source_engine)
            engines.add(other.source_engine)
            merged_metadata['source_engines'] = list(engines)
        else:
            merged_metadata = base.metadata

        # Take highest risk score
        risk_score = max(base.risk_score, other.risk_score)

        # Take highest confidence
        confidence = max(base.confidence, other.confidence)

        return SearchResult(
            ip=base.ip,
            port=base.port,
            hostname=base.hostname or other.hostname,
            service=base.service or other.service,
            banner=base.banner or other.banner,
            vulnerabilities=merged_vulns,
            metadata=merged_metadata,
            source_engine=base.source_engine,
            timestamp=base.timestamp,
            risk_score=risk_score,
            confidence=confidence
        )

    def filter_by_confidence(
        self,
        results: List[SearchResult],
        min_confidence: int = 70
    ) -> List[SearchResult]:
        """Filter results by minimum confidence score (inclusive)."""
        filtered = [r for r in results if r.confidence >= min_confidence]
        logger.info(f"Filtered by confidence >= {min_confidence}: {len(filtered)} results")
        return filtered

    def filter_by_risk_score(
        self,
        results: List[SearchResult],
        min_score: float = 0.0
    ) -> List[SearchResult]:
        """Filter results by minimum risk score (inclusive)."""
        filtered = [r for r in results if r.risk_score >= min_score]
        logger.info(f"Filtered by risk >= {min_score}: {len(filtered)} results")
        return filtered

    def filter_whitelist(
        self,
        results: List[SearchResult],
        whitelist_ips: Optional[List[str]] = None,
        whitelist_domains: Optional[List[str]] = None
    ) -> List[SearchResult]:
        """Filter out whitelisted IPs and domains.

        IPs match exactly; domains match exactly or as a dot-separated
        suffix of the hostname (e.g. "example.com" excludes
        "a.example.com").
        """
        if not whitelist_ips and not whitelist_domains:
            return results

        # Convert to sets once for O(1) membership tests.
        whitelist_ips = set(whitelist_ips or [])
        whitelist_domains = set(whitelist_domains or [])

        filtered = []
        for result in results:
            if result.ip in whitelist_ips:
                continue
            if result.hostname and result.hostname in whitelist_domains:
                continue
            # Check if hostname ends with any whitelisted domain
            if result.hostname:
                skip = False
                for domain in whitelist_domains:
                    if result.hostname.endswith(f".{domain}") or result.hostname == domain:
                        skip = True
                        break
                if skip:
                    continue
            filtered.append(result)

        excluded = len(results) - len(filtered)
        if excluded > 0:
            logger.info(f"Excluded {excluded} whitelisted results")

        return filtered

    def group_by_ip(
        self,
        results: List[SearchResult]
    ) -> Dict[str, List[SearchResult]]:
        """Group results by IP address."""
        grouped = defaultdict(list)
        for result in results:
            grouped[result.ip].append(result)
        return dict(grouped)

    def group_by_service(
        self,
        results: List[SearchResult]
    ) -> Dict[str, List[SearchResult]]:
        """Group results by service type ("unknown" when unset)."""
        grouped = defaultdict(list)
        for result in results:
            service = result.service or "unknown"
            grouped[service].append(result)
        return dict(grouped)

    def get_statistics(self, results: List[SearchResult]) -> Dict[str, Any]:
        """
        Get aggregate statistics for results.

        Args:
            results: List of search results

        Returns:
            Statistics dictionary

        NOTE(review): the empty-input return has no 'average_risk_score'
        key, unlike the populated case — confirm consumers handle this.
        """
        if not results:
            return {
                'total_results': 0,
                'unique_ips': 0,
                'unique_hostnames': 0,
                'engines_used': [],
                'vulnerability_counts': {},
                'risk_distribution': {},
                'top_services': []
            }

        # Count unique IPs and hostnames
        unique_ips = set(r.ip for r in results)
        unique_hostnames = set(r.hostname for r in results if r.hostname)

        # Count engines (merged results carry a 'source_engines' list)
        engines = set()
        for r in results:
            if r.metadata.get('source_engines'):
                engines.update(r.metadata['source_engines'])
            else:
                engines.add(r.source_engine)

        # Count vulnerabilities
        vuln_counts = defaultdict(int)
        for r in results:
            for vuln in r.vulnerabilities:
                vuln_counts[vuln] += 1

        # Risk distribution (bucket boundaries match RiskScorer's)
        risk_dist = {
            'critical': len([r for r in results if r.risk_score >= 9.0]),
            'high': len([r for r in results if 7.0 <= r.risk_score < 9.0]),
            'medium': len([r for r in results if 4.0 <= r.risk_score < 7.0]),
            'low': len([r for r in results if r.risk_score < 4.0])
        }

        # Top services (at most 10, most frequent first)
        service_counts = defaultdict(int)
        for r in results:
            service_counts[r.service or "unknown"] += 1
        top_services = sorted(
            service_counts.items(),
            key=lambda x: x[1],
            reverse=True
        )[:10]

        return {
            'total_results': len(results),
            'unique_ips': len(unique_ips),
            'unique_hostnames': len(unique_hostnames),
            'engines_used': list(engines),
            'vulnerability_counts': dict(vuln_counts),
            'risk_distribution': risk_dist,
            'top_services': top_services,
            'average_risk_score': sum(r.risk_score for r in results) / len(results)
        }
|
||||
@@ -0,0 +1,313 @@
|
||||
"""Risk scoring engine for AASRT."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .vulnerability_assessor import Vulnerability
|
||||
from src.engines import SearchResult
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RiskScorer:
    """Calculates risk scores for targets based on vulnerabilities."""

    # Severity weights for scoring.
    # NOTE(review): not referenced by any method below — kept because
    # external code or config may read it; confirm before removing.
    SEVERITY_WEIGHTS = {
        'CRITICAL': 1.5,
        'HIGH': 1.2,
        'MEDIUM': 1.0,
        'LOW': 0.5,
        'INFO': 0.1
    }

    # Context multipliers applied when the matching context flag is truthy.
    CONTEXT_MULTIPLIERS = {
        'public_internet': 1.2,
        'no_waf': 1.1,
        'known_vulnerable_version': 1.3,
        'ai_agent': 1.2,  # AI agents may have additional risk
        'clawsec_cve': 1.4,  # Known ClawSec CVE vulnerability
        'clawsec_critical': 1.5,  # Critical ClawSec CVE
    }

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize RiskScorer.

        Args:
            config: Configuration options (currently unused by scoring)
        """
        self.config = config or {}

    def calculate_score(
        self,
        vulnerabilities: List[Vulnerability],
        context: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Calculate risk score based on vulnerabilities.

        Formula:
        - Base score: Highest CVSS score found
        - Adjusted: base * (1 + 0.1 * critical_count + 0.05 * high_count)
        - Context multipliers applied
        - Capped at 10.0

        Args:
            vulnerabilities: List of discovered vulnerabilities
            context: Additional context flags (public_internet, no_waf, ...)

        Returns:
            Risk assessment dictionary with keys 'overall_score',
            'severity_breakdown', 'exploitability', 'impact',
            'confidence', and 'contributing_factors'.
        """
        if not vulnerabilities:
            # Same key set as the populated case so consumers can rely on
            # every key being present ('contributing_factors' included).
            return {
                'overall_score': 0.0,
                'severity_breakdown': {
                    'critical': 0, 'high': 0, 'medium': 0, 'low': 0, 'info': 0
                },
                'exploitability': 'NONE',
                'impact': 'NONE',
                'confidence': 100,
                'contributing_factors': []
            }

        context = context or {}

        # Get base score (highest CVSS)
        base_score = max(v.cvss_score for v in vulnerabilities)

        # Count by severity
        severity_counts = self._count_severities(vulnerabilities)

        # Apply vulnerability count multiplier
        critical_count = severity_counts['critical']
        high_count = severity_counts['high']

        # Increase score based on multiple vulnerabilities
        adjusted_score = base_score * (1.0 + (0.1 * critical_count) + (0.05 * high_count))

        # Apply context multipliers (each truthy flag compounds)
        for ctx_key, multiplier in self.CONTEXT_MULTIPLIERS.items():
            if context.get(ctx_key, False):
                adjusted_score *= multiplier

        # Cap at 10.0
        final_score = min(adjusted_score, 10.0)

        # Determine exploitability
        exploitability = self._calculate_exploitability(vulnerabilities, critical_count)

        # Determine impact
        impact = self._calculate_impact(vulnerabilities)

        return {
            'overall_score': round(final_score, 1),
            'severity_breakdown': severity_counts,
            'exploitability': exploitability,
            'impact': impact,
            'confidence': self._calculate_confidence(vulnerabilities),
            'contributing_factors': self._get_contributing_factors(vulnerabilities)
        }

    def _count_severities(self, vulnerabilities: List[Vulnerability]) -> Dict[str, int]:
        """Count vulnerabilities by severity level (unknown severities
        are silently ignored)."""
        counts = {'critical': 0, 'high': 0, 'medium': 0, 'low': 0, 'info': 0}
        for v in vulnerabilities:
            severity_key = v.severity.lower()
            if severity_key in counts:
                counts[severity_key] += 1
        return counts

    def _calculate_exploitability(
        self,
        vulnerabilities: List[Vulnerability],
        critical_count: int
    ) -> str:
        """Determine overall exploitability level.

        Returns one of 'CRITICAL', 'HIGH', 'MEDIUM', 'LOW'.
        """
        if critical_count >= 2:
            return 'CRITICAL'
        elif critical_count >= 1:
            return 'HIGH'

        # Check for easily exploitable vulnerabilities by check-name substring
        easy_exploit = ['api_key_exposure', 'no_authentication', 'shell_access']
        for v in vulnerabilities:
            if any(indicator in v.check_name for indicator in easy_exploit):
                return 'HIGH'

        high_count = len([v for v in vulnerabilities if v.severity == 'HIGH'])
        if high_count >= 2:
            return 'MEDIUM'

        return 'LOW'

    def _calculate_impact(self, vulnerabilities: List[Vulnerability]) -> str:
        """Determine potential impact level ('HIGH', 'MEDIUM', or 'LOW')."""
        # Check for high-impact vulnerabilities by check-name substring
        high_impact_indicators = [
            'api_key_exposure',
            'shell_access',
            'database_exposed',
            'admin_panel'
        ]

        for v in vulnerabilities:
            if any(indicator in v.check_name for indicator in high_impact_indicators):
                return 'HIGH'

        if any(v.cvss_score >= 7.0 for v in vulnerabilities):
            return 'MEDIUM'

        return 'LOW'

    def _calculate_confidence(self, vulnerabilities: List[Vulnerability]) -> int:
        """Calculate confidence in the assessment (0-100)."""
        if not vulnerabilities:
            return 100

        # Start with high confidence
        confidence = 100

        # Reduce confidence for findings worded as potential false positives
        for v in vulnerabilities:
            if 'potential' in v.check_name or 'possible' in v.description.lower():
                confidence -= 10

        return max(confidence, 0)

    def _get_contributing_factors(self, vulnerabilities: List[Vulnerability]) -> List[str]:
        """Get list of main contributing factors to the risk score
        (CRITICAL/HIGH findings, at most five, in input order)."""
        factors = []

        for v in vulnerabilities:
            if v.severity in ['CRITICAL', 'HIGH']:
                factors.append(f"{v.severity}: {v.description}")

        return factors[:5]  # Top 5 factors

    def score_result(self, result: SearchResult, vulnerabilities: List[Vulnerability]) -> SearchResult:
        """
        Apply risk score to a SearchResult.

        Mutates and returns the given result: sets ``risk_score``,
        stores the full assessment under ``metadata['risk_assessment']``,
        and replaces ``vulnerabilities`` with the check names.

        Args:
            result: SearchResult to score
            vulnerabilities: Assessed vulnerabilities

        Returns:
            Updated SearchResult with risk score
        """
        # Build context from result metadata
        context = {
            'public_internet': True,  # Assume public if found via search
            'ai_agent': self._is_ai_agent(result),
            'clawsec_cve': self._has_clawsec_cve(result),
            'clawsec_critical': self._has_critical_clawsec_cve(result)
        }

        # Check for WAF by scanning HTTP headers for known CDN/WAF vendors
        http_info = result.metadata.get('http') or {}
        http_headers = http_info.get('headers', {})
        if not any(waf in str(http_headers).lower() for waf in ['cloudflare', 'akamai', 'fastly']):
            context['no_waf'] = True

        # Calculate score
        risk_data = self.calculate_score(vulnerabilities, context)

        # Update result
        result.risk_score = risk_data['overall_score']
        result.metadata['risk_assessment'] = risk_data
        result.vulnerabilities = [v.check_name for v in vulnerabilities]

        return result

    def _has_clawsec_cve(self, result: SearchResult) -> bool:
        """Check if result has any ClawSec CVE associations."""
        return bool(result.metadata.get('clawsec_advisories'))

    def _has_critical_clawsec_cve(self, result: SearchResult) -> bool:
        """Check if result has a critical ClawSec CVE."""
        advisories = result.metadata.get('clawsec_advisories', [])
        return any(a.get('severity') == 'CRITICAL' for a in advisories)

    def _is_ai_agent(self, result: SearchResult) -> bool:
        """Check if result appears to be an AI agent (substring match over
        banner, service name, and HTTP title)."""
        ai_indicators = [
            'clawdbot', 'autogpt', 'langchain', 'openai',
            'anthropic', 'claude', 'gpt', 'agent'
        ]

        http_info = result.metadata.get('http') or {}
        http_title = http_info.get('title') or ''
        text = (
            (result.banner or '') +
            (result.service or '') +
            str(http_title)
        ).lower()

        return any(indicator in text for indicator in ai_indicators)

    def categorize_results(
        self,
        results: List[SearchResult]
    ) -> Dict[str, List[SearchResult]]:
        """
        Categorize results by risk level.

        Boundaries: critical >= 9.0, high >= 7.0, medium >= 4.0, else low.

        Args:
            results: List of scored results

        Returns:
            Dictionary with risk level categories
        """
        categories = {
            'critical': [],
            'high': [],
            'medium': [],
            'low': []
        }

        for result in results:
            if result.risk_score >= 9.0:
                categories['critical'].append(result)
            elif result.risk_score >= 7.0:
                categories['high'].append(result)
            elif result.risk_score >= 4.0:
                categories['medium'].append(result)
            else:
                categories['low'].append(result)

        return categories

    def get_summary(self, results: List[SearchResult]) -> Dict[str, Any]:
        """
        Get risk summary for a set of results.

        Args:
            results: List of scored results

        Returns:
            Summary statistics
        """
        if not results:
            return {
                'total': 0,
                'average_score': 0.0,
                'max_score': 0.0,
                'distribution': {'critical': 0, 'high': 0, 'medium': 0, 'low': 0}
            }

        categories = self.categorize_results(results)
        scores = [r.risk_score for r in results]

        return {
            'total': len(results),
            'average_score': round(sum(scores) / len(scores), 1),
            'max_score': max(scores),
            'distribution': {
                'critical': len(categories['critical']),
                'high': len(categories['high']),
                'medium': len(categories['medium']),
                'low': len(categories['low'])
            }
        }
|
||||
@@ -0,0 +1,441 @@
|
||||
"""Vulnerability assessment engine for AASRT."""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from src.engines import SearchResult
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.enrichment import ThreatEnricher
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Vulnerability:
    """A single finding produced by a vulnerability check.

    Severity is one of CRITICAL, HIGH, MEDIUM, LOW, INFO; ``cvss_score``
    is the associated CVSS number; ``evidence`` carries check-specific
    supporting data.
    """

    check_name: str
    severity: str  # CRITICAL, HIGH, MEDIUM, LOW, INFO
    cvss_score: float
    description: str
    evidence: Dict[str, Any] = field(default_factory=dict)
    remediation: Optional[str] = None
    cwe_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this finding as a plain dictionary (field order
        matches the declaration order above)."""
        keys = (
            'check_name', 'severity', 'cvss_score', 'description',
            'evidence', 'remediation', 'cwe_id',
        )
        return {key: getattr(self, key) for key in keys}
|
||||
|
||||
|
||||
class VulnerabilityAssessor:
|
||||
"""Performs passive vulnerability assessment on search results."""
|
||||
|
||||
# API key patterns for detection
|
||||
    # API key patterns for detection: provider -> {regex, human-readable
    # description, CVSS score assigned when the pattern matches}.
    API_KEY_PATTERNS = {
        'anthropic': {
            'pattern': r'sk-ant-[a-zA-Z0-9-_]{20,}',
            'description': 'Anthropic API key exposed',
            'cvss': 10.0
        },
        'openai': {
            'pattern': r'sk-[a-zA-Z0-9]{32,}',
            'description': 'OpenAI API key exposed',
            'cvss': 10.0
        },
        'aws_access_key': {
            'pattern': r'AKIA[0-9A-Z]{16}',
            'description': 'AWS Access Key ID exposed',
            'cvss': 9.8
        },
        'aws_secret': {
            # Exactly 40 base64-alphabet chars not adjacent to more of the
            # same alphabet; intentionally loose, hence "Potential" below.
            'pattern': r'(?<![A-Za-z0-9/+=])[A-Za-z0-9/+=]{40}(?![A-Za-z0-9/+=])',
            'description': 'Potential AWS Secret Key exposed',
            'cvss': 9.8
        },
        'github_token': {
            'pattern': r'ghp_[a-zA-Z0-9]{36}',
            'description': 'GitHub Personal Access Token exposed',
            'cvss': 9.5
        },
        'google_api': {
            'pattern': r'AIza[0-9A-Za-z\-_]{35}',
            'description': 'Google API key exposed',
            'cvss': 7.5
        },
        'stripe': {
            'pattern': r'sk_live_[0-9a-zA-Z]{24}',
            'description': 'Stripe Secret Key exposed',
            'cvss': 9.8
        }
    }

    # Dangerous functionality patterns: check name -> {list of regexes,
    # description, CVSS score, severity label}.
    DANGEROUS_PATTERNS = {
        'shell_access': {
            'patterns': [r'/shell', r'/exec', r'/execute', r'/api/execute', r'/cmd'],
            'description': 'Shell command execution endpoint detected',
            'cvss': 9.9,
            'severity': 'CRITICAL'
        },
        'debug_mode': {
            'patterns': [r'DEBUG\s*[=:]\s*[Tt]rue', r'debug\s*mode', r'stack\s*trace'],
            'description': 'Debug mode appears to be enabled',
            'cvss': 7.5,
            'severity': 'HIGH'
        },
        'file_upload': {
            'patterns': [r'/upload', r'/api/files', r'multipart/form-data'],
            'description': 'File upload functionality detected',
            'cvss': 7.8,
            'severity': 'HIGH'
        },
        'admin_panel': {
            'patterns': [r'/admin', r'admin\s*panel', r'administrator'],
            'description': 'Admin panel potentially exposed',
            'cvss': 8.5,
            'severity': 'HIGH'
        },
        'database_exposed': {
            'patterns': [r'mongodb://', r'mysql://', r'postgresql://', r'redis://'],
            'description': 'Database connection string exposed',
            'cvss': 9.5,
            'severity': 'CRITICAL'
        }
    }

    # Information disclosure patterns; same shape as DANGEROUS_PATTERNS.
    INFO_DISCLOSURE_PATTERNS = {
        'env_file': {
            'patterns': [r'\.env', r'environment\s*variables?'],
            'description': 'Environment file or variables exposed',
            'cvss': 8.0,
            'severity': 'HIGH'
        },
        'config_file': {
            'patterns': [r'config\.json', r'settings\.py', r'application\.yml'],
            'description': 'Configuration file exposed',
            'cvss': 7.5,
            'severity': 'HIGH'
        },
        'git_exposed': {
            'patterns': [r'\.git/', r'\.git/config'],
            'description': 'Git repository exposed',
            'cvss': 7.0,
            'severity': 'MEDIUM'
        },
        'source_code': {
            'patterns': [r'\.py$', r'\.js$', r'\.php$'],
            'description': 'Source code files potentially exposed',
            'cvss': 6.5,
            'severity': 'MEDIUM'
        }
    }
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None, threat_enricher: Optional['ThreatEnricher'] = None):
|
||||
"""
|
||||
Initialize VulnerabilityAssessor.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
threat_enricher: Optional ThreatEnricher for ClawSec integration
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.passive_only = self.config.get('passive_only', True)
|
||||
self.threat_enricher = threat_enricher
|
||||
|
||||
def assess(self, result: SearchResult) -> List[Vulnerability]:
|
||||
"""
|
||||
Perform vulnerability assessment on a search result.
|
||||
|
||||
Args:
|
||||
result: SearchResult to assess
|
||||
|
||||
Returns:
|
||||
List of discovered vulnerabilities
|
||||
"""
|
||||
vulnerabilities = []
|
||||
|
||||
# Check for API key exposure in banner
|
||||
if result.banner:
|
||||
vulnerabilities.extend(self._check_api_keys(result.banner))
|
||||
|
||||
# Check for dangerous functionality
|
||||
vulnerabilities.extend(self._check_dangerous_functionality(result))
|
||||
|
||||
# Check for information disclosure
|
||||
vulnerabilities.extend(self._check_information_disclosure(result))
|
||||
|
||||
# Check SSL/TLS issues
|
||||
vulnerabilities.extend(self._check_ssl_issues(result))
|
||||
|
||||
# Check for authentication issues (based on metadata)
|
||||
vulnerabilities.extend(self._check_authentication(result))
|
||||
|
||||
# Add pre-existing vulnerability indicators
|
||||
for vuln_name in result.vulnerabilities:
|
||||
if not any(v.check_name == vuln_name for v in vulnerabilities):
|
||||
vulnerabilities.append(self._create_from_indicator(vuln_name))
|
||||
|
||||
logger.debug(f"Assessed {result.ip}:{result.port} - {len(vulnerabilities)} vulnerabilities")
|
||||
return vulnerabilities
|
||||
|
||||
def assess_batch(self, results: List[SearchResult]) -> Dict[str, List[Vulnerability]]:
|
||||
"""
|
||||
Assess multiple results.
|
||||
|
||||
Args:
|
||||
results: List of SearchResults
|
||||
|
||||
Returns:
|
||||
Dictionary mapping result keys to vulnerability lists
|
||||
"""
|
||||
assessments = {}
|
||||
for result in results:
|
||||
key = f"{result.ip}:{result.port}"
|
||||
assessments[key] = self.assess(result)
|
||||
return assessments
|
||||
|
||||
def assess_with_intel(self, result: SearchResult) -> List[Vulnerability]:
|
||||
"""
|
||||
Perform vulnerability assessment enhanced with threat intelligence.
|
||||
|
||||
1. Enrich result with ClawSec CVE data
|
||||
2. Run standard passive checks
|
||||
3. Create Vulnerability objects for matched CVEs
|
||||
|
||||
Args:
|
||||
result: SearchResult to assess
|
||||
|
||||
Returns:
|
||||
List of discovered vulnerabilities including ClawSec CVEs
|
||||
"""
|
||||
# First, enrich the result if threat enricher is available
|
||||
if self.threat_enricher:
|
||||
result = self.threat_enricher.enrich(result)
|
||||
|
||||
# Run standard assessment
|
||||
vulnerabilities = self.assess(result)
|
||||
|
||||
# Add ClawSec CVE vulnerabilities
|
||||
if self.threat_enricher:
|
||||
clawsec_vulns = self._create_clawsec_vulnerabilities(result)
|
||||
vulnerabilities.extend(clawsec_vulns)
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _create_clawsec_vulnerabilities(self, result: SearchResult) -> List[Vulnerability]:
|
||||
"""
|
||||
Convert ClawSec advisory data to Vulnerability objects.
|
||||
|
||||
Args:
|
||||
result: SearchResult with clawsec_advisories in metadata
|
||||
|
||||
Returns:
|
||||
List of Vulnerability objects from ClawSec data
|
||||
"""
|
||||
vulns = []
|
||||
clawsec_data = result.metadata.get('clawsec_advisories', [])
|
||||
|
||||
for advisory in clawsec_data:
|
||||
vulns.append(Vulnerability(
|
||||
check_name=f"clawsec_{advisory['cve_id']}",
|
||||
severity=advisory.get('severity', 'MEDIUM'),
|
||||
cvss_score=advisory.get('cvss_score', 7.0),
|
||||
description=f"[ClawSec] {advisory.get('title', 'Known vulnerability')}",
|
||||
evidence={
|
||||
'cve_id': advisory['cve_id'],
|
||||
'source': 'ClawSec',
|
||||
'vuln_type': advisory.get('vuln_type', 'unknown'),
|
||||
'nvd_url': advisory.get('nvd_url')
|
||||
},
|
||||
remediation=advisory.get('action', 'See ClawSec advisory for remediation steps'),
|
||||
cwe_id=advisory.get('cwe_id')
|
||||
))
|
||||
|
||||
return vulns
|
||||
|
||||
def _check_api_keys(self, text: str) -> List[Vulnerability]:
|
||||
"""Check for exposed API keys in text."""
|
||||
vulnerabilities = []
|
||||
|
||||
for key_type, config in self.API_KEY_PATTERNS.items():
|
||||
if re.search(config['pattern'], text, re.IGNORECASE):
|
||||
vulnerabilities.append(Vulnerability(
|
||||
check_name=f"api_key_exposure_{key_type}",
|
||||
severity="CRITICAL",
|
||||
cvss_score=config['cvss'],
|
||||
description=config['description'],
|
||||
evidence={'pattern_matched': key_type},
|
||||
remediation="Immediately rotate the exposed API key and remove from public-facing content",
|
||||
cwe_id="CWE-798"
|
||||
))
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _check_dangerous_functionality(self, result: SearchResult) -> List[Vulnerability]:
|
||||
"""Check for dangerous functionality indicators."""
|
||||
vulnerabilities = []
|
||||
http_info = result.metadata.get('http') or {}
|
||||
text = (result.banner or '') + str(http_info)
|
||||
|
||||
for check_name, config in self.DANGEROUS_PATTERNS.items():
|
||||
for pattern in config['patterns']:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
vulnerabilities.append(Vulnerability(
|
||||
check_name=check_name,
|
||||
severity=config['severity'],
|
||||
cvss_score=config['cvss'],
|
||||
description=config['description'],
|
||||
evidence={'pattern': pattern},
|
||||
remediation=self._get_remediation(check_name)
|
||||
))
|
||||
break # Only add once per check
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _check_information_disclosure(self, result: SearchResult) -> List[Vulnerability]:
|
||||
"""Check for information disclosure."""
|
||||
vulnerabilities = []
|
||||
text = (result.banner or '') + str(result.metadata)
|
||||
|
||||
for check_name, config in self.INFO_DISCLOSURE_PATTERNS.items():
|
||||
for pattern in config['patterns']:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
vulnerabilities.append(Vulnerability(
|
||||
check_name=f"info_disclosure_{check_name}",
|
||||
severity=config['severity'],
|
||||
cvss_score=config['cvss'],
|
||||
description=config['description'],
|
||||
evidence={'pattern': pattern},
|
||||
remediation="Remove or restrict access to sensitive files"
|
||||
))
|
||||
break
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _check_ssl_issues(self, result: SearchResult) -> List[Vulnerability]:
|
||||
"""Check for SSL/TLS issues."""
|
||||
vulnerabilities = []
|
||||
ssl_info = result.metadata.get('ssl') or {}
|
||||
|
||||
if not ssl_info:
|
||||
# No SSL on HTTPS port might be an issue
|
||||
if result.port in [443, 8443]:
|
||||
vulnerabilities.append(Vulnerability(
|
||||
check_name="no_ssl_on_https_port",
|
||||
severity="MEDIUM",
|
||||
cvss_score=5.3,
|
||||
description="HTTPS port without SSL/TLS",
|
||||
remediation="Configure proper SSL/TLS certificate"
|
||||
))
|
||||
return vulnerabilities
|
||||
|
||||
cert = ssl_info.get('cert') or {}
|
||||
|
||||
# Check for expired certificate
|
||||
if cert.get('expired', False):
|
||||
vulnerabilities.append(Vulnerability(
|
||||
check_name="expired_ssl_certificate",
|
||||
severity="MEDIUM",
|
||||
cvss_score=5.0,
|
||||
description="SSL certificate has expired",
|
||||
remediation="Renew SSL certificate",
|
||||
cwe_id="CWE-295"
|
||||
))
|
||||
|
||||
# Check for self-signed certificate
|
||||
if cert.get('self_signed', False):
|
||||
vulnerabilities.append(Vulnerability(
|
||||
check_name="self_signed_certificate",
|
||||
severity="LOW",
|
||||
cvss_score=3.0,
|
||||
description="Self-signed SSL certificate detected",
|
||||
remediation="Use a certificate from a trusted CA"
|
||||
))
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _check_authentication(self, result: SearchResult) -> List[Vulnerability]:
|
||||
"""Check for authentication issues."""
|
||||
vulnerabilities = []
|
||||
http_info = result.metadata.get('http') or {}
|
||||
|
||||
if not http_info:
|
||||
return vulnerabilities
|
||||
|
||||
# Check for missing authentication on sensitive endpoints
|
||||
status = http_info.get('status')
|
||||
if status == 200:
|
||||
# 200 OK on root might indicate no auth
|
||||
title = http_info.get('title') or ''
|
||||
title = title.lower()
|
||||
if any(term in title for term in ['dashboard', 'admin', 'control panel']):
|
||||
vulnerabilities.append(Vulnerability(
|
||||
check_name="no_authentication",
|
||||
severity="CRITICAL",
|
||||
cvss_score=9.1,
|
||||
description="Dashboard accessible without authentication",
|
||||
evidence={'http_title': http_info.get('title')},
|
||||
remediation="Implement authentication mechanism",
|
||||
cwe_id="CWE-306"
|
||||
))
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _create_from_indicator(self, indicator: str) -> Vulnerability:
|
||||
"""Create a Vulnerability from a string indicator."""
|
||||
# Map common indicators to vulnerabilities
|
||||
indicator_map = {
|
||||
'debug_mode_enabled': ('DEBUG', 'HIGH', 7.5, "Debug mode is enabled"),
|
||||
'potential_api_key_exposure': ('API Keys', 'CRITICAL', 9.0, "Potential API key exposure detected"),
|
||||
'expired_ssl_certificate': ('SSL', 'MEDIUM', 5.0, "Expired SSL certificate"),
|
||||
'no_security_txt': ('Config', 'LOW', 2.0, "No security.txt file found"),
|
||||
'self_signed_certificate': ('SSL', 'LOW', 3.0, "Self-signed certificate"),
|
||||
}
|
||||
|
||||
if indicator in indicator_map:
|
||||
category, severity, cvss, desc = indicator_map[indicator]
|
||||
return Vulnerability(
|
||||
check_name=indicator,
|
||||
severity=severity,
|
||||
cvss_score=cvss,
|
||||
description=desc
|
||||
)
|
||||
|
||||
# Default for unknown indicators
|
||||
return Vulnerability(
|
||||
check_name=indicator,
|
||||
severity="INFO",
|
||||
cvss_score=1.0,
|
||||
description=f"Indicator detected: {indicator}"
|
||||
)
|
||||
|
||||
def _get_remediation(self, check_name: str) -> str:
|
||||
"""Get remediation advice for a vulnerability."""
|
||||
remediations = {
|
||||
'shell_access': "Disable or restrict shell execution endpoints. Implement authentication and authorization.",
|
||||
'debug_mode': "Disable debug mode in production environments.",
|
||||
'file_upload': "Implement file type validation, size limits, and malware scanning.",
|
||||
'admin_panel': "Restrict admin panel access to authorized networks. Implement strong authentication.",
|
||||
'database_exposed': "Remove database connection strings from public-facing content. Use environment variables.",
|
||||
}
|
||||
return remediations.get(check_name, "Review and remediate the identified issue.")
|
||||
|
||||
def get_severity_counts(self, vulnerabilities: List[Vulnerability]) -> Dict[str, int]:
|
||||
"""Get count of vulnerabilities by severity."""
|
||||
counts = {'CRITICAL': 0, 'HIGH': 0, 'MEDIUM': 0, 'LOW': 0, 'INFO': 0}
|
||||
for vuln in vulnerabilities:
|
||||
if vuln.severity in counts:
|
||||
counts[vuln.severity] += 1
|
||||
return counts
|
||||
@@ -0,0 +1,10 @@
|
||||
"""Search engine modules for AASRT."""
|
||||
|
||||
from .base import BaseSearchEngine, SearchResult
|
||||
from .shodan_engine import ShodanEngine
|
||||
|
||||
__all__ = [
|
||||
'BaseSearchEngine',
|
||||
'SearchResult',
|
||||
'ShodanEngine'
|
||||
]
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Abstract base class for search engine integrations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
import time
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.exceptions import RateLimitException
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SearchResult:
    """One host/port finding returned by any search engine."""

    ip: str
    port: int
    hostname: Optional[str] = None
    service: Optional[str] = None
    banner: Optional[str] = None
    vulnerabilities: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    source_engine: Optional[str] = None
    timestamp: Optional[str] = None
    risk_score: float = 0.0
    confidence: int = 100

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dictionary (attribute references, no copies)."""
        return {
            attr: getattr(self, attr)
            for attr in (
                'ip', 'port', 'hostname', 'service', 'banner',
                'vulnerabilities', 'metadata', 'source_engine',
                'timestamp', 'risk_score', 'confidence',
            )
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SearchResult':
        """Deserialize from a dictionary such as one produced by to_dict()."""
        get = data.get
        return cls(
            ip=get('ip', ''),
            port=get('port', 0),
            hostname=get('hostname'),
            service=get('service'),
            banner=get('banner'),
            vulnerabilities=get('vulnerabilities', []),
            metadata=get('metadata', {}),
            source_engine=get('source_engine'),
            timestamp=get('timestamp'),
            risk_score=get('risk_score', 0.0),
            confidence=get('confidence', 100)
        )
|
||||
|
||||
|
||||
class BaseSearchEngine(ABC):
    """Common contract and plumbing shared by all search engine integrations."""

    def __init__(
        self,
        api_key: str,
        rate_limit: float = 1.0,
        timeout: int = 30,
        max_results: int = 100
    ):
        """
        Initialize the search engine.

        Args:
            api_key: API key for authentication.
            rate_limit: Maximum queries per second (values <= 0 disable throttling).
            timeout: Request timeout in seconds.
            max_results: Maximum results to return per query.
        """
        self.api_key = api_key
        self.rate_limit = rate_limit
        self.timeout = timeout
        self.max_results = max_results
        # Bookkeeping for the simple client-side throttle below.
        self._last_request_time = 0.0
        self._request_count = 0

    @property
    @abstractmethod
    def name(self) -> str:
        """Short engine identifier (e.g. "shodan")."""

    @abstractmethod
    def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
        """
        Execute a search query and return results.

        Args:
            query: Search query string.
            max_results: Cap on returned results (overrides the default).

        Returns:
            List of SearchResult objects.

        Raises:
            APIException: If the API call fails.
            RateLimitException: If the rate limit is exceeded.
        """

    @abstractmethod
    def validate_credentials(self) -> bool:
        """
        Validate API credentials.

        Returns:
            True if credentials are valid.

        Raises:
            AuthenticationException: If credentials are invalid.
        """

    @abstractmethod
    def get_quota_info(self) -> Dict[str, Any]:
        """Return API quota/usage information as a dictionary."""

    def _rate_limit_wait(self) -> None:
        """Sleep as needed so requests stay under ``rate_limit`` per second."""
        if self.rate_limit <= 0:
            return

        shortest_gap = 1.0 / self.rate_limit
        since_last = time.time() - self._last_request_time
        if since_last < shortest_gap:
            pause = shortest_gap - since_last
            logger.debug(f"Rate limiting: waiting {pause:.2f}s")
            time.sleep(pause)

        self._last_request_time = time.time()
        self._request_count += 1

    def _check_rate_limit(self) -> None:
        """Hook for engine-specific quota checks; no-op by default."""
        pass

    def _parse_result(self, raw_result: Dict[str, Any]) -> SearchResult:
        """
        Parse a raw API result into a SearchResult.

        Engines should override this; the fallback keeps only ip/port.

        Args:
            raw_result: Raw result from the API.

        Returns:
            SearchResult object.
        """
        return SearchResult(
            ip=raw_result.get('ip', ''),
            port=raw_result.get('port', 0),
            source_engine=self.name
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of engine configuration and usage counters."""
        return {
            'engine': self.name,
            'request_count': self._request_count,
            'rate_limit': self.rate_limit,
            'timeout': self.timeout,
            'max_results': self.max_results
        }
|
||||
@@ -0,0 +1,589 @@
|
||||
"""
|
||||
Shodan search engine integration for AASRT.
|
||||
|
||||
This module provides a production-ready integration with the Shodan API
|
||||
for security reconnaissance. Features include:
|
||||
|
||||
- Automatic retry with exponential backoff for transient failures
|
||||
- Rate limiting to prevent API quota exhaustion
|
||||
- Comprehensive error handling with specific exception types
|
||||
- Detailed logging for debugging and monitoring
|
||||
- Graceful degradation when API is unavailable
|
||||
|
||||
Example:
|
||||
>>> from src.engines.shodan_engine import ShodanEngine
|
||||
>>> engine = ShodanEngine(api_key="your_key")
|
||||
>>> engine.validate_credentials()
|
||||
True
|
||||
>>> results = engine.search("http.html:clawdbot", max_results=10)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, TypeVar
import socket
import time

import shodan
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
    before_sleep_log,
    RetryError
)

from .base import BaseSearchEngine, SearchResult
from src.utils.logger import get_logger
from src.utils.validators import validate_ip, sanitize_output
from src.utils.exceptions import (
    APIException,
    RateLimitException,
    AuthenticationException,
    TimeoutException
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Type variable for the generic retry decorator (preserves wrapped signature).
T = TypeVar('T')

# =============================================================================
# Retry Configuration
# =============================================================================

# Exceptions that should trigger a retry -- transient network failures only.
# Auth and rate-limit errors are handled separately and never retried.
RETRYABLE_EXCEPTIONS = (
    socket.timeout,
    ConnectionError,
    ConnectionResetError,
    TimeoutError,
)

# Maximum number of retry attempts per operation.
MAX_RETRY_ATTEMPTS = 3

# Base delay for exponential backoff (seconds); attempt N waits BASE ** N.
RETRY_BASE_DELAY = 2

# Cap on the delay between retries (seconds).
RETRY_MAX_DELAY = 30
|
||||
|
||||
|
||||
def with_retry(func: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator that adds retry logic with exponential backoff.

    Retries on transient network errors (RETRYABLE_EXCEPTIONS) and on
    non-permanent Shodan API errors, but never on authentication or
    rate-limit errors, which are surfaced immediately.

    Args:
        func: Function to wrap with retry logic.

    Returns:
        Wrapped function with retry capability.
    """
    @wraps(func)
    def wrapper(*args, **kwargs) -> T:
        last_exception = None

        # Attempts are 1-indexed so the log lines read "Retry 1/3", etc.
        for attempt in range(1, MAX_RETRY_ATTEMPTS + 1):
            try:
                return func(*args, **kwargs)
            except RETRYABLE_EXCEPTIONS as e:
                last_exception = e
                if attempt < MAX_RETRY_ATTEMPTS:
                    # Exponential backoff: BASE ** attempt, capped at RETRY_MAX_DELAY.
                    delay = min(RETRY_BASE_DELAY ** attempt, RETRY_MAX_DELAY)
                    logger.warning(
                        f"Retry {attempt}/{MAX_RETRY_ATTEMPTS} for {func.__name__} "
                        f"after {delay}s delay. Error: {e}"
                    )
                    time.sleep(delay)
                else:
                    logger.error(
                        f"All {MAX_RETRY_ATTEMPTS} retries exhausted for {func.__name__}. "
                        f"Last error: {e}"
                    )
            except (AuthenticationException, RateLimitException):
                # Don't retry auth or rate limit errors -- they won't fix themselves.
                raise
            except shodan.APIError as e:
                error_msg = str(e).lower()
                # Don't retry permanent errors: map them to our exception types.
                if "invalid api key" in error_msg:
                    raise AuthenticationException(
                        "Invalid Shodan API key",
                        engine="shodan"
                    )
                if "rate limit" in error_msg:
                    raise RateLimitException(
                        f"Shodan rate limit exceeded: {e}",
                        engine="shodan"
                    )
                # Retry other API errors (assumed transient server-side issues).
                last_exception = e
                if attempt < MAX_RETRY_ATTEMPTS:
                    delay = min(RETRY_BASE_DELAY ** attempt, RETRY_MAX_DELAY)
                    logger.warning(
                        f"Retry {attempt}/{MAX_RETRY_ATTEMPTS} for {func.__name__} "
                        f"after {delay}s delay. API Error: {e}"
                    )
                    time.sleep(delay)

        # All retries exhausted -- wrap the last failure in an APIException.
        if last_exception:
            raise APIException(
                f"Operation failed after {MAX_RETRY_ATTEMPTS} retries: {last_exception}",
                engine="shodan"
            )
        # Defensive: should be unreachable, since every loop exit path above
        # either returned, raised, or recorded last_exception.
        raise APIException("Unexpected retry failure", engine="shodan")

    return wrapper
|
||||
|
||||
|
||||
class ShodanEngine(BaseSearchEngine):
|
||||
"""
|
||||
Shodan search engine integration for security reconnaissance.
|
||||
|
||||
This class provides a production-ready interface to the Shodan API with:
|
||||
- Automatic rate limiting to respect API quotas
|
||||
- Retry logic with exponential backoff for transient failures
|
||||
- Comprehensive error handling and logging
|
||||
- Result parsing with vulnerability detection
|
||||
|
||||
Attributes:
|
||||
name: Engine identifier ("shodan").
|
||||
api_key: Shodan API key (masked in logs).
|
||||
rate_limit: Maximum requests per second.
|
||||
timeout: Request timeout in seconds.
|
||||
max_results: Default maximum results per search.
|
||||
|
||||
Example:
|
||||
>>> engine = ShodanEngine(api_key="your_key")
|
||||
>>> if engine.validate_credentials():
|
||||
... results = engine.search("http.html:agent", max_results=50)
|
||||
... for result in results:
|
||||
... print(f"{result.ip}:{result.port}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
rate_limit: float = 1.0,
|
||||
timeout: int = 30,
|
||||
max_results: int = 100
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Shodan engine with API credentials.
|
||||
|
||||
Args:
|
||||
api_key: Shodan API key from https://account.shodan.io/.
|
||||
Never log or expose this value.
|
||||
rate_limit: Maximum queries per second. Default 1.0 to respect
|
||||
Shodan's free tier limits.
|
||||
timeout: Request timeout in seconds. Increase for slow connections.
|
||||
max_results: Maximum results per query. Higher values consume
|
||||
more query credits.
|
||||
|
||||
Raises:
|
||||
ValueError: If api_key is empty or None.
|
||||
"""
|
||||
if not api_key or not api_key.strip():
|
||||
raise ValueError("Shodan API key is required")
|
||||
|
||||
super().__init__(api_key, rate_limit, timeout, max_results)
|
||||
self._client = shodan.Shodan(api_key)
|
||||
self._api_key_preview = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "***"
|
||||
logger.debug(f"ShodanEngine initialized with key: {self._api_key_preview}")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return engine identifier."""
|
||||
return "shodan"
|
||||
|
||||
@with_retry
|
||||
def validate_credentials(self) -> bool:
|
||||
"""
|
||||
Validate Shodan API credentials by making a test API call.
|
||||
|
||||
This method performs a lightweight API call to verify the API key
|
||||
is valid and has not been revoked.
|
||||
|
||||
Returns:
|
||||
True if credentials are valid and API is accessible.
|
||||
|
||||
Raises:
|
||||
AuthenticationException: If API key is invalid or revoked.
|
||||
APIException: If API call fails for other reasons.
|
||||
|
||||
Example:
|
||||
>>> engine = ShodanEngine(api_key="your_key")
|
||||
>>> try:
|
||||
... engine.validate_credentials()
|
||||
... print("API key is valid")
|
||||
... except AuthenticationException:
|
||||
... print("Invalid API key")
|
||||
"""
|
||||
try:
|
||||
info = self._client.info()
|
||||
plan = info.get('plan', 'unknown')
|
||||
credits = info.get('query_credits', 0)
|
||||
logger.info(
|
||||
f"Shodan API validated. Plan: {plan}, "
|
||||
f"Query credits: {credits}"
|
||||
)
|
||||
return True
|
||||
except shodan.APIError as e:
|
||||
error_msg = str(e)
|
||||
if "Invalid API key" in error_msg:
|
||||
logger.error("Shodan authentication failed: Invalid API key")
|
||||
raise AuthenticationException(
|
||||
"Invalid Shodan API key",
|
||||
engine=self.name
|
||||
)
|
||||
logger.error(f"Shodan API validation error: {sanitize_output(error_msg)}")
|
||||
raise APIException(f"Shodan API error: {e}", engine=self.name)
|
||||
|
||||
@with_retry
|
||||
def get_quota_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get Shodan API quota and usage information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- engine: Engine name ("shodan")
|
||||
- plan: API plan type (e.g., "dev", "edu", "corp")
|
||||
- query_credits: Remaining query credits
|
||||
- scan_credits: Remaining scan credits
|
||||
- monitored_ips: Number of monitored IPs
|
||||
- unlocked: Whether account has unlocked features
|
||||
- error: Error message if call failed (optional)
|
||||
|
||||
Note:
|
||||
This call does not consume query credits.
|
||||
"""
|
||||
try:
|
||||
info = self._client.info()
|
||||
quota = {
|
||||
'engine': self.name,
|
||||
'plan': info.get('plan', 'unknown'),
|
||||
'query_credits': info.get('query_credits', 0),
|
||||
'scan_credits': info.get('scan_credits', 0),
|
||||
'monitored_ips': info.get('monitored_ips', 0),
|
||||
'unlocked': info.get('unlocked', False),
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
logger.debug(f"Shodan quota retrieved: {quota['query_credits']} credits remaining")
|
||||
return quota
|
||||
except shodan.APIError as e:
|
||||
logger.error(f"Failed to get Shodan quota: {sanitize_output(str(e))}")
|
||||
return {
|
||||
'engine': self.name,
|
||||
'error': str(e),
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
||||
"""
|
||||
Execute a Shodan search query with automatic pagination.
|
||||
|
||||
This method handles pagination automatically, respecting rate limits
|
||||
and the specified maximum results. Each page consumes one query credit.
|
||||
|
||||
Args:
|
||||
query: Shodan search query string. Supports Shodan's query syntax
|
||||
including filters like http.html:, port:, country:, etc.
|
||||
max_results: Maximum number of results to return. Defaults to
|
||||
the engine's max_results setting. Set to None for default.
|
||||
|
||||
Returns:
|
||||
List of SearchResult objects containing parsed Shodan data.
|
||||
May return fewer results than max_results if not enough matches.
|
||||
|
||||
Raises:
|
||||
APIException: If API call fails after all retries.
|
||||
RateLimitException: If rate limit is exceeded.
|
||||
AuthenticationException: If API key is invalid.
|
||||
ValidationException: If query is invalid.
|
||||
|
||||
Example:
|
||||
>>> results = engine.search("http.html:clawdbot", max_results=50)
|
||||
>>> for r in results:
|
||||
... print(f"{r.ip}:{r.port} - {r.service}")
|
||||
|
||||
Note:
|
||||
- Shodan returns max 100 results per page
|
||||
- Multiple pages consume multiple query credits
|
||||
- Consider using count() first to check total results
|
||||
"""
|
||||
# Validate and sanitize query
|
||||
if not query or not query.strip():
|
||||
raise APIException("Search query cannot be empty", engine=self.name)
|
||||
|
||||
query = query.strip()
|
||||
limit = max_results or self.max_results
|
||||
|
||||
# Log sanitized query (remove potential sensitive data)
|
||||
safe_query = sanitize_output(query)
|
||||
logger.info(f"Executing Shodan search: {safe_query} (limit: {limit})")
|
||||
|
||||
results: List[SearchResult] = []
|
||||
page = 1
|
||||
total_pages = 0
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
while len(results) < limit:
|
||||
# Apply rate limiting before each request
|
||||
self._rate_limit_wait()
|
||||
|
||||
# Execute search with retry logic
|
||||
response = self._execute_search_page(query, page)
|
||||
|
||||
if response is None or not response.get('matches'):
|
||||
logger.debug(f"No more matches at page {page}")
|
||||
break
|
||||
|
||||
# Parse matches
|
||||
matches = response.get('matches', [])
|
||||
for match in matches:
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
try:
|
||||
result = self._parse_result(match)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
# Log but continue on parse errors
|
||||
logger.warning(f"Failed to parse result: {e}")
|
||||
continue
|
||||
|
||||
# Check pagination limits
|
||||
total = response.get('total', 0)
|
||||
total_pages = (total + 99) // 100 # Ceiling division
|
||||
|
||||
if len(results) >= total or len(results) >= limit:
|
||||
break
|
||||
|
||||
page += 1
|
||||
|
||||
# Safety limit to prevent infinite loops
|
||||
if page > 100:
|
||||
logger.warning("Reached maximum page limit (100)")
|
||||
break
|
||||
|
||||
# Log completion stats
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
f"Shodan search complete: {len(results)} results "
|
||||
f"from {page} pages in {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
except (AuthenticationException, RateLimitException):
|
||||
# Re-raise known exceptions without wrapping
|
||||
raise
|
||||
except shodan.APIError as e:
|
||||
error_msg = str(e).lower()
|
||||
if "rate limit" in error_msg:
|
||||
logger.error("Shodan rate limit exceeded during search")
|
||||
raise RateLimitException(
|
||||
f"Shodan rate limit exceeded: {e}",
|
||||
engine=self.name
|
||||
)
|
||||
elif "invalid api key" in error_msg:
|
||||
logger.error("Shodan authentication failed during search")
|
||||
raise AuthenticationException(
|
||||
"Invalid Shodan API key",
|
||||
engine=self.name
|
||||
)
|
||||
else:
|
||||
logger.error(f"Shodan API error: {sanitize_output(str(e))}")
|
||||
raise APIException(
|
||||
f"Shodan search failed: {e}",
|
||||
engine=self.name
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error in Shodan search: {e}")
|
||||
raise APIException(
|
||||
f"Shodan search error: {type(e).__name__}: {e}",
|
||||
engine=self.name
|
||||
)
|
||||
|
||||
@with_retry
|
||||
def _execute_search_page(self, query: str, page: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Execute a single page of Shodan search with retry logic.
|
||||
|
||||
Args:
|
||||
query: Search query string.
|
||||
page: Page number (1-indexed).
|
||||
|
||||
Returns:
|
||||
Shodan API response dictionary or None on failure.
|
||||
"""
|
||||
logger.debug(f"Fetching Shodan results page {page}")
|
||||
return self._client.search(query, page=page)
|
||||
|
||||
def _parse_result(self, match: Dict[str, Any]) -> SearchResult:
    """
    Parse a raw Shodan match into a SearchResult.

    Runs lightweight passive checks on the banner data and the SSL/HTTP
    metadata to flag common vulnerability indicators.

    Args:
        match: Raw Shodan match data.

    Returns:
        SearchResult object with vulnerability indicators and metadata.
    """
    vulnerabilities = []
    data = match.get('data', '')
    # Lowercase once instead of once per check; banners can be large.
    data_lower = data.lower()

    # Check for common vulnerability indicators.
    # A case-insensitive 'debug' also covers the literal 'DEBUG=True'.
    if 'debug' in data_lower:
        vulnerabilities.append('debug_mode_enabled')

    if 'api_key' in data_lower or 'apikey' in data_lower:
        vulnerabilities.append('potential_api_key_exposure')

    ssl_data = match.get('ssl') or {}
    ssl_cert = ssl_data.get('cert') or {}
    if ssl_cert.get('expired', False):
        vulnerabilities.append('expired_ssl_certificate')

    # Check HTTP response for issues
    http_data = match.get('http') or {}
    if http_data and not http_data.get('securitytxt'):
        vulnerabilities.append('no_security_txt')

    # Build metadata
    location_data = match.get('location') or {}
    metadata = {
        'asn': match.get('asn'),
        'isp': match.get('isp'),
        'org': match.get('org'),
        'os': match.get('os'),
        'transport': match.get('transport'),
        'product': match.get('product'),
        'version': match.get('version'),
        'cpe': match.get('cpe', []),
        'http': http_data,
        'ssl': ssl_data,
        'location': {
            'country': location_data.get('country_name'),
            'city': location_data.get('city'),
            'latitude': location_data.get('latitude'),
            'longitude': location_data.get('longitude')
        }
    }

    # First hostname, if any, is treated as the canonical one.
    hostnames = match.get('hostnames', [])
    hostname = hostnames[0] if hostnames else None

    return SearchResult(
        ip=match.get('ip_str', ''),
        port=match.get('port', 0),
        hostname=hostname,
        service=match.get('product') or match.get('_shodan', {}).get('module'),
        banner=data[:1000] if data else None,  # Truncate long banners
        vulnerabilities=vulnerabilities,
        metadata=metadata,
        source_engine=self.name,
        timestamp=match.get('timestamp', datetime.utcnow().isoformat())
    )
|
||||
|
||||
@with_retry
def host_info(self, ip: str) -> Dict[str, Any]:
    """
    Get detailed information about a specific host.

    This method retrieves comprehensive information about a host
    including all open ports, services, banners, and historical data.

    Args:
        ip: IP address to lookup. Must be a valid IPv4 address.

    Returns:
        Dictionary containing:
        - ip_str: IP address as string
        - ports: List of open ports
        - data: List of service banners per port
        - hostnames: List of hostnames
        - vulns: List of vulnerabilities (if any)
        - location: Geographic information

        When Shodan has no data for the host, an empty skeleton dict is
        returned instead of raising.

    Raises:
        APIException: If the IP address is invalid or the lookup fails.

    Example:
        >>> info = engine.host_info("8.8.8.8")
        >>> print(f"Ports: {info.get('ports', [])}")
    """
    # Validate the IP first; wrap validation failures in APIException so
    # callers only have to handle one exception type from this engine.
    try:
        validate_ip(ip)
    except Exception as e:
        raise APIException(f"Invalid IP address: {ip}", engine=self.name) from e

    self._rate_limit_wait()
    logger.debug(f"Looking up host info for: {ip}")

    try:
        host_data = self._client.host(ip)
        logger.info(f"Retrieved host info for {ip}: {len(host_data.get('ports', []))} ports")
        return host_data
    except shodan.APIError as e:
        error_msg = str(e).lower()
        # "No information available" is a normal outcome, not a failure.
        if "no information available" in error_msg:
            logger.info(f"No Shodan data available for {ip}")
            return {'ip_str': ip, 'ports': [], 'data': []}
        logger.error(f"Failed to get host info for {ip}: {sanitize_output(str(e))}")
        raise APIException(f"Shodan host lookup failed: {e}", engine=self.name) from e
|
||||
|
||||
@with_retry
def count(self, query: str) -> int:
    """
    Return the number of results a query would produce, without
    consuming any Shodan query credits.

    Handy for validating a query or sizing a search before running it.

    Args:
        query: Shodan search query string.

    Returns:
        Estimated number of matching results; 0 for blank queries or on
        any API error.

    Note:
        - Does not consume query credits
        - Count may be approximate for large result sets

    Example:
        >>> count = engine.count("http.html:clawdbot")
        >>> if count > 0:
        ...     results = engine.search("http.html:clawdbot")
    """
    # Guard clause: never send a blank query to the API.
    if not (query and query.strip()):
        logger.warning("Empty query provided to count()")
        return 0

    self._rate_limit_wait()
    logger.debug(f"Counting results for query: {sanitize_output(query)}")

    try:
        response = self._client.count(query)
        total = response.get('total', 0)
        logger.info(f"Query '{sanitize_output(query)}' has {total} results")
        return total
    except shodan.APIError as e:
        logger.error(f"Failed to count results: {sanitize_output(str(e))}")
        return 0
    except Exception as e:
        logger.exception(f"Unexpected error in count: {e}")
        return 0
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Enrichment modules for AASRT.
|
||||
|
||||
This module contains data enrichment capabilities:
|
||||
- ClawSec threat intelligence integration
|
||||
- (Future) WHOIS lookups
|
||||
- (Future) Geolocation
|
||||
- (Future) SSL/TLS certificate analysis
|
||||
- (Future) DNS records
|
||||
"""
|
||||
|
||||
from .clawsec_feed import ClawSecFeedManager, ClawSecFeed, ClawSecAdvisory
|
||||
from .threat_enricher import ThreatEnricher
|
||||
|
||||
__all__ = [
|
||||
'ClawSecFeedManager',
|
||||
'ClawSecFeed',
|
||||
'ClawSecAdvisory',
|
||||
'ThreatEnricher'
|
||||
]
|
||||
@@ -0,0 +1,380 @@
|
||||
"""ClawSec Threat Intelligence Feed Manager for AASRT."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ClawSecAdvisory:
    """Represents a single ClawSec CVE advisory."""

    cve_id: str
    severity: str  # CRITICAL, HIGH, MEDIUM, LOW
    vuln_type: str  # e.g., "prompt_injection", "missing_authentication"
    cvss_score: float
    title: str
    description: str
    affected: List[str] = field(default_factory=list)
    action: str = ""
    nvd_url: Optional[str] = None
    cwe_id: Optional[str] = None
    published_date: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary (used for on-disk cache serialization)."""
        return {
            'cve_id': self.cve_id,
            'severity': self.severity,
            'vuln_type': self.vuln_type,
            'cvss_score': self.cvss_score,
            'title': self.title,
            'description': self.description,
            'affected': self.affected,
            'action': self.action,
            'nvd_url': self.nvd_url,
            'cwe_id': self.cwe_id,
            'published_date': self.published_date.isoformat() if self.published_date else None
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ClawSecAdvisory':
        """Create from dictionary.

        Accepts both the upstream feed schema ('id', 'type', 'published',
        'nvd_category_id') and the keys written by to_dict() ('cve_id',
        'vuln_type', 'published_date', 'cwe_id'), so that a cached feed
        round-trips through save_cache/load_cache without losing fields.
        """
        published = data.get('published', data.get('published_date'))
        if published and isinstance(published, str):
            try:
                published = datetime.fromisoformat(published.replace('Z', '+00:00'))
            except (ValueError, TypeError):
                # Unparseable timestamps are dropped rather than failing
                # the whole advisory.
                published = None

        return cls(
            cve_id=data.get('id', data.get('cve_id', '')),
            severity=data.get('severity', 'MEDIUM').upper(),
            vuln_type=data.get('type', data.get('vuln_type', 'unknown')),
            cvss_score=float(data.get('cvss_score', 0.0)),
            title=data.get('title', ''),
            description=data.get('description', ''),
            affected=data.get('affected', []),
            action=data.get('action', ''),
            nvd_url=data.get('nvd_url'),
            cwe_id=data.get('nvd_category_id', data.get('cwe_id')),
            published_date=published
        )
|
||||
|
||||
|
||||
@dataclass
class ClawSecFeed:
    """Container for the full ClawSec advisory feed."""

    advisories: List[ClawSecAdvisory]
    last_updated: datetime
    feed_version: str
    total_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the feed (and every advisory) for on-disk caching."""
        serialized = [advisory.to_dict() for advisory in self.advisories]
        return {
            'advisories': serialized,
            'last_updated': self.last_updated.isoformat(),
            'feed_version': self.feed_version,
            'total_count': self.total_count
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ClawSecFeed':
        """Rebuild a feed from a dictionary produced by to_dict()."""
        raw_advisories = data.get('advisories', [])
        fallback_stamp = datetime.utcnow().isoformat()
        return cls(
            advisories=[ClawSecAdvisory.from_dict(entry) for entry in raw_advisories],
            last_updated=datetime.fromisoformat(data.get('last_updated', fallback_stamp)),
            feed_version=data.get('feed_version', '0.0.0'),
            total_count=data.get('total_count', 0)
        )
|
||||
|
||||
|
||||
class ClawSecFeedManager:
    """
    Manages ClawSec threat intelligence feed with caching and offline support.

    Features:
    - HTTP fetch with configurable timeout
    - Local file caching for offline mode
    - Advisory matching by product/version/banner
    - Non-blocking background updates

    Thread safety: background_refresh() re-runs fetch_feed() on a daemon
    thread, so every write to the in-memory cache (fetch_feed and
    load_cache) is guarded by self._lock.
    """

    DEFAULT_FEED_URL = "https://clawsec.prompt.security/advisories/feed.json"
    DEFAULT_CACHE_FILE = "./data/clawsec_cache.json"
    DEFAULT_TTL = 86400  # 24 hours

    def __init__(self, config=None):
        """
        Initialize ClawSecFeedManager.

        Args:
            config: Configuration object with clawsec settings; when None,
                class-level defaults are used.
        """
        self.config = config

        # Get configuration values (fall back to class defaults when absent)
        if config:
            clawsec_config = config.get('clawsec', default={})
            self.feed_url = clawsec_config.get('feed_url', self.DEFAULT_FEED_URL)
            self.cache_file = clawsec_config.get('cache_file', self.DEFAULT_CACHE_FILE)
            self.cache_ttl = clawsec_config.get('cache_ttl_seconds', self.DEFAULT_TTL)
            self.offline_mode = clawsec_config.get('offline_mode', False)
            self.timeout = clawsec_config.get('timeout', 30)
        else:
            self.feed_url = self.DEFAULT_FEED_URL
            self.cache_file = self.DEFAULT_CACHE_FILE
            self.cache_ttl = self.DEFAULT_TTL
            self.offline_mode = False
            self.timeout = 30

        # In-memory cache; written only while holding self._lock because
        # background_refresh() may update it from a worker thread.
        self._cache: Optional[ClawSecFeed] = None
        self._cache_timestamp: Optional[datetime] = None
        self._lock = threading.Lock()

    def fetch_feed(self, force_refresh: bool = False) -> Optional[ClawSecFeed]:
        """
        Fetch the ClawSec advisory feed.

        Args:
            force_refresh: Force fetch from URL even if cache is valid.

        Returns:
            ClawSecFeed object, or None if the fetch fails and no cached
            copy is available.
        """
        # Check cache first
        if not force_refresh and self.is_cache_valid():
            logger.debug("Using cached ClawSec feed")
            return self._cache

        # In offline mode, only use cache
        if self.offline_mode:
            logger.info("ClawSec offline mode - using cached data only")
            return self.get_cached_feed()

        try:
            logger.info(f"Fetching ClawSec feed from {self.feed_url}")
            response = requests.get(self.feed_url, timeout=self.timeout)
            response.raise_for_status()

            data = response.json()
            feed = self._parse_feed(data)

            with self._lock:
                self._cache = feed
                self._cache_timestamp = datetime.utcnow()

            # Persist to disk for offline use
            self.save_cache()

            logger.info(f"ClawSec feed loaded: {feed.total_count} advisories")
            return feed

        except requests.RequestException as e:
            logger.warning(f"Failed to fetch ClawSec feed: {e}")
            # Fall back to cache
            return self.get_cached_feed()
        except (json.JSONDecodeError, KeyError) as e:
            logger.error(f"Failed to parse ClawSec feed: {e}")
            return self.get_cached_feed()

    def _parse_feed(self, data: Dict[str, Any]) -> ClawSecFeed:
        """Parse raw feed JSON into ClawSecFeed object.

        Malformed advisories are skipped with a warning rather than
        failing the whole feed.
        """
        advisories = []

        for advisory_data in data.get('advisories', []):
            try:
                advisory = ClawSecAdvisory.from_dict(advisory_data)
                advisories.append(advisory)
            except Exception as e:
                logger.warning(f"Failed to parse advisory: {e}")
                continue

        return ClawSecFeed(
            advisories=advisories,
            last_updated=datetime.utcnow(),
            feed_version=data.get('version', '0.0.0'),
            total_count=len(advisories)
        )

    def get_cached_feed(self) -> Optional[ClawSecFeed]:
        """Return the cached feed without a network call, loading it from
        disk if it is not yet in memory. Returns None when no cache exists."""
        if self._cache:
            return self._cache

        # Try loading from disk
        self.load_cache()
        return self._cache

    def is_cache_valid(self) -> bool:
        """Check whether the in-memory cache exists and is within TTL."""
        if not self._cache or not self._cache_timestamp:
            return False

        age = datetime.utcnow() - self._cache_timestamp
        return age.total_seconds() < self.cache_ttl

    def save_cache(self) -> None:
        """Persist cache to local file for offline mode (best-effort:
        failures are logged and swallowed)."""
        if not self._cache:
            return

        try:
            cache_path = Path(self.cache_file)
            cache_path.parent.mkdir(parents=True, exist_ok=True)

            cache_data = {
                'feed': self._cache.to_dict(),
                'cached_at': datetime.utcnow().isoformat()
            }

            with open(cache_path, 'w') as f:
                json.dump(cache_data, f, indent=2)

            logger.debug(f"ClawSec cache saved to {self.cache_file}")

        except Exception as e:
            logger.warning(f"Failed to save ClawSec cache: {e}")

    def load_cache(self) -> bool:
        """Load cache from the local file.

        Returns:
            True if a cache file was found and parsed, False otherwise.
        """
        try:
            cache_path = Path(self.cache_file)
            if not cache_path.exists():
                return False

            with open(cache_path, 'r') as f:
                cache_data = json.load(f)

            # Update the shared cache under the lock, mirroring fetch_feed():
            # background_refresh() may be writing concurrently.
            with self._lock:
                self._cache = ClawSecFeed.from_dict(cache_data.get('feed', {}))
                cached_at = cache_data.get('cached_at')
                if cached_at:
                    self._cache_timestamp = datetime.fromisoformat(cached_at)

            logger.info(f"ClawSec cache loaded: {self._cache.total_count} advisories")
            return True

        except Exception as e:
            logger.warning(f"Failed to load ClawSec cache: {e}")
            return False

    def match_advisories(
        self,
        product: Optional[str] = None,
        version: Optional[str] = None,
        banner: Optional[str] = None
    ) -> List[ClawSecAdvisory]:
        """
        Find matching advisories for a product/version/banner.

        Matching strategies (in order):
        1. Exact product name match in affected list
        2. Fuzzy product match (clawdbot, clawbot, claw-bot)
        3. Banner text contains product from affected

        Args:
            product: Product name to match.
            version: Version string to check.
                NOTE(review): currently unused -- version-range matching
                against the 'affected' constraints is not implemented yet.
            banner: Banner text to search.

        Returns:
            List of matching ClawSecAdvisory objects (deduplicated).
        """
        feed = self.get_cached_feed()
        if not feed:
            return []

        matches = []
        product_lower = (product or '').lower()
        banner_lower = (banner or '').lower()

        # AI agent keywords to look for
        ai_keywords = ['clawdbot', 'clawbot', 'moltbot', 'openclaw', 'autogpt', 'langchain']

        for advisory in feed.advisories:
            matched = False

            # Check each affected product
            for affected in advisory.affected:
                affected_lower = affected.lower()

                # Strategy 1: Direct product match
                if product_lower and product_lower in affected_lower:
                    matched = True
                    break

                # Strategy 2: Check AI keywords in affected and product/banner
                for keyword in ai_keywords:
                    if keyword in affected_lower:
                        if keyword in product_lower or keyword in banner_lower:
                            matched = True
                            break

                if matched:
                    break

                # Strategy 3: Banner contains affected product
                if banner_lower:
                    # Extract product name from affected (e.g., "ClawdBot < 2.0" -> "clawdbot")
                    affected_product = affected_lower.split('<')[0].split('>')[0].strip()
                    if affected_product and affected_product in banner_lower:
                        matched = True
                        break

            if matched and advisory not in matches:
                matches.append(advisory)

        logger.debug(f"ClawSec matched {len(matches)} advisories for product={product}")
        return matches

    def background_refresh(self) -> None:
        """Start a daemon thread that force-refreshes the feed without
        blocking the caller."""
        def _refresh():
            try:
                self.fetch_feed(force_refresh=True)
            except Exception as e:
                logger.warning(f"Background ClawSec refresh failed: {e}")

        thread = threading.Thread(target=_refresh, daemon=True)
        thread.start()
        logger.debug("ClawSec background refresh started")

    def get_statistics(self) -> Dict[str, Any]:
        """Get feed statistics for UI display."""
        feed = self.get_cached_feed()
        if not feed:
            return {
                'total_advisories': 0,
                'critical_count': 0,
                'high_count': 0,
                'last_updated': None,
                'is_stale': True
            }

        severity_counts = {'CRITICAL': 0, 'HIGH': 0, 'MEDIUM': 0, 'LOW': 0}
        for advisory in feed.advisories:
            if advisory.severity in severity_counts:
                severity_counts[advisory.severity] += 1

        return {
            'total_advisories': feed.total_count,
            'critical_count': severity_counts['CRITICAL'],
            'high_count': severity_counts['HIGH'],
            'medium_count': severity_counts['MEDIUM'],
            'low_count': severity_counts['LOW'],
            'last_updated': feed.last_updated.isoformat() if feed.last_updated else None,
            'feed_version': feed.feed_version,
            'is_stale': not self.is_cache_valid()
        }
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Threat Intelligence Enrichment for AASRT."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.engines import SearchResult
|
||||
from src.utils.logger import get_logger
|
||||
from .clawsec_feed import ClawSecAdvisory, ClawSecFeedManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ThreatEnricher:
    """
    Enriches SearchResult objects with ClawSec threat intelligence.

    Responsibilities:
    - Match results against ClawSec advisories
    - Add CVE metadata to result.metadata
    - Inject ClawSec vulnerabilities into result.vulnerabilities
    """

    def __init__(self, feed_manager: ClawSecFeedManager, config=None):
        """
        Initialize ThreatEnricher.

        Args:
            feed_manager: ClawSecFeedManager instance supplying advisories.
            config: Optional configuration object.
        """
        self.feed_manager = feed_manager
        self.config = config

    def enrich(self, result: SearchResult) -> SearchResult:
        """
        Enrich a single result with ClawSec threat intelligence.

        Args:
            result: SearchResult to enrich.

        Returns:
            The SearchResult, updated with any matched advisory metadata.
        """
        # Build the text to match against: the raw banner plus the HTTP
        # page title, so title-only fingerprints still match.
        match_text = result.banner or ''
        http_meta = result.metadata.get('http', {}) or {}
        page_title = http_meta.get('title') or ''
        if page_title:
            match_text = f"{match_text} {page_title}"

        product, version = self._extract_product_info(result)

        matched = self.feed_manager.match_advisories(
            product=product,
            version=version,
            banner=match_text
        )

        if matched:
            result = self._add_cve_context(result, matched)
            logger.debug(f"Enriched {result.ip}:{result.port} with {len(matched)} ClawSec advisories")

        return result

    def enrich_batch(self, results: List[SearchResult]) -> List[SearchResult]:
        """
        Enrich every result in a list.

        Args:
            results: SearchResults to enrich.

        Returns:
            The enriched SearchResults, in the same order.
        """
        return [self.enrich(single) for single in results]

    def _extract_product_info(self, result: SearchResult) -> Tuple[Optional[str], Optional[str]]:
        """
        Derive (product, version) for a result.

        Sources, in order: explicit metadata fields, AI-agent keywords in
        the HTTP title, the service name, and version-looking patterns in
        the banner.

        Args:
            result: SearchResult to analyze.

        Returns:
            Tuple of (product_name, version); either element may be None.
        """
        metadata = result.metadata if isinstance(result.metadata, dict) else {}

        # Explicit metadata fields are the primary source.
        product = metadata.get('product')
        version = metadata.get('version')

        # A recognized AI-agent keyword in the HTTP title overrides product.
        http_meta = metadata.get('http') or {}
        if http_meta:
            page_title = http_meta.get('title') or ''
            known_agents = {
                'clawdbot': 'ClawdBot',
                'moltbot': 'MoltBot',
                'autogpt': 'AutoGPT',
                'langchain': 'LangChain',
                'openclaw': 'OpenClaw'
            }
            if page_title:
                lowered_title = page_title.lower()
                for keyword, canonical in known_agents.items():
                    if keyword in lowered_title:
                        product = canonical
                        break

        # Fall back to the service name if it mentions a known agent.
        if not product and result.service:
            service_lower = result.service.lower()
            if any(kw in service_lower for kw in ('clawdbot', 'moltbot', 'autogpt', 'langchain')):
                product = result.service

        # Last resort for the version: scan the banner for version patterns.
        if result.banner and not version:
            import re
            patterns = (
                r'v?(\d+\.\d+(?:\.\d+)?)',  # v1.2.3 or 1.2.3
                r'version[:\s]+(\d+\.\d+(?:\.\d+)?)',  # version: 1.2.3
            )
            for pattern in patterns:
                hit = re.search(pattern, result.banner, re.IGNORECASE)
                if hit:
                    version = hit.group(1)
                    break

        return product, version

    def _add_cve_context(
        self,
        result: SearchResult,
        advisories: List[ClawSecAdvisory]
    ) -> SearchResult:
        """
        Record matched advisories on a result.

        Writes metadata['clawsec_advisories'], metadata['clawsec_severity']
        (the highest severity among matches), and appends one
        'clawsec_<CVE>' marker per advisory to result.vulnerabilities.

        Args:
            result: SearchResult to update.
            advisories: Matched ClawSecAdvisory objects.

        Returns:
            The updated SearchResult.
        """
        # Flatten each advisory into a plain dict for metadata storage.
        result.metadata['clawsec_advisories'] = [
            {
                'cve_id': adv.cve_id,
                'severity': adv.severity,
                'cvss_score': adv.cvss_score,
                'title': adv.title,
                'vuln_type': adv.vuln_type,
                'action': adv.action,
                'nvd_url': adv.nvd_url,
                'cwe_id': adv.cwe_id
            }
            for adv in advisories
        ]

        # Record the single highest severity for quick filtering.
        severity_rank = {'CRITICAL': 4, 'HIGH': 3, 'MEDIUM': 2, 'LOW': 1}
        result.metadata['clawsec_severity'] = max(
            (adv.severity for adv in advisories),
            key=lambda s: severity_rank.get(s, 0),
            default='LOW'
        )

        # Tag the result's vulnerability list, avoiding duplicates.
        for adv in advisories:
            marker = f"clawsec_{adv.cve_id}"
            if marker not in result.vulnerabilities:
                result.vulnerabilities.append(marker)

        return result

    def get_enrichment_stats(self, results: List[SearchResult]) -> Dict[str, Any]:
        """
        Summarize how many results were enriched and with what severities.

        Args:
            results: List of (already enriched) SearchResults.

        Returns:
            Dictionary with enrichment statistics.
        """
        enriched_count = 0
        total_cves = 0
        severity_counts = {'CRITICAL': 0, 'HIGH': 0, 'MEDIUM': 0, 'LOW': 0}
        seen_cves = set()

        for single in results:
            matched = single.metadata.get('clawsec_advisories', [])
            if not matched:
                continue
            enriched_count += 1
            total_cves += len(matched)
            for entry in matched:
                seen_cves.add(entry['cve_id'])
                sev = entry.get('severity', 'LOW')
                if sev in severity_counts:
                    severity_counts[sev] += 1

        return {
            'enriched_results': enriched_count,
            'total_results': len(results),
            'enrichment_rate': (enriched_count / len(results) * 100) if results else 0,
            'unique_cves': len(seen_cves),
            'total_cve_matches': total_cves,
            'severity_breakdown': severity_counts,
            'cve_ids': list(seen_cves)
        }
|
||||
+689
@@ -0,0 +1,689 @@
|
||||
"""
|
||||
CLI entry point for AASRT - AI Agent Security Reconnaissance Tool.
|
||||
|
||||
This module provides the command-line interface for AASRT with:
|
||||
- Shodan-based security reconnaissance scanning
|
||||
- Vulnerability assessment and risk scoring
|
||||
- Report generation (JSON/CSV)
|
||||
- Database storage and history tracking
|
||||
- Signal handling for graceful shutdown
|
||||
|
||||
Usage:
|
||||
python -m src.main status # Check API status
|
||||
python -m src.main scan --template clawdbot_instances
|
||||
python -m src.main history # View scan history
|
||||
python -m src.main templates # List available templates
|
||||
|
||||
Environment Variables:
|
||||
SHODAN_API_KEY: Required for scanning operations
|
||||
AASRT_LOG_LEVEL: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
AASRT_DEBUG: Enable debug mode (true/false)
|
||||
|
||||
Exit Codes:
|
||||
0: Success
|
||||
1: Error (invalid arguments, API errors, etc.)
|
||||
130: Interrupted by user (SIGINT/Ctrl+C)
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
|
||||
|
||||
from src import __version__
|
||||
from src.utils.config import Config
|
||||
from src.utils.logger import setup_logger, get_logger
|
||||
from src.core.query_manager import QueryManager
|
||||
from src.core.result_aggregator import ResultAggregator
|
||||
from src.core.vulnerability_assessor import VulnerabilityAssessor
|
||||
from src.core.risk_scorer import RiskScorer
|
||||
from src.storage.database import Database
|
||||
from src.reporting import JSONReporter, CSVReporter, ScanReport
|
||||
|
||||
# =============================================================================
|
||||
# Global State
|
||||
# =============================================================================
|
||||
|
||||
# Shared Rich console for all CLI output.
console = Console()
# Set to True by _signal_handler on the first SIGINT/SIGTERM.
_shutdown_requested = False
# Most recently opened database; closed by _cleanup() at interpreter exit.
_active_database: Optional[Database] = None
|
||||
|
||||
# =============================================================================
|
||||
# Signal Handlers
|
||||
# =============================================================================
|
||||
|
||||
def _signal_handler(signum: int, frame: Any) -> None:
    """
    React to SIGINT/SIGTERM: the first signal requests a graceful stop,
    a second one exits immediately with status 130.

    Args:
        signum: Signal number received.
        frame: Current stack frame (unused).
    """
    global _shutdown_requested

    if _shutdown_requested:
        # Second interrupt - force exit
        console.print("\n[red]Force shutdown requested. Exiting immediately.[/red]")
        sys.exit(130)

    _shutdown_requested = True
    sig_name = signal.Signals(signum).name
    console.print(f"\n[yellow]Received {sig_name}. Shutting down gracefully...[/yellow]")
    console.print("[dim]Press Ctrl+C again to force quit.[/dim]")
|
||||
|
||||
|
||||
def _cleanup() -> None:
    """
    atexit hook: close the active database connection, if any.

    Errors are deliberately swallowed -- this runs while the interpreter
    is shutting down and must never raise.
    """
    global _active_database

    if not _active_database:
        return
    try:
        _active_database.close()
    except Exception:
        pass  # Ignore errors during cleanup
|
||||
|
||||
|
||||
def is_shutdown_requested() -> bool:
    """
    Report whether a SIGINT/SIGTERM has asked the tool to stop.

    Returns:
        True if a shutdown signal was received.
    """
    return _shutdown_requested
|
||||
|
||||
|
||||
# Register signal handlers: SIGINT (Ctrl+C) and SIGTERM both route through
# _signal_handler so the first signal triggers a graceful stop; _cleanup
# closes the active database at interpreter exit.
signal.signal(signal.SIGINT, _signal_handler)
signal.signal(signal.SIGTERM, _signal_handler)
atexit.register(_cleanup)
|
||||
|
||||
# =============================================================================
|
||||
# Legal Disclaimer
|
||||
# =============================================================================
|
||||
|
||||
LEGAL_DISCLAIMER = """
|
||||
[bold red]WARNING: LEGAL DISCLAIMER[/bold red]
|
||||
|
||||
This tool is for [bold]authorized security research and defensive purposes only[/bold].
|
||||
Unauthorized access to computer systems is illegal under:
|
||||
- CFAA (Computer Fraud and Abuse Act) - United States
|
||||
- Computer Misuse Act - United Kingdom
|
||||
- Similar laws worldwide
|
||||
|
||||
By proceeding, you acknowledge that:
|
||||
1. You have authorization to scan target systems
|
||||
2. You will comply with all applicable laws and terms of service
|
||||
3. You will responsibly disclose findings
|
||||
4. You will not exploit discovered vulnerabilities
|
||||
|
||||
[bold yellow]The authors are not responsible for misuse of this tool.[/bold yellow]
|
||||
"""
|
||||
|
||||
# =============================================================================
|
||||
# CLI Command Group
|
||||
# =============================================================================
|
||||
|
||||
@click.group()
@click.version_option(version=__version__, prog_name="AASRT")
@click.option('--config', '-c', type=click.Path(exists=True), help='Path to config file')
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose output')
@click.pass_context
def cli(ctx: click.Context, config: Optional[str], verbose: bool) -> None:
    """
    AI Agent Security Reconnaissance Tool (AASRT).

    Discover and assess exposed AI agent implementations using Shodan.

    Use 'aasrt --help' for command list or 'aasrt COMMAND --help' for command details.
    """
    ctx.ensure_object(dict)

    # Configuration must load before anything else; abort the CLI if it fails.
    try:
        cfg = Config(config)
    except Exception as e:
        console.print(f"[red]Failed to load configuration: {e}[/red]")
        sys.exit(1)
    ctx.obj['config'] = cfg

    # The --verbose flag overrides the configured log level.
    level = 'DEBUG' if verbose else cfg.get('logging', 'level', default='INFO')
    target_file = cfg.get('logging', 'file')
    try:
        setup_logger('aasrt', level=level, log_file=target_file)
    except Exception as e:
        # Logging problems are non-fatal; warn and continue.
        console.print(f"[yellow]Warning: Could not setup logging: {e}[/yellow]")

    ctx.obj['verbose'] = verbose

    # Log startup in debug mode
    logger = get_logger('aasrt')
    logger.debug(f"AASRT v{__version__} starting (verbose={verbose})")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scan Command
|
||||
# =============================================================================
|
||||
|
||||
@cli.command()
@click.option('--query', '-q', help='Custom Shodan search query')
@click.option('--template', '-t', help='Use predefined query template')
@click.option('--max-results', '-m', default=100, type=int, help='Max results to retrieve (1-10000)')
@click.option('--output', '-o', type=click.Path(), help='Output file path')
@click.option('--format', '-f', 'output_format',
              type=click.Choice(['json', 'csv', 'both']),
              default='json', help='Output format')
@click.option('--no-assess', is_flag=True, help='Skip vulnerability assessment')
@click.option('--save-db/--no-save-db', default=True, help='Save results to database')
@click.option('--yes', '-y', is_flag=True, help='Skip legal disclaimer confirmation')
@click.pass_context
def scan(
    ctx: click.Context,
    query: Optional[str],
    template: Optional[str],
    max_results: int,
    output: Optional[str],
    output_format: str,
    no_assess: bool,
    save_db: bool,
    yes: bool
) -> None:
    """
    Perform a security reconnaissance scan using Shodan.

    Searches for exposed AI agent implementations and assesses their
    security posture using passive analysis techniques.

    Examples:
        aasrt scan --template clawdbot_instances
        aasrt scan --query 'http.title:"AutoGPT"'
        aasrt scan -t exposed_env_files -m 50 -f csv
    """
    # Module-level handle so the shutdown/cleanup path can close the DB.
    global _active_database

    config = ctx.obj['config']
    logger = get_logger('aasrt')

    logger.info(f"Starting scan command (template={template}, query={query[:50] if query else None})")

    # Display legal disclaimer; -y/--yes skips the interactive confirmation.
    if not yes:
        console.print(Panel(LEGAL_DISCLAIMER, title="Legal Notice", border_style="red"))
        if not click.confirm('\nDo you agree to the terms above?', default=False):
            console.print('[red]Scan aborted. You must agree to terms of use.[/red]')
            logger.info("Scan aborted: User declined legal disclaimer")
            sys.exit(1)

    # Validate max_results: reject < 1, silently clamp anything above 10000.
    if max_results < 1:
        console.print('[red]Error: max-results must be at least 1[/red]')
        sys.exit(1)
    if max_results > 10000:
        console.print('[yellow]Warning: Limiting max-results to 10000[/yellow]')
        max_results = 10000

    # Validate inputs: fall back to the default template when neither
    # --query nor --template was provided.
    if not query and not template:
        console.print('[yellow]No query or template specified. Using default template: clawdbot_instances[/yellow]')
        template = 'clawdbot_instances'

    # Check for shutdown before heavy operations (signal handler sets the flag).
    if is_shutdown_requested():
        console.print('[yellow]Scan cancelled due to shutdown request.[/yellow]')
        sys.exit(130)

    # Initialize query manager
    try:
        query_manager = QueryManager(config)
    except Exception as e:
        console.print(f'[red]Failed to initialize query manager: {e}[/red]')
        logger.error(f"Query manager initialization failed: {e}")
        sys.exit(1)

    # Check if Shodan is available (API key configured)
    if not query_manager.is_available():
        console.print('[red]Shodan is not available. Please check your API key in .env file.[/red]')
        console.print('[dim]Set SHODAN_API_KEY environment variable or add to .env file.[/dim]')
        sys.exit(1)

    console.print('\n[green]Starting Shodan scan...[/green]')
    logger.info(f"Scan started: template={template}, max_results={max_results}")

    # Generate scan ID and start the wall-clock timer for duration reporting.
    scan_id = str(uuid.uuid4())
    start_time = time.time()

    # Execute scan with interrupt checking. scan_error records whether the
    # search phase failed so the final status can be downgraded later.
    all_results = []
    scan_error = None

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        console=console
    ) as progress:
        if template:
            task = progress.add_task(f"[cyan]Scanning with template: {template}...", total=100)
            try:
                # Skip the call (but still finish the bar) if shutdown arrived.
                if not is_shutdown_requested():
                    all_results = query_manager.execute_template(template, max_results=max_results)
                progress.update(task, completed=100)
            except KeyboardInterrupt:
                console.print('\n[yellow]Scan interrupted by user.[/yellow]')
                scan_error = "Interrupted"
            except Exception as e:
                console.print(f'[red]Template execution failed: {e}[/red]')
                logger.error(f"Template execution error: {e}", exc_info=True)
                scan_error = str(e)
        else:
            task = progress.add_task("[cyan]Executing query...", total=100)
            try:
                if not is_shutdown_requested():
                    all_results = query_manager.execute_query(query, max_results=max_results)
                progress.update(task, completed=100)
            except KeyboardInterrupt:
                console.print('\n[yellow]Scan interrupted by user.[/yellow]')
                scan_error = "Interrupted"
            except Exception as e:
                console.print(f'[red]Query execution failed: {e}[/red]')
                logger.error(f"Query execution error: {e}", exc_info=True)
                scan_error = str(e)

    # Check if scan was interrupted or had errors; partial results are kept.
    if is_shutdown_requested():
        console.print('[yellow]Scan was interrupted. Saving partial results...[/yellow]')

    # Aggregate and deduplicate results
    console.print('\n[cyan]Aggregating results...[/cyan]')
    aggregator = ResultAggregator()
    unique_results = aggregator.aggregate({'shodan': all_results})

    console.print(f'Found [green]{len(unique_results)}[/green] unique results')
    logger.info(f"Aggregated {len(unique_results)} unique results from {len(all_results)} total")

    # Vulnerability assessment (skip if shutdown requested); assessment
    # failures on a single result are logged but do not abort the loop.
    if not no_assess and unique_results and not is_shutdown_requested():
        console.print('\n[cyan]Assessing vulnerabilities...[/cyan]')
        assessor = VulnerabilityAssessor(config.get('vulnerability_checks', default={}))
        scorer = RiskScorer()

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            console=console
        ) as progress:
            task = progress.add_task("[cyan]Analyzing...", total=len(unique_results))

            for result in unique_results:
                if is_shutdown_requested():
                    console.print('[yellow]Assessment interrupted.[/yellow]')
                    break
                try:
                    vulns = assessor.assess(result)
                    scorer.score_result(result, vulns)
                except Exception as e:
                    logger.warning(f"Failed to assess result {result.ip}: {e}")
                progress.advance(task)

    # Calculate duration
    duration = time.time() - start_time

    # Determine final status: an error with no results is a failure; an error
    # with some results, or a shutdown request, is a partial run.
    final_status = 'completed'
    if scan_error:
        final_status = 'failed' if not unique_results else 'partial'
    elif is_shutdown_requested():
        final_status = 'partial'

    # Create report
    # NOTE(review): from_results marks the report 'completed' regardless of
    # final_status; only the database record below reflects it — confirm intended.
    report = ScanReport.from_results(
        scan_id=scan_id,
        results=unique_results,
        engines=['shodan'],
        query=query,
        template_name=template,
        duration=duration
    )

    # Display summary
    _display_summary(report)

    # Save to database (best-effort: failures warn but do not abort)
    if save_db:
        try:
            db = Database(config)
            _active_database = db  # Track for cleanup

            scan_record = db.create_scan(
                engines=['shodan'],
                query=query,
                template_name=template
            )

            if unique_results:
                db.add_findings(scan_record.scan_id, unique_results)

            db.update_scan(
                scan_record.scan_id,
                status=final_status,
                total_results=len(unique_results),
                duration_seconds=duration
            )
            console.print(f'\n[green]Results saved to database. Scan ID: {scan_record.scan_id}[/green]')
            logger.info(f"Saved scan {scan_record.scan_id} with {len(unique_results)} findings")
        except Exception as e:
            console.print(f'[yellow]Warning: Failed to save to database: {e}[/yellow]')
            logger.error(f"Database save error: {e}", exc_info=True)

    # Generate reports (also best-effort)
    output_dir = config.get('reporting', 'output_dir', default='./reports')

    try:
        if output_format in ['json', 'both']:
            json_reporter = JSONReporter(output_dir)
            json_path = json_reporter.generate(report, output)
            console.print(f'[green]JSON report: {json_path}[/green]')

        if output_format in ['csv', 'both']:
            csv_reporter = CSVReporter(output_dir)
            csv_path = csv_reporter.generate(report, output)
            console.print(f'[green]CSV report: {csv_path}[/green]')
    except Exception as e:
        console.print(f'[yellow]Warning: Failed to generate report: {e}[/yellow]')
        logger.error(f"Report generation error: {e}", exc_info=True)

    # Final status message; only a full failure exits non-zero.
    if final_status == 'completed':
        console.print(f'\n[bold green]Scan completed in {duration:.1f} seconds[/bold green]')
    elif final_status == 'partial':
        console.print(f'\n[bold yellow]Scan partially completed in {duration:.1f} seconds[/bold yellow]')
    else:
        console.print(f'\n[bold red]Scan failed after {duration:.1f} seconds[/bold red]')
        sys.exit(1)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def _display_summary(report: ScanReport) -> None:
    """
    Render a Rich-formatted summary of a scan to the console.

    Shows the scan ID and duration, result totals with average risk,
    a severity-distribution table, and the five highest-risk findings.

    Args:
        report: ScanReport object with scan results.
    """
    console.print('\n')

    # Overview panel
    summary_text = f"""
[bold]Scan ID:[/bold] {report.scan_id[:8]}...
[bold]Duration:[/bold] {report.duration_seconds:.1f}s
[bold]Total Results:[/bold] {report.total_results}
[bold]Average Risk Score:[/bold] {report.average_risk_score}/10
"""
    console.print(Panel(summary_text, title="Scan Summary", border_style="green"))

    # Severity-distribution table
    dist_table = Table(title="Risk Distribution")
    dist_table.add_column("Severity", style="bold")
    dist_table.add_column("Count", justify="right")

    severity_rows = (
        ("[red]Critical[/red]", report.critical_findings),
        ("[orange1]High[/orange1]", report.high_findings),
        ("[yellow]Medium[/yellow]", report.medium_findings),
        ("[green]Low[/green]", report.low_findings),
    )
    for label, count in severity_rows:
        dist_table.add_row(label, str(count))

    console.print(dist_table)

    if not report.findings:
        return

    # Five highest-risk findings, color-coded by CVSS-like bands
    console.print('\n[bold]Top 5 Highest Risk Findings:[/bold]')
    ranked = sorted(
        report.findings,
        key=lambda entry: entry.get('risk_score', 0),
        reverse=True
    )

    for idx, entry in enumerate(ranked[:5], 1):
        risk = entry.get('risk_score', 0)
        host = entry.get('target_hostname', '')

        if risk >= 9.0:
            color = 'red'  # Critical
        elif risk >= 7.0:
            color = 'orange1'  # High
        elif risk >= 4.0:
            color = 'yellow'  # Medium
        else:
            color = 'green'  # Low

        target = f"{entry.get('target_ip', 'N/A')}:{entry.get('target_port', 'N/A')}"
        if host:
            target += f" ({host})"

        console.print(f" {idx}. [{color}]{target}[/{color}] - Risk: [{color}]{risk}[/{color}]")
|
||||
|
||||
|
||||
@cli.command()
@click.option('--scan-id', '-s', help='Generate report for specific scan')
@click.option('--format', '-f', 'output_format',
              type=click.Choice(['json', 'csv', 'both']),
              default='json', help='Output format')
@click.option('--output', '-o', type=click.Path(), help='Output file path')
@click.pass_context
def report(ctx, scan_id, output_format, output):
    """Generate a report from a previous scan."""
    config = ctx.obj['config']

    try:
        db = Database(config)
    except Exception as e:
        console.print(f'[red]Failed to connect to database: {e}[/red]')
        sys.exit(1)

    # Resolve which scan to report on: the explicit ID, or the most recent.
    if scan_id:
        target = db.get_scan(scan_id)
        if not target:
            console.print(f'[red]Scan not found: {scan_id}[/red]')
            sys.exit(1)
    else:
        recent = db.get_recent_scans(limit=1)
        if not recent:
            console.print('[yellow]No scans found in database.[/yellow]')
            sys.exit(0)
        target = recent[0]

    findings = db.get_findings(scan_id=target.scan_id)
    report_data = ScanReport.from_scan(target, findings)
    output_dir = config.get('reporting', 'output_dir', default='./reports')

    # Emit the requested format(s).
    if output_format in ('json', 'both'):
        json_path = JSONReporter(output_dir).generate(report_data, output)
        console.print(f'[green]JSON report: {json_path}[/green]')

    if output_format in ('csv', 'both'):
        csv_path = CSVReporter(output_dir).generate(report_data, output)
        console.print(f'[green]CSV report: {csv_path}[/green]')
|
||||
|
||||
|
||||
@cli.command()
@click.pass_context
def status(ctx):
    """Show status of Shodan API configuration."""
    config = ctx.obj['config']

    console.print('\n[bold]Shodan API Status[/bold]\n')

    try:
        manager = QueryManager(config)
    except Exception as e:
        console.print(f'[red]Failed to initialize: {e}[/red]')
        return

    # Determine engine health: configured -> key valid -> quota details.
    if not manager.is_available():
        status_str = "[red]Not Configured[/red]"
        details = "Add SHODAN_API_KEY to .env file"
    elif not manager.validate_engine():
        status_str = "[red]Invalid[/red]"
        details = "API key validation failed"
    else:
        quota = manager.get_quota_info()
        status_str = "[green]OK[/green]"
        details = f"Credits: {quota.get('query_credits', 'N/A')}, Plan: {quota.get('plan', 'N/A')}"

    engine_table = Table(title="Engine Status")
    engine_table.add_column("Engine", style="bold")
    engine_table.add_column("Status")
    engine_table.add_column("Details")
    engine_table.add_row("Shodan", status_str, details)
    console.print(engine_table)

    # List the query templates the manager knows about.
    names = manager.get_available_templates()
    console.print(f'\n[bold]Available Query Templates:[/bold] {len(names)}')
    for template in sorted(names):
        console.print(f' - {template}')
|
||||
|
||||
|
||||
@cli.command()
@click.option('--limit', '-l', default=10, help='Number of recent scans to show')
@click.pass_context
def history(ctx, limit):
    """Show scan history from database."""
    config = ctx.obj['config']

    try:
        db = Database(config)
        scans = db.get_recent_scans(limit=limit)
    except Exception as e:
        console.print(f'[red]Failed to access database: {e}[/red]')
        return

    if not scans:
        console.print('[yellow]No scans found in database.[/yellow]')
        return

    table = Table(title=f"Recent Scans (Last {limit})")
    for header, opts in (
        ("Scan ID", {'style': 'cyan'}),
        ("Timestamp", {}),
        ("Template/Query", {}),
        ("Results", {'justify': 'right'}),
        ("Status", {}),
    ):
        table.add_column(header, **opts)

    for record in scans:
        short_id = record.scan_id[:8] + "..."
        when = record.timestamp.strftime("%Y-%m-%d %H:%M") if record.timestamp else "N/A"

        # Prefer the template name; otherwise show a truncated query.
        if record.query and len(record.query) > 30:
            query_text = record.query[:30] + "..."
        else:
            query_text = record.query
        label = record.template_name or query_text or "N/A"

        if record.status == "completed":
            color = "green"
        elif record.status == "running":
            color = "yellow"
        else:
            color = "red"

        table.add_row(short_id, when, label, str(record.total_results),
                      f"[{color}]{record.status}[/{color}]")

    console.print(table)

    # Show database stats
    stats = db.get_statistics()
    console.print(f'\n[bold]Database Statistics:[/bold]')
    console.print(f' Total Scans: {stats["total_scans"]}')
    console.print(f' Total Findings: {stats["total_findings"]}')
    console.print(f' Unique IPs: {stats["unique_ips"]}')
|
||||
|
||||
|
||||
@cli.command()
@click.pass_context
def templates(ctx):
    """List available query templates."""
    config = ctx.obj['config']

    try:
        manager = QueryManager(config)
    except Exception as e:
        console.print(f'[red]Failed to initialize: {e}[/red]')
        return

    # Renamed local so it no longer shadows this command function's name.
    available = manager.get_available_templates()

    console.print('\n[bold]Available Shodan Query Templates[/bold]\n')

    listing = Table()
    listing.add_column("Template Name", style="cyan")
    listing.add_column("Queries")

    for name in sorted(available):
        count = len(manager.templates.get(name, []))
        listing.add_row(name, f"{count} queries")

    console.print(listing)

    console.print('\n[dim]Use with: aasrt scan --template <template_name>[/dim]')
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Entry Point
|
||||
# =============================================================================
|
||||
|
||||
def main() -> None:
    """
    Main entry point for AASRT CLI.

    Wraps the Click command group so Ctrl-C and unexpected exceptions are
    translated into conventional exit codes. Called when running
    `python -m src.main` or the `aasrt` command.

    Exit Codes:
        0: Success
        1: Error
        130: Interrupted by user
    """
    try:
        cli(obj={})
    except KeyboardInterrupt:
        console.print("\n[yellow]Operation cancelled by user.[/yellow]")
        sys.exit(130)
    except Exception as e:
        log = get_logger('aasrt')
        log.exception(f"Unexpected error: {e}")
        console.print(f"\n[red]Unexpected error: {e}[/red]")
        console.print("[dim]Check logs for details.[/dim]")
        sys.exit(1)
|
||||
|
||||
|
||||
# Script entry point: delegate to main(), which maps exceptions to exit codes.
if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,7 @@
|
||||
"""Reporting modules for AASRT."""
|
||||
|
||||
from .base import BaseReporter, ScanReport
|
||||
from .json_reporter import JSONReporter
|
||||
from .csv_reporter import CSVReporter
|
||||
|
||||
__all__ = ['BaseReporter', 'ScanReport', 'JSONReporter', 'CSVReporter']
|
||||
@@ -0,0 +1,199 @@
|
||||
"""Base reporter class for AASRT."""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.engines import SearchResult
|
||||
from src.storage.database import Scan, Finding
|
||||
|
||||
|
||||
@dataclass
class ScanReport:
    """Container for scan report data.

    Aggregates scan metadata, severity counts, and per-finding detail
    dicts. Built either from database rows (`from_scan`) or from live
    search results (`from_results`), and serialized via `to_dict`.
    """

    scan_id: str
    timestamp: datetime
    engines_used: List[str]
    query: Optional[str] = None
    template_name: Optional[str] = None
    total_results: int = 0
    duration_seconds: float = 0.0
    status: str = "completed"

    # Summary statistics (CVSS-like bands: >=9 critical, >=7 high, >=4 medium)
    critical_findings: int = 0
    high_findings: int = 0
    medium_findings: int = 0
    low_findings: int = 0
    average_risk_score: float = 0.0

    # Detailed findings as plain dicts (JSON/CSV friendly)
    findings: List[Dict[str, Any]] = field(default_factory=list)

    # Additional metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _summarize(risk_scores: List[float]) -> Dict[str, Any]:
        """Compute severity-band counts and the rounded average risk score.

        Shared by `from_scan` and `from_results`, which previously
        duplicated this logic.
        """
        summary = {
            'critical': sum(1 for s in risk_scores if s >= 9.0),
            'high': sum(1 for s in risk_scores if 7.0 <= s < 9.0),
            'medium': sum(1 for s in risk_scores if 4.0 <= s < 7.0),
            'low': sum(1 for s in risk_scores if s < 4.0),
        }
        # Guard against division by zero on an empty scan.
        avg = sum(risk_scores) / len(risk_scores) if risk_scores else 0.0
        summary['average'] = round(avg, 1)
        return summary

    @classmethod
    def from_scan(
        cls,
        scan: 'Scan',
        findings: 'List[Finding]'
    ) -> 'ScanReport':
        """Create ScanReport from database objects.

        Args:
            scan: Persisted Scan row.
            findings: Finding rows belonging to that scan.

        Returns:
            A populated ScanReport.
        """
        import json

        summary = cls._summarize([f.risk_score for f in findings])

        return cls(
            scan_id=scan.scan_id,
            timestamp=scan.timestamp,
            engines_used=json.loads(scan.engines_used) if scan.engines_used else [],
            query=scan.query,
            template_name=scan.template_name,
            total_results=len(findings),
            duration_seconds=scan.duration_seconds or 0.0,
            status=scan.status,
            critical_findings=summary['critical'],
            high_findings=summary['high'],
            medium_findings=summary['medium'],
            low_findings=summary['low'],
            average_risk_score=summary['average'],
            findings=[f.to_dict() for f in findings],
            metadata=json.loads(scan.metadata) if scan.metadata else {}
        )

    @classmethod
    def from_results(
        cls,
        scan_id: str,
        results: 'List[SearchResult]',
        engines: List[str],
        query: Optional[str] = None,
        template_name: Optional[str] = None,
        duration: float = 0.0,
        status: str = "completed"
    ) -> 'ScanReport':
        """Create ScanReport from search results.

        Args:
            scan_id: Identifier assigned to this scan run.
            results: SearchResult objects (each must expose `risk_score`
                and `to_dict()`).
            engines: Names of the engines that produced the results.
            query: Raw query string, if a custom query was used.
            template_name: Template name, if a template was used.
            duration: Wall-clock scan duration in seconds.
            status: Final scan status ('completed', 'partial', 'failed').
                Previously hard-coded to 'completed'; callers that compute
                a final status can now pass it through.

        Returns:
            A populated ScanReport.
        """
        summary = cls._summarize([r.risk_score for r in results])

        return cls(
            scan_id=scan_id,
            timestamp=datetime.utcnow(),  # naive UTC, consistent with the reporters
            engines_used=engines,
            query=query,
            template_name=template_name,
            total_results=len(results),
            duration_seconds=duration,
            status=status,
            critical_findings=summary['critical'],
            high_findings=summary['high'],
            medium_findings=summary['medium'],
            low_findings=summary['low'],
            average_risk_score=summary['average'],
            findings=[r.to_dict() for r in results]
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a nested dictionary (scan_metadata/summary/findings)."""
        return {
            'scan_metadata': {
                'scan_id': self.scan_id,
                'timestamp': self.timestamp.isoformat() if self.timestamp else None,
                'engines_used': self.engines_used,
                'query': self.query,
                'template_name': self.template_name,
                'total_results': self.total_results,
                'duration_seconds': self.duration_seconds,
                'status': self.status
            },
            'summary': {
                'critical_findings': self.critical_findings,
                'high_findings': self.high_findings,
                'medium_findings': self.medium_findings,
                'low_findings': self.low_findings,
                'average_risk_score': self.average_risk_score
            },
            'findings': self.findings,
            'metadata': self.metadata
        }
|
||||
|
||||
|
||||
class BaseReporter(ABC):
    """Abstract base class for reporters.

    Concrete reporters supply a format name, a file extension, and the
    two generation entry points; filename/path helpers are shared here.
    """

    def __init__(self, output_dir: str = "./reports"):
        """
        Initialize reporter and ensure the output directory exists.

        Args:
            output_dir: Directory for report output
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    @property
    @abstractmethod
    def format_name(self) -> str:
        """Return the format name (e.g., 'json', 'csv')."""

    @property
    @abstractmethod
    def file_extension(self) -> str:
        """Return the file extension."""

    @abstractmethod
    def generate(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Generate a report file.

        Args:
            report: ScanReport data
            filename: Optional custom filename (without extension)

        Returns:
            Path to generated report file
        """

    @abstractmethod
    def generate_string(self, report: ScanReport) -> str:
        """
        Generate report as a string.

        Args:
            report: ScanReport data

        Returns:
            Report content as string
        """

    def get_filename(self, scan_id: str, custom_name: Optional[str] = None) -> str:
        """Generate a filename for the report, honoring a custom name if given."""
        if custom_name:
            return f"{custom_name}.{self.file_extension}"
        stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        return f"scan_{scan_id[:8]}_{stamp}.{self.file_extension}"

    def get_filepath(self, filename: str) -> str:
        """Get full file path for a report inside the output directory."""
        return os.path.join(self.output_dir, filename)
|
||||
@@ -0,0 +1,221 @@
|
||||
"""CSV report generator for AASRT."""
|
||||
|
||||
import csv
|
||||
import io
|
||||
from typing import List, Optional
|
||||
|
||||
from .base import BaseReporter, ScanReport
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CSVReporter(BaseReporter):
    """Generates CSV format reports."""

    # Default columns for findings export
    DEFAULT_COLUMNS = [
        'target_ip',
        'target_port',
        'target_hostname',
        'service',
        'risk_score',
        'vulnerabilities',
        'source_engine',
        'first_seen',
        'status',
        'confidence'
    ]

    def __init__(
        self,
        output_dir: str = "./reports",
        columns: Optional[List[str]] = None,
        include_metadata: bool = False
    ):
        """
        Initialize CSV reporter.

        Args:
            output_dir: Output directory for reports
            columns: Custom columns to include
            include_metadata: Whether to include metadata columns
                ('location', 'isp', 'asn')
        """
        super().__init__(output_dir)
        # Copy the caller's list: previously `columns or DEFAULT_COLUMNS.copy()`
        # aliased the argument, so extend() below mutated the caller's list.
        self.columns = list(columns) if columns else self.DEFAULT_COLUMNS.copy()
        self.include_metadata = include_metadata

        if include_metadata:
            self.columns.extend(['location', 'isp', 'asn'])

    @property
    def format_name(self) -> str:
        """Format identifier used for reporter selection."""
        return "csv"

    @property
    def file_extension(self) -> str:
        """Extension appended to generated filenames."""
        return "csv"

    def generate(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Generate CSV report file.

        Args:
            report: ScanReport data
            filename: Optional custom filename

        Returns:
            Path to generated report file
        """
        output_filename = self.get_filename(report.scan_id, filename)
        filepath = self.get_filepath(output_filename)

        content = self.generate_string(report)

        # newline='' so the csv module controls line endings, per the docs.
        with open(filepath, 'w', encoding='utf-8', newline='') as f:
            f.write(content)

        logger.info(f"Generated CSV report: {filepath}")
        return filepath

    def generate_string(self, report: ScanReport) -> str:
        """
        Generate CSV report as string.

        Args:
            report: ScanReport data

        Returns:
            CSV content as string
        """
        output = io.StringIO()
        # extrasaction='ignore': silently drop finding keys not in self.columns.
        writer = csv.DictWriter(output, fieldnames=self.columns, extrasaction='ignore')

        writer.writeheader()
        for finding in report.findings:
            writer.writerow(self._format_finding(finding))

        return output.getvalue()

    def _format_finding(self, finding: dict) -> dict:
        """Flatten a finding dict into one CSV row keyed by self.columns.

        Lists become '; '-joined strings, dicts are stringified, and the
        metadata-derived columns (location/isp/asn) are pulled out of the
        nested 'metadata' dict. Missing columns become empty strings.
        """
        metadata = finding.get('metadata', {})
        row = {}

        for col in self.columns:
            if col in finding:
                value = finding[col]

                # Convert lists to semicolon-separated strings
                if isinstance(value, list):
                    value = '; '.join(str(v) for v in value)
                # Convert dicts to string representation
                elif isinstance(value, dict):
                    value = str(value)

                row[col] = value
            elif col == 'location':
                location = metadata.get('location', {})
                if isinstance(location, dict):
                    row[col] = f"{location.get('country', '')}, {location.get('city', '')}"
                else:
                    row[col] = ''
            elif col == 'isp':
                row[col] = metadata.get('isp', '')
            elif col == 'asn':
                row[col] = metadata.get('asn', '')
            else:
                row[col] = ''

        return row

    def generate_summary(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Generate a summary CSV file of key-value scan metrics.

        Args:
            report: ScanReport data
            filename: Optional custom filename

        Returns:
            Path to generated file
        """
        summary_filename = filename or f"summary_{report.scan_id[:8]}"
        if not summary_filename.endswith('.csv'):
            summary_filename = f"{summary_filename}.csv"

        filepath = self.get_filepath(summary_filename)

        rows = [
            ['Metric', 'Value'],
            ['Scan ID', report.scan_id],
            ['Timestamp', report.timestamp.isoformat() if report.timestamp else ''],
            ['Engines Used', ', '.join(report.engines_used)],
            ['Query', report.query or ''],
            ['Template', report.template_name or ''],
            ['Total Results', report.total_results],
            ['Duration (seconds)', report.duration_seconds],
            ['Critical Findings', report.critical_findings],
            ['High Findings', report.high_findings],
            ['Medium Findings', report.medium_findings],
            ['Low Findings', report.low_findings],
            ['Average Risk Score', report.average_risk_score],
        ]

        with open(filepath, 'w', encoding='utf-8', newline='') as f:
            csv.writer(f).writerows(rows)

        logger.info(f"Generated CSV summary: {filepath}")
        return filepath

    def generate_vulnerability_report(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Generate a vulnerability-focused CSV report (one row per
        target/vulnerability pair).

        Args:
            report: ScanReport data
            filename: Optional custom filename

        Returns:
            Path to generated file
        """
        vuln_filename = filename or f"vulnerabilities_{report.scan_id[:8]}"
        if not vuln_filename.endswith('.csv'):
            vuln_filename = f"{vuln_filename}.csv"

        filepath = self.get_filepath(vuln_filename)

        with open(filepath, 'w', encoding='utf-8', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Target IP', 'Port', 'Hostname', 'Vulnerability', 'Risk Score'])

            for finding in report.findings:
                ip = finding.get('target_ip', '')
                port = finding.get('target_port', '')
                hostname = finding.get('target_hostname', '')
                risk_score = finding.get('risk_score', 0)

                vulns = finding.get('vulnerabilities', [])
                if vulns:
                    for vuln in vulns:
                        writer.writerow([ip, port, hostname, vuln, risk_score])
                else:
                    # Still emit one row so clean hosts appear in the report.
                    writer.writerow([ip, port, hostname, 'None detected', risk_score])

        logger.info(f"Generated vulnerability CSV: {filepath}")
        return filepath
|
||||
@@ -0,0 +1,122 @@
|
||||
"""JSON report generator for AASRT."""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from .base import BaseReporter, ScanReport
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class JSONReporter(BaseReporter):
    """Generates JSON format reports.

    All serialization goes through a single private helper so that the
    formatting options (indentation, ``default=str`` fallback for
    non-serializable values, ``ensure_ascii=False``) stay consistent across
    the full report, summary, and findings-only outputs.
    """

    def __init__(self, output_dir: str = "./reports", pretty: bool = True):
        """
        Initialize JSON reporter.

        Args:
            output_dir: Output directory for reports
            pretty: Whether to format JSON with indentation
        """
        super().__init__(output_dir)
        self.pretty = pretty

    @property
    def format_name(self) -> str:
        """Short identifier for this report format."""
        return "json"

    @property
    def file_extension(self) -> str:
        """File extension used for generated report files."""
        return "json"

    def _dumps(self, data) -> str:
        """Serialize data as JSON using the reporter's shared options.

        Non-serializable values (e.g. datetimes) fall back to ``str()``;
        non-ASCII text is written verbatim rather than escaped.
        """
        return json.dumps(
            data,
            indent=2 if self.pretty else None,
            default=str,
            ensure_ascii=False
        )

    def generate(self, report: ScanReport, filename: Optional[str] = None) -> str:
        """
        Generate JSON report file.

        Args:
            report: ScanReport data
            filename: Optional custom filename

        Returns:
            Path to generated report file
        """
        output_filename = self.get_filename(report.scan_id, filename)
        filepath = self.get_filepath(output_filename)

        content = self.generate_string(report)

        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(content)

        logger.info(f"Generated JSON report: {filepath}")
        return filepath

    def generate_string(self, report: ScanReport) -> str:
        """
        Generate JSON report as string.

        Args:
            report: ScanReport data

        Returns:
            JSON string
        """
        data = report.to_dict()

        # Add report metadata
        data['report_metadata'] = {
            'format': 'json',
            'version': '1.0',
            'generated_by': 'AASRT (AI Agent Security Reconnaissance Tool)'
        }

        return self._dumps(data)

    def generate_summary(self, report: ScanReport) -> str:
        """
        Generate a summary-only JSON report.

        Args:
            report: ScanReport data

        Returns:
            JSON string with summary only
        """
        summary = {
            'scan_id': report.scan_id,
            'timestamp': report.timestamp.isoformat() if report.timestamp else None,
            'engines_used': report.engines_used,
            'total_results': report.total_results,
            'summary': {
                'critical_findings': report.critical_findings,
                'high_findings': report.high_findings,
                'medium_findings': report.medium_findings,
                'low_findings': report.low_findings,
                'average_risk_score': report.average_risk_score
            }
        }
        # NOTE: previously the summary escaped non-ASCII (ensure_ascii default)
        # while the full report did not; serialization is now uniform.
        return self._dumps(summary)

    def generate_findings_only(self, report: ScanReport) -> str:
        """
        Generate JSON with findings only (no metadata).

        Args:
            report: ScanReport data

        Returns:
            JSON string with findings array
        """
        return self._dumps(report.findings)
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Storage modules for AASRT."""
|
||||
|
||||
from .database import Database, Scan, Finding
|
||||
|
||||
__all__ = ['Database', 'Scan', 'Finding']
|
||||
@@ -0,0 +1,806 @@
|
||||
"""
|
||||
Database storage layer for AASRT.
|
||||
|
||||
This module provides a production-ready database layer with:
|
||||
- Connection pooling for efficient resource usage
|
||||
- Automatic retry logic for transient failures
|
||||
- Context managers for proper session cleanup
|
||||
- Support for SQLite (default) and PostgreSQL
|
||||
- Comprehensive logging and error handling
|
||||
|
||||
Example:
|
||||
>>> from src.storage.database import Database
|
||||
>>> db = Database()
|
||||
>>> scan = db.create_scan(engines=["shodan"], query="http.html:agent")
|
||||
>>> db.add_findings(scan.scan_id, results)
|
||||
>>> db.update_scan(scan.scan_id, status="completed")
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, TypeVar
|
||||
|
||||
from sqlalchemy import (
    create_engine, Column, String, Integer, Float, DateTime,
    Text, Boolean, ForeignKey, Index, event, text
)
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, relationship, Session, scoped_session
|
||||
from sqlalchemy.pool import QueuePool, StaticPool
|
||||
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
|
||||
|
||||
from src.engines import SearchResult
|
||||
from src.utils.config import Config
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# =============================================================================
|
||||
# Retry Configuration
|
||||
# =============================================================================
|
||||
|
||||
# Generic return type preserved by the with_db_retry decorator.
T = TypeVar('T')

# Maximum retry attempts for transient database errors
MAX_DB_RETRIES: int = 3

# Base delay for exponential backoff (seconds); actual delays double
# each attempt: 0.5s, 1.0s, 2.0s, ...
DB_RETRY_BASE_DELAY: float = 0.5

# Exceptions that should trigger a retry (connection drops, deadlocks).
# Constraint violations (IntegrityError) are deliberately excluded.
RETRYABLE_DB_EXCEPTIONS = (OperationalError,)
|
||||
|
||||
|
||||
def with_db_retry(func: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator that adds retry logic for transient database errors.

    Retries on connection errors and deadlocks but not on
    constraint violations or other permanent errors.

    Args:
        func: Database function to wrap with retry logic.

    Returns:
        Wrapped function with retry capability.
    """
    @wraps(func)
    def wrapper(*args, **kwargs) -> T:
        failure = None
        attempt = 0

        while attempt < MAX_DB_RETRIES:
            attempt += 1
            try:
                return func(*args, **kwargs)
            except IntegrityError as exc:
                # Constraint violations are permanent -- never retried.
                logger.error(f"Database integrity error in {func.__name__}: {exc}")
                raise
            except RETRYABLE_DB_EXCEPTIONS as exc:
                failure = exc
                if attempt >= MAX_DB_RETRIES:
                    logger.error(
                        f"All {MAX_DB_RETRIES} database retries exhausted for {func.__name__}"
                    )
                    continue
                # Exponential backoff before the next attempt.
                delay = DB_RETRY_BASE_DELAY * (2 ** (attempt - 1))
                logger.warning(
                    f"Database retry {attempt}/{MAX_DB_RETRIES} for {func.__name__} "
                    f"after {delay:.2f}s. Error: {exc}"
                )
                time.sleep(delay)
            except SQLAlchemyError as exc:
                # Any other SQLAlchemy failure is logged and propagated as-is.
                logger.error(f"Database error in {func.__name__}: {exc}")
                raise

        # Retries exhausted: surface the last transient error.
        if failure is not None:
            raise failure
        raise SQLAlchemyError(f"Unexpected database error in {func.__name__}")

    return wrapper
|
||||
|
||||
|
||||
class Scan(Base):
    """Scan record model.

    One row per reconnaissance run. JSON-encoded columns (engines_used,
    extra_data) are decoded by to_dict().
    """

    __tablename__ = 'scans'

    # Primary key: UUID4 string assigned by Database.create_scan().
    scan_id = Column(String(36), primary_key=True)
    timestamp = Column(DateTime, nullable=False, default=datetime.utcnow)
    engines_used = Column(Text)  # JSON array
    query = Column(Text)
    template_name = Column(String(255))
    total_results = Column(Integer, default=0)
    duration_seconds = Column(Float)
    status = Column(String(50), default='running')  # running, completed, failed, partial
    extra_data = Column(Text)  # JSON

    # Relationships
    # delete-orphan: deleting a Scan cascades to all its Findings.
    findings = relationship("Finding", back_populates="scan", cascade="all, delete-orphan")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        JSON-encoded columns are decoded; timestamp becomes an ISO-8601
        string (or None). extra_data is exposed under the 'metadata' key.
        """
        return {
            'scan_id': self.scan_id,
            'timestamp': self.timestamp.isoformat() if self.timestamp else None,
            'engines_used': json.loads(self.engines_used) if self.engines_used else [],
            'query': self.query,
            'template_name': self.template_name,
            'total_results': self.total_results,
            'duration_seconds': self.duration_seconds,
            'status': self.status,
            'metadata': json.loads(self.extra_data) if self.extra_data else {}
        }
|
||||
|
||||
|
||||
class Finding(Base):
    """Finding record model.

    One row per discovered (target, port) observation within a scan.
    JSON-encoded columns (vulnerabilities, extra_data) are decoded by
    to_dict().
    """

    __tablename__ = 'findings'

    # Primary key: UUID4 string assigned in from_search_result().
    finding_id = Column(String(36), primary_key=True)
    scan_id = Column(String(36), ForeignKey('scans.scan_id'), nullable=False)
    source_engine = Column(String(50))
    target_ip = Column(String(45), nullable=False)  # Support IPv6
    target_port = Column(Integer, nullable=False)
    target_hostname = Column(String(255))
    service = Column(String(255))
    banner = Column(Text)
    risk_score = Column(Float, default=0.0)
    vulnerabilities = Column(Text)  # JSON array
    first_seen = Column(DateTime, default=datetime.utcnow)
    last_seen = Column(DateTime, default=datetime.utcnow)
    status = Column(String(50), default='new')  # new, confirmed, false_positive, remediated
    confidence = Column(Integer, default=100)
    extra_data = Column(Text)  # JSON

    # Relationships
    scan = relationship("Scan", back_populates="findings")

    # Indexes
    # Descending indexes support the common "highest risk first" and
    # "most recent first" query orderings used by Database.get_findings().
    __table_args__ = (
        Index('idx_findings_risk', risk_score.desc()),
        Index('idx_findings_timestamp', first_seen.desc()),
        Index('idx_findings_ip', target_ip),
        Index('idx_findings_status', status),
    )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        JSON-encoded columns are decoded; datetimes become ISO-8601 strings
        (or None). extra_data is exposed under the 'metadata' key.
        """
        return {
            'finding_id': self.finding_id,
            'scan_id': self.scan_id,
            'source_engine': self.source_engine,
            'target_ip': self.target_ip,
            'target_port': self.target_port,
            'target_hostname': self.target_hostname,
            'service': self.service,
            'banner': self.banner,
            'risk_score': self.risk_score,
            'vulnerabilities': json.loads(self.vulnerabilities) if self.vulnerabilities else [],
            'first_seen': self.first_seen.isoformat() if self.first_seen else None,
            'last_seen': self.last_seen.isoformat() if self.last_seen else None,
            'status': self.status,
            'confidence': self.confidence,
            'metadata': json.loads(self.extra_data) if self.extra_data else {}
        }

    @classmethod
    def from_search_result(cls, result: SearchResult, scan_id: str) -> 'Finding':
        """Create Finding from SearchResult.

        Generates a fresh UUID for the finding and JSON-encodes the
        list/dict fields for storage in Text columns.
        """
        return cls(
            finding_id=str(uuid.uuid4()),
            scan_id=scan_id,
            source_engine=result.source_engine,
            target_ip=result.ip,
            target_port=result.port,
            target_hostname=result.hostname,
            service=result.service,
            banner=result.banner,
            risk_score=result.risk_score,
            vulnerabilities=json.dumps(result.vulnerabilities),
            confidence=result.confidence,
            extra_data=json.dumps(result.metadata)
        )
|
||||
|
||||
|
||||
class Database:
    """
    Database manager for AASRT with connection pooling and retry logic.

    This class provides a thread-safe database layer with:
    - Connection pooling for efficient resource usage
    - Automatic retry on transient failures
    - Context managers for proper session cleanup
    - Support for SQLite and PostgreSQL

    Attributes:
        config: Configuration instance.
        engine: SQLAlchemy engine with connection pool.
        Session: Scoped session factory.

    Example:
        >>> db = Database()
        >>> with db.session_scope() as session:
        ...     scan = Scan(scan_id="123", ...)
        ...     session.add(scan)
        >>> # Session is automatically committed and closed
    """

    # Connection pool settings (PostgreSQL only; SQLite uses StaticPool)
    POOL_SIZE = 5
    MAX_OVERFLOW = 10
    POOL_TIMEOUT = 30
    POOL_RECYCLE = 3600  # Recycle connections after 1 hour

    def __init__(self, config: Optional[Config] = None) -> None:
        """
        Initialize database connection with connection pooling.

        Args:
            config: Configuration instance. If None, uses default Config.

        Raises:
            SQLAlchemyError: If database connection fails.
        """
        self.config = config or Config()
        self.engine = None
        self.Session = None
        self._db_type: str = "unknown"
        self._initialize()

    def _initialize(self) -> None:
        """
        Initialize database connection, pooling, and create tables.

        Sets up connection pooling appropriate for the database type:
        - SQLite: Uses StaticPool for thread safety
        - PostgreSQL: Uses QueuePool with configurable size
        """
        self._db_type = self.config.get('database', 'type', default='sqlite')

        if self._db_type == 'sqlite':
            db_path = self.config.get('database', 'sqlite', 'path', default='./data/scanner.db')
            # Ensure directory exists
            os.makedirs(os.path.dirname(db_path), exist_ok=True)
            connection_string = f"sqlite:///{db_path}"

            # SQLite configuration - use StaticPool for thread safety
            # Also enable WAL mode for better concurrent access
            self.engine = create_engine(
                connection_string,
                echo=False,
                poolclass=StaticPool,
                connect_args={
                    "check_same_thread": False,
                    "timeout": 30
                }
            )

            # Enable WAL mode for better concurrent access
            @event.listens_for(self.engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                cursor = dbapi_connection.cursor()
                cursor.execute("PRAGMA journal_mode=WAL")
                cursor.execute("PRAGMA synchronous=NORMAL")
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.close()

        else:
            # Any non-sqlite type is treated as PostgreSQL with connection pooling
            host = self.config.get('database', 'postgresql', 'host', default='localhost')
            port = self.config.get('database', 'postgresql', 'port', default=5432)
            database = self.config.get('database', 'postgresql', 'database', default='aasrt')
            user = self.config.get('database', 'postgresql', 'user')
            password = self.config.get('database', 'postgresql', 'password')
            ssl_mode = self.config.get('database', 'postgresql', 'ssl_mode', default='prefer')

            # Mask password in logs
            safe_conn_str = f"postgresql://{user}:***@{host}:{port}/{database}"
            connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={ssl_mode}"

            self.engine = create_engine(
                connection_string,
                echo=False,
                poolclass=QueuePool,
                pool_size=self.POOL_SIZE,
                max_overflow=self.MAX_OVERFLOW,
                pool_timeout=self.POOL_TIMEOUT,
                pool_recycle=self.POOL_RECYCLE,
                pool_pre_ping=True  # Verify connections before use
            )
            logger.debug(f"PostgreSQL connection: {safe_conn_str}")

        # Use scoped_session for thread safety
        self.Session = scoped_session(sessionmaker(bind=self.engine))

        # Create tables
        Base.metadata.create_all(self.engine)
        logger.info(f"Database initialized: {self._db_type}")

    @contextmanager
    def session_scope(self) -> Generator[Session, None, None]:
        """
        Provide a transactional scope around a series of operations.

        This context manager handles session lifecycle:
        - Creates a new session
        - Commits on success
        - Rolls back on exception
        - Always closes the session

        Yields:
            SQLAlchemy Session object.

        Raises:
            SQLAlchemyError: On database errors (after rollback).

        Example:
            >>> with db.session_scope() as session:
            ...     session.add(Scan(...))
            ...     # Automatically committed if no exception
        """
        session = self.Session()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Database session error, rolling back: {e}")
            raise
        finally:
            session.close()

    def get_session(self) -> Session:
        """
        Get a database session (legacy method).

        Note:
            Prefer using session_scope() context manager for new code.
            This method is kept for backward compatibility.

        Returns:
            SQLAlchemy Session object.
        """
        return self.Session()

    def close(self) -> None:
        """
        Close all database connections and cleanup resources.

        Call this method during application shutdown to properly
        release database connections.
        """
        if self.Session:
            self.Session.remove()
        if self.engine:
            self.engine.dispose()
        logger.info("Database connections closed")

    def health_check(self) -> Dict[str, Any]:
        """
        Perform a health check on the database connection.

        Returns:
            Dictionary with health status:
            - healthy: bool indicating if database is accessible
            - db_type: Database type (sqlite/postgresql)
            - latency_ms: Response time in milliseconds
            - error: Error message if unhealthy (optional)
        """
        start_time = time.time()
        try:
            with self.session_scope() as session:
                # Simple query to verify connection. Raw strings are rejected
                # by SQLAlchemy 2.0 (deprecated in 1.4), so wrap in text().
                session.execute(text("SELECT 1"))

            latency = (time.time() - start_time) * 1000
            return {
                "healthy": True,
                "db_type": self._db_type,
                "latency_ms": round(latency, 2),
                "pool_size": getattr(self.engine.pool, 'size', lambda: 'N/A')() if hasattr(self.engine, 'pool') else 'N/A'
            }
        except Exception as e:
            latency = (time.time() - start_time) * 1000
            logger.error(f"Database health check failed: {e}")
            return {
                "healthy": False,
                "db_type": self._db_type,
                "latency_ms": round(latency, 2),
                "error": str(e)
            }

    # =========================================================================
    # Scan Operations
    # =========================================================================

    @with_db_retry
    def create_scan(
        self,
        engines: List[str],
        query: Optional[str] = None,
        template_name: Optional[str] = None
    ) -> Scan:
        """
        Create a new scan record in the database.

        Args:
            engines: List of engine names used for the scan (e.g., ["shodan"]).
            query: Search query string (if using custom query).
            template_name: Template name (if using predefined template).

        Returns:
            Created Scan object with generated scan_id.

        Raises:
            SQLAlchemyError: If database operation fails.

        Example:
            >>> scan = db.create_scan(engines=["shodan"], template_name="clawdbot")
            >>> print(scan.scan_id)
        """
        scan = Scan(
            scan_id=str(uuid.uuid4()),
            timestamp=datetime.utcnow(),
            engines_used=json.dumps(engines),
            query=query,
            template_name=template_name,
            status='running'
        )

        with self.session_scope() as session:
            session.add(scan)
            # Flush to ensure data is written before expunge
            session.flush()
            logger.info(f"Created scan: {scan.scan_id}")
            # Need to expunge to use outside session
            session.expunge(scan)
            return scan

    @with_db_retry
    def update_scan(
        self,
        scan_id: str,
        status: Optional[str] = None,
        total_results: Optional[int] = None,
        duration_seconds: Optional[float] = None,
        metadata: Optional[Dict] = None
    ) -> Optional[Scan]:
        """
        Update a scan record with new values.

        Args:
            scan_id: UUID of the scan to update.
            status: New status (running, completed, failed, partial).
            total_results: Number of results found.
            duration_seconds: Total scan duration.
            metadata: Additional metadata to merge.

        Returns:
            Updated Scan object, or None if scan not found.
        """
        with self.session_scope() as session:
            scan = session.query(Scan).filter(Scan.scan_id == scan_id).first()
            if not scan:
                logger.warning(f"Scan not found for update: {scan_id}")
                return None

            if status:
                scan.status = status
            if total_results is not None:
                scan.total_results = total_results
            if duration_seconds is not None:
                scan.duration_seconds = duration_seconds
            if metadata:
                # Merge new metadata keys over any previously stored ones
                existing = json.loads(scan.extra_data) if scan.extra_data else {}
                existing.update(metadata)
                scan.extra_data = json.dumps(existing)

            # Flush to ensure changes are written before expunge
            session.flush()
            logger.debug(f"Updated scan {scan_id}: status={status}, results={total_results}")
            session.expunge(scan)
            return scan

    @with_db_retry
    def get_scan(self, scan_id: str) -> Optional[Scan]:
        """
        Get a scan by its UUID.

        Args:
            scan_id: UUID of the scan.

        Returns:
            Scan object or None if not found.
        """
        with self.session_scope() as session:
            scan = session.query(Scan).filter(Scan.scan_id == scan_id).first()
            if scan:
                session.expunge(scan)
            return scan

    @with_db_retry
    def get_recent_scans(self, limit: int = 10) -> List[Scan]:
        """
        Get the most recent scans.

        Args:
            limit: Maximum number of scans to return.

        Returns:
            List of Scan objects ordered by timestamp descending.
        """
        with self.session_scope() as session:
            scans = session.query(Scan).order_by(Scan.timestamp.desc()).limit(limit).all()
            for scan in scans:
                session.expunge(scan)
            return scans

    @with_db_retry
    def delete_scan(self, scan_id: str) -> bool:
        """
        Delete a scan and all its associated findings.

        Args:
            scan_id: UUID of the scan to delete.

        Returns:
            True if scan was deleted, False if not found.
        """
        with self.session_scope() as session:
            scan = session.query(Scan).filter(Scan.scan_id == scan_id).first()
            if scan:
                # Findings are removed via the delete-orphan cascade;
                # session_scope commits on exit.
                session.delete(scan)
                logger.info(f"Deleted scan: {scan_id}")
                return True
            return False

    # =========================================================================
    # Finding Operations
    # =========================================================================

    @with_db_retry
    def add_findings(self, scan_id: str, results: List[SearchResult]) -> int:
        """
        Add findings from search results to the database.

        Args:
            scan_id: Parent scan UUID.
            results: List of SearchResult objects to store.

        Returns:
            Number of findings successfully added.

        Raises:
            SQLAlchemyError: If database operation fails.

        Note:
            Findings are added in batches for efficiency; a result that
            fails conversion is logged and skipped rather than aborting
            the whole batch.
        """
        if not results:
            logger.debug(f"No findings to add for scan {scan_id}")
            return 0

        with self.session_scope() as session:
            count = 0
            for result in results:
                try:
                    finding = Finding.from_search_result(result, scan_id)
                    session.add(finding)
                    count += 1
                except Exception as e:
                    logger.warning(f"Failed to create finding from result: {e}")
                    continue

            logger.info(f"Added {count} findings to scan {scan_id}")
            return count

    @with_db_retry
    def get_findings(
        self,
        scan_id: Optional[str] = None,
        min_risk_score: Optional[float] = None,
        status: Optional[str] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[Finding]:
        """
        Get findings with optional filters.

        Args:
            scan_id: Filter by scan UUID.
            min_risk_score: Minimum risk score (0.0-10.0).
            status: Finding status filter (new, confirmed, false_positive, remediated).
            limit: Maximum results to return.
            offset: Number of results to skip (for pagination).

        Returns:
            List of Finding objects matching filters, ordered by risk score descending.
        """
        with self.session_scope() as session:
            query = session.query(Finding)

            if scan_id:
                query = query.filter(Finding.scan_id == scan_id)
            if min_risk_score is not None:
                query = query.filter(Finding.risk_score >= min_risk_score)
            if status:
                query = query.filter(Finding.status == status)

            query = query.order_by(Finding.risk_score.desc())
            findings = query.offset(offset).limit(limit).all()

            for finding in findings:
                session.expunge(finding)
            return findings

    @with_db_retry
    def get_finding(self, finding_id: str) -> Optional[Finding]:
        """
        Get a single finding by its UUID.

        Args:
            finding_id: UUID of the finding.

        Returns:
            Finding object or None if not found.
        """
        with self.session_scope() as session:
            finding = session.query(Finding).filter(Finding.finding_id == finding_id).first()
            if finding:
                session.expunge(finding)
            return finding

    @with_db_retry
    def update_finding_status(self, finding_id: str, status: str) -> bool:
        """
        Update the status of a finding.

        Args:
            finding_id: UUID of the finding.
            status: New status (new, confirmed, false_positive, remediated).

        Returns:
            True if finding was updated, False if not found.
        """
        with self.session_scope() as session:
            finding = session.query(Finding).filter(Finding.finding_id == finding_id).first()
            if finding:
                finding.status = status
                finding.last_seen = datetime.utcnow()
                logger.debug(f"Updated finding {finding_id} status to {status}")
                return True
            logger.warning(f"Finding not found for status update: {finding_id}")
            return False

    @with_db_retry
    def get_finding_by_target(self, ip: str, port: int) -> Optional[Finding]:
        """
        Get the most recent finding for a specific target.

        Args:
            ip: Target IP address.
            port: Target port number.

        Returns:
            Most recent Finding for the target, or None if not found.
        """
        with self.session_scope() as session:
            finding = session.query(Finding).filter(
                Finding.target_ip == ip,
                Finding.target_port == port
            ).order_by(Finding.last_seen.desc()).first()
            if finding:
                session.expunge(finding)
            return finding

    # =========================================================================
    # Statistics and Maintenance
    # =========================================================================

    @with_db_retry
    def get_statistics(self) -> Dict[str, Any]:
        """
        Get overall database statistics.

        Returns:
            Dictionary containing:
            - total_scans: Total number of scans
            - total_findings: Total number of findings
            - unique_ips: Count of unique IP addresses
            - risk_distribution: Dict with critical/high/medium/low counts
            - last_scan_time: Timestamp of most recent scan (or None)

        Example:
            >>> stats = db.get_statistics()
            >>> print(f"Critical findings: {stats['risk_distribution']['critical']}")
        """
        with self.session_scope() as session:
            total_scans = session.query(Scan).count()
            total_findings = session.query(Finding).count()

            # Risk distribution using CVSS-like thresholds
            critical = session.query(Finding).filter(Finding.risk_score >= 9.0).count()
            high = session.query(Finding).filter(
                Finding.risk_score >= 7.0,
                Finding.risk_score < 9.0
            ).count()
            medium = session.query(Finding).filter(
                Finding.risk_score >= 4.0,
                Finding.risk_score < 7.0
            ).count()
            low = session.query(Finding).filter(Finding.risk_score < 4.0).count()

            # Unique IPs discovered
            unique_ips = session.query(Finding.target_ip).distinct().count()

            # Last scan timestamp
            last_scan = session.query(Scan).order_by(Scan.timestamp.desc()).first()
            last_scan_time = last_scan.timestamp.isoformat() if last_scan else None

            return {
                'total_scans': total_scans,
                'total_findings': total_findings,
                'unique_ips': unique_ips,
                'risk_distribution': {
                    'critical': critical,
                    'high': high,
                    'medium': medium,
                    'low': low
                },
                'last_scan_time': last_scan_time
            }

    @with_db_retry
    def cleanup_old_data(self, days: int = 90) -> int:
        """
        Remove scan data older than specified days.

        This is a maintenance operation that removes old scans and their
        associated findings to manage database size.

        Args:
            days: Age threshold in days. Scans older than this will be deleted.
                Default is 90 days.

        Returns:
            Number of scans deleted (findings are cascade deleted).

        Raises:
            ValueError: If days is less than 1.
            SQLAlchemyError: If database operation fails.

        Example:
            >>> # Remove data older than 30 days
            >>> deleted = db.cleanup_old_data(days=30)
            >>> print(f"Removed {deleted} old scans")
        """
        if days < 1:
            raise ValueError("Days must be at least 1")

        cutoff = datetime.utcnow() - timedelta(days=days)
        logger.info(f"Cleaning up data older than {cutoff.isoformat()}")

        with self.session_scope() as session:
            # Count first for logging
            old_scans = session.query(Scan).filter(Scan.timestamp < cutoff).all()
            count = len(old_scans)

            if count == 0:
                logger.info("No old data to clean up")
                return 0

            for scan in old_scans:
                session.delete(scan)

            logger.info(f"Cleaned up {count} scans older than {days} days")
            return count
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Utility modules for AASRT."""
|
||||
|
||||
from .config import Config
|
||||
from .logger import setup_logger, get_logger
|
||||
from .exceptions import (
|
||||
AASRTException,
|
||||
APIException,
|
||||
RateLimitException,
|
||||
ConfigurationException,
|
||||
ValidationException
|
||||
)
|
||||
from .validators import validate_ip, validate_domain, validate_query
|
||||
|
||||
__all__ = [
|
||||
'Config',
|
||||
'setup_logger',
|
||||
'get_logger',
|
||||
'AASRTException',
|
||||
'APIException',
|
||||
'RateLimitException',
|
||||
'ConfigurationException',
|
||||
'ValidationException',
|
||||
'validate_ip',
|
||||
'validate_domain',
|
||||
'validate_query'
|
||||
]
|
||||
@@ -0,0 +1,513 @@
|
||||
"""
|
||||
Configuration management for AASRT.
|
||||
|
||||
This module provides a production-ready configuration management system with:
|
||||
- Singleton pattern for global configuration access
|
||||
- YAML file loading with deep merging
|
||||
- Environment variable overrides
|
||||
- Validation of required settings
|
||||
- Support for structured logging configuration
|
||||
- Health check capabilities
|
||||
|
||||
Configuration priority (highest to lowest):
|
||||
1. Environment variables
|
||||
2. YAML configuration file
|
||||
3. Default values
|
||||
|
||||
Example:
|
||||
>>> from src.utils.config import Config
|
||||
>>> config = Config()
|
||||
>>> shodan_key = config.get_shodan_key()
|
||||
>>> log_level = config.get('logging', 'level', default='INFO')
|
||||
|
||||
Environment Variables:
|
||||
SHODAN_API_KEY: Required Shodan API key
|
||||
AASRT_LOG_LEVEL: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
AASRT_ENVIRONMENT: Deployment environment (development, staging, production)
|
||||
AASRT_DEBUG: Enable debug mode (true/false)
|
||||
DB_TYPE: Database type (sqlite, postgresql)
|
||||
DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD: PostgreSQL settings
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .exceptions import ConfigurationException
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# Validation Constants
|
||||
# =============================================================================
|
||||
|
||||
VALID_LOG_LEVELS: Set[str] = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||
VALID_ENVIRONMENTS: Set[str] = {"development", "staging", "production"}
|
||||
VALID_DB_TYPES: Set[str] = {"sqlite", "postgresql", "mysql"}
|
||||
REQUIRED_SETTINGS: List[str] = [] # API key is optional until scan is run
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration manager for AASRT with singleton pattern.
|
||||
|
||||
This class provides centralized configuration management with:
|
||||
- Thread-safe singleton access
|
||||
- YAML file configuration
|
||||
- Environment variable overrides
|
||||
- Validation of critical settings
|
||||
- Health check for configuration state
|
||||
|
||||
Attributes:
|
||||
_instance: Singleton instance.
|
||||
_config: Configuration dictionary.
|
||||
_initialized: Flag indicating initialization status.
|
||||
_config_path: Path to loaded configuration file.
|
||||
_environment: Current deployment environment.
|
||||
|
||||
Example:
|
||||
>>> config = Config()
|
||||
>>> api_key = config.get_shodan_key()
|
||||
>>> if not api_key:
|
||||
... print("Warning: Shodan API key not configured")
|
||||
"""
|
||||
|
||||
_instance: Optional['Config'] = None
|
||||
_config: Dict[str, Any] = {}
|
||||
|
||||
def __new__(cls, config_path: Optional[str] = None):
|
||||
"""
|
||||
Singleton pattern implementation.
|
||||
|
||||
Args:
|
||||
config_path: Optional path to YAML configuration file.
|
||||
|
||||
Returns:
|
||||
Singleton Config instance.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None) -> None:
    """
    Initialize configuration from multiple sources.

    Configuration is loaded in order of priority:
    1. Default values
    2. YAML configuration file
    3. Environment variables (highest priority)

    Because Config is a singleton, __init__ runs on every ``Config()``
    call; the ``_initialized`` guard makes all calls after the first
    a no-op.

    Args:
        config_path: Path to YAML configuration file.
            If not provided, searches common locations.

    Raises:
        ConfigurationException: If YAML file is malformed.
    """
    # Singleton guard: skip re-initialization on repeated construction.
    if self._initialized:
        return

    # Load environment variables from .env file (no-op when absent).
    load_dotenv()

    # Store metadata
    self._config_path: Optional[str] = None
    self._environment: str = os.getenv('AASRT_ENVIRONMENT', 'development')
    self._validation_errors: List[str] = []

    # Default configuration (lowest priority layer)
    self._config = self._get_defaults()

    # Load from file if provided
    if config_path:
        self._load_from_file(config_path)
    else:
        # Try to find config file in common locations; first hit wins.
        for path in ['config.yaml', 'config.yml', './config/config.yaml']:
            if os.path.exists(path):
                self._load_from_file(path)
                break

    # Override with environment variables (highest priority layer)
    self._load_from_env()

    # Validate configuration (collects warnings, does not raise)
    self._validate_config()

    self._initialized = True
    logger.info(f"Configuration initialized (environment: {self._environment})")
|
||||
|
||||
def _get_defaults(self) -> Dict[str, Any]:
    """Get default configuration values.

    These form the lowest-priority layer; YAML files and environment
    variables merged on top may override any of them.
    """
    return {
        'shodan': {
            'enabled': True,
            'rate_limit': 1,       # requests per second
            'max_results': 100,
            'timeout': 30          # seconds per API call
        },
        'vulnerability_checks': {
            'enabled': True,
            'passive_only': True,  # no active probing by default
            'timeout_per_check': 10
        },
        'reporting': {
            'formats': ['json', 'csv'],
            'output_dir': './reports',
            'anonymize_by_default': False
        },
        'filtering': {
            'whitelist_ips': [],
            'whitelist_domains': [],
            'min_confidence_score': 70,
            'exclude_honeypots': True
        },
        'logging': {
            'level': 'INFO',
            'file': './logs/scanner.log',
            'max_size_mb': 100,    # rotation threshold
            'backup_count': 5
        },
        'database': {
            'type': 'sqlite',
            'sqlite': {
                'path': './data/scanner.db'
            }
        },
        'api_keys': {},            # populated from env (SHODAN_API_KEY)
        'clawsec': {
            'enabled': True,
            'feed_url': 'https://clawsec.prompt.security/advisories/feed.json',
            'cache_ttl_seconds': 86400,  # 24 hours
            'cache_file': './data/clawsec_cache.json',
            'offline_mode': False,
            'timeout': 30,
            'auto_refresh': True
        }
    }
|
||||
|
||||
def _load_from_file(self, path: str) -> None:
    """
    Merge settings from a YAML file into the current configuration.

    A missing file is logged as a warning and otherwise ignored;
    malformed YAML is a hard error.

    Args:
        path: Path to YAML configuration file.

    Raises:
        ConfigurationException: If the file contains invalid YAML.
    """
    try:
        with open(path, 'r') as fh:
            loaded = yaml.safe_load(fh)
        if loaded:
            self._deep_merge(self._config, loaded)
            self._config_path = path
            logger.info(f"Loaded configuration from {path}")
    except FileNotFoundError:
        logger.warning(f"Configuration file not found: {path}")
    except yaml.YAMLError as e:
        raise ConfigurationException(f"Invalid YAML in configuration file: {e}")
|
||||
|
||||
def _load_from_env(self) -> None:
    """
    Apply environment-variable overrides on top of file/default config.

    Covers the Shodan API key, logging level, deployment environment,
    debug flag, database connection settings, and result limits.
    Invalid values are logged and ignored rather than raised.
    """
    env = os.getenv

    # Shodan API key
    if env('SHODAN_API_KEY'):
        self._set_nested(('api_keys', 'shodan'), env('SHODAN_API_KEY'))

    # Log level (validated against the known level names)
    log_level = env('AASRT_LOG_LEVEL', '').upper()
    if log_level:
        if log_level in VALID_LOG_LEVELS:
            self._set_nested(('logging', 'level'), log_level)
        else:
            logger.warning(f"Invalid log level '{log_level}', using default")

    # Deployment environment
    environment = env('AASRT_ENVIRONMENT', '').lower()
    if environment in VALID_ENVIRONMENTS:
        self._environment = environment

    # Debug flag forces DEBUG logging
    if env('AASRT_DEBUG', '').lower() in ('true', '1', 'yes'):
        self._set_nested(('logging', 'level'), 'DEBUG')

    # Database backend selection
    db_type = env('DB_TYPE', '').lower()
    if db_type in VALID_DB_TYPES:
        self._set_nested(('database', 'type'), db_type)

    # PostgreSQL connection settings (string-valued, copied verbatim)
    for var, key in (
        ('DB_HOST', 'host'),
        ('DB_NAME', 'database'),
        ('DB_USER', 'user'),
        ('DB_PASSWORD', 'password'),
        ('DB_SSL_MODE', 'ssl_mode'),
    ):
        if env(var):
            self._set_nested(('database', 'postgresql', key), env(var))

    # PostgreSQL port (integer-valued)
    if env('DB_PORT'):
        try:
            self._set_nested(('database', 'postgresql', 'port'), int(env('DB_PORT')))
        except ValueError:
            logger.warning("Invalid DB_PORT, using default")

    # Max results limit
    if env('AASRT_MAX_RESULTS'):
        try:
            self._set_nested(('shodan', 'max_results'), int(env('AASRT_MAX_RESULTS')))
        except ValueError:
            logger.warning("Invalid AASRT_MAX_RESULTS, using default")
|
||||
|
||||
def _validate_config(self) -> None:
    """
    Sanity-check key configuration values.

    Problems are collected into ``self._validation_errors`` and logged
    as warnings; nothing is raised, so the application can degrade
    gracefully with defaults.
    """
    errors: List[str] = []

    # Log level must be one of the recognized names.
    log_level = self.get('logging', 'level', default='INFO')
    if log_level.upper() not in VALID_LOG_LEVELS:
        errors.append(f"Invalid log level: {log_level}")

    # Database backend must be supported.
    db_type = self.get('database', 'type', default='sqlite')
    if db_type.lower() not in VALID_DB_TYPES:
        errors.append(f"Invalid database type: {db_type}")

    # max_results must be a positive integer.
    max_results = self.get('shodan', 'max_results', default=100)
    if not isinstance(max_results, int) or max_results < 1:
        errors.append(f"Invalid max_results: {max_results}")

    # A missing API key is only a debug note, not a validation error.
    if not self.get_shodan_key():
        logger.debug("Shodan API key not configured - scans will require it")

    self._validation_errors = errors
    for error in errors:
        logger.warning(f"Configuration validation: {error}")
|
||||
|
||||
def _deep_merge(self, base: Dict, overlay: Dict) -> None:
|
||||
"""
|
||||
Deep merge overlay dictionary into base dictionary.
|
||||
|
||||
Args:
|
||||
base: Base dictionary to merge into (modified in place).
|
||||
overlay: Overlay dictionary to merge from.
|
||||
"""
|
||||
for key, value in overlay.items():
|
||||
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
|
||||
self._deep_merge(base[key], value)
|
||||
else:
|
||||
base[key] = value
|
||||
|
||||
def _set_nested(self, path: tuple, value: Any) -> None:
|
||||
"""
|
||||
Set a nested configuration value by key path.
|
||||
|
||||
Args:
|
||||
path: Tuple of keys representing the path.
|
||||
value: Value to set at the path.
|
||||
"""
|
||||
current = self._config
|
||||
for key in path[:-1]:
|
||||
if key not in current:
|
||||
current[key] = {}
|
||||
current = current[key]
|
||||
current[path[-1]] = value
|
||||
|
||||
def get(self, *keys: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get a configuration value by nested keys.
|
||||
|
||||
Args:
|
||||
*keys: Nested keys to traverse (e.g., 'database', 'type').
|
||||
default: Default value if path not found.
|
||||
|
||||
Returns:
|
||||
Configuration value or default.
|
||||
|
||||
Example:
|
||||
>>> config.get('shodan', 'max_results', default=100)
|
||||
100
|
||||
"""
|
||||
current = self._config
|
||||
for key in keys:
|
||||
if isinstance(current, dict) and key in current:
|
||||
current = current[key]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
def get_shodan_key(self) -> Optional[str]:
|
||||
"""
|
||||
Get Shodan API key.
|
||||
|
||||
Returns:
|
||||
Shodan API key string, or None if not configured.
|
||||
"""
|
||||
return self.get('api_keys', 'shodan')
|
||||
|
||||
def get_shodan_config(self) -> Dict[str, Any]:
    """Return the 'shodan' settings subtree (enabled, rate_limit,
    max_results, timeout); empty dict when the section is absent."""
    return self.get('shodan', default={})
|
||||
|
||||
def get_clawsec_config(self) -> Dict[str, Any]:
    """Return the 'clawsec' settings subtree (feed_url, cache options,
    offline_mode, timeout, auto_refresh); empty dict when absent."""
    return self.get('clawsec', default={})
|
||||
|
||||
def is_clawsec_enabled(self) -> bool:
    """Whether ClawSec vulnerability lookup is enabled (defaults True).

    Note: returns the raw configured value, which is normally a bool.
    """
    return self.get('clawsec', 'enabled', default=True)
|
||||
|
||||
def get_database_config(self) -> Dict[str, Any]:
    """Return the 'database' settings subtree (type plus per-backend
    connection settings); empty dict when the section is absent."""
    return self.get('database', default={})
|
||||
|
||||
def get_logging_config(self) -> Dict[str, Any]:
    """Return the 'logging' settings subtree (level, file, max_size_mb,
    backup_count); empty dict when the section is absent."""
    return self.get('logging', default={})
|
||||
|
||||
@property
def environment(self) -> str:
    """Current deployment environment name: 'development', 'staging',
    or 'production' (set from AASRT_ENVIRONMENT, default 'development')."""
    return self._environment
|
||||
|
||||
@property
def is_production(self) -> bool:
    """True when the deployment environment is exactly 'production'."""
    return self._environment == 'production'
|
||||
|
||||
@property
def is_debug(self) -> bool:
    """True when the effective logging level is DEBUG
    (case-insensitive comparison against the configured level)."""
    return self.get('logging', 'level', default='INFO').upper() == 'DEBUG'
|
||||
|
||||
@property
def all(self) -> Dict[str, Any]:
    """
    Get all configuration as dictionary.

    Returns:
        A *shallow* copy of the configuration dictionary — nested
        dicts are still shared with the live config, so mutating them
        affects the singleton.
    """
    return self._config.copy()
|
||||
|
||||
def reload(self, config_path: Optional[str] = None) -> None:
    """
    Rebuild configuration from defaults, file, and environment.

    Allows refreshing settings without restarting the application.

    Args:
        config_path: Optional path to a configuration file; falls back
            to the previously loaded file path when omitted.
    """
    logger.info("Reloading configuration...")
    self._initialized = False
    self._config = self._get_defaults()

    # Prefer the explicitly supplied path; otherwise re-read the
    # file that was loaded before (if any).
    target = config_path if config_path else self._config_path
    if target:
        self._load_from_file(target)

    self._load_from_env()
    self._validate_config()
    self._initialized = True
    logger.info("Configuration reloaded successfully")
|
||||
|
||||
def health_check(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform a health check on configuration.
|
||||
|
||||
Returns:
|
||||
Dictionary with health status:
|
||||
- healthy: bool indicating if configuration is valid
|
||||
- environment: Current deployment environment
|
||||
- config_file: Path to loaded config file (if any)
|
||||
- validation_errors: List of validation errors
|
||||
- shodan_configured: Whether Shodan API key is set
|
||||
- clawsec_enabled: Whether ClawSec is enabled
|
||||
"""
|
||||
return {
|
||||
"healthy": len(self._validation_errors) == 0,
|
||||
"environment": self._environment,
|
||||
"config_file": self._config_path,
|
||||
"validation_errors": self._validation_errors.copy(),
|
||||
"shodan_configured": bool(self.get_shodan_key()),
|
||||
"clawsec_enabled": self.is_clawsec_enabled(),
|
||||
"log_level": self.get('logging', 'level', default='INFO'),
|
||||
"database_type": self.get('database', 'type', default='sqlite')
|
||||
}
|
||||
|
||||
@staticmethod
def reset_instance() -> None:
    """
    Reset the singleton instance (for testing).

    Warning:
        This should only be used in tests. It will cause any
        existing references to the old instance to be stale.
    """
    # Next Config() call builds a fresh instance from scratch.
    Config._instance = None
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Custom exceptions for AASRT."""
|
||||
|
||||
|
||||
class AASRTException(Exception):
    """Root of the AASRT exception hierarchy.

    Catching this type handles every error raised by the tool's own
    code (API, rate-limit, configuration, validation, auth, timeout).
    """
|
||||
|
||||
|
||||
class APIException(AASRTException):
    """Raised when an external API call fails.

    Args:
        message: Human-readable error description.
        engine: Name of the search engine that failed, if known.
        status_code: HTTP status code returned by the API, if any.
    """

    def __init__(self, message: str, engine: str = None, status_code: int = None):
        # Stored as attributes so callers can branch on engine/status.
        self.engine = engine
        self.status_code = status_code
        super().__init__(message)
|
||||
|
||||
|
||||
class RateLimitException(AASRTException):
    """Raised when an API rate limit is exceeded.

    Args:
        message: Human-readable error description.
        engine: Name of the engine that throttled the request, if known.
        retry_after: Seconds to wait before retrying, if the API
            provided one.
    """

    def __init__(self, message: str, engine: str = None, retry_after: int = None):
        # Kept as attributes so callers can implement backoff.
        self.engine = engine
        self.retry_after = retry_after
        super().__init__(message)
|
||||
|
||||
|
||||
class ConfigurationException(AASRTException):
    """Raised when configuration is missing, malformed, or invalid."""
|
||||
|
||||
|
||||
class ValidationException(AASRTException):
    """Raised when input validation fails (bad IP, domain, path, etc.)."""
|
||||
|
||||
|
||||
class AuthenticationException(AASRTException):
    """Raised when authentication against an engine fails.

    Args:
        message: Human-readable error description.
        engine: Name of the engine that rejected the credentials, if known.
    """

    def __init__(self, message: str, engine: str = None):
        # Stored so callers can report which engine's credentials failed.
        self.engine = engine
        super().__init__(message)
|
||||
|
||||
|
||||
class TimeoutException(AASRTException):
    """Raised when a request times out.

    Args:
        message: Human-readable error description.
        engine: Name of the engine whose request timed out, if known.
        timeout: The timeout value (seconds) that was exceeded, if known.
    """

    def __init__(self, message: str, engine: str = None, timeout: int = None):
        # Stored so callers can log/adjust the offending timeout.
        self.engine = engine
        self.timeout = timeout
        super().__init__(message)
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Logging setup for AASRT."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from typing import Optional
|
||||
|
||||
|
||||
_loggers = {}
|
||||
|
||||
|
||||
def setup_logger(
    name: str = "aasrt",
    level: str = "INFO",
    log_file: Optional[str] = None,
    max_size_mb: int = 100,
    backup_count: int = 5
) -> logging.Logger:
    """
    Setup and configure a logger with console and optional file output.

    Handlers are set to DEBUG; the effective threshold comes from the
    logger's own level, so the same handler setup serves any verbosity.

    Args:
        name: Logger name (also the cache key in _loggers)
        level: Log level (DEBUG, INFO, WARNING, ERROR); unknown names
            silently fall back to INFO
        log_file: Path to log file (optional; parent dirs are created)
        max_size_mb: Max log file size in MB before rotation
        backup_count: Number of rotated backup files to keep

    Returns:
        Configured logger instance (cached; later calls with the same
        name return the existing logger unchanged)
    """
    # Cached logger: reconfiguring would attach duplicate handlers.
    if name in _loggers:
        return _loggers[name]

    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, level.upper(), logging.INFO))

    # Prevent duplicate handlers (logger may be pre-configured elsewhere)
    if logger.handlers:
        return logger

    # Console handler (plain text; no ANSI colors are emitted)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)

    # Console format: timestamp, level, logger name, message
    console_format = logging.Formatter(
        '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler.setFormatter(console_format)
    logger.addHandler(console_handler)

    # File handler if log_file specified
    if log_file:
        # Ensure directory exists
        log_dir = os.path.dirname(log_file)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)

        file_handler = RotatingFileHandler(
            log_file,
            maxBytes=max_size_mb * 1024 * 1024,
            backupCount=backup_count
        )
        file_handler.setLevel(logging.DEBUG)

        # File lines additionally record source file:line for debugging.
        file_format = logging.Formatter(
            '%(asctime)s | %(levelname)-8s | %(name)s | %(filename)s:%(lineno)d | %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(file_format)
        logger.addHandler(file_handler)

    _loggers[name] = logger
    return logger
|
||||
|
||||
|
||||
def get_logger(name: str = "aasrt") -> logging.Logger:
    """
    Return the cached logger for ``name``, creating it with default
    settings when it has not been set up yet.

    Args:
        name: Logger name

    Returns:
        Logger instance
    """
    try:
        return _loggers[name]
    except KeyError:
        return setup_logger(name)
|
||||
@@ -0,0 +1,583 @@
|
||||
"""
|
||||
Input validation utilities for AASRT.
|
||||
|
||||
This module provides comprehensive input validation and sanitization functions
|
||||
for security-sensitive operations including:
|
||||
- IP address and domain validation
|
||||
- Port number and query string validation
|
||||
- File path sanitization (directory traversal prevention)
|
||||
- API key format validation
|
||||
- Template name whitelist validation
|
||||
- Configuration value validation
|
||||
|
||||
All validators raise ValidationException on invalid input with descriptive
|
||||
error messages for debugging.
|
||||
|
||||
Example:
|
||||
>>> from src.utils.validators import validate_ip, validate_file_path
|
||||
>>> validate_ip("192.168.1.1") # Returns True
|
||||
>>> validate_file_path("../../../etc/passwd") # Raises ValidationException
|
||||
"""
|
||||
|
||||
import re
|
||||
import os
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
import validators
|
||||
|
||||
from .exceptions import ValidationException
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# Valid log levels for configuration
|
||||
VALID_LOG_LEVELS: Set[str] = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||
|
||||
# Valid environment names
|
||||
VALID_ENVIRONMENTS: Set[str] = {"development", "staging", "production"}
|
||||
|
||||
# Valid database types
|
||||
VALID_DB_TYPES: Set[str] = {"sqlite", "postgresql", "mysql"}
|
||||
|
||||
# Valid report formats
|
||||
VALID_REPORT_FORMATS: Set[str] = {"json", "csv", "html", "pdf"}
|
||||
|
||||
# Valid query template names (whitelist)
|
||||
VALID_TEMPLATES: Set[str] = {
|
||||
"clawdbot_instances",
|
||||
"autogpt_instances",
|
||||
"langchain_agents",
|
||||
"openai_agents",
|
||||
"anthropic_agents",
|
||||
"ai_agent_general",
|
||||
"agent_gpt",
|
||||
"babyagi_instances",
|
||||
"crewai_instances",
|
||||
"autogen_instances",
|
||||
"superagi_instances",
|
||||
"flowise_instances",
|
||||
"dify_instances",
|
||||
}
|
||||
|
||||
# Maximum limits for various inputs
|
||||
MAX_QUERY_LENGTH: int = 2000
|
||||
MAX_RESULTS_LIMIT: int = 10000
|
||||
MIN_RESULTS_LIMIT: int = 1
|
||||
MAX_PORT: int = 65535
|
||||
MIN_PORT: int = 1
|
||||
MAX_FILE_PATH_LENGTH: int = 4096
|
||||
MAX_API_KEY_LENGTH: int = 256
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# IP and Network Validators
|
||||
# =============================================================================
|
||||
|
||||
def validate_ip(ip: str) -> bool:
    """
    Validate an IP address (IPv4 or IPv6).

    Args:
        ip: Candidate IP address string (surrounding whitespace ignored).

    Returns:
        True when the value parses as a valid IPv4/IPv6 address.

    Raises:
        ValidationException: If the value is None, not a string, empty,
            or not a syntactically valid IP address.

    Example:
        >>> validate_ip("192.168.1.1")
        True
        >>> validate_ip("2001:db8::1")
        True
    """
    if ip is None:
        raise ValidationException("IP address cannot be None")
    if not isinstance(ip, str):
        raise ValidationException(f"IP address must be a string, got {type(ip).__name__}")

    candidate = ip.strip()
    if not candidate:
        raise ValidationException("IP address cannot be empty")

    try:
        # Accepts both IPv4 and IPv6 literals.
        ipaddress.ip_address(candidate)
    except ValueError:
        raise ValidationException(f"Invalid IP address: {candidate}")
    return True
|
||||
|
||||
|
||||
def validate_domain(domain: str) -> bool:
    """
    Validate a domain name.

    Mirrors the guard behavior of validate_ip: None, non-string, and
    empty inputs raise a ValidationException instead of being passed to
    the ``validators`` library (which would raise an unrelated
    TypeError or return a failure object with a confusing message).

    Args:
        domain: Domain name string (surrounding whitespace ignored).

    Returns:
        True if valid.

    Raises:
        ValidationException: If domain is None, not a string, empty,
            or not a valid domain name.
    """
    if domain is None:
        raise ValidationException("Domain cannot be None")
    if not isinstance(domain, str):
        raise ValidationException(f"Domain must be a string, got {type(domain).__name__}")

    domain = domain.strip()
    if not domain:
        raise ValidationException("Domain cannot be empty")

    if validators.domain(domain):
        return True
    raise ValidationException(f"Invalid domain: {domain}")
|
||||
|
||||
|
||||
def validate_query(query: str, engine: str) -> bool:
    """
    Validate a search query for a specific engine.

    Rejects empty queries, queries containing characters commonly used
    for injection (angle brackets, null bytes), and queries exceeding
    the engine's length limit.

    Args:
        query: Search query string
        engine: Search engine name ("shodan" or "censys")

    Returns:
        True if valid

    Raises:
        ValidationException: If query is empty, contains dangerous
            characters, or is too long for the engine
    """
    if not query or not query.strip():
        raise ValidationException("Query cannot be empty")

    # Characters that could enable script injection or confuse parsers.
    for pattern in (r'[<>]', r'\x00'):
        if re.search(pattern, query):
            raise ValidationException(f"Query contains invalid characters: {pattern}")

    # Per-engine length ceilings.
    if engine == "shodan" and len(query) > 1000:
        raise ValidationException("Shodan query too long (max 1000 chars)")
    if engine == "censys" and len(query) > 2000:
        raise ValidationException("Censys query too long (max 2000 chars)")

    return True
|
||||
|
||||
|
||||
def validate_port(port: int) -> bool:
    """
    Validate a TCP/UDP port number.

    bool is explicitly rejected even though it is a subclass of int,
    so ``validate_port(True)`` no longer slips through as port 1.

    Args:
        port: Port number (1-65535)

    Returns:
        True if valid

    Raises:
        ValidationException: If port is not an int in the range 1-65535
    """
    if isinstance(port, bool) or not isinstance(port, int) or not 1 <= port <= 65535:
        raise ValidationException(f"Invalid port number: {port}")
    return True
|
||||
|
||||
|
||||
def validate_api_key(api_key: str, engine: str) -> bool:
    """
    Perform a basic format sanity check on an API key.

    This does not verify the key against the provider — only that it is
    non-empty and plausibly shaped for the given engine.

    Args:
        api_key: API key string
        engine: Search engine name

    Returns:
        True if the key passes the format check

    Raises:
        ValidationException: If the key is empty or obviously malformed
    """
    if not api_key or not api_key.strip():
        raise ValidationException(f"API key for {engine} cannot be empty")

    # Shodan keys are typically 32 characters; under 20 is suspect.
    if engine == "shodan" and len(api_key) < 20:
        raise ValidationException("Shodan API key appears too short")

    return True
|
||||
|
||||
|
||||
def sanitize_output(text: str) -> str:
    """
    Sanitize text for safe output by redacting potential secrets.

    Known credential formats (API keys, cloud access keys, tokens,
    password-style assignments) are replaced with REDACTED markers so
    they cannot leak into logs or reports. Non-string input is
    stringified first; None becomes the empty string.

    Args:
        text: Text to sanitize.

    Returns:
        Sanitized text with sensitive data replaced by REDACTED markers.

    Example:
        >>> sanitize_output("key: sk-ant-abc123...")
        'key: sk-ant-***REDACTED***'
    """
    if text is None:
        return ""
    if not isinstance(text, str):
        text = str(text)

    # (pattern, replacement) pairs; more specific formats come first so
    # broader patterns don't partially mangle them.
    redactions = (
        # Anthropic API keys
        (r'sk-ant-[a-zA-Z0-9-_]{20,}', 'sk-ant-***REDACTED***'),
        # OpenAI API keys
        (r'sk-[a-zA-Z0-9]{40,}', 'sk-***REDACTED***'),
        # AWS Access Key
        (r'AKIA[0-9A-Z]{16}', 'AKIA***REDACTED***'),
        # AWS Secret Key
        (r'(?i)aws_secret_access_key["\s:=]+["\']?[A-Za-z0-9/+=]{40}', 'aws_secret_access_key=***REDACTED***'),
        # GitHub tokens
        (r'ghp_[a-zA-Z0-9]{36}', 'ghp_***REDACTED***'),
        (r'gho_[a-zA-Z0-9]{36}', 'gho_***REDACTED***'),
        # Google API keys
        (r'AIza[0-9A-Za-z-_]{35}', 'AIza***REDACTED***'),
        # Stripe keys
        (r'sk_live_[a-zA-Z0-9]{24,}', 'sk_live_***REDACTED***'),
        (r'sk_test_[a-zA-Z0-9]{24,}', 'sk_test_***REDACTED***'),
        # Shodan API key (32 hex chars — note: also matches MD5 hashes)
        (r'[a-fA-F0-9]{32}', '***REDACTED_KEY***'),
        # Generic password patterns
        (r'password["\s:=]+["\']?[\w@#$%^&*!?]+', 'password=***REDACTED***'),
        (r'passwd["\s:=]+["\']?[\w@#$%^&*!?]+', 'passwd=***REDACTED***'),
        (r'secret["\s:=]+["\']?[\w@#$%^&*!?]+', 'secret=***REDACTED***'),
        # Bearer tokens
        (r'Bearer\s+[a-zA-Z0-9._-]+', 'Bearer ***REDACTED***'),
        # Basic auth
        (r'Basic\s+[a-zA-Z0-9+/=]+', 'Basic ***REDACTED***'),
    )

    sanitized = text
    for pattern, replacement in redactions:
        sanitized = re.sub(pattern, replacement, sanitized, flags=re.IGNORECASE)
    return sanitized
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# File Path Validators
|
||||
# =============================================================================
|
||||
|
||||
def validate_file_path(
    path: str,
    must_exist: bool = False,
    allow_absolute: bool = True,
    base_dir: Optional[str] = None
) -> str:
    """
    Validate and sanitize a file path to prevent directory traversal attacks.

    Args:
        path: File path to validate.
        must_exist: If True, the file must exist.
        allow_absolute: If True, allow absolute paths.
        base_dir: If provided, ensure path resolves within this directory.

    Returns:
        Sanitized, normalized file path.

    Raises:
        ValidationException: If path is invalid or potentially dangerous.

    Example:
        >>> validate_file_path("reports/scan.json")
        'reports/scan.json'
    """
    if path is None:
        raise ValidationException("File path cannot be None")
    if not isinstance(path, str):
        raise ValidationException(f"File path must be a string, got {type(path).__name__}")

    path = path.strip()
    if not path:
        raise ValidationException("File path cannot be empty")
    if len(path) > MAX_FILE_PATH_LENGTH:
        raise ValidationException(f"File path too long (max {MAX_FILE_PATH_LENGTH} chars)")

    # Null bytes can truncate paths at the OS level (security risk).
    if '\x00' in path:
        raise ValidationException("File path contains null bytes")

    # Normalize the path
    try:
        normalized = os.path.normpath(path)
    except Exception as e:
        raise ValidationException(f"Invalid file path: {e}")

    # Check for directory traversal
    if '..' in normalized.split(os.sep):
        raise ValidationException("Path traversal detected: '..' not allowed")

    # Check absolute path restriction
    if not allow_absolute and os.path.isabs(normalized):
        raise ValidationException("Absolute paths not allowed")

    # Containment check against base_dir.
    # BUGFIX: the previous bare startswith() check allowed sibling-prefix
    # escapes (base "/srv/app" wrongly accepted "/srv/app_evil/x");
    # os.path.commonpath compares whole path components instead.
    if base_dir:
        base_abs = os.path.abspath(base_dir)
        full_path = os.path.abspath(os.path.join(base_abs, normalized))
        try:
            contained = os.path.commonpath([base_abs, full_path]) == base_abs
        except ValueError:
            # Different drives (Windows) — cannot be contained.
            contained = False
        if not contained:
            raise ValidationException("Path escapes base directory")

    # Check existence if required
    if must_exist and not os.path.exists(path):
        raise ValidationException(f"File does not exist: {path}")

    return normalized
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Template and Configuration Validators
|
||||
# =============================================================================
|
||||
|
||||
def validate_template_name(template: str) -> bool:
    """
    Validate a query template name against the whitelist.

    Args:
        template: Template name to validate.

    Returns:
        True if template is valid.

    Raises:
        ValidationException: If template is None, empty, or not in the
            allowed list.

    Example:
        >>> validate_template_name("clawdbot_instances")
        True
        >>> validate_template_name("malicious_query")
        ValidationException: Invalid template name
    """
    if template is None:
        raise ValidationException("Template name cannot be None")

    # Coerce with str() for consistency with the other validators in this
    # module (e.g. validate_log_level): a non-string argument now fails with
    # ValidationException instead of leaking an AttributeError from .strip().
    template = str(template).strip().lower()
    if not template:
        raise ValidationException("Template name cannot be empty")

    if template not in VALID_TEMPLATES:
        valid_list = ", ".join(sorted(VALID_TEMPLATES))
        raise ValidationException(
            f"Invalid template name: '{template}'. Valid templates: {valid_list}"
        )

    return True
|
||||
|
||||
|
||||
def validate_max_results(max_results: Union[int, str]) -> int:
    """
    Validate and normalize max_results parameter.

    Args:
        max_results: Maximum number of results (int or numeric string).

    Returns:
        Validated integer value within [MIN_RESULTS_LIMIT, MAX_RESULTS_LIMIT].

    Raises:
        ValidationException: If value is None, not an integer, or out of range.

    Example:
        >>> validate_max_results(100)
        100
        >>> validate_max_results("50")
        50
        >>> validate_max_results(-1)
        ValidationException: max_results must be positive
    """
    if max_results is None:
        raise ValidationException("max_results cannot be None")

    # bool is a subclass of int, so True/False would otherwise slip through
    # the isinstance(int) check below as 1/0. Reject it explicitly.
    if isinstance(max_results, bool):
        raise ValidationException("max_results must be an integer, got bool")

    # Convert string to int if needed
    if isinstance(max_results, str):
        try:
            max_results = int(max_results.strip())
        except ValueError:
            # 'from None' keeps the domain error clean instead of chaining
            # the low-level ValueError into the traceback.
            raise ValidationException(
                f"max_results must be a number, got: '{max_results}'"
            ) from None

    if not isinstance(max_results, int):
        raise ValidationException(
            f"max_results must be an integer, got {type(max_results).__name__}"
        )

    if max_results < MIN_RESULTS_LIMIT:
        raise ValidationException(f"max_results must be at least {MIN_RESULTS_LIMIT}")

    if max_results > MAX_RESULTS_LIMIT:
        raise ValidationException(f"max_results cannot exceed {MAX_RESULTS_LIMIT}")

    return max_results
|
||||
|
||||
|
||||
def validate_log_level(level: str) -> str:
    """
    Validate a log level string and return its canonical form.

    Args:
        level: Log level string (case-insensitive).

    Returns:
        The log level normalized to uppercase.

    Raises:
        ValidationException: If the level is None or not a recognized level.
    """
    if level is None:
        raise ValidationException("Log level cannot be None")

    # Canonicalize before checking against the whitelist.
    level = str(level).strip().upper()

    if level in VALID_LOG_LEVELS:
        return level

    valid_list = ", ".join(sorted(VALID_LOG_LEVELS))
    raise ValidationException(f"Invalid log level: '{level}'. Valid levels: {valid_list}")
|
||||
|
||||
|
||||
def validate_environment(env: str) -> str:
    """
    Validate an environment name and return its canonical form.

    Args:
        env: Environment name (case-insensitive).

    Returns:
        The environment name normalized to lowercase.

    Raises:
        ValidationException: If the environment is None or unrecognized.
    """
    if env is None:
        raise ValidationException("Environment cannot be None")

    # Canonicalize before checking against the whitelist.
    env = str(env).strip().lower()

    if env in VALID_ENVIRONMENTS:
        return env

    valid_list = ", ".join(sorted(VALID_ENVIRONMENTS))
    raise ValidationException(f"Invalid environment: '{env}'. Valid environments: {valid_list}")
|
||||
|
||||
|
||||
def validate_db_type(db_type: str) -> str:
    """
    Validate a database type and return its canonical form.

    Args:
        db_type: Database type string (case-insensitive).

    Returns:
        The database type normalized to lowercase.

    Raises:
        ValidationException: If the database type is None or unrecognized.
    """
    if db_type is None:
        raise ValidationException("Database type cannot be None")

    # Canonicalize before checking against the whitelist.
    db_type = str(db_type).strip().lower()

    if db_type in VALID_DB_TYPES:
        return db_type

    valid_list = ", ".join(sorted(VALID_DB_TYPES))
    raise ValidationException(f"Invalid database type: '{db_type}'. Valid types: {valid_list}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Batch Validation Helpers
|
||||
# =============================================================================
|
||||
|
||||
def validate_config_dict(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate a configuration dictionary.

    Known section/key pairs are validated and normalized IN PLACE; the same
    (mutated) dictionary is returned. Unknown sections and keys are ignored.

    Args:
        config: Configuration dictionary to validate.

    Returns:
        The input dictionary with validated/normalized values.

    Raises:
        ValidationException: If any recognized configuration value is invalid.
    """
    # NOTE: the previous version created an unused local `validated = {}`;
    # the function has always mutated and returned `config` itself.
    # Each entry: (section, key, validator applied to config[section][key]).
    checks = (
        ('logging', 'level', validate_log_level),
        ('database', 'type', validate_db_type),
        ('shodan', 'max_results', validate_max_results),
    )
    for section, key, validator in checks:
        if section in config and key in config[section]:
            config[section][key] = validator(config[section][key])

    return config
|
||||
|
||||
|
||||
# Injection signatures checked by is_safe_string(). Compiled once at import
# time instead of rebuilding the pattern list (and re-matching each pattern
# string) on every call. The former r'\x00' pattern was dropped: null bytes
# are already rejected by an explicit substring check before the regex pass.
_DANGEROUS_PATTERNS = [
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'<script',
        r'javascript:',
        r'on\w+\s*=',
        r'<!--',
        r'--\s*>',
    )
]


def is_safe_string(text: str, max_length: int = 1000) -> bool:
    """
    Check if a string is safe (no obvious injection attempts).

    Args:
        text: Text to check.
        max_length: Maximum allowed length.

    Returns:
        True if the string appears safe, False otherwise. Non-string input
        (including None) is reported as unsafe rather than raising.
    """
    # A predicate should answer, not raise: None and any non-str value
    # (previously a TypeError from len()) are simply unsafe.
    if not isinstance(text, str):
        return False

    if len(text) > max_length:
        return False

    # Null bytes can truncate strings in C-backed APIs — never legitimate.
    if '\x00' in text:
        return False

    # Reject anything matching a known injection signature.
    return not any(p.search(text) for p in _DANGEROUS_PATTERNS)
|
||||
Reference in New Issue
Block a user