import json
import logging
import re
import subprocess
import shlex
import shutil
import urllib.parse
import os
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from core.llm_manager import LLMManager
logger = logging.getLogger(__name__)
class BaseAgent:
"""
Autonomous AI-Powered Security Agent.
This agent operates like a real pentester:
1. Discovers attack surface dynamically
2. Analyzes responses intelligently
3. Adapts testing based on findings
4. Intensifies when it finds something interesting
5. Documents real PoCs
"""
    def __init__(self, agent_name: str, config: Dict, llm_manager: LLMManager, context_prompts: Dict):
        """Initialize agent state, role configuration, and optional add-ons.

        Args:
            agent_name: Key into config['agent_roles'] selecting this agent's role.
            config: Full application configuration dict.
            llm_manager: LLM backend used for planning and analysis.
            context_prompts: Shared prompt templates for this agent.
        """
        self.agent_name = agent_name
        self.config = config
        self.llm_manager = llm_manager
        self.context_prompts = context_prompts
        # Role-specific settings; unknown roles fall back to empty defaults.
        self.agent_role_config = self.config.get('agent_roles', {}).get(agent_name, {})
        self.tools_allowed = self.agent_role_config.get('tools_allowed', [])
        self.description = self.agent_role_config.get('description', 'Autonomous Security Tester')
        # Attack surface discovered (populated during assessment)
        self.discovered_endpoints = []
        self.discovered_params = []
        self.discovered_forms = []
        self.tech_stack = {}
        # Findings accumulated across tool runs
        self.vulnerabilities = []
        self.interesting_findings = []
        self.tool_history = []
        # Knowledge augmentation (opt-in via env)
        self.augmentor = None
        if os.getenv('ENABLE_KNOWLEDGE_AUGMENTATION', 'false').lower() == 'true':
            try:
                # Imported lazily so the dependency is only required when enabled.
                from core.knowledge_augmentor import KnowledgeAugmentor
                ka_config = config.get('knowledge_augmentation', {})
                self.augmentor = KnowledgeAugmentor(
                    dataset_path=ka_config.get('dataset_path', 'models/bug-bounty/bugbounty_finetuning_dataset.json'),
                    max_patterns=ka_config.get('max_patterns_per_query', 3)
                )
                logger.info("Knowledge augmentation enabled")
            except Exception as e:
                # Best-effort: augmentation failure must not block the agent.
                logger.warning(f"Knowledge augmentation init failed: {e}")
        # MCP tool client (opt-in via config)
        self.mcp_client = None
        if config.get('mcp_servers', {}).get('enabled', False):
            try:
                from core.mcp_client import MCPToolClient
                self.mcp_client = MCPToolClient(config)
                logger.info("MCP tool client enabled")
            except Exception as e:
                # Best-effort: MCP is optional; fall back to subprocess tools.
                logger.warning(f"MCP client init failed: {e}")
        # Browser validation (opt-in via env)
        self.browser_validation_enabled = (
            os.getenv('ENABLE_BROWSER_VALIDATION', 'false').lower() == 'true'
        )
        logger.info(f"Initialized {self.agent_name} - Autonomous Agent")
def _extract_targets(self, user_input: str) -> List[str]:
"""Extract target URLs from input."""
targets = []
if os.path.isfile(user_input.strip()):
with open(user_input.strip(), 'r') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#'):
targets.append(self._normalize_url(line))
return targets
url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+'
urls = re.findall(url_pattern, user_input)
if urls:
return [self._normalize_url(u) for u in urls]
domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
domains = re.findall(domain_pattern, user_input)
if domains:
return [f"http://{d}" for d in domains]
return []
def _normalize_url(self, url: str) -> str:
url = url.strip()
if not url.startswith(('http://', 'https://')):
url = f"http://{url}"
return url
def _get_domain(self, url: str) -> str:
parsed = urllib.parse.urlparse(url)
return parsed.netloc or parsed.path.split('/')[0]
def run_command(self, tool: str, args: str, timeout: int = 60) -> Dict:
"""Execute command and capture output."""
result = {
"tool": tool,
"args": args,
"command": "",
"success": False,
"output": "",
"timestamp": datetime.now().isoformat()
}
tool_path = self.config.get('tools', {}).get(tool) or shutil.which(tool)
if not tool_path:
result["output"] = f"[!] Tool '{tool}' not found - using alternative"
logger.warning(f"Tool not found: {tool}")
self.tool_history.append(result)
return result
try:
if tool == "curl":
cmd = f"{tool_path} {args}"
else:
cmd = f"{tool_path} {args}"
result["command"] = cmd
print(f" [>] {tool}: {args[:80]}{'...' if len(args) > 80 else ''}")
proc = subprocess.run(
cmd,
shell=True,
capture_output=True,
text=True,
timeout=timeout
)
output = proc.stdout or proc.stderr
result["output"] = output[:8000] if output else "[No output]"
result["success"] = proc.returncode == 0
except subprocess.TimeoutExpired:
result["output"] = f"[!] Timeout after {timeout}s"
except Exception as e:
result["output"] = f"[!] Error: {str(e)}"
self.tool_history.append(result)
return result
def run_mcp_tool(self, tool_name: str, arguments: Optional[Dict] = None) -> Optional[str]:
"""Execute a tool via MCP if available, returns None for subprocess fallback."""
if not self.mcp_client or not self.mcp_client.enabled:
return None
import asyncio
try:
result = asyncio.run(self.mcp_client.try_tool(tool_name, arguments))
if result is not None:
logger.info(f"MCP tool executed: {tool_name}")
return result
except Exception as e:
logger.debug(f"MCP tool '{tool_name}' not available: {e}")
return None
def run_browser_validation(self, finding_id: str, url: str,
payload: str = None) -> Dict:
"""Validate a finding using Playwright browser.
Only executes if ENABLE_BROWSER_VALIDATION is set.
Returns validation result with screenshots.
"""
if not self.browser_validation_enabled:
return {"skipped": True, "reason": "Browser validation disabled"}
try:
from core.browser_validator import validate_finding_sync
screenshots_dir = self.config.get('browser_validation', {}).get(
'screenshots_dir', 'reports/screenshots'
)
return validate_finding_sync(
finding_id=finding_id,
url=url,
payload=payload,
screenshots_dir=f"{screenshots_dir}/{self.agent_name}",
headless=self.config.get('browser_validation', {}).get('headless', True)
)
except Exception as e:
logger.error(f"Browser validation failed for {finding_id}: {e}")
return {"finding_id": finding_id, "error": str(e)}
def get_augmented_context(self, vulnerability_types: List[str]) -> str:
"""Get knowledge augmentation context for detected vulnerability types.
Returns formatted pattern context string to inject into prompts.
"""
if not self.augmentor:
return ""
augmentation = ""
technologies = list(self.tech_stack.get('detected', []))
for vtype in vulnerability_types[:3]: # Limit to avoid context bloat
patterns = self.augmentor.get_relevant_patterns(
vulnerability_type=vtype,
technologies=technologies
)
if patterns:
augmentation += patterns
return augmentation
    def execute(self, user_input: str, campaign_data: Dict = None, recon_context: Dict = None) -> Dict:
        """
        Execute security assessment.
        If recon_context is provided, skip discovery and use the context.
        Otherwise extract targets and run discovery.

        Args:
            user_input: Free-form request containing targets (URL, domain,
                IP, or a path to a file listing targets).
            campaign_data: Optional campaign metadata (unused in this method).
            recon_context: Pre-collected recon data; when truthy, delegates
                to _execute_with_context and skips target extraction.

        Returns:
            Result dict with findings, the report text under 'llm_response',
            and scan metadata — or an 'error' entry when no targets parse.
        """
        # Check if we have recon context (pre-collected data)
        if recon_context:
            return self._execute_with_context(user_input, recon_context)
        # Legacy mode: extract targets and do discovery
        targets = self._extract_targets(user_input)
        if not targets:
            return {
                "error": "No targets found",
                "llm_response": "Please provide a URL, domain, IP, or file with targets."
            }
        print(f"\n{'='*70}")
        print(f" NEUROSPLOIT AUTONOMOUS AGENT - {self.agent_name.upper()}")
        print(f"{'='*70}")
        print(f" Mode: Adaptive AI-Driven Testing")
        print(f" Targets: {len(targets)}")
        print(f"{'='*70}\n")
        all_findings = []
        for idx, target in enumerate(targets, 1):
            if len(targets) > 1:
                print(f"\n[TARGET {idx}/{len(targets)}] {target}")
                print("=" * 60)
            # Reset per-target state so findings don't bleed across targets.
            # NOTE(review): because of this reset, the counters below
            # (tools_executed, vulnerabilities_found, endpoints_discovered)
            # reflect only the LAST target — confirm this is intended.
            self.tool_history = []
            self.vulnerabilities = []
            self.discovered_endpoints = []
            findings = self._autonomous_assessment(target)
            all_findings.extend(findings)
        final_report = self._generate_final_report(targets, all_findings)
        return {
            "agent_name": self.agent_name,
            "input": user_input,
            "targets": targets,
            "targets_count": len(targets),
            "tools_executed": len(self.tool_history),
            "vulnerabilities_found": len(self.vulnerabilities),
            "findings": all_findings,
            "llm_response": final_report,
            "scan_data": {
                "targets": targets,
                "tools_executed": len(self.tool_history),
                "endpoints_discovered": len(self.discovered_endpoints)
            }
        }
    def _execute_with_context(self, user_input: str, recon_context: Dict) -> Dict:
        """
        ADAPTIVE AI Mode - Analyzes context sufficiency, runs tools if needed.
        Flow:
        1. Analyze what user is asking for
        2. Check if context has sufficient data
        3. If insufficient → Run necessary tools to collect data
        4. Perform final analysis with complete data

        Args:
            user_input: The user's assessment request.
            recon_context: Pre-collected recon data; keys read here are
                'target', 'attack_surface', 'data', and 'vulnerabilities'.

        Returns:
            Result dict with the tool history as 'findings', the LLM
            analysis under 'llm_response', and adaptive-mode metadata.
        """
        target = recon_context.get('target', {}).get('primary_target', 'Unknown')
        print(f"\n{'='*70}")
        print(f" NEUROSPLOIT ADAPTIVE AI - {self.agent_name.upper()}")
        print(f"{'='*70}")
        print(f" Mode: Adaptive (LLM + Tools when needed)")
        print(f" Target: {target}")
        print(f" Context loaded with:")
        attack_surface = recon_context.get('attack_surface', {})
        print(f" - Subdomains: {attack_surface.get('total_subdomains', 0)}")
        print(f" - Live hosts: {attack_surface.get('live_hosts', 0)}")
        print(f" - URLs: {attack_surface.get('total_urls', 0)}")
        print(f" - URLs with params: {attack_surface.get('urls_with_params', 0)}")
        print(f" - Open ports: {attack_surface.get('open_ports', 0)}")
        print(f" - Vulnerabilities: {attack_surface.get('vulnerabilities_found', 0)}")
        print(f"{'='*70}\n")
        # Extract context data (every key defaults to an empty container)
        data = recon_context.get('data', {})
        urls_with_params = data.get('urls', {}).get('with_params', [])
        technologies = data.get('technologies', [])
        api_endpoints = data.get('api_endpoints', [])
        interesting_paths = data.get('interesting_paths', [])
        existing_vulns = recon_context.get('vulnerabilities', {}).get('all', [])
        unique_params = data.get('unique_params', {})
        subdomains = data.get('subdomains', [])
        live_hosts = data.get('live_hosts', [])
        open_ports = data.get('open_ports', [])
        js_files = data.get('js_files', [])
        secrets = data.get('secrets', [])
        all_urls = data.get('urls', {}).get('all', [])
        # Phase 1: AI Analyzes Context Sufficiency
        print(f"[PHASE 1] Analyzing Context Sufficiency")
        print("-" * 50)
        # Compact numeric summary fed to the gap-analysis prompt.
        context_summary = {
            "urls_with_params": len(urls_with_params),
            "total_urls": len(all_urls),
            "technologies": technologies,
            "api_endpoints": len(api_endpoints),
            "open_ports": len(open_ports),
            "js_files": len(js_files),
            "existing_vulns": len(existing_vulns),
            "subdomains": len(subdomains),
            "live_hosts": len(live_hosts),
            "params_found": list(unique_params.keys())[:20]
        }
        gaps = self._analyze_context_gaps(user_input, context_summary, target)
        # Start fresh tool history; seed vulnerabilities with recon findings.
        self.tool_history = []
        self.vulnerabilities = list(existing_vulns)
        # Phase 2: Run tools to fill gaps if needed
        if gaps.get('needs_tools', False):
            print(f"\n[PHASE 2] Collecting Missing Data")
            print("-" * 50)
            print(f" [!] Context insufficient for: {', '.join(gaps.get('missing', []))}")
            print(f" [*] Running tools to collect data...")
            self._fill_context_gaps(target, gaps, urls_with_params, all_urls)
        else:
            print(f"\n[PHASE 2] Context Sufficient")
            print("-" * 50)
            print(f" [+] All required data available in context")
        # Phase 3: Final AI Analysis
        print(f"\n[PHASE 3] AI Analysis")
        print("-" * 50)
        context_text = self._build_context_text(target, recon_context)
        llm_response = self._final_analysis(user_input, context_text, target)
        return {
            "agent_name": self.agent_name,
            "input": user_input,
            "targets": [target],
            "targets_count": 1,
            "tools_executed": len(self.tool_history),
            "vulnerabilities_found": len(self.vulnerabilities),
            "findings": self.tool_history,
            "llm_response": llm_response,
            "context_used": True,
            "mode": "adaptive",
            "scan_data": {
                "targets": [target],
                "tools_executed": len(self.tool_history),
                "context_based": True
            }
        }
    def _analyze_context_gaps(self, user_input: str, context_summary: Dict, target: str) -> Dict:
        """AI analyzes what user wants and what's missing in context.

        Sends a structured prompt to the LLM and parses its line-oriented
        reply (NEEDS_TOOLS / MISSING / TESTS_NEEDED / URLS_TO_TEST / REASON).
        Malformed or missing lines leave the corresponding defaults in place.

        Args:
            user_input: The user's original request.
            context_summary: Counts/lists summarizing available recon data.
            target: Primary target identifier.

        Returns:
            Dict with keys 'needs_tools' (bool), 'missing', 'tests_needed',
            'urls_to_test' (lists of str), and 'reason' (str).
        """
        analysis_prompt = f"""Analyze this user request and context to determine what data is missing.
USER REQUEST:
{user_input}
AVAILABLE CONTEXT DATA:
- URLs with parameters: {context_summary['urls_with_params']}
- Total URLs discovered: {context_summary['total_urls']}
- Technologies detected: {', '.join(context_summary['technologies']) if context_summary['technologies'] else 'None'}
- API endpoints: {context_summary['api_endpoints']}
- Open ports scanned: {context_summary['open_ports']}
- JavaScript files: {context_summary['js_files']}
- Existing vulnerabilities: {context_summary['existing_vulns']}
- Subdomains: {context_summary['subdomains']}
- Live hosts: {context_summary['live_hosts']}
- Parameters found: {', '.join(context_summary['params_found'][:15]) if context_summary['params_found'] else 'None'}
TARGET: {target}
DETERMINE what the user wants to test/analyze and if we have sufficient data.
Respond in this EXACT format:
NEEDS_TOOLS: YES or NO
MISSING: [comma-separated list of what's missing]
TESTS_NEEDED: [comma-separated list of test types needed: sqli, xss, lfi, ssrf, rce, port_scan, subdomain, crawl, etc.]
URLS_TO_TEST: [list specific URLs from context to test, or DISCOVER if need to find URLs]
REASON: [brief explanation]"""
        system = "You are a security assessment planner. Analyze context and determine data gaps. Be concise."
        response = self.llm_manager.generate(analysis_prompt, system)
        # Parse response (defaults survive if a tag is absent or malformed)
        gaps = {
            "needs_tools": False,
            "missing": [],
            "tests_needed": [],
            "urls_to_test": [],
            "reason": ""
        }
        for line in response.split('\n'):
            line = line.strip()
            if line.startswith('NEEDS_TOOLS:'):
                gaps['needs_tools'] = 'YES' in line.upper()
            elif line.startswith('MISSING:'):
                # NOTE(review): replace() strips the tag anywhere in the line,
                # not just as a prefix — harmless for well-formed replies.
                items = line.replace('MISSING:', '').strip().strip('[]')
                gaps['missing'] = [x.strip() for x in items.split(',') if x.strip()]
            elif line.startswith('TESTS_NEEDED:'):
                items = line.replace('TESTS_NEEDED:', '').strip().strip('[]')
                gaps['tests_needed'] = [x.strip().lower() for x in items.split(',') if x.strip()]
            elif line.startswith('URLS_TO_TEST:'):
                items = line.replace('URLS_TO_TEST:', '').strip().strip('[]')
                # Only concrete URLs survive; placeholder tokens such as
                # DISCOVER are filtered out by the startswith('http') check.
                gaps['urls_to_test'] = [x.strip() for x in items.split(',') if x.strip() and x.startswith('http')]
            elif line.startswith('REASON:'):
                gaps['reason'] = line.replace('REASON:', '').strip()
        print(f" [*] User wants: {', '.join(gaps['tests_needed']) if gaps['tests_needed'] else 'general analysis'}")
        print(f" [*] Data sufficient: {'No' if gaps['needs_tools'] else 'Yes'}")
        if gaps['missing']:
            print(f" [*] Missing: {', '.join(gaps['missing'])}")
        return gaps
def _fill_context_gaps(self, target: str, gaps: Dict, urls_with_params: List, all_urls: List):
"""Run tools to collect missing data based on identified gaps."""
tests_needed = gaps.get('tests_needed', [])
urls_to_test = gaps.get('urls_to_test', [])
# If no specific URLs, use from context
if not urls_to_test or 'DISCOVER' in str(urls_to_test).upper():
urls_to_test = urls_with_params[:20] if urls_with_params else all_urls[:20]
# Normalize target
if not target.startswith(('http://', 'https://')):
target = f"http://{target}"
tools_run = 0
max_tools = 30
# XSS Testing
if any(t in tests_needed for t in ['xss', 'cross-site', 'reflected', 'stored']):
print(f"\n [XSS] Running XSS tests...")
xss_payloads = [
'',
'">',
"'-alert(1)-'",
'
',
'