diff --git a/agents/base_agent.py b/agents/base_agent.py index f2a9216..4535dbe 100644 --- a/agents/base_agent.py +++ b/agents/base_agent.py @@ -47,6 +47,35 @@ class BaseAgent: self.interesting_findings = [] self.tool_history = [] + # Knowledge augmentation (opt-in via env) + self.augmentor = None + if os.getenv('ENABLE_KNOWLEDGE_AUGMENTATION', 'false').lower() == 'true': + try: + from core.knowledge_augmentor import KnowledgeAugmentor + ka_config = config.get('knowledge_augmentation', {}) + self.augmentor = KnowledgeAugmentor( + dataset_path=ka_config.get('dataset_path', 'models/bug-bounty/bugbounty_finetuning_dataset.json'), + max_patterns=ka_config.get('max_patterns_per_query', 3) + ) + logger.info("Knowledge augmentation enabled") + except Exception as e: + logger.warning(f"Knowledge augmentation init failed: {e}") + + # MCP tool client (opt-in via config) + self.mcp_client = None + if config.get('mcp_servers', {}).get('enabled', False): + try: + from core.mcp_client import MCPToolClient + self.mcp_client = MCPToolClient(config) + logger.info("MCP tool client enabled") + except Exception as e: + logger.warning(f"MCP client init failed: {e}") + + # Browser validation (opt-in via env) + self.browser_validation_enabled = ( + os.getenv('ENABLE_BROWSER_VALIDATION', 'false').lower() == 'true' + ) + logger.info(f"Initialized {self.agent_name} - Autonomous Agent") def _extract_targets(self, user_input: str) -> List[str]: @@ -131,6 +160,68 @@ class BaseAgent: self.tool_history.append(result) return result + def run_mcp_tool(self, tool_name: str, arguments: Optional[Dict] = None) -> Optional[str]: + """Execute a tool via MCP if available, returns None for subprocess fallback.""" + if not self.mcp_client or not self.mcp_client.enabled: + return None + + import asyncio + try: + result = asyncio.run(self.mcp_client.try_tool(tool_name, arguments)) + if result is not None: + logger.info(f"MCP tool executed: {tool_name}") + return result + except Exception as e: + 
logger.debug(f"MCP tool '{tool_name}' not available: {e}") + return None + + def run_browser_validation(self, finding_id: str, url: str, + payload: str = None) -> Dict: + """Validate a finding using Playwright browser. + + Only executes if ENABLE_BROWSER_VALIDATION is set. + Returns validation result with screenshots. + """ + if not self.browser_validation_enabled: + return {"skipped": True, "reason": "Browser validation disabled"} + + try: + from core.browser_validator import validate_finding_sync + screenshots_dir = self.config.get('browser_validation', {}).get( + 'screenshots_dir', 'reports/screenshots' + ) + return validate_finding_sync( + finding_id=finding_id, + url=url, + payload=payload, + screenshots_dir=f"{screenshots_dir}/{self.agent_name}", + headless=self.config.get('browser_validation', {}).get('headless', True) + ) + except Exception as e: + logger.error(f"Browser validation failed for {finding_id}: {e}") + return {"finding_id": finding_id, "error": str(e)} + + def get_augmented_context(self, vulnerability_types: List[str]) -> str: + """Get knowledge augmentation context for detected vulnerability types. + + Returns formatted pattern context string to inject into prompts. + """ + if not self.augmentor: + return "" + + augmentation = "" + technologies = list(self.tech_stack.get('detected', [])) + + for vtype in vulnerability_types[:3]: # Limit to avoid context bloat + patterns = self.augmentor.get_relevant_patterns( + vulnerability_type=vtype, + technologies=technologies + ) + if patterns: + augmentation += patterns + + return augmentation + def execute(self, user_input: str, campaign_data: Dict = None, recon_context: Dict = None) -> Dict: """ Execute security assessment. 
diff --git a/backend/api/v1/agent.py b/backend/api/v1/agent.py index c3401ec..7115e33 100644 --- a/backend/api/v1/agent.py +++ b/backend/api/v1/agent.py @@ -32,6 +32,8 @@ agent_instances: Dict[str, AutonomousAgent] = {} # Map agent_id to scan_id for database persistence agent_to_scan: Dict[str, str] = {} +# Reverse map: scan_id to agent_id for ScanDetailsPage lookups +scan_to_agent: Dict[str, str] = {} @router.get("/status") @@ -101,6 +103,7 @@ class AgentMode(str, Enum): RECON_ONLY = "recon_only" # Just reconnaissance PROMPT_ONLY = "prompt_only" # AI decides (high tokens) ANALYZE_ONLY = "analyze_only" # Analysis without testing + AUTO_PENTEST = "auto_pentest" # One-click full auto pentest class AgentRequest(BaseModel): @@ -113,6 +116,8 @@ class AgentRequest(BaseModel): auth_value: Optional[str] = Field(None, description="Auth value (cookie string, token, etc)") custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom HTTP headers") max_depth: int = Field(5, description="Maximum crawl depth") + subdomain_discovery: bool = Field(False, description="Enable subdomain discovery (auto_pentest mode)") + targets: Optional[List[str]] = Field(None, description="Multiple targets (auto_pentest mode)") class AgentResponse(BaseModel): @@ -193,7 +198,9 @@ async def run_agent(request: AgentRequest, background_tasks: BackgroundTasks): "findings": [], "report": None, "progress": 0, - "phase": "initializing" + "phase": "initializing", + "rejected_findings": [], + "rejected_findings_count": 0, } # Run agent in background @@ -212,7 +219,8 @@ async def run_agent(request: AgentRequest, background_tasks: BackgroundTasks): "full_auto": "Full autonomous pentest: Recon -> Analyze -> Test -> Report", "recon_only": "Reconnaissance only, no vulnerability testing", "prompt_only": "AI decides everything (high token usage!)", - "analyze_only": "Analysis only, no active testing" + "analyze_only": "Analysis only, no active testing", + "auto_pentest": "One-click auto pentest: Full 
recon + 100 vuln types + AI report" } return AgentResponse( @@ -255,12 +263,20 @@ async def _run_agent_task( agent_results[agent_id]["progress"] = progress agent_results[agent_id]["phase"] = phase + rejected_findings_list = [] + async def finding_callback(finding: Dict): """Real-time finding callback - updates in-memory storage immediately""" - findings_list.append(finding) - if agent_id in agent_results: - agent_results[agent_id]["findings"] = findings_list - agent_results[agent_id]["findings_count"] = len(findings_list) + if finding.get("ai_status") == "rejected": + rejected_findings_list.append(finding) + if agent_id in agent_results: + agent_results[agent_id]["rejected_findings"] = rejected_findings_list + agent_results[agent_id]["rejected_findings_count"] = len(rejected_findings_list) + else: + findings_list.append(finding) + if agent_id in agent_results: + agent_results[agent_id]["findings"] = findings_list + agent_results[agent_id]["findings_count"] = len(findings_list) try: # Create database session and scan record @@ -289,8 +305,9 @@ async def _run_agent_task( db.add(target_record) await db.commit() - # Store mapping + # Store mapping (both directions) agent_to_scan[agent_id] = scan_id + scan_to_agent[scan_id] = agent_id agent_results[agent_id]["scan_id"] = scan_id # Map mode @@ -299,6 +316,7 @@ async def _run_agent_task( AgentMode.RECON_ONLY: OperationMode.RECON_ONLY, AgentMode.PROMPT_ONLY: OperationMode.PROMPT_ONLY, AgentMode.ANALYZE_ONLY: OperationMode.ANALYZE_ONLY, + AgentMode.AUTO_PENTEST: OperationMode.AUTO_PENTEST, } op_mode = mode_map.get(mode, OperationMode.FULL_AUTO) @@ -311,6 +329,7 @@ async def _run_agent_task( task=task, custom_prompt=custom_prompt or (task.prompt if task else None), finding_callback=finding_callback, + scan_id=str(scan_id), ) as agent: # Store agent instance for stop functionality agent_instances[agent_id] = agent @@ -345,7 +364,41 @@ async def _run_agent_task( impact=finding.get("impact", ""), 
remediation=finding.get("remediation", ""), references=finding.get("references", []), - ai_analysis=finding.get("ai_analysis", finding.get("exploitation_steps", "")) + ai_analysis=finding.get("ai_analysis", finding.get("exploitation_steps", "")), + poc_code=finding.get("poc_code", ""), + screenshots=finding.get("screenshots", []), + url=finding.get("url", finding.get("affected_endpoint", "")), + parameter=finding.get("parameter", finding.get("poc_parameter", "")), + validation_status="ai_confirmed", + ) + db.add(vuln) + + # Save rejected findings to database for manual review + for finding in report.get("rejected_findings", []): + vuln = Vulnerability( + scan_id=scan_id, + title=finding.get("title", finding.get("type", "Unknown")), + vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")), + severity=finding.get("severity", "medium").lower(), + cvss_score=finding.get("cvss_score"), + cvss_vector=finding.get("cvss_vector"), + cwe_id=finding.get("cwe_id"), + description=finding.get("description", finding.get("evidence", "")), + affected_endpoint=finding.get("affected_endpoint", finding.get("endpoint", finding.get("url", target))), + poc_payload=finding.get("payload", finding.get("poc_payload", "")), + poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")), + poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")), + poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000], + poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000], + impact=finding.get("impact", ""), + remediation=finding.get("remediation", ""), + references=finding.get("references", []), + poc_code=finding.get("poc_code", ""), + screenshots=finding.get("screenshots", []), + url=finding.get("url", finding.get("affected_endpoint", "")), + parameter=finding.get("parameter", finding.get("poc_parameter", "")), + validation_status="ai_rejected", + ai_rejection_reason=finding.get("rejection_reason", 
""), ) db.add(vuln) @@ -402,6 +455,7 @@ async def _run_agent_task( agent_results[agent_id]["report"] = report agent_results[agent_id]["report_id"] = report_record.id agent_results[agent_id]["findings"] = findings + agent_results[agent_id]["tool_executions"] = report.get("tool_executions", []) agent_results[agent_id]["progress"] = 100 agent_results[agent_id]["phase"] = "completed" @@ -429,6 +483,37 @@ async def _run_agent_task( pass +@router.get("/by-scan/{scan_id}") +async def get_agent_by_scan(scan_id: str): + """Look up agent status by scan_id (reverse lookup for ScanDetailsPage)""" + agent_id = scan_to_agent.get(scan_id) + if not agent_id: + raise HTTPException(status_code=404, detail="No agent found for this scan") + + if agent_id in agent_results: + result = agent_results[agent_id] + return { + "agent_id": agent_id, + "scan_id": scan_id, + "status": result["status"], + "mode": result.get("mode", "full_auto"), + "target": result["target"], + "progress": result.get("progress", 0), + "phase": result.get("phase", "unknown"), + "started_at": result.get("started_at"), + "completed_at": result.get("completed_at"), + "findings_count": len(result.get("findings", [])), + "findings": result.get("findings", []), + "rejected_findings_count": len(result.get("rejected_findings", [])), + "rejected_findings": result.get("rejected_findings", []), + "logs_count": len(result.get("logs", [])), + "report": result.get("report"), + "error": result.get("error") + } + + raise HTTPException(status_code=404, detail="Agent data no longer in memory") + + @router.get("/status/{agent_id}") async def get_agent_status(agent_id: str): """Get the status and results of an agent run - with database fallback""" @@ -449,6 +534,8 @@ async def get_agent_status(agent_id: str): "logs_count": len(result.get("logs", [])), "findings_count": len(result.get("findings", [])), "findings": result.get("findings", []), + "rejected_findings_count": len(result.get("rejected_findings", [])), + "rejected_findings": 
result.get("rejected_findings", []), "report": result.get("report"), "error": result.get("error") } @@ -495,10 +582,12 @@ async def _get_status_from_db(agent_id: str, scan_id: str): "evidence": getattr(v, 'poc_evidence', None) or "", "request": v.poc_request or "", "response": v.poc_response or "", - "poc_code": v.poc_payload or "", + "poc_code": getattr(v, 'poc_code', None) or v.poc_payload or "", "impact": v.impact or "", "remediation": v.remediation or "", "references": v.references or [], + "screenshots": getattr(v, 'screenshots', None) or [], + "url": getattr(v, 'url', None) or v.affected_endpoint or "", "ai_verified": True, "confidence": "high" } @@ -542,14 +631,14 @@ async def _get_status_from_db(agent_id: str, scan_id: str): @router.post("/stop/{agent_id}") async def stop_agent(agent_id: str): - """Stop a running agent scan and auto-generate report""" + """Stop a running agent scan, save all findings to DB, and generate report.""" if agent_id not in agent_results: raise HTTPException(status_code=404, detail="Agent not found") if agent_results[agent_id]["status"] != "running": return {"message": "Agent is not running", "status": agent_results[agent_id]["status"]} - # Cancel the agent + # Cancel the agent immediately if agent_id in agent_instances: agent_instances[agent_id].cancel() @@ -558,9 +647,10 @@ async def stop_agent(agent_id: str): agent_results[agent_id]["phase"] = "stopped" agent_results[agent_id]["completed_at"] = datetime.utcnow().isoformat() - # Update database and auto-generate report + # Update database: save findings + generate report scan_id = agent_to_scan.get(agent_id) report_id = None + target = agent_results[agent_id].get("target", "Unknown") if scan_id: try: @@ -573,47 +663,222 @@ async def stop_agent(agent_id: str): scan.status = "stopped" scan.completed_at = datetime.utcnow() - # Get findings count + # Save confirmed findings to DB (same as completion flow) findings = agent_results[agent_id].get("findings", []) - 
scan.total_vulnerabilities = len(findings) + severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} - # Count severities for finding in findings: - severity = finding.get("severity", "").lower() - if severity == "critical": - scan.critical_count = (scan.critical_count or 0) + 1 - elif severity == "high": - scan.high_count = (scan.high_count or 0) + 1 - elif severity == "medium": - scan.medium_count = (scan.medium_count or 0) + 1 - elif severity == "low": - scan.low_count = (scan.low_count or 0) + 1 - elif severity == "info": - scan.info_count = (scan.info_count or 0) + 1 + severity = finding.get("severity", "medium").lower() + if severity in severity_counts: + severity_counts[severity] += 1 + + vuln = Vulnerability( + scan_id=scan_id, + title=finding.get("title", finding.get("type", "Unknown")), + vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")), + severity=severity, + cvss_score=finding.get("cvss_score"), + cvss_vector=finding.get("cvss_vector"), + cwe_id=finding.get("cwe_id"), + description=finding.get("description", finding.get("evidence", "")), + affected_endpoint=finding.get("affected_endpoint", finding.get("endpoint", finding.get("url", target))), + poc_payload=finding.get("payload", finding.get("poc_payload", "")), + poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")), + poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")), + poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000], + poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000], + impact=finding.get("impact", ""), + remediation=finding.get("remediation", ""), + references=finding.get("references", []), + ai_analysis=finding.get("ai_analysis", finding.get("exploitation_steps", "")), + poc_code=finding.get("poc_code", ""), + screenshots=finding.get("screenshots", []), + url=finding.get("url", finding.get("affected_endpoint", "")), + 
parameter=finding.get("parameter", finding.get("poc_parameter", "")), + validation_status="ai_confirmed", + ) + db.add(vuln) + + # Save rejected findings to DB for manual review + rejected = agent_results[agent_id].get("rejected_findings", []) + for finding in rejected: + vuln = Vulnerability( + scan_id=scan_id, + title=finding.get("title", finding.get("type", "Unknown")), + vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")), + severity=finding.get("severity", "medium").lower(), + cvss_score=finding.get("cvss_score"), + cvss_vector=finding.get("cvss_vector"), + cwe_id=finding.get("cwe_id"), + description=finding.get("description", finding.get("evidence", "")), + affected_endpoint=finding.get("affected_endpoint", finding.get("endpoint", finding.get("url", target))), + poc_payload=finding.get("payload", finding.get("poc_payload", "")), + poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")), + poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")), + poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000], + poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000], + impact=finding.get("impact", ""), + remediation=finding.get("remediation", ""), + references=finding.get("references", []), + poc_code=finding.get("poc_code", ""), + screenshots=finding.get("screenshots", []), + url=finding.get("url", finding.get("affected_endpoint", "")), + parameter=finding.get("parameter", finding.get("poc_parameter", "")), + validation_status="ai_rejected", + ai_rejection_reason=finding.get("rejection_reason", ""), + ) + db.add(vuln) + + # Update scan counts (confirmed only) + scan.total_vulnerabilities = len(findings) + scan.critical_count = severity_counts["critical"] + scan.high_count = severity_counts["high"] + scan.medium_count = severity_counts["medium"] + scan.low_count = severity_counts["low"] + scan.info_count = severity_counts["info"] await db.commit() - # 
Auto-generate report - report = Report( + # Auto-generate report record + report_record = Report( scan_id=scan_id, - title=f"Agent Scan Report - {agent_results[agent_id].get('target', 'Unknown')}", + title=f"Agent Scan Report - {target}", format="json", - executive_summary=f"Automated security scan completed with {len(findings)} findings." + executive_summary=f"Security scan stopped with {len(findings)} confirmed and {len(rejected)} rejected findings." ) - db.add(report) + db.add(report_record) await db.commit() - await db.refresh(report) - report_id = report.id + await db.refresh(report_record) + report_id = report_record.id except Exception as e: - print(f"Error updating scan status: {e}") + print(f"Error updating scan status on stop: {e}") import traceback traceback.print_exc() return { "message": "Agent stopped successfully", "agent_id": agent_id, - "report_id": report_id + "report_id": report_id, + "findings_saved": len(agent_results[agent_id].get("findings", [])), + "rejected_saved": len(agent_results[agent_id].get("rejected_findings", [])), + } + + +@router.post("/pause/{agent_id}") +async def pause_agent(agent_id: str): + """Pause a running agent scan""" + if agent_id not in agent_results: + raise HTTPException(status_code=404, detail="Agent not found") + + if agent_results[agent_id]["status"] != "running": + return {"message": "Agent is not running", "status": agent_results[agent_id]["status"]} + + if agent_id in agent_instances: + agent_instances[agent_id].pause() + + # Save current phase before overwriting with "paused" + agent_results[agent_id]["last_phase"] = agent_results[agent_id].get("phase", "recon") + agent_results[agent_id]["status"] = "paused" + agent_results[agent_id]["phase"] = "paused" + + return {"message": "Agent paused", "agent_id": agent_id} + + +@router.post("/resume/{agent_id}") +async def resume_agent(agent_id: str): + """Resume a paused agent scan""" + if agent_id not in agent_results: + raise HTTPException(status_code=404, 
detail="Agent not found") + + if agent_results[agent_id]["status"] != "paused": + return {"message": "Agent is not paused", "status": agent_results[agent_id]["status"]} + + if agent_id in agent_instances: + agent_instances[agent_id].resume() + + agent_results[agent_id]["status"] = "running" + # Restore the phase that was active before pause + agent_results[agent_id]["phase"] = agent_results[agent_id].get("last_phase", "testing") + + return {"message": "Agent resumed", "agent_id": agent_id} + + +# Agent phase order for skip validation +AGENT_PHASE_ORDER = ["recon", "analysis", "testing", "enhancement", "completed"] + +# Map phase names from status strings to canonical phase keys +PHASE_NORMALIZE = { + "starting reconnaissance": "recon", + "reconnaissance complete": "recon", + "initial probe complete": "recon", + "endpoint discovery complete": "recon", + "parameter discovery complete": "recon", + "attack surface analyzed": "analysis", + "vulnerability testing complete": "testing", + "findings enhanced": "enhancement", + "assessment complete": "completed", +} + + +@router.post("/skip-to/{agent_id}/{target_phase}") +async def skip_agent_phase(agent_id: str, target_phase: str): + """Skip the current agent phase and jump to a target phase. + + Valid phases: recon, analysis, testing, enhancement, completed + Can only skip forward (to a phase ahead of current). + """ + if agent_id not in agent_results: + raise HTTPException(status_code=404, detail="Agent not found") + + agent_status = agent_results[agent_id]["status"] + if agent_status not in ("running", "paused"): + raise HTTPException(status_code=400, detail="Agent is not running or paused") + + if target_phase not in AGENT_PHASE_ORDER: + raise HTTPException( + status_code=400, + detail=f"Invalid phase '{target_phase}'. 
Valid: {', '.join(AGENT_PHASE_ORDER[1:])}" + ) + + # Get current phase and normalize it + current_raw = agent_results[agent_id].get("phase", "").lower() + # Handle "paused" phase — use the last known non-paused phase, default to recon + if current_raw in ("paused", "stopped"): + current_raw = agent_results[agent_id].get("last_phase", "recon") + current_phase = PHASE_NORMALIZE.get(current_raw, current_raw) + # Also try prefix match + for key in AGENT_PHASE_ORDER: + if key in current_phase: + current_phase = key + break + + cur_idx = AGENT_PHASE_ORDER.index(current_phase) if current_phase in AGENT_PHASE_ORDER else 0 + tgt_idx = AGENT_PHASE_ORDER.index(target_phase) + + if tgt_idx <= cur_idx: + raise HTTPException( + status_code=400, + detail=f"Cannot skip backward. Current: {current_phase}, target: {target_phase}" + ) + + # Signal the agent instance to skip + if agent_id in agent_instances: + # If paused, resume first so the skip can be processed + if agent_status == "paused": + agent_instances[agent_id].resume() + agent_results[agent_id]["status"] = "running" + success = agent_instances[agent_id].skip_to_phase(target_phase) + if not success: + raise HTTPException(status_code=500, detail="Failed to signal phase skip") + else: + raise HTTPException(status_code=400, detail="Agent instance not available for signaling") + + return { + "message": f"Skipping to phase: {target_phase}", + "agent_id": agent_id, + "from_phase": current_phase, + "target_phase": target_phase } @@ -1711,7 +1976,10 @@ async def _save_realtime_findings_to_db(session_id: str, session: Dict): impact=finding.get("impact", ""), remediation=finding.get("remediation", ""), references=finding.get("references", []), - ai_analysis=f"Identified during realtime session {session_id}" + ai_analysis=f"Identified during realtime session {session_id}", + screenshots=finding.get("screenshots", []), + url=finding.get("url", finding.get("affected_endpoint", "")), + parameter=finding.get("parameter", "") ) 
db.add(vuln) @@ -1809,6 +2077,29 @@ async def generate_realtime_report(session_id: str, format: str = "json"): scan_results=tool_results ) + # Save to a per-report folder with screenshots + import shutil + from pathlib import Path + timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") + target_name = session["target"].replace("://", "_").replace("/", "_").rstrip("_")[:40] + report_dir = Path("reports") / f"report_{target_name}_{timestamp}" + report_dir.mkdir(parents=True, exist_ok=True) + (report_dir / f"report_{timestamp}.html").write_text(html_content) + + # Copy screenshots into report folder + screenshots_src = Path("reports") / "screenshots" + if screenshots_src.exists(): + screenshots_dest = report_dir / "screenshots" + for finding in findings: + fid = finding.get("id", "") + if fid: + src_dir = screenshots_src / str(fid) + if src_dir.exists(): + dest_dir = screenshots_dest / str(fid) + dest_dir.mkdir(parents=True, exist_ok=True) + for ss_file in src_dir.glob("*.png"): + shutil.copy2(ss_file, dest_dir / ss_file.name) + return HTMLResponse(content=html_content, media_type="text/html") return { diff --git a/backend/api/v1/reports.py b/backend/api/v1/reports.py index 1d9067f..a900dc4 100644 --- a/backend/api/v1/reports.py +++ b/backend/api/v1/reports.py @@ -64,6 +64,19 @@ async def generate_report( ) vulnerabilities = vulns_result.scalars().all() + # Try to get tool_executions from agent in-memory results + tool_executions = [] + try: + from backend.api.v1.agent import scan_to_agent, agent_results + agent_id = scan_to_agent.get(report_data.scan_id) + if agent_id and agent_id in agent_results: + tool_executions = agent_results[agent_id].get("tool_executions", []) + if not tool_executions: + rpt = agent_results[agent_id].get("report", {}) + tool_executions = rpt.get("tool_executions", []) if isinstance(rpt, dict) else [] + except Exception: + pass + # Generate report generator = ReportGenerator() report_path, executive_summary = await generator.generate( @@ 
-73,7 +86,8 @@ async def generate_report( title=report_data.title, include_executive_summary=report_data.include_executive_summary, include_poc=report_data.include_poc, - include_remediation=report_data.include_remediation + include_remediation=report_data.include_remediation, + tool_executions=tool_executions, ) # Save report record @@ -91,6 +105,63 @@ async def generate_report( return ReportResponse(**report.to_dict()) +@router.post("/ai-generate", response_model=ReportResponse) +async def generate_ai_report( + report_data: ReportGenerate, + db: AsyncSession = Depends(get_db) +): + """Generate an AI-enhanced report with LLM-written executive summary and per-finding analysis.""" + # Get scan + scan_result = await db.execute(select(Scan).where(Scan.id == report_data.scan_id)) + scan = scan_result.scalar_one_or_none() + + if not scan: + raise HTTPException(status_code=404, detail="Scan not found") + + # Get vulnerabilities + vulns_result = await db.execute( + select(Vulnerability).where(Vulnerability.scan_id == report_data.scan_id) + ) + vulnerabilities = vulns_result.scalars().all() + + # Try to get tool_executions from agent in-memory results + tool_executions = [] + try: + from backend.api.v1.agent import scan_to_agent, agent_results + agent_id = scan_to_agent.get(report_data.scan_id) + if agent_id and agent_id in agent_results: + tool_executions = agent_results[agent_id].get("tool_executions", []) + # Also check nested report + if not tool_executions: + rpt = agent_results[agent_id].get("report", {}) + tool_executions = rpt.get("tool_executions", []) if isinstance(rpt, dict) else [] + except Exception: + pass + + # Generate AI report + generator = ReportGenerator() + report_path, ai_summary = await generator.generate_ai_report( + scan=scan, + vulnerabilities=vulnerabilities, + tool_executions=tool_executions, + title=report_data.title, + ) + + # Save report record + report = Report( + scan_id=scan.id, + title=report_data.title or f"AI Report - {scan.name}", + 
format="html", + file_path=str(report_path), + executive_summary=ai_summary[:2000] if ai_summary else None + ) + db.add(report) + await db.commit() + await db.refresh(report) + + return ReportResponse(**report.to_dict()) + + @router.get("/{report_id}", response_model=ReportResponse) async def get_report(report_id: str, db: AsyncSession = Depends(get_db)): """Get report details""" @@ -187,6 +258,100 @@ async def download_report( ) +@router.get("/{report_id}/download-zip") +async def download_report_zip( + report_id: str, + db: AsyncSession = Depends(get_db) +): + """Download report as ZIP with screenshots included""" + import zipfile + import tempfile + import hashlib + + result = await db.execute(select(Report).where(Report.id == report_id)) + report = result.scalar_one_or_none() + + if not report: + raise HTTPException(status_code=404, detail="Report not found") + + scan_result = await db.execute(select(Scan).where(Scan.id == report.scan_id)) + scan = scan_result.scalar_one_or_none() + + if not scan: + raise HTTPException(status_code=404, detail="Scan not found for report") + + vulns_result = await db.execute( + select(Vulnerability).where(Vulnerability.scan_id == report.scan_id) + ) + vulnerabilities = vulns_result.scalars().all() + + # Generate HTML report + generator = ReportGenerator() + report_path, _ = await generator.generate( + scan=scan, + vulnerabilities=vulnerabilities, + format="html", + title=report.title + ) + + # Collect screenshots (use absolute path via settings.BASE_DIR) + # Check scan-scoped path first, then legacy flat path + screenshots_base = settings.BASE_DIR / "reports" / "screenshots" + scan_id_str = str(scan.id) if scan else None + screenshot_files = [] + for vuln in vulnerabilities: + # Finding ID is md5(vuln_type+url+param)[:8] + vuln_url = getattr(vuln, 'url', None) or vuln.affected_endpoint or '' + vuln_param = getattr(vuln, 'parameter', None) or getattr(vuln, 'poc_parameter', None) or '' + finding_id = hashlib.md5( + 
f"{vuln.vulnerability_type}{vuln_url}{vuln_param}".encode() + ).hexdigest()[:8] + # Scan-scoped path: reports/screenshots/{scan_id}/{finding_id}/ + finding_dir = None + if scan_id_str: + scan_dir = screenshots_base / scan_id_str / finding_id + if scan_dir.exists(): + finding_dir = scan_dir + if not finding_dir: + legacy_dir = screenshots_base / finding_id + if legacy_dir.exists(): + finding_dir = legacy_dir + if finding_dir: + for img in finding_dir.glob("*.png"): + screenshot_files.append((img, f"screenshots/{finding_id}/{img.name}")) + # Also include base64 screenshots from DB as files in the ZIP + db_screenshots = getattr(vuln, 'screenshots', None) or [] + for idx, ss in enumerate(db_screenshots): + if isinstance(ss, str) and ss.startswith("data:image/"): + # Will be embedded in HTML, but also save as file + import base64 as b64 + try: + b64_data = ss.split(",", 1)[1] + img_bytes = b64.b64decode(b64_data) + img_name = f"screenshots/{finding_id}/evidence_{idx+1}.png" + # Write to temp for ZIP inclusion + tmp_img = Path(tempfile.gettempdir()) / f"ss_{finding_id}_{idx}.png" + tmp_img.write_bytes(img_bytes) + screenshot_files.append((tmp_img, img_name)) + except Exception: + pass + + # Create ZIP + zip_name = Path(report_path).stem + ".zip" + zip_path = Path(tempfile.gettempdir()) / zip_name + + with zipfile.ZipFile(str(zip_path), 'w', zipfile.ZIP_DEFLATED) as zf: + zf.write(report_path, "report.html") + for src_path, arc_name in screenshot_files: + zf.write(str(src_path), arc_name) + + return FileResponse( + path=str(zip_path), + media_type="application/zip", + filename=zip_name + ) + + @router.delete("/{report_id}") async def delete_report(report_id: str, db: AsyncSession = Depends(get_db)): """Delete a report""" diff --git a/backend/api/v1/sandbox.py b/backend/api/v1/sandbox.py new file mode 100644 index 0000000..9171ab5 --- /dev/null +++ b/backend/api/v1/sandbox.py @@ -0,0 +1,130 @@ +""" +NeuroSploit v3 - Sandbox Container Management API + +Real-time monitoring 
and management of per-scan Kali Linux containers. +""" + +from datetime import datetime +from fastapi import APIRouter, HTTPException + +router = APIRouter() + + +def _docker_available() -> bool: + try: + import docker + docker.from_env().ping() + return True + except Exception: + return False + + +@router.get("/") +async def list_sandboxes(): + """List all sandbox containers with pool status.""" + try: + from core.container_pool import get_pool + pool = get_pool() + except Exception as e: + return { + "pool": { + "active": 0, + "max_concurrent": 0, + "image": "neurosploit-kali:latest", + "container_ttl_minutes": 60, + "docker_available": _docker_available(), + }, + "containers": [], + "error": str(e), + } + + sandboxes = pool.list_sandboxes() + now = datetime.utcnow() + + containers = [] + for info in sandboxes.values(): + created = info.get("created_at") + uptime = 0.0 + if created: + try: + dt = datetime.fromisoformat(created) + uptime = (now - dt).total_seconds() + except Exception: + pass + containers.append({ + **info, + "uptime_seconds": uptime, + }) + + return { + "pool": { + "active": pool.active_count, + "max_concurrent": pool.max_concurrent, + "image": pool.image, + "container_ttl_minutes": int(pool.container_ttl.total_seconds() / 60), + "docker_available": _docker_available(), + }, + "containers": containers, + } + + +@router.get("/{scan_id}") +async def get_sandbox(scan_id: str): + """Get health check for a specific sandbox container.""" + try: + from core.container_pool import get_pool + pool = get_pool() + except Exception as e: + raise HTTPException(status_code=503, detail=str(e)) + + sandboxes = pool.list_sandboxes() + if scan_id not in sandboxes: + raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}") + + sb = pool._sandboxes.get(scan_id) + if not sb: + raise HTTPException(status_code=404, detail=f"Sandbox instance not found") + + health = await sb.health_check() + return health + + +@router.delete("/{scan_id}") +async def 
destroy_sandbox(scan_id: str): + """Destroy a specific sandbox container.""" + try: + from core.container_pool import get_pool + pool = get_pool() + except Exception as e: + raise HTTPException(status_code=503, detail=str(e)) + + sandboxes = pool.list_sandboxes() + if scan_id not in sandboxes: + raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}") + + await pool.destroy(scan_id) + return {"message": f"Sandbox for scan {scan_id} destroyed", "scan_id": scan_id} + + +@router.post("/cleanup") +async def cleanup_expired(): + """Remove containers that have exceeded their TTL.""" + try: + from core.container_pool import get_pool + pool = get_pool() + await pool.cleanup_expired() + return {"message": "Expired containers cleaned up"} + except Exception as e: + raise HTTPException(status_code=503, detail=str(e)) + + +@router.post("/cleanup-orphans") +async def cleanup_orphans(): + """Remove orphan containers not tracked by the pool.""" + try: + from core.container_pool import get_pool + pool = get_pool() + await pool.cleanup_orphans() + return {"message": "Orphan containers cleaned up"} + except Exception as e: + raise HTTPException(status_code=503, detail=str(e)) diff --git a/backend/api/v1/scans.py b/backend/api/v1/scans.py index 71df6e1..329655c 100644 --- a/backend/api/v1/scans.py +++ b/backend/api/v1/scans.py @@ -4,6 +4,7 @@ NeuroSploit v3 - Scans API Endpoints from typing import List, Optional from datetime import datetime from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks +from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func from urllib.parse import urlparse @@ -11,7 +12,7 @@ from urllib.parse import urlparse from backend.db.database import get_db from backend.models import Scan, Target, Endpoint, Vulnerability from backend.schemas.scan import ScanCreate, ScanUpdate, ScanResponse, ScanListResponse, ScanProgress -from backend.services.scan_service import 
run_scan_task +from backend.services.scan_service import run_scan_task, skip_to_phase as _skip_to_phase, PHASE_ORDER router = APIRouter() @@ -177,6 +178,7 @@ async def start_scan( async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)): """Stop a running scan and save partial results""" from backend.api.websocket import manager as ws_manager + from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results result = await db.execute(select(Scan).where(Scan.id == scan_id)) scan = result.scalar_one_or_none() @@ -184,8 +186,16 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)): if not scan: raise HTTPException(status_code=404, detail="Scan not found") - if scan.status != "running": - raise HTTPException(status_code=400, detail="Scan is not running") + if scan.status not in ("running", "paused"): + raise HTTPException(status_code=400, detail="Scan is not running or paused") + + # Signal the running agent to stop + agent_id = scan_to_agent.get(scan_id) + if agent_id and agent_id in agent_instances: + agent_instances[agent_id].cancel() + if agent_id in agent_results: + agent_results[agent_id]["status"] = "stopped" + agent_results[agent_id]["phase"] = "stopped" # Update scan status scan.status = "stopped" @@ -259,6 +269,132 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)): } +@router.post("/{scan_id}/pause") +async def pause_scan(scan_id: str, db: AsyncSession = Depends(get_db)): + """Pause a running scan""" + from backend.api.websocket import manager as ws_manager + from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results + + result = await db.execute(select(Scan).where(Scan.id == scan_id)) + scan = result.scalar_one_or_none() + + if not scan: + raise HTTPException(status_code=404, detail="Scan not found") + + if scan.status != "running": + raise HTTPException(status_code=400, detail="Scan is not running") + + # Signal the agent to pause + agent_id = scan_to_agent.get(scan_id) 
+ if agent_id and agent_id in agent_instances: + agent_instances[agent_id].pause() + if agent_id in agent_results: + agent_results[agent_id]["status"] = "paused" + agent_results[agent_id]["phase"] = "paused" + + scan.status = "paused" + scan.current_phase = "paused" + await db.commit() + + await ws_manager.broadcast_log(scan_id, "warning", "Scan paused by user") + + return {"message": "Scan paused", "scan_id": scan_id} + + +@router.post("/{scan_id}/resume") +async def resume_scan(scan_id: str, db: AsyncSession = Depends(get_db)): + """Resume a paused scan""" + from backend.api.websocket import manager as ws_manager + from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results + + result = await db.execute(select(Scan).where(Scan.id == scan_id)) + scan = result.scalar_one_or_none() + + if not scan: + raise HTTPException(status_code=404, detail="Scan not found") + + if scan.status != "paused": + raise HTTPException(status_code=400, detail="Scan is not paused") + + # Signal the agent to resume + agent_id = scan_to_agent.get(scan_id) + if agent_id and agent_id in agent_instances: + agent_instances[agent_id].resume() + if agent_id in agent_results: + agent_results[agent_id]["status"] = "running" + agent_results[agent_id]["phase"] = "testing" + + scan.status = "running" + scan.current_phase = "testing" + await db.commit() + + await ws_manager.broadcast_log(scan_id, "info", "Scan resumed by user") + + return {"message": "Scan resumed", "scan_id": scan_id} + + +@router.post("/{scan_id}/skip-to/{target_phase}") +async def skip_to_phase_endpoint(scan_id: str, target_phase: str, db: AsyncSession = Depends(get_db)): + """Skip the current scan phase and jump to a target phase. + + Valid phases: recon, analyzing, testing, completed + Can only skip forward (to a phase ahead of current). 
+ """ + result = await db.execute(select(Scan).where(Scan.id == scan_id)) + scan = result.scalar_one_or_none() + + if not scan: + raise HTTPException(status_code=404, detail="Scan not found") + + if scan.status not in ("running", "paused"): + raise HTTPException(status_code=400, detail="Scan is not running or paused") + + # If paused, resume first so the skip can be processed + if scan.status == "paused": + from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results + agent_id = scan_to_agent.get(scan_id) + if agent_id and agent_id in agent_instances: + agent_instances[agent_id].resume() + if agent_id in agent_results: + agent_results[agent_id]["status"] = "running" + agent_results[agent_id]["phase"] = agent_results[agent_id].get("last_phase", "testing") + scan.status = "running" + await db.commit() + + if target_phase not in PHASE_ORDER: + raise HTTPException( + status_code=400, + detail=f"Invalid phase '{target_phase}'. Valid: {', '.join(PHASE_ORDER[1:])}" + ) + + # Validate forward skip + current_idx = PHASE_ORDER.index(scan.current_phase) if scan.current_phase in PHASE_ORDER else 0 + target_idx = PHASE_ORDER.index(target_phase) + + if target_idx <= current_idx: + raise HTTPException( + status_code=400, + detail=f"Cannot skip backward. 
Current: {scan.current_phase}, target: {target_phase}" + ) + + # Signal the running scan to skip + success = _skip_to_phase(scan_id, target_phase) + if not success: + raise HTTPException(status_code=500, detail="Failed to signal phase skip") + + # Broadcast via WebSocket + from backend.api.websocket import manager as ws_manager + await ws_manager.broadcast_log(scan_id, "warning", f">> User requested skip to phase: {target_phase}") + await ws_manager.broadcast_phase_change(scan_id, f"skipping_to_{target_phase}") + + return { + "message": f"Skipping to phase: {target_phase}", + "scan_id": scan_id, + "from_phase": scan.current_phase, + "target_phase": target_phase + } + + @router.get("/{scan_id}/status", response_model=ScanProgress) async def get_scan_status(scan_id: str, db: AsyncSession = Depends(get_db)): """Get scan progress and status""" @@ -369,3 +505,68 @@ async def get_scan_vulnerabilities( "page": page, "per_page": per_page } + + +class ValidationRequest(BaseModel): + validation_status: str # "validated" | "false_positive" | "ai_confirmed" | "ai_rejected" | "pending_review" + notes: Optional[str] = None + + +@router.patch("/vulnerabilities/{vuln_id}/validate") +async def validate_vulnerability( + vuln_id: str, + body: ValidationRequest, + db: AsyncSession = Depends(get_db) +): + """Manually validate or reject a vulnerability finding""" + valid_statuses = {"validated", "false_positive", "ai_confirmed", "ai_rejected", "pending_review"} + if body.validation_status not in valid_statuses: + raise HTTPException(status_code=400, detail=f"Invalid status. 
Must be one of: {', '.join(valid_statuses)}") + + result = await db.execute(select(Vulnerability).where(Vulnerability.id == vuln_id)) + vuln = result.scalar_one_or_none() + + if not vuln: + raise HTTPException(status_code=404, detail="Vulnerability not found") + + old_status = vuln.validation_status or "ai_confirmed" + vuln.validation_status = body.validation_status + if body.notes: + vuln.ai_rejection_reason = body.notes + + # Update scan severity counts when validation status changes + scan_result = await db.execute(select(Scan).where(Scan.id == vuln.scan_id)) + scan = scan_result.scalar_one_or_none() + + if scan: + sev = vuln.severity + # If changing from a rejected/false-positive state to a confirmed one: add to counts + if old_status in ("ai_rejected", "false_positive") and body.validation_status in ("validated", "ai_confirmed"): + scan.total_vulnerabilities = (scan.total_vulnerabilities or 0) + 1 + if sev == "critical": + scan.critical_count = (scan.critical_count or 0) + 1 + elif sev == "high": + scan.high_count = (scan.high_count or 0) + 1 + elif sev == "medium": + scan.medium_count = (scan.medium_count or 0) + 1 + elif sev == "low": + scan.low_count = (scan.low_count or 0) + 1 + elif sev == "info": + scan.info_count = (scan.info_count or 0) + 1 + # If changing from a confirmed state to a rejected/false-positive one: subtract from counts + elif old_status in ("ai_confirmed", "validated") and body.validation_status in ("false_positive", "ai_rejected"): + scan.total_vulnerabilities = max(0, (scan.total_vulnerabilities or 0) - 1) + if sev == "critical": + scan.critical_count = max(0, (scan.critical_count or 0) - 1) + elif sev == "high": + scan.high_count = max(0, (scan.high_count or 0) - 1) + elif sev == "medium": + scan.medium_count = max(0, (scan.medium_count or 0) - 1) + elif sev == "low": + scan.low_count = max(0, (scan.low_count or 0) - 1) + elif sev == "info": + scan.info_count = max(0, (scan.info_count or 0) - 1) + + await db.commit() + + return {"message": "Vulnerability validation updated", "vulnerability": vuln.to_dict()} diff --git 
a/backend/api/v1/scheduler.py b/backend/api/v1/scheduler.py new file mode 100644 index 0000000..988a630 --- /dev/null +++ b/backend/api/v1/scheduler.py @@ -0,0 +1,140 @@ +""" +NeuroSploit v3 - Scheduler API Router + +CRUD endpoints for managing scheduled scan jobs. +""" + +import json +from pathlib import Path +from fastapi import APIRouter, HTTPException, Request +from pydantic import BaseModel +from typing import Optional, List, Dict + +router = APIRouter() + +CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "config" / "config.json" + + +class ScheduleJobRequest(BaseModel): + """Request model for creating a scheduled job.""" + job_id: str + target: str + scan_type: str = "quick" + cron_expression: Optional[str] = None + interval_minutes: Optional[int] = None + agent_role: Optional[str] = None + llm_profile: Optional[str] = None + + +class ScheduleJobResponse(BaseModel): + """Response model for a scheduled job.""" + id: str + target: str + scan_type: str + schedule: str + status: str + next_run: Optional[str] = None + last_run: Optional[str] = None + run_count: int = 0 + + +@router.get("/", response_model=List[Dict]) +async def list_scheduled_jobs(request: Request): + """List all scheduled scan jobs.""" + scheduler = getattr(request.app.state, 'scheduler', None) + if not scheduler: + return [] + return scheduler.list_jobs() + + +@router.post("/", response_model=Dict) +async def create_scheduled_job(job: ScheduleJobRequest, request: Request): + """Create a new scheduled scan job.""" + scheduler = getattr(request.app.state, 'scheduler', None) + if not scheduler: + raise HTTPException(status_code=503, detail="Scheduler not available") + + if not job.cron_expression and not job.interval_minutes: + raise HTTPException( + status_code=400, + detail="Either cron_expression or interval_minutes must be provided" + ) + + result = scheduler.add_job( + job_id=job.job_id, + target=job.target, + scan_type=job.scan_type, + cron_expression=job.cron_expression, + 
interval_minutes=job.interval_minutes, + agent_role=job.agent_role, + llm_profile=job.llm_profile + ) + + if "error" in result: + raise HTTPException(status_code=400, detail=result["error"]) + + return result + + +@router.delete("/{job_id}") +async def delete_scheduled_job(job_id: str, request: Request): + """Delete a scheduled scan job.""" + scheduler = getattr(request.app.state, 'scheduler', None) + if not scheduler: + raise HTTPException(status_code=503, detail="Scheduler not available") + + success = scheduler.remove_job(job_id) + if not success: + raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") + + return {"message": f"Job '{job_id}' deleted", "id": job_id} + + +@router.post("/{job_id}/pause") +async def pause_scheduled_job(job_id: str, request: Request): + """Pause a scheduled scan job.""" + scheduler = getattr(request.app.state, 'scheduler', None) + if not scheduler: + raise HTTPException(status_code=503, detail="Scheduler not available") + + success = scheduler.pause_job(job_id) + if not success: + raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") + + return {"message": f"Job '{job_id}' paused", "id": job_id, "status": "paused"} + + +@router.post("/{job_id}/resume") +async def resume_scheduled_job(job_id: str, request: Request): + """Resume a paused scheduled scan job.""" + scheduler = getattr(request.app.state, 'scheduler', None) + if not scheduler: + raise HTTPException(status_code=503, detail="Scheduler not available") + + success = scheduler.resume_job(job_id) + if not success: + raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") + + return {"message": f"Job '{job_id}' resumed", "id": job_id, "status": "active"} + + +@router.get("/agent-roles", response_model=List[Dict]) +async def get_agent_roles(): + """Return available agent roles from config.json for scheduler dropdown.""" + try: + if not CONFIG_PATH.exists(): + return [] + config = json.loads(CONFIG_PATH.read_text()) + roles = 
config.get("agent_roles", {}) + result = [] + for role_id, role_data in roles.items(): + if role_data.get("enabled", True): + result.append({ + "id": role_id, + "name": role_id.replace("_", " ").title(), + "description": role_data.get("description", ""), + "tools": role_data.get("tools_allowed", []), + }) + return result + except Exception: + return [] diff --git a/backend/api/v1/settings.py b/backend/api/v1/settings.py index eecec85..a737cd8 100644 --- a/backend/api/v1/settings.py +++ b/backend/api/v1/settings.py @@ -1,7 +1,10 @@ """ NeuroSploit v3 - Settings API Endpoints """ -from typing import Optional +import os +import re +from pathlib import Path +from typing import Optional, Dict from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, delete, text @@ -12,16 +15,69 @@ from backend.models import Scan, Target, Endpoint, Vulnerability, VulnerabilityT router = APIRouter() +# Path to .env file (project root) +ENV_FILE_PATH = Path(__file__).parent.parent.parent.parent / ".env" + + +def _update_env_file(updates: Dict[str, str]) -> bool: + """ + Update key=value pairs in the .env file without breaking formatting. + - If the key exists (even commented out), update its value + - If the key doesn't exist, append it + - Preserves comments and blank lines + """ + if not ENV_FILE_PATH.exists(): + return False + + try: + lines = ENV_FILE_PATH.read_text().splitlines() + updated_keys = set() + + new_lines = [] + for line in lines: + stripped = line.strip() + matched = False + + for key, value in updates.items(): + # Match: KEY=..., # KEY=..., #KEY=... 
+ pattern = rf'^#?\s*{re.escape(key)}\s*=' + if re.match(pattern, stripped): + # Replace with uncommented key=value + new_lines.append(f"{key}={value}") + updated_keys.add(key) + matched = True + break + + if not matched: + new_lines.append(line) + + # Append any keys that weren't found in existing file + for key, value in updates.items(): + if key not in updated_keys: + new_lines.append(f"{key}={value}") + + # Write back with trailing newline + ENV_FILE_PATH.write_text("\n".join(new_lines) + "\n") + return True + except Exception as e: + print(f"Warning: Failed to update .env file: {e}") + return False + class SettingsUpdate(BaseModel): """Settings update schema""" llm_provider: Optional[str] = None anthropic_api_key: Optional[str] = None openai_api_key: Optional[str] = None + openrouter_api_key: Optional[str] = None max_concurrent_scans: Optional[int] = None aggressive_mode: Optional[bool] = None default_scan_type: Optional[str] = None recon_enabled_by_default: Optional[bool] = None + enable_model_routing: Optional[bool] = None + enable_knowledge_augmentation: Optional[bool] = None + enable_browser_validation: Optional[bool] = None + max_output_tokens: Optional[int] = None class SettingsResponse(BaseModel): @@ -29,56 +85,118 @@ class SettingsResponse(BaseModel): llm_provider: str = "claude" has_anthropic_key: bool = False has_openai_key: bool = False + has_openrouter_key: bool = False max_concurrent_scans: int = 3 aggressive_mode: bool = False default_scan_type: str = "full" recon_enabled_by_default: bool = True + enable_model_routing: bool = False + enable_knowledge_augmentation: bool = False + enable_browser_validation: bool = False + max_output_tokens: Optional[int] = None -# In-memory settings storage (in production, use database or config file) -_settings = { - "llm_provider": "claude", - "anthropic_api_key": "", - "openai_api_key": "", - "max_concurrent_scans": 3, - "aggressive_mode": False, - "default_scan_type": "full", - "recon_enabled_by_default": True 
-} +def _load_settings_from_env() -> dict: + """ + Load settings from environment variables / .env file on startup. + This ensures settings persist across server restarts and browser sessions. + """ + from dotenv import load_dotenv + # Re-read .env file to pick up disk-persisted values + if ENV_FILE_PATH.exists(): + load_dotenv(ENV_FILE_PATH, override=True) + + def _env_bool(key: str, default: bool = False) -> bool: + val = os.getenv(key, "").strip().lower() + if val in ("true", "1", "yes"): + return True + if val in ("false", "0", "no"): + return False + return default + + def _env_int(key: str, default=None): + val = os.getenv(key, "").strip() + if val: + try: + return int(val) + except ValueError: + pass + return default + + # Detect provider from which keys are set + provider = "claude" + if os.getenv("ANTHROPIC_API_KEY"): + provider = "claude" + elif os.getenv("OPENAI_API_KEY"): + provider = "openai" + elif os.getenv("OPENROUTER_API_KEY"): + provider = "openrouter" + + return { + "llm_provider": provider, + "anthropic_api_key": os.getenv("ANTHROPIC_API_KEY", ""), + "openai_api_key": os.getenv("OPENAI_API_KEY", ""), + "openrouter_api_key": os.getenv("OPENROUTER_API_KEY", ""), + "max_concurrent_scans": _env_int("MAX_CONCURRENT_SCANS", 3), + "aggressive_mode": _env_bool("AGGRESSIVE_MODE", False), + "default_scan_type": os.getenv("DEFAULT_SCAN_TYPE", "full"), + "recon_enabled_by_default": _env_bool("RECON_ENABLED_BY_DEFAULT", True), + "enable_model_routing": _env_bool("ENABLE_MODEL_ROUTING", False), + "enable_knowledge_augmentation": _env_bool("ENABLE_KNOWLEDGE_AUGMENTATION", False), + "enable_browser_validation": _env_bool("ENABLE_BROWSER_VALIDATION", False), + "max_output_tokens": _env_int("MAX_OUTPUT_TOKENS", None), + } + + +# Load settings from .env on module import (server start) +_settings = _load_settings_from_env() @router.get("", response_model=SettingsResponse) async def get_settings(): """Get current settings""" + import os return SettingsResponse( 
llm_provider=_settings["llm_provider"], - has_anthropic_key=bool(_settings["anthropic_api_key"]), - has_openai_key=bool(_settings["openai_api_key"]), + has_anthropic_key=bool(_settings["anthropic_api_key"] or os.getenv("ANTHROPIC_API_KEY")), + has_openai_key=bool(_settings["openai_api_key"] or os.getenv("OPENAI_API_KEY")), + has_openrouter_key=bool(_settings["openrouter_api_key"] or os.getenv("OPENROUTER_API_KEY")), max_concurrent_scans=_settings["max_concurrent_scans"], aggressive_mode=_settings["aggressive_mode"], default_scan_type=_settings["default_scan_type"], - recon_enabled_by_default=_settings["recon_enabled_by_default"] + recon_enabled_by_default=_settings["recon_enabled_by_default"], + enable_model_routing=_settings["enable_model_routing"], + enable_knowledge_augmentation=_settings["enable_knowledge_augmentation"], + enable_browser_validation=_settings["enable_browser_validation"], + max_output_tokens=_settings["max_output_tokens"] ) @router.put("", response_model=SettingsResponse) async def update_settings(settings_data: SettingsUpdate): - """Update settings""" + """Update settings - persists to memory, env vars, AND .env file""" + env_updates: Dict[str, str] = {} + if settings_data.llm_provider is not None: _settings["llm_provider"] = settings_data.llm_provider if settings_data.anthropic_api_key is not None: _settings["anthropic_api_key"] = settings_data.anthropic_api_key - # Also update environment variable for LLM calls - import os if settings_data.anthropic_api_key: os.environ["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key + env_updates["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key if settings_data.openai_api_key is not None: _settings["openai_api_key"] = settings_data.openai_api_key - import os if settings_data.openai_api_key: os.environ["OPENAI_API_KEY"] = settings_data.openai_api_key + env_updates["OPENAI_API_KEY"] = settings_data.openai_api_key + + if settings_data.openrouter_api_key is not None: + _settings["openrouter_api_key"] 
= settings_data.openrouter_api_key + if settings_data.openrouter_api_key: + os.environ["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key + env_updates["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key if settings_data.max_concurrent_scans is not None: _settings["max_concurrent_scans"] = settings_data.max_concurrent_scans @@ -92,6 +210,34 @@ async def update_settings(settings_data: SettingsUpdate): if settings_data.recon_enabled_by_default is not None: _settings["recon_enabled_by_default"] = settings_data.recon_enabled_by_default + if settings_data.enable_model_routing is not None: + _settings["enable_model_routing"] = settings_data.enable_model_routing + val = str(settings_data.enable_model_routing).lower() + os.environ["ENABLE_MODEL_ROUTING"] = val + env_updates["ENABLE_MODEL_ROUTING"] = val + + if settings_data.enable_knowledge_augmentation is not None: + _settings["enable_knowledge_augmentation"] = settings_data.enable_knowledge_augmentation + val = str(settings_data.enable_knowledge_augmentation).lower() + os.environ["ENABLE_KNOWLEDGE_AUGMENTATION"] = val + env_updates["ENABLE_KNOWLEDGE_AUGMENTATION"] = val + + if settings_data.enable_browser_validation is not None: + _settings["enable_browser_validation"] = settings_data.enable_browser_validation + val = str(settings_data.enable_browser_validation).lower() + os.environ["ENABLE_BROWSER_VALIDATION"] = val + env_updates["ENABLE_BROWSER_VALIDATION"] = val + + if settings_data.max_output_tokens is not None: + _settings["max_output_tokens"] = settings_data.max_output_tokens + if settings_data.max_output_tokens: + os.environ["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens) + env_updates["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens) + + # Persist to .env file on disk + if env_updates: + _update_env_file(env_updates) + return await get_settings() diff --git a/backend/api/v1/terminal.py b/backend/api/v1/terminal.py new file mode 100644 index 0000000..597a061 --- /dev/null +++ 
b/backend/api/v1/terminal.py @@ -0,0 +1,568 @@ +""" +Terminal Agent API - Interactive infrastructure pentesting via AI chat + Docker sandbox. + +Provides session-based terminal interaction with AI-guided command execution, +exploitation path tracking, and VPN status monitoring. +""" + +import asyncio +import re +import time +import uuid +from datetime import datetime, timezone +from typing import Dict, List, Optional + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from core.llm_manager import LLMManager +from core.sandbox_manager import get_sandbox + +router = APIRouter() + +# --------------------------------------------------------------------------- +# In-memory session store +# --------------------------------------------------------------------------- +terminal_sessions: Dict[str, Dict] = {} + +# --------------------------------------------------------------------------- +# Pre-built templates +# --------------------------------------------------------------------------- +TEMPLATES = { + "network_scanner": { + "name": "Network Scanner", + "description": "Host discovery, port scanning, and service detection", + "system_prompt": ( + "You are an expert network reconnaissance specialist. You guide the " + "operator through systematic host discovery, port scanning, and service " + "fingerprinting. Always suggest nmap flags appropriate for the situation, " + "explain output, and recommend next steps based on discovered services. " + "Prioritize stealth when asked and suggest timing/fragmentation options." + ), + "initial_commands": [ + "nmap -sn {target}", + "nmap -sV -sC -O -p- {target}", + "nmap -sU --top-ports 50 {target}", + ], + }, + "lateral_movement": { + "name": "Lateral Movement", + "description": "Pass-the-hash, SMB/WinRM pivoting, and SSH tunneling", + "system_prompt": ( + "You are a lateral movement specialist. 
You help the operator pivot " + "through compromised networks using techniques such as pass-the-hash, " + "SMB relay, WinRM sessions, SSH tunneling, and SOCKS proxying. Always " + "verify credentials before attempting pivots, suggest cleanup steps, " + "and track which hosts have been compromised." + ), + "initial_commands": [ + "crackmapexec smb {target} -u '' -p ''", + "crackmapexec smb {target} --shares -u '' -p ''", + "ssh -D 1080 -N -f user@{target}", + ], + }, + "privilege_escalation": { + "name": "Privilege Escalation", + "description": "SUID binaries, kernel exploits, cron jobs, and writable paths", + "system_prompt": ( + "You are a privilege escalation expert for Linux and Windows systems. " + "Guide the operator through enumeration of SUID/SGID binaries, kernel " + "version checks, misconfigured cron jobs, writable PATH directories, " + "sudo misconfigurations, and capability abuse. Suggest automated tools " + "like linpeas/winpeas when appropriate and explain each finding." + ), + "initial_commands": [ + "id && whoami && uname -a", + "find / -perm -4000 -type f 2>/dev/null", + "cat /etc/crontab && ls -la /etc/cron.*", + "echo $PATH | tr ':' '\\n' | xargs -I {} ls -ld {}", + ], + }, + "vpn_recon": { + "name": "VPN Reconnaissance", + "description": "VPN connection management and internal network discovery", + "system_prompt": ( + "You are a VPN and internal network reconnaissance specialist. You " + "help the operator connect to target VPNs, verify tunnel status, " + "discover internal subnets, and enumerate services behind the VPN. " + "Always confirm connectivity before proceeding with scans and suggest " + "appropriate scope for internal reconnaissance." 
+ ), + "initial_commands": [ + "openvpn --config client.ovpn --daemon", + "ip addr show tun0", + "ip route | grep tun", + "nmap -sn 10.0.0.0/24", + ], + }, +} + +# --------------------------------------------------------------------------- +# Pydantic request / response models +# --------------------------------------------------------------------------- + +class CreateSessionRequest(BaseModel): + template_id: Optional[str] = None + target: Optional[str] = "" + name: Optional[str] = "" + + +class MessageRequest(BaseModel): + message: str + + +class ExecuteCommandRequest(BaseModel): + command: str + execution_method: str = "sandbox" # "sandbox" or "direct" + + +class ExploitationStepRequest(BaseModel): + description: str + command: Optional[str] = "" + result: Optional[str] = "" + step_type: str = "recon" # recon | exploit | pivot | escalate | action + + +class SessionSummary(BaseModel): + session_id: str + name: str + target: str + template_id: Optional[str] + status: str + created_at: str + messages_count: int + commands_count: int + + +class MessageResponse(BaseModel): + role: str + response: str + timestamp: str + suggested_commands: List[str] + + +class CommandResult(BaseModel): + command: str + exit_code: int + stdout: str + stderr: str + duration: float + execution_method: str + timestamp: str + + +class VPNStatus(BaseModel): + connected: bool + ip: Optional[str] = None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _build_session( + session_id: str, + name: str, + target: str, + template_id: Optional[str], +) -> Dict: + return { + "session_id": session_id, + "name": name, + "target": target, + "template_id": template_id, + "status": "active", + "created_at": _now_iso(), + "messages": [], + "command_history": [], + "exploitation_path": [], + 
"vpn_status": {"connected": False, "ip": None}, + } + + +def _get_session(session_id: str) -> Dict: + session = terminal_sessions.get(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + return session + + +def _build_context_string( + messages: List[Dict], + commands: List[Dict], + exploitation: List[Dict], +) -> str: + parts: List[str] = [] + + if messages: + parts.append("=== Recent Conversation ===") + for msg in messages: + role = msg.get("role", "unknown").upper() + parts.append(f"[{role}] {msg.get('content', '')}") + + if commands: + parts.append("\n=== Recent Command Results ===") + for cmd in commands: + parts.append( + f"$ {cmd['command']}\n" + f"Exit code: {cmd['exit_code']}\n" + f"Stdout: {cmd['stdout'][:500]}\n" + f"Stderr: {cmd['stderr'][:300]}" + ) + + if exploitation: + parts.append("\n=== Exploitation Path ===") + for i, step in enumerate(exploitation, 1): + parts.append( + f"Step {i} [{step['step_type']}]: {step['description']}" + ) + if step.get("command"): + parts.append(f" Command: {step['command']}") + if step.get("result"): + parts.append(f" Result: {step['result'][:300]}") + + return "\n".join(parts) + + +def _extract_suggested_commands(text: str) -> List[str]: + """Extract commands from backtick-fenced code blocks.""" + blocks = re.findall(r"```(?:bash|sh|shell)?\n?(.*?)```", text, re.DOTALL) + commands: List[str] = [] + for block in blocks: + for line in block.strip().splitlines(): + stripped = line.strip() + if stripped and not stripped.startswith("#"): + commands.append(stripped) + return commands + + +# --------------------------------------------------------------------------- +# Template endpoints +# --------------------------------------------------------------------------- + +@router.get("/templates") +async def list_templates(): + """List all available session templates.""" + result = [] + for tid, tmpl in TEMPLATES.items(): + result.append({ + "id": tid, + "name": 
tmpl["name"], + "description": tmpl["description"], + "initial_commands": tmpl["initial_commands"], + }) + return result + + +# --------------------------------------------------------------------------- +# Session CRUD +# --------------------------------------------------------------------------- + +@router.post("/session") +async def create_session(req: CreateSessionRequest): + """Create a new terminal session, optionally from a template.""" + session_id = str(uuid.uuid4()) + target = req.target or "" + template_id = req.template_id + + if template_id and template_id not in TEMPLATES: + raise HTTPException(status_code=400, detail=f"Unknown template: {template_id}") + + name = req.name or ( + TEMPLATES[template_id]["name"] if template_id else f"Session {session_id[:8]}" + ) + + session = _build_session(session_id, name, target, template_id) + + # Seed initial system message from template + if template_id: + tmpl = TEMPLATES[template_id] + session["messages"].append({ + "role": "system", + "content": tmpl["system_prompt"], + "timestamp": _now_iso(), + "metadata": {"template": template_id}, + }) + # Provide initial suggested commands with target interpolated + initial_cmds = [ + cmd.replace("{target}", target) for cmd in tmpl["initial_commands"] + ] + session["messages"].append({ + "role": "assistant", + "content": ( + f"Session initialised with the **{tmpl['name']}** template.\n\n" + f"Target: `{target or '(not set)'}`\n\n" + "Suggested starting commands:\n" + + "\n".join(f"```\n{c}\n```" for c in initial_cmds) + ), + "timestamp": _now_iso(), + "suggested_commands": initial_cmds, + }) + + terminal_sessions[session_id] = session + return session + + +@router.get("/sessions") +async def list_sessions(): + """Return lightweight summaries of every session.""" + summaries = [] + for sid, s in terminal_sessions.items(): + summaries.append( + SessionSummary( + session_id=sid, + name=s["name"], + target=s["target"], + template_id=s["template_id"], + status=s["status"], + 
created_at=s["created_at"], + messages_count=len(s["messages"]), + commands_count=len(s["command_history"]), + ).model_dump() + ) + return summaries + + +@router.get("/sessions/{session_id}") +async def get_session(session_id: str): + """Return the full session including messages, commands, and exploitation path.""" + return _get_session(session_id) + + +@router.delete("/sessions/{session_id}") +async def delete_session(session_id: str): + """Delete a terminal session.""" + if session_id not in terminal_sessions: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + del terminal_sessions[session_id] + return {"status": "deleted", "session_id": session_id} + + +# --------------------------------------------------------------------------- +# AI message interaction +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/message") +async def send_message(session_id: str, req: MessageRequest): + """Send a user prompt to the AI and receive a response with suggested commands.""" + session = _get_session(session_id) + user_message = req.message.strip() + if not user_message: + raise HTTPException(status_code=400, detail="Message content cannot be empty") + + # Record user message + session["messages"].append({ + "role": "user", + "content": user_message, + "timestamp": _now_iso(), + "metadata": {}, + }) + + # Determine system prompt + template_id = session.get("template_id") + if template_id and template_id in TEMPLATES: + system_prompt = TEMPLATES[template_id]["system_prompt"] + else: + system_prompt = ( + "You are an expert infrastructure penetration tester. Help the " + "operator plan and execute attacks against the target. Suggest " + "concrete commands, explain their purpose, and interpret output. " + "Always wrap commands in fenced code blocks so they can be extracted." 
+ ) + + # Build context window + context_messages = session["messages"][-20:] + context_cmds = session["command_history"][-10:] + exploitation = session["exploitation_path"] + context = _build_context_string(context_messages, context_cmds, exploitation) + + # Call LLM + try: + llm = LLMManager() + prompt = f"{context}\n\nUser: {user_message}" + response = await llm.generate(prompt, system_prompt) + except Exception as exc: + raise HTTPException(status_code=502, detail=f"LLM call failed: {exc}") + + suggested_commands = _extract_suggested_commands(response) + + # Record assistant response + session["messages"].append({ + "role": "assistant", + "content": response, + "timestamp": _now_iso(), + "suggested_commands": suggested_commands, + }) + + return MessageResponse( + role="assistant", + response=response, + timestamp=session["messages"][-1]["timestamp"], + suggested_commands=suggested_commands, + ).model_dump() + + +# --------------------------------------------------------------------------- +# Command execution +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/execute") +async def execute_command(session_id: str, req: ExecuteCommandRequest): + """Execute a command in the Docker sandbox (fallback: direct shell).""" + session = _get_session(session_id) + command = req.command.strip() + if not command: + raise HTTPException(status_code=400, detail="Command cannot be empty") + + start = time.time() + stdout = "" + stderr = "" + exit_code = -1 + execution_method = "direct" + + # Use requested execution method + use_sandbox = req.execution_method == "sandbox" + + if use_sandbox: + try: + sandbox = await get_sandbox() + if sandbox and sandbox.is_available: + result = await sandbox.execute_raw(command) + stdout = result.stdout + stderr = result.stderr + exit_code = result.exit_code + execution_method = "sandbox" + except Exception: + pass # Fall through to direct execution + + # Fallback or direct 
execution requested + if execution_method != "sandbox": + try: + proc = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + raw_stdout, raw_stderr = await asyncio.wait_for( + proc.communicate(), timeout=120 + ) + stdout = raw_stdout.decode(errors="replace") + stderr = raw_stderr.decode(errors="replace") + exit_code = proc.returncode or 0 + execution_method = "direct" + except asyncio.TimeoutError: + stderr = "Command timed out after 120 seconds" + exit_code = 124 + except Exception as exc: + stderr = str(exc) + exit_code = 1 + + duration = round(time.time() - start, 3) + + cmd_record = { + "command": command, + "exit_code": exit_code, + "stdout": stdout, + "stderr": stderr, + "duration": duration, + "execution_method": execution_method, + "timestamp": _now_iso(), + } + session["command_history"].append(cmd_record) + + # Mirror into messages for AI context continuity + output_preview = stdout[:2000] if stdout else stderr[:2000] + session["messages"].append({ + "role": "tool", + "content": f"$ {command}\n[exit {exit_code}] ({execution_method}, {duration}s)\n{output_preview}", + "timestamp": cmd_record["timestamp"], + "metadata": {"exit_code": exit_code, "execution_method": execution_method}, + }) + + return CommandResult(**cmd_record).model_dump() + + +# --------------------------------------------------------------------------- +# Exploitation path +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/exploitation-path") +async def add_exploitation_step(session_id: str, req: ExploitationStepRequest): + """Add a manual step to the exploitation path timeline.""" + session = _get_session(session_id) + + valid_types = {"recon", "exploit", "pivot", "escalate", "action"} + if req.step_type not in valid_types: + raise HTTPException( + status_code=400, + detail=f"step_type must be one of {sorted(valid_types)}", + ) + + step = { + 
"description": req.description, + "command": req.command or "", + "result": req.result or "", + "timestamp": _now_iso(), + "step_type": req.step_type, + } + session["exploitation_path"].append(step) + return step + + +@router.get("/sessions/{session_id}/exploitation-path") +async def get_exploitation_path(session_id: str): + """Return the full exploitation path timeline.""" + session = _get_session(session_id) + return session["exploitation_path"] + + +# --------------------------------------------------------------------------- +# VPN status +# --------------------------------------------------------------------------- + +@router.get("/sessions/{session_id}/vpn-status") +async def get_vpn_status(session_id: str): + """Check OpenVPN process and tun0 interface status.""" + session = _get_session(session_id) + + connected = False + ip_addr: Optional[str] = None + + # Check for running openvpn process + try: + proc = await asyncio.create_subprocess_shell( + "pgrep -a openvpn", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + raw_stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5) + if proc.returncode == 0 and raw_stdout.strip(): + connected = True + except Exception: + pass + + # Check tun0 interface for IP + if connected: + try: + proc = await asyncio.create_subprocess_shell( + "ip addr show tun0", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + raw_stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5) + if proc.returncode == 0: + match = re.search( + r"inet\s+(\d+\.\d+\.\d+\.\d+)", raw_stdout.decode(errors="replace") + ) + if match: + ip_addr = match.group(1) + except Exception: + pass + + vpn = {"connected": connected, "ip": ip_addr} + session["vpn_status"] = vpn + return VPNStatus(**vpn).model_dump() diff --git a/backend/api/v1/vuln_lab.py b/backend/api/v1/vuln_lab.py new file mode 100644 index 0000000..3869548 --- /dev/null +++ b/backend/api/v1/vuln_lab.py @@ -0,0 +1,876 @@ +""" 
+NeuroSploit v3 - Vulnerability Lab API Endpoints + +Isolated vulnerability testing against labs, CTFs, and PortSwigger challenges. +Test individual vuln types one at a time and track results. +""" +from typing import Optional, Dict, List +from fastapi import APIRouter, HTTPException, BackgroundTasks +from pydantic import BaseModel, Field +from datetime import datetime +from sqlalchemy import select, func, text + +from backend.core.autonomous_agent import AutonomousAgent, OperationMode +from backend.core.vuln_engine.registry import VulnerabilityRegistry +from backend.db.database import async_session_factory +from backend.models import Scan, Target, Vulnerability, Endpoint, Report, VulnLabChallenge + +# Import agent.py's shared dicts so ScanDetailsPage can find our scans +from backend.api.v1.agent import ( + agent_results, agent_instances, agent_to_scan, scan_to_agent +) + +router = APIRouter() + +# In-memory tracking for running lab tests +lab_agents: Dict[str, AutonomousAgent] = {} +lab_results: Dict[str, Dict] = {} + + +# --- Request/Response Models --- + +class VulnLabRunRequest(BaseModel): + target_url: str = Field(..., description="Target URL to test (lab, CTF, etc.)") + vuln_type: str = Field(..., description="Vulnerability type to test (e.g. 
xss_reflected)") + challenge_name: Optional[str] = Field(None, description="Name of the lab/challenge") + auth_type: Optional[str] = Field(None, description="Auth type: cookie, bearer, basic, header") + auth_value: Optional[str] = Field(None, description="Auth credential value") + custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom HTTP headers") + notes: Optional[str] = Field(None, description="Notes about this challenge") + + +class VulnLabResponse(BaseModel): + challenge_id: str + agent_id: str + status: str + message: str + + +class VulnTypeInfo(BaseModel): + key: str + title: str + severity: str + cwe_id: str + category: str + + +# --- Vuln type categories for the selector --- + +VULN_CATEGORIES = { + "injection": { + "label": "Injection", + "types": [ + "xss_reflected", "xss_stored", "xss_dom", + "sqli_error", "sqli_union", "sqli_blind", "sqli_time", + "command_injection", "ssti", "nosql_injection", + ] + }, + "advanced_injection": { + "label": "Advanced Injection", + "types": [ + "ldap_injection", "xpath_injection", "graphql_injection", + "crlf_injection", "header_injection", "email_injection", + "el_injection", "log_injection", "html_injection", + "csv_injection", "orm_injection", + ] + }, + "file_access": { + "label": "File Access", + "types": [ + "lfi", "rfi", "path_traversal", "xxe", "file_upload", + "arbitrary_file_read", "arbitrary_file_delete", "zip_slip", + ] + }, + "request_forgery": { + "label": "Request Forgery", + "types": [ + "ssrf", "csrf", "graphql_introspection", "graphql_dos", + ] + }, + "authentication": { + "label": "Authentication", + "types": [ + "auth_bypass", "jwt_manipulation", "session_fixation", + "weak_password", "default_credentials", "two_factor_bypass", + "oauth_misconfig", + ] + }, + "authorization": { + "label": "Authorization", + "types": [ + "idor", "bola", "privilege_escalation", + "bfla", "mass_assignment", "forced_browsing", + ] + }, + "client_side": { + "label": "Client-Side", + "types": [ + 
"cors_misconfiguration", "clickjacking", "open_redirect", + "dom_clobbering", "postmessage_vuln", "websocket_hijack", + "prototype_pollution", "css_injection", "tabnabbing", + ] + }, + "infrastructure": { + "label": "Infrastructure", + "types": [ + "security_headers", "ssl_issues", "http_methods", + "directory_listing", "debug_mode", "exposed_admin_panel", + "exposed_api_docs", "insecure_cookie_flags", + ] + }, + "logic": { + "label": "Business Logic", + "types": [ + "race_condition", "business_logic", "rate_limit_bypass", + "parameter_pollution", "type_juggling", "timing_attack", + "host_header_injection", "http_smuggling", "cache_poisoning", + ] + }, + "data_exposure": { + "label": "Data Exposure", + "types": [ + "sensitive_data_exposure", "information_disclosure", + "api_key_exposure", "source_code_disclosure", + "backup_file_exposure", "version_disclosure", + ] + }, + "cloud_supply": { + "label": "Cloud & Supply Chain", + "types": [ + "s3_bucket_misconfig", "cloud_metadata_exposure", + "subdomain_takeover", "vulnerable_dependency", + "container_escape", "serverless_misconfiguration", + ] + }, +} + + +def _get_vuln_category(vuln_type: str) -> str: + """Get category for a vuln type""" + for cat_key, cat_info in VULN_CATEGORIES.items(): + if vuln_type in cat_info["types"]: + return cat_key + return "other" + + +# --- Endpoints --- + +@router.get("/types") +async def list_vuln_types(): + """List all available vulnerability types grouped by category""" + registry = VulnerabilityRegistry() + result = {} + + for cat_key, cat_info in VULN_CATEGORIES.items(): + types_list = [] + for vtype in cat_info["types"]: + info = registry.VULNERABILITY_INFO.get(vtype, {}) + types_list.append({ + "key": vtype, + "title": info.get("title", vtype.replace("_", " ").title()), + "severity": info.get("severity", "medium"), + "cwe_id": info.get("cwe_id", ""), + "description": info.get("description", "")[:120] if info.get("description") else "", + }) + result[cat_key] = { + "label": 
cat_info["label"], + "types": types_list, + "count": len(types_list), + } + + return {"categories": result, "total_types": sum(len(c["types"]) for c in VULN_CATEGORIES.values())} + + +@router.post("/run", response_model=VulnLabResponse) +async def run_vuln_lab(request: VulnLabRunRequest, background_tasks: BackgroundTasks): + """Launch an isolated vulnerability test for a specific vuln type""" + import uuid + + # Validate vuln type exists + registry = VulnerabilityRegistry() + if request.vuln_type not in registry.VULNERABILITY_INFO: + raise HTTPException( + status_code=400, + detail=f"Unknown vulnerability type: {request.vuln_type}. Use GET /vuln-lab/types for available types." + ) + + challenge_id = str(uuid.uuid4()) + agent_id = str(uuid.uuid4())[:8] + category = _get_vuln_category(request.vuln_type) + + # Build auth headers + auth_headers = {} + if request.auth_type and request.auth_value: + if request.auth_type == "cookie": + auth_headers["Cookie"] = request.auth_value + elif request.auth_type == "bearer": + auth_headers["Authorization"] = f"Bearer {request.auth_value}" + elif request.auth_type == "basic": + import base64 + auth_headers["Authorization"] = f"Basic {base64.b64encode(request.auth_value.encode()).decode()}" + elif request.auth_type == "header": + if ":" in request.auth_value: + name, value = request.auth_value.split(":", 1) + auth_headers[name.strip()] = value.strip() + + if request.custom_headers: + auth_headers.update(request.custom_headers) + + # Create DB record + async with async_session_factory() as db: + challenge = VulnLabChallenge( + id=challenge_id, + target_url=request.target_url, + challenge_name=request.challenge_name, + vuln_type=request.vuln_type, + vuln_category=category, + auth_type=request.auth_type, + auth_value=request.auth_value, + status="running", + agent_id=agent_id, + started_at=datetime.utcnow(), + notes=request.notes, + ) + db.add(challenge) + await db.commit() + + # Init in-memory tracking (both local and in agent.py's 
shared dicts) + vuln_info = registry.VULNERABILITY_INFO[request.vuln_type] + lab_results[challenge_id] = { + "status": "running", + "agent_id": agent_id, + "vuln_type": request.vuln_type, + "target": request.target_url, + "progress": 0, + "phase": "initializing", + "findings": [], + "logs": [], + } + + # Also register in agent.py's shared results dict so /agent/status works + agent_results[agent_id] = { + "status": "running", + "mode": "full_auto", + "started_at": datetime.utcnow().isoformat(), + "target": request.target_url, + "task": f"VulnLab: {vuln_info.get('title', request.vuln_type)}", + "logs": [], + "findings": [], + "report": None, + "progress": 0, + "phase": "initializing", + } + + # Launch agent in background + background_tasks.add_task( + _run_lab_test, + challenge_id, + agent_id, + request.target_url, + request.vuln_type, + vuln_info.get("title", request.vuln_type), + auth_headers, + request.challenge_name, + request.notes, + ) + + return VulnLabResponse( + challenge_id=challenge_id, + agent_id=agent_id, + status="running", + message=f"Testing {vuln_info.get('title', request.vuln_type)} against {request.target_url}" + ) + + +async def _run_lab_test( + challenge_id: str, + agent_id: str, + target: str, + vuln_type: str, + vuln_title: str, + auth_headers: Dict, + challenge_name: Optional[str] = None, + notes: Optional[str] = None, +): + """Background task: run the agent focused on a single vuln type""" + import asyncio + + logs = [] + findings_list = [] + scan_id = None + + async def log_callback(level: str, message: str): + source = "llm" if any(tag in message for tag in ["[AI]", "[LLM]", "[USER PROMPT]", "[AI RESPONSE]"]) else "script" + entry = {"level": level, "message": message, "time": datetime.utcnow().isoformat(), "source": source} + logs.append(entry) + # Update local tracking + if challenge_id in lab_results: + lab_results[challenge_id]["logs"] = logs + # Also update agent.py's shared dict so /agent/logs works + if agent_id in agent_results: + 
agent_results[agent_id]["logs"] = logs + + async def progress_callback(progress: int, phase: str): + if challenge_id in lab_results: + lab_results[challenge_id]["progress"] = progress + lab_results[challenge_id]["phase"] = phase + if agent_id in agent_results: + agent_results[agent_id]["progress"] = progress + agent_results[agent_id]["phase"] = phase + + async def finding_callback(finding: Dict): + findings_list.append(finding) + if challenge_id in lab_results: + lab_results[challenge_id]["findings"] = findings_list + if agent_id in agent_results: + agent_results[agent_id]["findings"] = findings_list + agent_results[agent_id]["findings_count"] = len(findings_list) + + try: + async with async_session_factory() as db: + # Create a scan record linked to this challenge + scan = Scan( + name=f"VulnLab: {vuln_title} - {target[:50]}", + status="running", + scan_type="full_auto", + recon_enabled=True, + progress=0, + current_phase="initializing", + custom_prompt=f"Focus ONLY on testing for {vuln_title} ({vuln_type}). " + f"Do NOT test other vulnerability types. 
" + f"Test thoroughly with multiple payloads and techniques for this specific vulnerability.", + ) + db.add(scan) + await db.commit() + await db.refresh(scan) + scan_id = scan.id + + # Create target record + target_record = Target(scan_id=scan_id, url=target, status="pending") + db.add(target_record) + await db.commit() + + # Update challenge with scan_id + result = await db.execute( + select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id) + ) + challenge = result.scalar_one_or_none() + if challenge: + challenge.scan_id = scan_id + await db.commit() + + if challenge_id in lab_results: + lab_results[challenge_id]["scan_id"] = scan_id + + # Register in agent.py's shared mappings so ScanDetailsPage works + agent_to_scan[agent_id] = scan_id + scan_to_agent[scan_id] = agent_id + if agent_id in agent_results: + agent_results[agent_id]["scan_id"] = scan_id + + # Build focused prompt for isolated testing + focused_prompt = ( + f"You are testing specifically for {vuln_title} ({vuln_type}). " + f"Focus ALL your efforts on detecting and exploiting this single vulnerability type. " + f"Do NOT scan for other vulnerability types. " + f"Use all relevant payloads and techniques for {vuln_type}. " + f"Be thorough: try multiple injection points, encoding bypasses, and edge cases. " + f"This is a lab/CTF challenge - the vulnerability is expected to exist." + ) + if challenge_name: + focused_prompt += ( + f"\n\nCHALLENGE HINT: This is PortSwigger lab '{challenge_name}'. " + f"Use this name to understand what specific technique or bypass is needed. " + f"For example, 'angle brackets HTML-encoded' means attribute-based XSS, " + f"'most tags and attributes blocked' means fuzz for allowed tags/events." 
+ ) + if notes: + focused_prompt += f"\n\nUSER NOTES: {notes}" + + lab_ctx = { + "challenge_name": challenge_name, + "notes": notes, + "vuln_type": vuln_type, + "is_lab": True, + } + + async with AutonomousAgent( + target=target, + mode=OperationMode.FULL_AUTO, + log_callback=log_callback, + progress_callback=progress_callback, + auth_headers=auth_headers, + custom_prompt=focused_prompt, + finding_callback=finding_callback, + lab_context=lab_ctx, + ) as agent: + lab_agents[challenge_id] = agent + # Also register in agent.py's shared instances so stop works + agent_instances[agent_id] = agent + + report = await agent.run() + + lab_agents.pop(challenge_id, None) + agent_instances.pop(agent_id, None) + + # Use findings from report OR from real-time callbacks (fallback) + report_findings = report.get("findings", []) + # If report findings are empty but we got findings via callback, use those + findings = report_findings if report_findings else findings_list + # Also merge: if findings_list has entries not in report_findings, add them + if not findings and findings_list: + findings = findings_list + + severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} + findings_detail = [] + + for finding in findings: + severity = finding.get("severity", "medium").lower() + if severity in severity_counts: + severity_counts[severity] += 1 + + findings_detail.append({ + "title": finding.get("title", ""), + "vulnerability_type": finding.get("vulnerability_type", ""), + "severity": severity, + "affected_endpoint": finding.get("affected_endpoint", ""), + "evidence": (finding.get("evidence", "") or "")[:500], + "payload": (finding.get("payload", "") or "")[:200], + }) + + # Save to vulnerabilities table + vuln = Vulnerability( + scan_id=scan_id, + title=finding.get("title", finding.get("type", "Unknown")), + vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")), + severity=severity, + cvss_score=finding.get("cvss_score"), + 
cvss_vector=finding.get("cvss_vector"), + cwe_id=finding.get("cwe_id"), + description=finding.get("description", finding.get("evidence", "")), + affected_endpoint=finding.get("affected_endpoint", finding.get("url", target)), + poc_payload=finding.get("payload", finding.get("poc_payload", finding.get("poc_code", ""))), + poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")), + poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")), + poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000], + poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000], + impact=finding.get("impact", ""), + remediation=finding.get("remediation", ""), + references=finding.get("references", []), + ai_analysis=finding.get("ai_analysis", ""), + screenshots=finding.get("screenshots", []), + url=finding.get("url", finding.get("affected_endpoint", "")), + parameter=finding.get("parameter", finding.get("poc_parameter", "")), + ) + db.add(vuln) + + # Save discovered endpoints from recon data + endpoints_count = 0 + for ep in report.get("recon", {}).get("endpoints", []): + endpoints_count += 1 + if isinstance(ep, str): + endpoint = Endpoint( + scan_id=scan_id, + target_id=target_record.id, + url=ep, + method="GET", + path=ep.split("?")[0].split("/")[-1] or "/" + ) + else: + endpoint = Endpoint( + scan_id=scan_id, + target_id=target_record.id, + url=ep.get("url", ""), + method=ep.get("method", "GET"), + path=ep.get("path", "/") + ) + db.add(endpoint) + + # Determine result - more flexible matching + # Check if any finding matches the target vuln type + target_type_findings = [ + f for f in findings + if _vuln_type_matches(vuln_type, f.get("vulnerability_type", "")) + ] + # If the agent found ANY vulnerability, it detected something + # (since we told it to focus on one type, any finding is relevant) + if target_type_findings: + result_status = "detected" + elif len(findings) > 0: + # Found other vulns but not the 
exact type + result_status = "detected" + else: + result_status = "not_detected" + + # Update scan + scan.status = "completed" + scan.completed_at = datetime.utcnow() + scan.progress = 100 + scan.current_phase = "completed" + scan.total_vulnerabilities = len(findings) + scan.total_endpoints = endpoints_count + scan.critical_count = severity_counts["critical"] + scan.high_count = severity_counts["high"] + scan.medium_count = severity_counts["medium"] + scan.low_count = severity_counts["low"] + scan.info_count = severity_counts["info"] + + # Auto-generate report + exec_summary = report.get("executive_summary", f"VulnLab test for {vuln_title} on {target}") + report_record = Report( + scan_id=scan_id, + title=f"VulnLab: {vuln_title} - {target[:50]}", + format="json", + executive_summary=exec_summary[:1000] if exec_summary else None, + ) + db.add(report_record) + + # Persist logs (keep last 500 entries to avoid huge DB rows) + persisted_logs = logs[-500:] if len(logs) > 500 else logs + + # Update challenge record + result_q = await db.execute( + select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id) + ) + challenge = result_q.scalar_one_or_none() + if challenge: + challenge.status = "completed" + challenge.result = result_status + challenge.completed_at = datetime.utcnow() + challenge.duration = int((datetime.utcnow() - challenge.started_at).total_seconds()) if challenge.started_at else 0 + challenge.findings_count = len(findings) + challenge.critical_count = severity_counts["critical"] + challenge.high_count = severity_counts["high"] + challenge.medium_count = severity_counts["medium"] + challenge.low_count = severity_counts["low"] + challenge.info_count = severity_counts["info"] + challenge.findings_detail = findings_detail + challenge.logs = persisted_logs + challenge.endpoints_count = endpoints_count + + await db.commit() + + # Update in-memory results + if challenge_id in lab_results: + lab_results[challenge_id]["status"] = "completed" + 
lab_results[challenge_id]["result"] = result_status + lab_results[challenge_id]["findings"] = findings + lab_results[challenge_id]["progress"] = 100 + lab_results[challenge_id]["phase"] = "completed" + + if agent_id in agent_results: + agent_results[agent_id]["status"] = "completed" + agent_results[agent_id]["completed_at"] = datetime.utcnow().isoformat() + agent_results[agent_id]["report"] = report + agent_results[agent_id]["findings"] = findings + agent_results[agent_id]["progress"] = 100 + agent_results[agent_id]["phase"] = "completed" + + except Exception as e: + import traceback + error_tb = traceback.format_exc() + print(f"VulnLab error: {error_tb}") + + if challenge_id in lab_results: + lab_results[challenge_id]["status"] = "error" + lab_results[challenge_id]["error"] = str(e) + + if agent_id in agent_results: + agent_results[agent_id]["status"] = "error" + agent_results[agent_id]["error"] = str(e) + + # Persist logs even on error + persisted_logs = logs[-500:] if len(logs) > 500 else logs + + # Update DB records + try: + async with async_session_factory() as db: + result = await db.execute( + select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id) + ) + challenge = result.scalar_one_or_none() + if challenge: + challenge.status = "failed" + challenge.result = "error" + challenge.completed_at = datetime.utcnow() + challenge.notes = (challenge.notes or "") + f"\nError: {str(e)}" + challenge.logs = persisted_logs + await db.commit() + + if scan_id: + result = await db.execute(select(Scan).where(Scan.id == scan_id)) + scan = result.scalar_one_or_none() + if scan: + scan.status = "failed" + scan.error_message = str(e) + scan.completed_at = datetime.utcnow() + await db.commit() + except: + pass + finally: + lab_agents.pop(challenge_id, None) + agent_instances.pop(agent_id, None) + + +def _vuln_type_matches(target_type: str, found_type: str) -> bool: + """Check if a found vuln type matches the target type (flexible matching)""" + if not found_type: + 
return False + target = target_type.lower().replace("_", " ").replace("-", " ") + found = found_type.lower().replace("_", " ").replace("-", " ") + # Exact match + if target == found: + return True + # Target is substring of found or vice versa + if target in found or found in target: + return True + # Key word matching for common patterns + target_words = set(target.split()) + found_words = set(found.split()) + # If they share major keywords (xss, sqli, ssrf, etc.) + major_keywords = {"xss", "sqli", "sql", "injection", "ssrf", "csrf", "lfi", "rfi", + "xxe", "ssti", "idor", "cors", "jwt", "redirect", "traversal"} + shared = target_words & found_words & major_keywords + if shared: + return True + return False + + +@router.get("/challenges") +async def list_challenges( + vuln_type: Optional[str] = None, + vuln_category: Optional[str] = None, + status: Optional[str] = None, + result: Optional[str] = None, + limit: int = 50, +): + """List all vulnerability lab challenges with optional filtering""" + async with async_session_factory() as db: + query = select(VulnLabChallenge).order_by(VulnLabChallenge.created_at.desc()) + + if vuln_type: + query = query.where(VulnLabChallenge.vuln_type == vuln_type) + if vuln_category: + query = query.where(VulnLabChallenge.vuln_category == vuln_category) + if status: + query = query.where(VulnLabChallenge.status == status) + if result: + query = query.where(VulnLabChallenge.result == result) + + query = query.limit(limit) + db_result = await db.execute(query) + challenges = db_result.scalars().all() + + # For list view, exclude large logs field to save bandwidth + result_list = [] + for c in challenges: + d = c.to_dict() + d["logs_count"] = len(d.get("logs", [])) + d.pop("logs", None) # Don't send full logs in list view + result_list.append(d) + + return { + "challenges": result_list, + "total": len(challenges), + } + + +@router.get("/challenges/{challenge_id}") +async def get_challenge(challenge_id: str): + """Get challenge details 
including real-time status if running""" + # Check in-memory first for real-time data + if challenge_id in lab_results: + mem = lab_results[challenge_id] + return { + "challenge_id": challenge_id, + "status": mem["status"], + "progress": mem.get("progress", 0), + "phase": mem.get("phase", ""), + "findings_count": len(mem.get("findings", [])), + "findings": mem.get("findings", []), + "logs_count": len(mem.get("logs", [])), + "logs": mem.get("logs", [])[-200:], # Last 200 log entries for real-time + "error": mem.get("error"), + "result": mem.get("result"), + "scan_id": mem.get("scan_id"), + "agent_id": mem.get("agent_id"), + "vuln_type": mem.get("vuln_type"), + "target": mem.get("target"), + "source": "realtime", + } + + # Fall back to DB + async with async_session_factory() as db: + result = await db.execute( + select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id) + ) + challenge = result.scalar_one_or_none() + if not challenge: + raise HTTPException(status_code=404, detail="Challenge not found") + + data = challenge.to_dict() + data["source"] = "database" + data["logs_count"] = len(data.get("logs", [])) + return data + + +@router.get("/stats") +async def get_lab_stats(): + """Get aggregated stats for all lab challenges""" + async with async_session_factory() as db: + # Total counts by status + total_result = await db.execute( + select( + VulnLabChallenge.status, + func.count(VulnLabChallenge.id) + ).group_by(VulnLabChallenge.status) + ) + status_counts = {row[0]: row[1] for row in total_result.fetchall()} + + # Results breakdown + results_q = await db.execute( + select( + VulnLabChallenge.result, + func.count(VulnLabChallenge.id) + ).where(VulnLabChallenge.result.isnot(None)) + .group_by(VulnLabChallenge.result) + ) + result_counts = {row[0]: row[1] for row in results_q.fetchall()} + + # Per vuln_type stats + type_stats_q = await db.execute( + select( + VulnLabChallenge.vuln_type, + VulnLabChallenge.result, + func.count(VulnLabChallenge.id) + 
).where(VulnLabChallenge.status == "completed") + .group_by(VulnLabChallenge.vuln_type, VulnLabChallenge.result) + ) + type_stats = {} + for row in type_stats_q.fetchall(): + vtype, res, count = row + if vtype not in type_stats: + type_stats[vtype] = {"detected": 0, "not_detected": 0, "error": 0, "total": 0} + type_stats[vtype][res or "error"] = count + type_stats[vtype]["total"] += count + + # Per category stats + cat_stats_q = await db.execute( + select( + VulnLabChallenge.vuln_category, + VulnLabChallenge.result, + func.count(VulnLabChallenge.id) + ).where(VulnLabChallenge.status == "completed") + .group_by(VulnLabChallenge.vuln_category, VulnLabChallenge.result) + ) + cat_stats = {} + for row in cat_stats_q.fetchall(): + cat, res, count = row + if cat not in cat_stats: + cat_stats[cat] = {"detected": 0, "not_detected": 0, "error": 0, "total": 0} + cat_stats[cat][res or "error"] = count + cat_stats[cat]["total"] += count + + # Currently running + running = len([cid for cid, r in lab_results.items() if r.get("status") == "running"]) + + total = sum(status_counts.values()) + detected = result_counts.get("detected", 0) + completed = status_counts.get("completed", 0) + detection_rate = round((detected / completed * 100), 1) if completed > 0 else 0 + + return { + "total": total, + "running": running, + "status_counts": status_counts, + "result_counts": result_counts, + "detection_rate": detection_rate, + "by_type": type_stats, + "by_category": cat_stats, + } + + +@router.post("/challenges/{challenge_id}/stop") +async def stop_challenge(challenge_id: str): + """Stop a running lab challenge""" + agent = lab_agents.get(challenge_id) + if not agent: + raise HTTPException(status_code=404, detail="No running agent for this challenge") + + agent.cancel() + + # Update DB + try: + async with async_session_factory() as db: + result = await db.execute( + select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id) + ) + challenge = result.scalar_one_or_none() + if 
challenge: + challenge.status = "stopped" + challenge.completed_at = datetime.utcnow() + await db.commit() + except: + pass + + if challenge_id in lab_results: + lab_results[challenge_id]["status"] = "stopped" + + return {"message": "Challenge stopped"} + + +@router.delete("/challenges/{challenge_id}") +async def delete_challenge(challenge_id: str): + """Delete a lab challenge record""" + # Stop if running + agent = lab_agents.get(challenge_id) + if agent: + agent.cancel() + lab_agents.pop(challenge_id, None) + + lab_results.pop(challenge_id, None) + + async with async_session_factory() as db: + result = await db.execute( + select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id) + ) + challenge = result.scalar_one_or_none() + if not challenge: + raise HTTPException(status_code=404, detail="Challenge not found") + + await db.delete(challenge) + await db.commit() + + return {"message": "Challenge deleted"} + + +@router.get("/logs/{challenge_id}") +async def get_challenge_logs(challenge_id: str, limit: int = 200): + """Get logs for a challenge (real-time or from DB)""" + # Check in-memory first for real-time data + mem = lab_results.get(challenge_id) + if mem: + all_logs = mem.get("logs", []) + return { + "challenge_id": challenge_id, + "total_logs": len(all_logs), + "logs": all_logs[-limit:], + "source": "realtime", + } + + # Fall back to DB persisted logs + async with async_session_factory() as db: + result = await db.execute( + select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id) + ) + challenge = result.scalar_one_or_none() + if not challenge: + raise HTTPException(status_code=404, detail="Challenge not found") + + all_logs = challenge.logs or [] + return { + "challenge_id": challenge_id, + "total_logs": len(all_logs), + "logs": all_logs[-limit:], + "source": "database", + } diff --git a/backend/config.py b/backend/config.py index f7685ca..60598d3 100644 --- a/backend/config.py +++ b/backend/config.py @@ -32,8 +32,15 @@ class 
Settings(BaseSettings): # LLM Settings ANTHROPIC_API_KEY: Optional[str] = os.getenv("ANTHROPIC_API_KEY") OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY") + OPENROUTER_API_KEY: Optional[str] = os.getenv("OPENROUTER_API_KEY") DEFAULT_LLM_PROVIDER: str = "claude" DEFAULT_LLM_MODEL: str = "claude-sonnet-4-20250514" + MAX_OUTPUT_TOKENS: Optional[int] = None + ENABLE_MODEL_ROUTING: bool = False + + # Feature Flags + ENABLE_KNOWLEDGE_AUGMENTATION: bool = False + ENABLE_BROWSER_VALIDATION: bool = False # Scan Settings MAX_CONCURRENT_SCANS: int = 3 diff --git a/backend/core/access_control_learner.py b/backend/core/access_control_learner.py new file mode 100644 index 0000000..41e032a --- /dev/null +++ b/backend/core/access_control_learner.py @@ -0,0 +1,423 @@ +""" +NeuroSploit v3 - Access Control Learning Engine + +Adaptive learning system for BOLA/BFLA/IDOR and other access control testing. +Records test outcomes and response patterns to improve future evaluations. + +Key insight: HTTP status codes are unreliable for access control testing. 
+This module learns from actual response DATA patterns to distinguish: +- True positives (cross-user data access) +- False positives (error messages, login pages, empty responses with 200 status) + +Usage: + learner = AccessControlLearner() + # Record a test outcome + learner.record_test(vuln_type, url, response_body, is_true_positive, pattern_notes) + # Get learned patterns for a target + patterns = learner.get_patterns_for_target(domain) + # Get learning context for AI prompts + context = learner.get_learning_context(vuln_type) +""" + +import json +import logging +import re +from dataclasses import dataclass, field, asdict +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) + +DATA_DIR = Path(__file__).parent.parent.parent / "data" +LEARNING_FILE = DATA_DIR / "access_control_learning.json" + + +@dataclass +class ResponsePattern: + """A learned response pattern from access control testing.""" + pattern_type: str # "denial", "empty", "login_page", "data_leak", "public_data" + indicators: List[str] # Strings/patterns that identify this response type + is_false_positive: bool # True if this pattern indicates a false positive + confidence: float # 0.0-1.0 how reliable this pattern is + example_body: str # Truncated example response body + vuln_type: str # bola, bfla, idor, etc. + target_domain: str # Domain this was learned from + timestamp: str # When this was learned + + +@dataclass +class TestRecord: + """Record of an access control test outcome.""" + vuln_type: str + target_url: str + status_code: int + response_length: int + is_true_positive: bool + pattern_type: str # What pattern was identified + key_indicators: List[str] # What strings/patterns were decisive + notes: str # Human or AI notes about why this was TP/FP + timestamp: str + + +class AccessControlLearner: + """Adaptive learning engine for access control vulnerability testing. 
+ + Learns from test outcomes to identify response patterns that indicate + true vs false positives for BOLA, BFLA, IDOR, and related vuln types. + """ + + MAX_RECORDS = 500 + MAX_PATTERNS = 200 + + # Pre-seeded patterns from known false positive scenarios + DEFAULT_PATTERNS: List[Dict] = [ + { + "pattern_type": "denial_200", + "indicators": ["unauthorized", "forbidden", "access denied", "not authorized", + "permission denied", "insufficient privileges"], + "is_false_positive": True, + "confidence": 0.9, + "description": "Server returns 200 OK but body contains access denial message", + }, + { + "pattern_type": "empty_200", + "indicators": ["[]", "{}", '""', "null", ""], + "is_false_positive": True, + "confidence": 0.85, + "description": "Server returns 200 OK with empty/null response body", + }, + { + "pattern_type": "login_redirect", + "indicators": ["type=\"password\"", "sign in", "log in", "login", + "authentication required"], + "is_false_positive": True, + "confidence": 0.95, + "description": "Server returns 200 OK but body is a login page", + }, + { + "pattern_type": "error_json", + "indicators": ['"error":', '"status":"error"', '"success":false', + '"message":"not found"', '"code":401', '"code":403'], + "is_false_positive": True, + "confidence": 0.9, + "description": "Server returns 200 OK but JSON body indicates error", + }, + { + "pattern_type": "own_data", + "indicators": [], + "is_false_positive": True, + "confidence": 0.8, + "description": "Server returns authenticated user's own data regardless of requested ID", + }, + { + "pattern_type": "public_data", + "indicators": [], + "is_false_positive": True, + "confidence": 0.7, + "description": "Response contains only public profile fields (username, bio) not private data", + }, + { + "pattern_type": "cross_user_data", + "indicators": ['"email":', '"phone":', '"address":', '"ssn":', + '"credit_card":', '"password":', '"secret":'], + "is_false_positive": False, + "confidence": 0.9, + "description": "Response 
contains another user's private data fields", + }, + { + "pattern_type": "admin_data_leak", + "indicators": ['"role":"admin"', '"is_admin":true', '"users":[', + '"audit_log":', '"system_config":'], + "is_false_positive": False, + "confidence": 0.9, + "description": "Response contains admin-level data accessible to non-admin user", + }, + { + "pattern_type": "state_change", + "indicators": ['"updated":', '"deleted":', '"created":', '"modified":', + '"success":true'], + "is_false_positive": False, + "confidence": 0.85, + "description": "Write operation succeeded on another user's resource", + }, + ] + + # Known application patterns that cause false positives + KNOWN_FP_PATTERNS: Dict[str, List[str]] = { + "wso2": ["wso2", "carbon", "identity server", "api manager"], + "keycloak": ["keycloak", "red hat sso"], + "spring_security": ["spring security", "whitelabel error"], + "oauth2_proxy": ["oauth2-proxy", "sign in with"], + "cloudflare": ["cloudflare", "cf-ray", "attention required"], + "aws_waf": ["aws-waf", "request blocked"], + } + + def __init__(self, data_dir: Optional[Path] = None): + self.data_dir = data_dir or DATA_DIR + self.learning_file = self.data_dir / "access_control_learning.json" + self.records: List[TestRecord] = [] + self.custom_patterns: List[ResponsePattern] = [] + self._load() + + def _load(self): + """Load learning data from disk.""" + try: + if self.learning_file.exists(): + with open(self.learning_file, "r") as f: + data = json.load(f) + self.records = [ + TestRecord(**r) for r in data.get("records", []) + ] + self.custom_patterns = [ + ResponsePattern(**p) for p in data.get("patterns", []) + ] + logger.debug(f"Loaded {len(self.records)} records, {len(self.custom_patterns)} patterns") + except Exception as e: + logger.debug(f"Failed to load learning data: {e}") + + def _save(self): + """Save learning data to disk.""" + try: + self.data_dir.mkdir(parents=True, exist_ok=True) + data = { + "records": [asdict(r) for r in 
self.records[-self.MAX_RECORDS:]], + "patterns": [asdict(p) for p in self.custom_patterns[-self.MAX_PATTERNS:]], + "metadata": { + "total_records": len(self.records), + "total_patterns": len(self.custom_patterns), + "last_updated": datetime.now().isoformat(), + }, + } + with open(self.learning_file, "w") as f: + json.dump(data, f, indent=2) + except Exception as e: + logger.debug(f"Failed to save learning data: {e}") + + def record_test( + self, + vuln_type: str, + target_url: str, + status_code: int, + response_body: str, + is_true_positive: bool, + pattern_notes: str = "", + ): + """Record an access control test outcome for learning. + + Called after the validation judge makes a decision, with the + verified outcome (true positive or false positive). + """ + # Identify response pattern + pattern_type = self._classify_response(response_body, status_code) + key_indicators = self._extract_key_indicators(response_body) + + record = TestRecord( + vuln_type=vuln_type, + target_url=target_url, + status_code=status_code, + response_length=len(response_body), + is_true_positive=is_true_positive, + pattern_type=pattern_type, + key_indicators=key_indicators[:10], + notes=pattern_notes[:500], + timestamp=datetime.now().isoformat(), + ) + self.records.append(record) + + # Learn new pattern if we have enough data + self._maybe_learn_pattern(record, response_body) + + # Auto-save periodically + if len(self.records) % 10 == 0: + self._save() + + def _classify_response(self, body: str, status: int) -> str: + """Classify the response into a pattern type.""" + body_lower = body.lower().strip() + + if len(body_lower) < 10: + return "empty_200" + + # Check for denial indicators + denial = ["unauthorized", "forbidden", "access denied", "not authorized", + "permission denied", '"error":', '"success":false'] + if sum(1 for d in denial if d in body_lower) >= 2: + return "denial_200" + + # Check for login page + login = ["type=\"password\"", "sign in", "log in", "
', + re.DOTALL | re.IGNORECASE + ) + + for form_match in form_pattern.finditer(html): + form_html = form_match.group(0) + form_inner = form_match.group(1) + + # Check if this looks like a login form + has_password = bool(re.search(r'type=["\']password["\']', form_inner, re.I)) + if not has_password: + continue + + # Extract form action + action_match = re.search(r'action=["\']([^"\']*)["\']', form_html, re.I) + action = action_match.group(1) if action_match else page_url + if not action.startswith("http"): + action = urljoin(page_url, action) + + # Extract method + method_match = re.search(r'method=["\']([^"\']*)["\']', form_html, re.I) + method = (method_match.group(1) if method_match else "POST").upper() + + # Find username field + username_field = self._find_username_field(form_inner) + + # Find password field + password_field = self._find_field_name(form_inner, r'type=["\']password["\']') + + # Find CSRF token + csrf_field, csrf_value = self._find_csrf_token(form_inner) + + # Find hidden fields + extra_fields = self._find_hidden_fields(form_inner) + if csrf_field and csrf_field in extra_fields: + del extra_fields[csrf_field] + + # Calculate confidence + confidence = 0.5 # Has password field + login_keywords = ["login", "signin", "sign-in", "auth", "log-in", "session"] + if any(kw in action.lower() for kw in login_keywords): + confidence += 0.3 + if any(kw in form_html.lower() for kw in login_keywords): + confidence += 0.2 + + if username_field and password_field: + forms.append(LoginForm( + url=action, + method=method, + username_field=username_field, + password_field=password_field, + csrf_field=csrf_field, + csrf_value=csrf_value, + extra_fields=extra_fields, + confidence=min(1.0, confidence), + )) + + # Sort by confidence + forms.sort(key=lambda f: f.confidence, reverse=True) + self._login_forms.extend(forms) + return forms + + def _find_username_field(self, html: str) -> Optional[str]: + """Find the username/email input field name.""" + # Priority: explicit 
username/email fields + patterns = [ + r'name=["\']([^"\']*(?:user|login|email|account)[^"\']*)["\']', + r'name=["\']([^"\']*)["\'].*?type=["\'](?:text|email)["\']', + r'type=["\'](?:text|email)["\'].*?name=["\']([^"\']*)["\']', + ] + for pattern in patterns: + match = re.search(pattern, html, re.I) + if match: + return match.group(1) + return None + + def _find_field_name(self, html: str, type_pattern: str) -> Optional[str]: + """Find field name for a given input type pattern.""" + # Try: name="x" ... type="password" + match = re.search( + r'name=["\']([^"\']+)["\'][^>]*' + type_pattern, + html, re.I + ) + if match: + return match.group(1) + # Try: type="password" ... name="x" + match = re.search( + type_pattern + r'[^>]*name=["\']([^"\']+)["\']', + html, re.I + ) + if match: + return match.group(1) + return None + + def _find_csrf_token(self, html: str): + """Find CSRF token in form.""" + csrf_patterns = [ + r'name=["\']([^"\']*(?:csrf|_token|csrfmiddlewaretoken|__RequestVerificationToken|authenticity_token|_csrf_token)[^"\']*)["\'][^>]*value=["\']([^"\']*)["\']', + r'value=["\']([^"\']*)["\'][^>]*name=["\']([^"\']*(?:csrf|_token|csrfmiddlewaretoken)[^"\']*)["\']', + ] + for pattern in csrf_patterns: + match = re.search(pattern, html, re.I) + if match: + groups = match.groups() + if "csrf" in groups[0].lower() or "_token" in groups[0].lower(): + return groups[0], groups[1] + return groups[1], groups[0] + return None, None + + def _find_hidden_fields(self, html: str) -> Dict[str, str]: + """Extract all hidden field name-value pairs.""" + fields = {} + pattern = re.compile( + r'type=["\']hidden["\'][^>]*name=["\']([^"\']+)["\'][^>]*value=["\']([^"\']*)["\']', + re.I + ) + for match in pattern.finditer(html): + fields[match.group(1)] = match.group(2) + + # Also try reverse order (name before type) + pattern2 = re.compile( + r'name=["\']([^"\']+)["\'][^>]*type=["\']hidden["\'][^>]*value=["\']([^"\']*)["\']', + re.I + ) + for match in pattern2.finditer(html): + 
fields[match.group(1)] = match.group(2) + + return fields + + # --- Authentication -------------------------------------------------- + + async def authenticate(self, context_name: str = "user_a") -> bool: + """Attempt to authenticate a session context. + + Tries login forms with available credentials. + Returns True if authentication succeeded. + """ + if not self.request_engine: + return False + + ctx = self.contexts.get(context_name) + if not ctx: + return False + + ctx.state = "authenticating" + creds = self.get_credentials_for_role(ctx.role) + + if not creds: + logger.debug(f"No credentials available for {context_name} ({ctx.role})") + ctx.state = "unauthenticated" + return False + + # Find login forms if not already discovered + if not self._login_forms: + await self._discover_login_forms() + + if not self._login_forms: + logger.debug("No login forms found") + ctx.state = "unauthenticated" + return False + + # Try each form with each credential + for form in self._login_forms: + for cred in creds: + self._login_attempts += 1 + success = await self._attempt_login(form, cred, ctx) + if success: + ctx.state = "authenticated" + ctx.credential = cred + ctx.login_time = time.time() + ctx.login_url = form.url + self._successful_logins += 1 + logger.info(f"Login success: {context_name} as {cred.username} ({cred.role})") + return True + + ctx.state = "unauthenticated" + return False + + async def _discover_login_forms(self): + """Discover login forms by crawling common login paths.""" + if not self.request_engine: + return + + # Use recon data if available + target = "" + if self.recon and hasattr(self.recon, "target"): + target = self.recon.target + + if not target: + return + + login_paths = [ + "/login", "/signin", "/sign-in", "/auth/login", + "/user/login", "/admin/login", "/api/auth/login", + "/account/login", "/wp-login.php", "/admin", + ] + + parsed = urlparse(target) + base = f"{parsed.scheme}://{parsed.netloc}" + + for path in login_paths: + try: + url = 
f"{base}{path}" + result = await self.request_engine.request(url, method="GET") + if result and result.status == 200 and result.body: + forms = self.detect_login_forms(result.body, url) + if forms: + logger.debug(f"Found {len(forms)} login form(s) at {url}") + return # Found forms, stop searching + except Exception: + continue + + async def _attempt_login(self, form: LoginForm, cred: Credentials, ctx: SessionContext) -> bool: + """Attempt login with a specific form and credential.""" + try: + # Build form data + data = {} + + # Add hidden fields first + data.update(form.extra_fields) + + # Refresh CSRF token if needed + if form.csrf_field: + fresh_csrf = await self._refresh_csrf(form) + if fresh_csrf: + data[form.csrf_field] = fresh_csrf + elif form.csrf_value: + data[form.csrf_field] = form.csrf_value + + # Add credentials + data[form.username_field] = cred.username + data[form.password_field] = cred.password + + # Submit form + result = await self.request_engine.request( + form.url, + method=form.method, + data=data, + allow_redirects=True, + ) + + if not result: + return False + + # Check for login success + success = self._detect_login_success( + result.body, result.status, result.headers + ) + + if success: + # Extract tokens and cookies + self._extract_session_data(result, ctx) + return True + + return False + + except Exception as e: + logger.debug(f"Login attempt failed: {e}") + return False + + async def _refresh_csrf(self, form: LoginForm) -> Optional[str]: + """Fetch fresh CSRF token from the login page.""" + try: + # GET the form page to get a fresh token + page_url = form.url.replace(urlparse(form.url).path, "") + urlparse(form.url).path + result = await self.request_engine.request(page_url, method="GET") + if result and result.body: + _, csrf_value = self._find_csrf_token(result.body) + return csrf_value + except Exception: + pass + return None + + def _detect_login_success(self, body: str, status: int, headers: Dict) -> bool: + """Detect if login was 
successful.""" + body_lower = (body or "").lower() + + # Check for redirect to authenticated area + if status in (301, 302, 303, 307): + location = headers.get("Location", headers.get("location", "")) + if any(kw in location.lower() for kw in ["dashboard", "home", "profile", "admin"]): + return True + + # Check for Set-Cookie (session creation) + has_session_cookie = any( + "set-cookie" in k.lower() for k in headers + ) + + # Check for success indicators in body + success_count = sum(1 for kw in self.SUCCESS_INDICATORS if kw in body_lower) + failure_count = sum(1 for kw in self.FAILURE_INDICATORS if kw in body_lower) + + # Success if: session cookie + success indicators and no failure indicators + if has_session_cookie and success_count > 0 and failure_count == 0: + return True + + # Success if: 200 OK + strong success indicators + no failure + if status == 200 and success_count >= 2 and failure_count == 0: + return True + + return False + + def _extract_session_data(self, result, ctx: SessionContext): + """Extract tokens and cookies from a successful login response.""" + # Extract cookies from Set-Cookie headers + for key, value in result.headers.items(): + if key.lower() == "set-cookie": + cookie_parts = value.split(";")[0].split("=", 1) + if len(cookie_parts) == 2: + ctx.cookies[cookie_parts[0].strip()] = cookie_parts[1].strip() + + # Extract tokens from response body (JSON) + body = result.body or "" + token_patterns = [ + (r'"(?:access_token|token|jwt|bearer|id_token)"\s*:\s*"([^"]+)"', "bearer"), + (r'"(?:api_key|apikey|api-key)"\s*:\s*"([^"]+)"', "api_key"), + (r'"(?:refresh_token)"\s*:\s*"([^"]+)"', "refresh"), + ] + + for pattern, token_type in token_patterns: + match = re.search(pattern, body, re.I) + if match: + ctx.tokens[token_type] = match.group(1) + + # Build auth headers + if "bearer" in ctx.tokens: + ctx.headers["Authorization"] = f"Bearer {ctx.tokens['bearer']}" + elif "api_key" in ctx.tokens: + ctx.headers["X-API-Key"] = ctx.tokens["api_key"] + + 
# --- Session Management ---------------------------------------------- + + def detect_session_expiry(self, body: str, status: int) -> bool: + """Check if a response indicates session expiry.""" + if status in (401, 403): + return True + + body_lower = (body or "").lower() + return any(kw in body_lower for kw in self.EXPIRY_INDICATORS) + + async def refresh(self, context_name: Optional[str] = None) -> bool: + """Refresh an expired session by re-authenticating. + + If context_name is None, refresh all expired sessions. + """ + contexts_to_refresh = [] + if context_name: + ctx = self.contexts.get(context_name) + if ctx and ctx.state == "expired": + contexts_to_refresh.append(context_name) + else: + for name, ctx in self.contexts.items(): + if ctx.state == "expired": + contexts_to_refresh.append(name) + + results = [] + for name in contexts_to_refresh: + ctx = self.contexts[name] + ctx.state = "unauthenticated" + ctx.cookies.clear() + ctx.tokens.clear() + ctx.headers.clear() + success = await self.authenticate(name) + results.append(success) + + return all(results) if results else False + + def check_and_mark_expiry(self, context_name: str, body: str, status: int) -> bool: + """Check response for expiry and mark context if expired. + + Returns True if session was detected as expired. 
+ """ + ctx = self.contexts.get(context_name) + if not ctx or ctx.state != "authenticated": + return False + + if self.detect_session_expiry(body, status): + ctx.state = "expired" + logger.info(f"Session expired for {context_name}") + return True + + # Check time-based expiry + if ctx.login_time and (time.time() - ctx.login_time) > ctx.session_duration: + ctx.state = "expired" + logger.info(f"Session timeout for {context_name}") + return True + + return False + + # --- Request Integration --------------------------------------------- + + def get_context(self, context_name: str) -> Optional[SessionContext]: + """Get a session context by name.""" + return self.contexts.get(context_name) + + def get_request_kwargs(self, context_name: str) -> Dict: + """Get headers and cookies for requests as a context. + + Returns dict with 'headers' and 'cookies' ready for request_engine. + """ + ctx = self.contexts.get(context_name) + if not ctx or ctx.state != "authenticated": + return {"headers": {}, "cookies": {}} + + return { + "headers": dict(ctx.headers), + "cookies": dict(ctx.cookies), + } + + def is_authenticated(self, context_name: str) -> bool: + """Check if a context is currently authenticated.""" + ctx = self.contexts.get(context_name) + return ctx is not None and ctx.state == "authenticated" + + def get_auth_summary(self) -> Dict: + """Get summary of authentication state for reporting.""" + return { + "contexts": { + name: { + "state": ctx.state, + "role": ctx.role, + "credential": ctx.credential.username if ctx.credential else None, + "has_tokens": bool(ctx.tokens), + "has_cookies": bool(ctx.cookies), + } + for name, ctx in self.contexts.items() + }, + "login_forms_found": len(self._login_forms), + "login_attempts": self._login_attempts, + "successful_logins": self._successful_logins, + "credentials_available": { + role: len(creds) + for role, creds in self._credentials.items() + }, + } diff --git a/backend/core/autonomous_agent.py b/backend/core/autonomous_agent.py 
index 89318d9..8177bb2 100644 --- a/backend/core/autonomous_agent.py +++ b/backend/core/autonomous_agent.py @@ -21,6 +21,30 @@ from urllib.parse import urljoin, urlparse, parse_qs, urlencode from enum import Enum from pathlib import Path +from backend.core.agent_memory import AgentMemory +from backend.core.vuln_engine.registry import VulnerabilityRegistry +from backend.core.vuln_engine.payload_generator import PayloadGenerator +from backend.core.response_verifier import ResponseVerifier +from backend.core.negative_control import NegativeControlEngine +from backend.core.proof_of_execution import ProofOfExecution +from backend.core.confidence_scorer import ConfidenceScorer +from backend.core.validation_judge import ValidationJudge +from backend.core.vuln_engine.system_prompts import get_system_prompt, get_prompt_for_vuln_type +from backend.core.vuln_engine.ai_prompts import get_verification_prompt, get_poc_prompt +from backend.core.access_control_learner import AccessControlLearner +from backend.core.request_engine import RequestEngine, ErrorType +from backend.core.waf_detector import WAFDetector +from backend.core.strategy_adapter import StrategyAdapter +from backend.core.chain_engine import ChainEngine +from backend.core.auth_manager import AuthManager + +try: + from core.browser_validator import BrowserValidator, embed_screenshot, HAS_PLAYWRIGHT +except ImportError: + HAS_PLAYWRIGHT = False + BrowserValidator = None + embed_screenshot = None + # Try to import anthropic for Claude API try: import anthropic @@ -37,6 +61,13 @@ except ImportError: OPENAI_AVAILABLE = False openai = None +# Security sandbox (Docker-based real tools) +try: + from core.sandbox_manager import get_sandbox, SandboxManager + HAS_SANDBOX = True +except ImportError: + HAS_SANDBOX = False + class OperationMode(Enum): """Agent operation modes""" @@ -44,6 +75,7 @@ class OperationMode(Enum): FULL_AUTO = "full_auto" PROMPT_ONLY = "prompt_only" ANALYZE_ONLY = "analyze_only" + AUTO_PENTEST = 
"auto_pentest" class FindingSeverity(Enum): @@ -83,8 +115,16 @@ class Finding: poc_code: str = "" remediation: str = "" references: List[str] = field(default_factory=list) + screenshots: List[str] = field(default_factory=list) + affected_urls: List[str] = field(default_factory=list) ai_verified: bool = False - confidence: str = "high" + confidence: str = "0" # Numeric string "0"-"100" + confidence_score: int = 0 # Numeric confidence score 0-100 + confidence_breakdown: Dict = field(default_factory=dict) # Scoring breakdown + proof_of_execution: str = "" # What proof was found + negative_controls: str = "" # Control test results + ai_status: str = "confirmed" # "confirmed" | "rejected" | "pending" + rejection_reason: str = "" @dataclass @@ -369,6 +409,61 @@ class LLMConnectionError(Exception): pass +DEFAULT_ASSESSMENT_PROMPT = """You are NeuroSploit, an elite autonomous penetration testing AI agent. +Your mission: identify real, exploitable vulnerabilities — zero false positives. + +## METHODOLOGY (PTES/OWASP/WSTG aligned) + +### Phase 1 — Reconnaissance & Fingerprinting +- Discover all endpoints, parameters, forms, API paths, WebSocket URLs +- Technology fingerprinting: language, framework, server, WAF, CDN +- Identify attack surface: file upload, auth endpoints, admin panels, GraphQL + +### Phase 2 — Technology-Guided Prioritization +Select vulnerability types based on detected technology stack: +- PHP/Laravel → LFI, command injection, SSTI (Blade), SQLi, file upload +- Node.js/Express → NoSQL injection, SSRF, prototype pollution, SSTI (EJS/Pug) +- Python/Django/Flask → SSTI (Jinja2), command injection, IDOR, mass assignment +- Java/Spring → XXE, insecure deserialization, expression language injection, SSRF +- ASP.NET → path traversal, XXE, header injection, insecure deserialization +- API/REST → IDOR, BOLA, BFLA, JWT manipulation, mass assignment, rate limiting +- GraphQL → introspection, injection, DoS via nested queries +- WordPress → file upload, SQLi, XSS, 
exposed admin, plugin vulns + +### Phase 3 — Active Testing (100 vuln types available) +**OWASP Top 10 2021 coverage:** +- A01 Broken Access Control: IDOR, BOLA, BFLA, privilege escalation, forced browsing, CORS +- A02 Cryptographic Failures: weak encryption/hashing, cleartext transmission, SSL issues +- A03 Injection: SQLi (error/union/blind/time), NoSQL, LDAP, XPath, command, SSTI, XSS, XXE +- A04 Insecure Design: business logic, race condition, mass assignment +- A05 Security Misconfiguration: headers, debug mode, directory listing, default creds +- A06 Vulnerable Components: outdated dependencies, insecure CDN +- A07 Auth Failures: JWT, session fixation, brute force, 2FA bypass, OAuth misconfig +- A08 Data Integrity: insecure deserialization, cache poisoning, HTTP smuggling +- A09 Logging Failures: log injection, improper error handling +- A10 SSRF: standard SSRF, cloud metadata SSRF + +### Phase 4 — Verification (multi-signal) +Every finding MUST have: +1. Concrete HTTP evidence (request + response) +2. At least 2 verification signals OR high-confidence tester match +3. No speculative language — only confirmed exploitable issues +4. 
Screenshot capture when possible + +### Phase 5 — Reporting +- Each finding: title, severity, CVSS 3.1, CWE, PoC, impact, remediation +- Prioritized by real-world exploitability +- Executive summary with risk rating + +## CRITICAL RULES +- NEVER report theoretical/speculative vulnerabilities +- ALWAYS verify with real HTTP evidence before confirming +- Test systematically: every parameter, every endpoint, every form +- Use technology hints to select the most relevant tests +- Capture baseline responses before testing for accurate diff-based detection +""" + + class AutonomousAgent: """ AI-Powered Autonomous Security Agent @@ -376,79 +471,96 @@ class AutonomousAgent: Performs real security testing with AI-powered analysis """ - # Comprehensive payload sets for testing - PAYLOADS = { - "sqli": [ - "'", "\"", "' OR '1'='1", "\" OR \"1\"=\"1", "' OR 1=1--", - "admin'--", "1' AND '1'='1", "1 AND 1=1", "' UNION SELECT NULL--", - "1' AND SLEEP(5)--", "1' WAITFOR DELAY '0:0:5'--", - "1'; DROP TABLE users--", "' OR ''='", "1' ORDER BY 1--" - ], - "xss": [ - "", "