Add files via upload

This commit is contained in:
Joas A Santos
2026-02-11 10:47:33 -03:00
committed by GitHub
parent e32573a950
commit 30acd5afc7
52 changed files with 22492 additions and 706 deletions
+327 -36
View File
@@ -32,6 +32,8 @@ agent_instances: Dict[str, AutonomousAgent] = {}
# Map agent_id to scan_id for database persistence
agent_to_scan: Dict[str, str] = {}
# Reverse map: scan_id to agent_id for ScanDetailsPage lookups
scan_to_agent: Dict[str, str] = {}
@router.get("/status")
@@ -101,6 +103,7 @@ class AgentMode(str, Enum):
RECON_ONLY = "recon_only" # Just reconnaissance
PROMPT_ONLY = "prompt_only" # AI decides (high tokens)
ANALYZE_ONLY = "analyze_only" # Analysis without testing
AUTO_PENTEST = "auto_pentest" # One-click full auto pentest
class AgentRequest(BaseModel):
@@ -113,6 +116,8 @@ class AgentRequest(BaseModel):
auth_value: Optional[str] = Field(None, description="Auth value (cookie string, token, etc)")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom HTTP headers")
max_depth: int = Field(5, description="Maximum crawl depth")
subdomain_discovery: bool = Field(False, description="Enable subdomain discovery (auto_pentest mode)")
targets: Optional[List[str]] = Field(None, description="Multiple targets (auto_pentest mode)")
class AgentResponse(BaseModel):
@@ -193,7 +198,9 @@ async def run_agent(request: AgentRequest, background_tasks: BackgroundTasks):
"findings": [],
"report": None,
"progress": 0,
"phase": "initializing"
"phase": "initializing",
"rejected_findings": [],
"rejected_findings_count": 0,
}
# Run agent in background
@@ -212,7 +219,8 @@ async def run_agent(request: AgentRequest, background_tasks: BackgroundTasks):
"full_auto": "Full autonomous pentest: Recon -> Analyze -> Test -> Report",
"recon_only": "Reconnaissance only, no vulnerability testing",
"prompt_only": "AI decides everything (high token usage!)",
"analyze_only": "Analysis only, no active testing"
"analyze_only": "Analysis only, no active testing",
"auto_pentest": "One-click auto pentest: Full recon + 100 vuln types + AI report"
}
return AgentResponse(
@@ -255,12 +263,20 @@ async def _run_agent_task(
agent_results[agent_id]["progress"] = progress
agent_results[agent_id]["phase"] = phase
rejected_findings_list = []
async def finding_callback(finding: Dict):
"""Real-time finding callback - updates in-memory storage immediately"""
findings_list.append(finding)
if agent_id in agent_results:
agent_results[agent_id]["findings"] = findings_list
agent_results[agent_id]["findings_count"] = len(findings_list)
if finding.get("ai_status") == "rejected":
rejected_findings_list.append(finding)
if agent_id in agent_results:
agent_results[agent_id]["rejected_findings"] = rejected_findings_list
agent_results[agent_id]["rejected_findings_count"] = len(rejected_findings_list)
else:
findings_list.append(finding)
if agent_id in agent_results:
agent_results[agent_id]["findings"] = findings_list
agent_results[agent_id]["findings_count"] = len(findings_list)
try:
# Create database session and scan record
@@ -289,8 +305,9 @@ async def _run_agent_task(
db.add(target_record)
await db.commit()
# Store mapping
# Store mapping (both directions)
agent_to_scan[agent_id] = scan_id
scan_to_agent[scan_id] = agent_id
agent_results[agent_id]["scan_id"] = scan_id
# Map mode
@@ -299,6 +316,7 @@ async def _run_agent_task(
AgentMode.RECON_ONLY: OperationMode.RECON_ONLY,
AgentMode.PROMPT_ONLY: OperationMode.PROMPT_ONLY,
AgentMode.ANALYZE_ONLY: OperationMode.ANALYZE_ONLY,
AgentMode.AUTO_PENTEST: OperationMode.AUTO_PENTEST,
}
op_mode = mode_map.get(mode, OperationMode.FULL_AUTO)
@@ -311,6 +329,7 @@ async def _run_agent_task(
task=task,
custom_prompt=custom_prompt or (task.prompt if task else None),
finding_callback=finding_callback,
scan_id=str(scan_id),
) as agent:
# Store agent instance for stop functionality
agent_instances[agent_id] = agent
@@ -345,7 +364,41 @@ async def _run_agent_task(
impact=finding.get("impact", ""),
remediation=finding.get("remediation", ""),
references=finding.get("references", []),
ai_analysis=finding.get("ai_analysis", finding.get("exploitation_steps", ""))
ai_analysis=finding.get("ai_analysis", finding.get("exploitation_steps", "")),
poc_code=finding.get("poc_code", ""),
screenshots=finding.get("screenshots", []),
url=finding.get("url", finding.get("affected_endpoint", "")),
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
validation_status="ai_confirmed",
)
db.add(vuln)
# Save rejected findings to database for manual review
for finding in report.get("rejected_findings", []):
vuln = Vulnerability(
scan_id=scan_id,
title=finding.get("title", finding.get("type", "Unknown")),
vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
severity=finding.get("severity", "medium").lower(),
cvss_score=finding.get("cvss_score"),
cvss_vector=finding.get("cvss_vector"),
cwe_id=finding.get("cwe_id"),
description=finding.get("description", finding.get("evidence", "")),
affected_endpoint=finding.get("affected_endpoint", finding.get("endpoint", finding.get("url", target))),
poc_payload=finding.get("payload", finding.get("poc_payload", "")),
poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
impact=finding.get("impact", ""),
remediation=finding.get("remediation", ""),
references=finding.get("references", []),
poc_code=finding.get("poc_code", ""),
screenshots=finding.get("screenshots", []),
url=finding.get("url", finding.get("affected_endpoint", "")),
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
validation_status="ai_rejected",
ai_rejection_reason=finding.get("rejection_reason", ""),
)
db.add(vuln)
@@ -402,6 +455,7 @@ async def _run_agent_task(
agent_results[agent_id]["report"] = report
agent_results[agent_id]["report_id"] = report_record.id
agent_results[agent_id]["findings"] = findings
agent_results[agent_id]["tool_executions"] = report.get("tool_executions", [])
agent_results[agent_id]["progress"] = 100
agent_results[agent_id]["phase"] = "completed"
@@ -429,6 +483,37 @@ async def _run_agent_task(
pass
@router.get("/by-scan/{scan_id}")
async def get_agent_by_scan(scan_id: str):
    """Return the in-memory agent status for a scan (reverse lookup by scan_id).

    The scan ID is resolved to an agent ID through ``scan_to_agent``; the
    response is then built from that agent's entry in ``agent_results``.

    Raises:
        HTTPException: 404 when no agent is mapped to the scan, or when the
            mapped agent no longer has an entry in ``agent_results``.
    """
    agent_id = scan_to_agent.get(scan_id)
    if not agent_id:
        raise HTTPException(status_code=404, detail="No agent found for this scan")
    if agent_id not in agent_results:
        raise HTTPException(status_code=404, detail="Agent data no longer in memory")

    snapshot = agent_results[agent_id]
    confirmed = snapshot.get("findings", [])
    rejected = snapshot.get("rejected_findings", [])
    log_entries = snapshot.get("logs", [])

    return {
        "agent_id": agent_id,
        "scan_id": scan_id,
        "status": snapshot["status"],
        "mode": snapshot.get("mode", "full_auto"),
        "target": snapshot["target"],
        "progress": snapshot.get("progress", 0),
        "phase": snapshot.get("phase", "unknown"),
        "started_at": snapshot.get("started_at"),
        "completed_at": snapshot.get("completed_at"),
        "findings_count": len(confirmed),
        "findings": confirmed,
        "rejected_findings_count": len(rejected),
        "rejected_findings": rejected,
        "logs_count": len(log_entries),
        "report": snapshot.get("report"),
        "error": snapshot.get("error")
    }
@router.get("/status/{agent_id}")
async def get_agent_status(agent_id: str):
"""Get the status and results of an agent run - with database fallback"""
@@ -449,6 +534,8 @@ async def get_agent_status(agent_id: str):
"logs_count": len(result.get("logs", [])),
"findings_count": len(result.get("findings", [])),
"findings": result.get("findings", []),
"rejected_findings_count": len(result.get("rejected_findings", [])),
"rejected_findings": result.get("rejected_findings", []),
"report": result.get("report"),
"error": result.get("error")
}
@@ -495,10 +582,12 @@ async def _get_status_from_db(agent_id: str, scan_id: str):
"evidence": getattr(v, 'poc_evidence', None) or "",
"request": v.poc_request or "",
"response": v.poc_response or "",
"poc_code": v.poc_payload or "",
"poc_code": getattr(v, 'poc_code', None) or v.poc_payload or "",
"impact": v.impact or "",
"remediation": v.remediation or "",
"references": v.references or [],
"screenshots": getattr(v, 'screenshots', None) or [],
"url": getattr(v, 'url', None) or v.affected_endpoint or "",
"ai_verified": True,
"confidence": "high"
}
@@ -542,14 +631,14 @@ async def _get_status_from_db(agent_id: str, scan_id: str):
@router.post("/stop/{agent_id}")
async def stop_agent(agent_id: str):
"""Stop a running agent scan and auto-generate report"""
"""Stop a running agent scan, save all findings to DB, and generate report."""
if agent_id not in agent_results:
raise HTTPException(status_code=404, detail="Agent not found")
if agent_results[agent_id]["status"] != "running":
return {"message": "Agent is not running", "status": agent_results[agent_id]["status"]}
# Cancel the agent
# Cancel the agent immediately
if agent_id in agent_instances:
agent_instances[agent_id].cancel()
@@ -558,9 +647,10 @@ async def stop_agent(agent_id: str):
agent_results[agent_id]["phase"] = "stopped"
agent_results[agent_id]["completed_at"] = datetime.utcnow().isoformat()
# Update database and auto-generate report
# Update database: save findings + generate report
scan_id = agent_to_scan.get(agent_id)
report_id = None
target = agent_results[agent_id].get("target", "Unknown")
if scan_id:
try:
@@ -573,47 +663,222 @@ async def stop_agent(agent_id: str):
scan.status = "stopped"
scan.completed_at = datetime.utcnow()
# Get findings count
# Save confirmed findings to DB (same as completion flow)
findings = agent_results[agent_id].get("findings", [])
scan.total_vulnerabilities = len(findings)
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
# Count severities
for finding in findings:
severity = finding.get("severity", "").lower()
if severity == "critical":
scan.critical_count = (scan.critical_count or 0) + 1
elif severity == "high":
scan.high_count = (scan.high_count or 0) + 1
elif severity == "medium":
scan.medium_count = (scan.medium_count or 0) + 1
elif severity == "low":
scan.low_count = (scan.low_count or 0) + 1
elif severity == "info":
scan.info_count = (scan.info_count or 0) + 1
severity = finding.get("severity", "medium").lower()
if severity in severity_counts:
severity_counts[severity] += 1
vuln = Vulnerability(
scan_id=scan_id,
title=finding.get("title", finding.get("type", "Unknown")),
vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
severity=severity,
cvss_score=finding.get("cvss_score"),
cvss_vector=finding.get("cvss_vector"),
cwe_id=finding.get("cwe_id"),
description=finding.get("description", finding.get("evidence", "")),
affected_endpoint=finding.get("affected_endpoint", finding.get("endpoint", finding.get("url", target))),
poc_payload=finding.get("payload", finding.get("poc_payload", "")),
poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
impact=finding.get("impact", ""),
remediation=finding.get("remediation", ""),
references=finding.get("references", []),
ai_analysis=finding.get("ai_analysis", finding.get("exploitation_steps", "")),
poc_code=finding.get("poc_code", ""),
screenshots=finding.get("screenshots", []),
url=finding.get("url", finding.get("affected_endpoint", "")),
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
validation_status="ai_confirmed",
)
db.add(vuln)
# Save rejected findings to DB for manual review
rejected = agent_results[agent_id].get("rejected_findings", [])
for finding in rejected:
vuln = Vulnerability(
scan_id=scan_id,
title=finding.get("title", finding.get("type", "Unknown")),
vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
severity=finding.get("severity", "medium").lower(),
cvss_score=finding.get("cvss_score"),
cvss_vector=finding.get("cvss_vector"),
cwe_id=finding.get("cwe_id"),
description=finding.get("description", finding.get("evidence", "")),
affected_endpoint=finding.get("affected_endpoint", finding.get("endpoint", finding.get("url", target))),
poc_payload=finding.get("payload", finding.get("poc_payload", "")),
poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
impact=finding.get("impact", ""),
remediation=finding.get("remediation", ""),
references=finding.get("references", []),
poc_code=finding.get("poc_code", ""),
screenshots=finding.get("screenshots", []),
url=finding.get("url", finding.get("affected_endpoint", "")),
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
validation_status="ai_rejected",
ai_rejection_reason=finding.get("rejection_reason", ""),
)
db.add(vuln)
# Update scan counts (confirmed only)
scan.total_vulnerabilities = len(findings)
scan.critical_count = severity_counts["critical"]
scan.high_count = severity_counts["high"]
scan.medium_count = severity_counts["medium"]
scan.low_count = severity_counts["low"]
scan.info_count = severity_counts["info"]
await db.commit()
# Auto-generate report
report = Report(
# Auto-generate report record
report_record = Report(
scan_id=scan_id,
title=f"Agent Scan Report - {agent_results[agent_id].get('target', 'Unknown')}",
title=f"Agent Scan Report - {target}",
format="json",
executive_summary=f"Automated security scan completed with {len(findings)} findings."
executive_summary=f"Security scan stopped with {len(findings)} confirmed and {len(rejected)} rejected findings."
)
db.add(report)
db.add(report_record)
await db.commit()
await db.refresh(report)
report_id = report.id
await db.refresh(report_record)
report_id = report_record.id
except Exception as e:
print(f"Error updating scan status: {e}")
print(f"Error updating scan status on stop: {e}")
import traceback
traceback.print_exc()
return {
"message": "Agent stopped successfully",
"agent_id": agent_id,
"report_id": report_id
"report_id": report_id,
"findings_saved": len(agent_results[agent_id].get("findings", [])),
"rejected_saved": len(agent_results[agent_id].get("rejected_findings", [])),
}
@router.post("/pause/{agent_id}")
async def pause_agent(agent_id: str):
    """Pause a running agent scan.

    Returns an informational message (not an error) when the agent exists but
    is not currently running.

    Raises:
        HTTPException: 404 when the agent ID is not present in ``agent_results``.
    """
    if agent_id not in agent_results:
        raise HTTPException(status_code=404, detail="Agent not found")

    record = agent_results[agent_id]
    if record["status"] != "running":
        return {"message": "Agent is not running", "status": record["status"]}

    if agent_id in agent_instances:
        agent_instances[agent_id].pause()

    # Remember the phase that was active so it can be restored later; the
    # keyword arguments are evaluated before "phase" is overwritten.
    record.update(
        last_phase=record.get("phase", "recon"),
        status="paused",
        phase="paused",
    )
    return {"message": "Agent paused", "agent_id": agent_id}
@router.post("/resume/{agent_id}")
async def resume_agent(agent_id: str):
    """Resume a paused agent scan.

    Returns an informational message (not an error) when the agent exists but
    is not currently paused.

    Raises:
        HTTPException: 404 when the agent ID is not present in ``agent_results``.
    """
    if agent_id not in agent_results:
        raise HTTPException(status_code=404, detail="Agent not found")

    record = agent_results[agent_id]
    if record["status"] != "paused":
        return {"message": "Agent is not paused", "status": record["status"]}

    if agent_id in agent_instances:
        agent_instances[agent_id].resume()

    # Go back to the phase recorded at pause time ("testing" if none was saved).
    record.update(
        status="running",
        phase=record.get("last_phase", "testing"),
    )
    return {"message": "Agent resumed", "agent_id": agent_id}
# Agent phase order for skip validation.
# skip_agent_phase compares positions in this list and only accepts a target
# whose index is strictly greater than the current phase's index.
AGENT_PHASE_ORDER = ["recon", "analysis", "testing", "enhancement", "completed"]
# Map phase names from status strings to canonical phase keys.
# Keys are lowercase because skip_agent_phase lower-cases the stored phase
# before looking it up here; strings not listed fall through unchanged.
PHASE_NORMALIZE = {
"starting reconnaissance": "recon",
"reconnaissance complete": "recon",
"initial probe complete": "recon",
"endpoint discovery complete": "recon",
"parameter discovery complete": "recon",
"attack surface analyzed": "analysis",
"vulnerability testing complete": "testing",
"findings enhanced": "enhancement",
"assessment complete": "completed",
}
@router.post("/skip-to/{agent_id}/{target_phase}")
async def skip_agent_phase(agent_id: str, target_phase: str):
    """Skip the current agent phase and jump to a target phase.

    Valid phases: recon, analysis, testing, enhancement, completed.
    Can only skip forward (to a phase ahead of the current one), so "recon"
    is never reachable as a target and is left out of the error hint.

    A paused agent is resumed first so the skip signal can be processed; its
    in-memory phase is restored from ``last_phase`` at the same time, the same
    way ``resume_agent`` does it.

    Raises:
        HTTPException: 404 if the agent is unknown; 400 if the agent is not
            running/paused, the phase name is invalid, the skip is not
            forward, or no live agent instance exists; 500 if the agent
            instance reports that the skip signal failed.
    """
    if agent_id not in agent_results:
        raise HTTPException(status_code=404, detail="Agent not found")

    state = agent_results[agent_id]
    agent_status = state["status"]
    if agent_status not in ("running", "paused"):
        raise HTTPException(status_code=400, detail="Agent is not running or paused")

    if target_phase not in AGENT_PHASE_ORDER:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid phase '{target_phase}'. Valid: {', '.join(AGENT_PHASE_ORDER[1:])}"
        )

    # Get current phase and normalize it.
    current_raw = str(state.get("phase") or "").lower()
    # "paused"/"stopped" hide the real phase: use the last known non-paused
    # phase instead, defaulting to recon. It is lower-cased as well because
    # PHASE_NORMALIZE only has lowercase keys.
    if current_raw in ("paused", "stopped"):
        current_raw = str(state.get("last_phase") or "recon").lower()

    current_phase = PHASE_NORMALIZE.get(current_raw, current_raw)
    # Fall back to a substring match against the canonical phase keys.
    for key in AGENT_PHASE_ORDER:
        if key in current_phase:
            current_phase = key
            break

    # Unknown phases (e.g. "initializing") count as position 0.
    cur_idx = AGENT_PHASE_ORDER.index(current_phase) if current_phase in AGENT_PHASE_ORDER else 0
    tgt_idx = AGENT_PHASE_ORDER.index(target_phase)
    if tgt_idx <= cur_idx:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot skip backward. Current: {current_phase}, target: {target_phase}"
        )

    # Signal the agent instance to skip.
    if agent_id not in agent_instances:
        raise HTTPException(status_code=400, detail="Agent instance not available for signaling")

    agent = agent_instances[agent_id]
    # If paused, resume first so the skip can be processed.
    if agent_status == "paused":
        agent.resume()
        state["status"] = "running"
        # Do not leave status "running" paired with phase "paused".
        state["phase"] = state.get("last_phase", "testing")

    success = agent.skip_to_phase(target_phase)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to signal phase skip")

    return {
        "message": f"Skipping to phase: {target_phase}",
        "agent_id": agent_id,
        "from_phase": current_phase,
        "target_phase": target_phase
    }
@@ -1711,7 +1976,10 @@ async def _save_realtime_findings_to_db(session_id: str, session: Dict):
impact=finding.get("impact", ""),
remediation=finding.get("remediation", ""),
references=finding.get("references", []),
ai_analysis=f"Identified during realtime session {session_id}"
ai_analysis=f"Identified during realtime session {session_id}",
screenshots=finding.get("screenshots", []),
url=finding.get("url", finding.get("affected_endpoint", "")),
parameter=finding.get("parameter", "")
)
db.add(vuln)
@@ -1809,6 +2077,29 @@ async def generate_realtime_report(session_id: str, format: str = "json"):
scan_results=tool_results
)
# Save to a per-report folder with screenshots
import shutil
from pathlib import Path
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
target_name = session["target"].replace("://", "_").replace("/", "_").rstrip("_")[:40]
report_dir = Path("reports") / f"report_{target_name}_{timestamp}"
report_dir.mkdir(parents=True, exist_ok=True)
(report_dir / f"report_{timestamp}.html").write_text(html_content)
# Copy screenshots into report folder
screenshots_src = Path("reports") / "screenshots"
if screenshots_src.exists():
screenshots_dest = report_dir / "screenshots"
for finding in findings:
fid = finding.get("id", "")
if fid:
src_dir = screenshots_src / str(fid)
if src_dir.exists():
dest_dir = screenshots_dest / str(fid)
dest_dir.mkdir(parents=True, exist_ok=True)
for ss_file in src_dir.glob("*.png"):
shutil.copy2(ss_file, dest_dir / ss_file.name)
return HTMLResponse(content=html_content, media_type="text/html")
return {
+166 -1
View File
@@ -64,6 +64,19 @@ async def generate_report(
)
vulnerabilities = vulns_result.scalars().all()
# Try to get tool_executions from agent in-memory results
tool_executions = []
try:
from backend.api.v1.agent import scan_to_agent, agent_results
agent_id = scan_to_agent.get(report_data.scan_id)
if agent_id and agent_id in agent_results:
tool_executions = agent_results[agent_id].get("tool_executions", [])
if not tool_executions:
rpt = agent_results[agent_id].get("report", {})
tool_executions = rpt.get("tool_executions", []) if isinstance(rpt, dict) else []
except Exception:
pass
# Generate report
generator = ReportGenerator()
report_path, executive_summary = await generator.generate(
@@ -73,7 +86,8 @@ async def generate_report(
title=report_data.title,
include_executive_summary=report_data.include_executive_summary,
include_poc=report_data.include_poc,
include_remediation=report_data.include_remediation
include_remediation=report_data.include_remediation,
tool_executions=tool_executions,
)
# Save report record
@@ -91,6 +105,63 @@ async def generate_report(
return ReportResponse(**report.to_dict())
@router.post("/ai-generate", response_model=ReportResponse)
async def generate_ai_report(
    report_data: ReportGenerate,
    db: AsyncSession = Depends(get_db)
):
    """Generate an AI-enhanced report with LLM-written executive summary and per-finding analysis.

    Loads the scan and its vulnerabilities from the database, attaches any
    tool executions still held in the agent's in-memory results, delegates to
    ``ReportGenerator.generate_ai_report`` and persists a ``Report`` row.

    Raises:
        HTTPException: 404 when the scan does not exist.
    """
    scan_id = report_data.scan_id

    # Get scan
    scan = (
        await db.execute(select(Scan).where(Scan.id == scan_id))
    ).scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")

    # Get vulnerabilities
    vulnerabilities = (
        await db.execute(select(Vulnerability).where(Vulnerability.scan_id == scan_id))
    ).scalars().all()

    # Best effort: tool executions only exist while the agent's results are
    # still in memory, so any failure here simply leaves the list empty.
    tool_runs = []
    try:
        from backend.api.v1.agent import scan_to_agent, agent_results

        linked_agent = scan_to_agent.get(scan_id)
        if linked_agent and linked_agent in agent_results:
            cached = agent_results[linked_agent]
            tool_runs = cached.get("tool_executions", [])
            if not tool_runs:
                # Also check the nested report payload.
                nested = cached.get("report", {})
                tool_runs = nested.get("tool_executions", []) if isinstance(nested, dict) else []
    except Exception:
        pass

    # Generate AI report
    report_path, ai_summary = await ReportGenerator().generate_ai_report(
        scan=scan,
        vulnerabilities=vulnerabilities,
        tool_executions=tool_runs,
        title=report_data.title,
    )

    # Save report record (the stored summary is capped at 2000 characters)
    stored_summary = ai_summary[:2000] if ai_summary else None
    report = Report(
        scan_id=scan.id,
        title=report_data.title or f"AI Report - {scan.name}",
        format="html",
        file_path=str(report_path),
        executive_summary=stored_summary
    )
    db.add(report)
    await db.commit()
    await db.refresh(report)

    return ReportResponse(**report.to_dict())
@router.get("/{report_id}", response_model=ReportResponse)
async def get_report(report_id: str, db: AsyncSession = Depends(get_db)):
"""Get report details"""
@@ -187,6 +258,100 @@ async def download_report(
)
@router.get("/{report_id}/download-zip")
async def download_report_zip(
    report_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Download report as ZIP with screenshots included.

    The HTML report is regenerated from the scan and its vulnerabilities and
    stored as ``report.html``. Screenshots are collected per vulnerability
    from the scan-scoped directory ``reports/screenshots/{scan_id}/{finding_id}``
    (falling back to the legacy ``reports/screenshots/{finding_id}``) and from
    base64 ``data:image/`` entries stored on the vulnerability record.

    Base64 screenshots are written straight into the archive, and the archive
    itself lives in a uniquely named temp file that is removed once the
    response has been sent, so concurrent downloads cannot overwrite each
    other and no temp files are left behind.

    Raises:
        HTTPException: 404 when the report or its scan does not exist.
    """
    import base64
    import hashlib
    import os
    import tempfile
    import zipfile

    from starlette.background import BackgroundTask

    result = await db.execute(select(Report).where(Report.id == report_id))
    report = result.scalar_one_or_none()
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")

    scan_result = await db.execute(select(Scan).where(Scan.id == report.scan_id))
    scan = scan_result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found for report")

    vulns_result = await db.execute(
        select(Vulnerability).where(Vulnerability.scan_id == report.scan_id)
    )
    vulnerabilities = vulns_result.scalars().all()

    # Generate HTML report
    generator = ReportGenerator()
    report_path, _ = await generator.generate(
        scan=scan,
        vulnerabilities=vulnerabilities,
        format="html",
        title=report.title
    )

    # Collect screenshots (absolute path via settings.BASE_DIR).
    screenshots_base = settings.BASE_DIR / "reports" / "screenshots"
    scan_id_str = str(scan.id)
    # (archive name, payload) pairs; payload is a Path on disk or raw bytes.
    archive_entries = []

    for vuln in vulnerabilities:
        # Finding ID is md5(vuln_type+url+param)[:8] (an identifier, not a security hash)
        vuln_url = getattr(vuln, 'url', None) or vuln.affected_endpoint or ''
        vuln_param = getattr(vuln, 'parameter', None) or getattr(vuln, 'poc_parameter', None) or ''
        finding_id = hashlib.md5(
            f"{vuln.vulnerability_type}{vuln_url}{vuln_param}".encode()
        ).hexdigest()[:8]

        # Check the scan-scoped path first, then the legacy flat path.
        candidate_dirs = (
            screenshots_base / scan_id_str / finding_id,
            screenshots_base / finding_id,
        )
        finding_dir = next((d for d in candidate_dirs if d.exists()), None)
        if finding_dir:
            for img in finding_dir.glob("*.png"):
                archive_entries.append((f"screenshots/{finding_id}/{img.name}", img))

        # Also include base64 screenshots from the DB as files in the ZIP.
        db_screenshots = getattr(vuln, 'screenshots', None) or []
        for idx, ss in enumerate(db_screenshots, start=1):
            if not (isinstance(ss, str) and ss.startswith("data:image/")):
                continue
            try:
                img_bytes = base64.b64decode(ss.split(",", 1)[1])
            except (IndexError, ValueError):
                # Malformed data URI (no payload or invalid base64): skip it.
                continue
            archive_entries.append(
                (f"screenshots/{finding_id}/evidence_{idx}.png", img_bytes)
            )

    # Create the ZIP in a unique temp file; the download keeps the report-based name.
    zip_name = Path(report_path).stem + ".zip"
    fd, zip_path = tempfile.mkstemp(suffix=".zip")
    os.close(fd)
    try:
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            zf.write(report_path, "report.html")
            for arc_name, payload in archive_entries:
                if isinstance(payload, bytes):
                    zf.writestr(arc_name, payload)
                else:
                    zf.write(str(payload), arc_name)
    except Exception:
        # Do not leave a partial archive behind when building it fails.
        os.unlink(zip_path)
        raise

    return FileResponse(
        path=zip_path,
        media_type="application/zip",
        filename=zip_name,
        # Remove the temp archive after the response body has been sent.
        background=BackgroundTask(os.unlink, zip_path),
    )
@router.delete("/{report_id}")
async def delete_report(report_id: str, db: AsyncSession = Depends(get_db)):
"""Delete a report"""
+130
View File
@@ -0,0 +1,130 @@
"""
NeuroSploit v3 - Sandbox Container Management API
Real-time monitoring and management of per-scan Kali Linux containers.
"""
from datetime import datetime
from fastapi import APIRouter, HTTPException
router = APIRouter()
def _docker_available() -> bool:
try:
import docker
docker.from_env().ping()
return True
except Exception:
return False
def _sandbox_uptime_seconds(created, now) -> float:
    """Return seconds between ISO-8601 string ``created`` and naive-UTC ``now``; 0.0 if unusable."""
    if not created:
        return 0.0
    try:
        started = datetime.fromisoformat(created)
    except (TypeError, ValueError):
        # Not a string, or not a parseable ISO-8601 timestamp.
        return 0.0
    offset = started.utcoffset()
    if offset is not None:
        # Timezone-aware value: convert to naive UTC so it can be subtracted
        # from the naive ``now`` instead of raising TypeError.
        started = started.replace(tzinfo=None) - offset
    return (now - started).total_seconds()


@router.get("/")
async def list_sandboxes():
    """List all sandbox containers with pool status.

    When the container pool cannot be loaded, a placeholder pool description
    with an empty container list and the error text is returned instead of
    raising. Otherwise each sandbox's info dict is returned with an added
    ``uptime_seconds`` computed from its ``created_at`` value.
    """
    try:
        from core.container_pool import get_pool
        pool = get_pool()
    except Exception as e:
        return {
            "pool": {
                "active": 0,
                "max_concurrent": 0,
                "image": "neurosploit-kali:latest",
                "container_ttl_minutes": 60,
                "docker_available": _docker_available(),
            },
            "containers": [],
            "error": str(e),
        }

    sandboxes = pool.list_sandboxes()
    now = datetime.utcnow()
    containers = [
        {**info, "uptime_seconds": _sandbox_uptime_seconds(info.get("created_at"), now)}
        for info in sandboxes.values()
    ]

    return {
        "pool": {
            "active": pool.active_count,
            "max_concurrent": pool.max_concurrent,
            "image": pool.image,
            "container_ttl_minutes": int(pool.container_ttl.total_seconds() / 60),
            "docker_available": _docker_available(),
        },
        "containers": containers,
    }
@router.get("/{scan_id}")
async def get_sandbox(scan_id: str):
    """Get health check for a specific sandbox container.

    Raises:
        HTTPException: 503 when the container pool cannot be obtained; 404
            when the pool lists no sandbox for ``scan_id`` or holds no
            sandbox instance for it.
    """
    try:
        from core.container_pool import get_pool
        pool = get_pool()
    except Exception as exc:
        raise HTTPException(status_code=503, detail=str(exc))

    known = pool.list_sandboxes()
    if scan_id not in known:
        raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}")

    # The listing holds info dicts; the object with health_check() is read
    # from the pool's internal mapping.
    instance = pool._sandboxes.get(scan_id)
    if not instance:
        raise HTTPException(status_code=404, detail="Sandbox instance not found")

    return await instance.health_check()
@router.delete("/{scan_id}")
async def destroy_sandbox(scan_id: str):
    """Destroy a specific sandbox container.

    Raises:
        HTTPException: 503 when the container pool cannot be obtained; 404
            when the pool lists no sandbox for ``scan_id``.
    """
    try:
        from core.container_pool import get_pool
        pool = get_pool()
    except Exception as exc:
        raise HTTPException(status_code=503, detail=str(exc))

    if scan_id not in pool.list_sandboxes():
        raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}")

    await pool.destroy(scan_id)
    confirmation = f"Sandbox for scan {scan_id} destroyed"
    return {"message": confirmation, "scan_id": scan_id}
@router.post("/cleanup")
async def cleanup_expired():
    """Remove containers that have exceeded their TTL.

    Raises:
        HTTPException: 503 when the pool cannot be obtained or the cleanup
            call itself fails; the detail carries the error text.
    """
    try:
        from core.container_pool import get_pool
        await get_pool().cleanup_expired()
    except Exception as exc:
        raise HTTPException(status_code=503, detail=str(exc))
    else:
        return {"message": "Expired containers cleaned up"}
@router.post("/cleanup-orphans")
async def cleanup_orphans():
    """Remove orphan containers not tracked by the pool.

    Raises:
        HTTPException: 503 when the pool cannot be obtained or the cleanup
            call itself fails; the detail carries the error text.
    """
    try:
        from core.container_pool import get_pool
        await get_pool().cleanup_orphans()
    except Exception as exc:
        raise HTTPException(status_code=503, detail=str(exc))
    else:
        return {"message": "Orphan containers cleaned up"}
+204 -3
View File
@@ -4,6 +4,7 @@ NeuroSploit v3 - Scans API Endpoints
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from urllib.parse import urlparse
@@ -11,7 +12,7 @@ from urllib.parse import urlparse
from backend.db.database import get_db
from backend.models import Scan, Target, Endpoint, Vulnerability
from backend.schemas.scan import ScanCreate, ScanUpdate, ScanResponse, ScanListResponse, ScanProgress
from backend.services.scan_service import run_scan_task
from backend.services.scan_service import run_scan_task, skip_to_phase as _skip_to_phase, PHASE_ORDER
router = APIRouter()
@@ -177,6 +178,7 @@ async def start_scan(
async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
"""Stop a running scan and save partial results"""
from backend.api.websocket import manager as ws_manager
from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = result.scalar_one_or_none()
@@ -184,8 +186,16 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
if scan.status != "running":
raise HTTPException(status_code=400, detail="Scan is not running")
if scan.status not in ("running", "paused"):
raise HTTPException(status_code=400, detail="Scan is not running or paused")
# Signal the running agent to stop
agent_id = scan_to_agent.get(scan_id)
if agent_id and agent_id in agent_instances:
agent_instances[agent_id].cancel()
if agent_id in agent_results:
agent_results[agent_id]["status"] = "stopped"
agent_results[agent_id]["phase"] = "stopped"
# Update scan status
scan.status = "stopped"
@@ -259,6 +269,132 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
}
@router.post("/{scan_id}/pause")
async def pause_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
    """Pause a running scan.

    Signals the agent mapped to this scan (if one is live) to pause, marks
    both the in-memory agent state and the scan row as paused, and broadcasts
    a log line to websocket listeners.

    Raises:
        HTTPException: 404 when the scan does not exist; 400 when it is not
            running.
    """
    from backend.api.websocket import manager as ws_manager
    from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results

    result = await db.execute(select(Scan).where(Scan.id == scan_id))
    scan = result.scalar_one_or_none()

    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")

    if scan.status != "running":
        raise HTTPException(status_code=400, detail="Scan is not running")

    # Signal the agent to pause
    agent_id = scan_to_agent.get(scan_id)
    if agent_id and agent_id in agent_instances:
        agent_instances[agent_id].pause()
        if agent_id in agent_results:
            state = agent_results[agent_id]
            # Save the active phase before overwriting it with "paused" so a
            # later resume or skip can restore it. An already-paused agent is
            # left alone so the saved phase is not replaced by "paused".
            active_phase = state.get("phase", "recon")
            if active_phase != "paused":
                state["last_phase"] = active_phase
            state["status"] = "paused"
            state["phase"] = "paused"

    scan.status = "paused"
    scan.current_phase = "paused"
    await db.commit()

    await ws_manager.broadcast_log(scan_id, "warning", "Scan paused by user")

    return {"message": "Scan paused", "scan_id": scan_id}
@router.post("/{scan_id}/resume")
async def resume_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
    """Resume a paused scan.

    Signals the in-process agent registered for the scan (when there is one)
    to resume, marks the scan as running again and broadcasts an info log
    line to WebSocket listeners.

    The phase reported after resuming is the agent record's ``last_phase``
    when one is present, matching what the skip-to endpoint does when it
    resumes a paused scan; otherwise it falls back to "testing".

    Raises:
        HTTPException: 404 when the scan does not exist, 400 when it is not
            currently paused.
    """
    from backend.api.websocket import manager as ws_manager
    from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results

    result = await db.execute(select(Scan).where(Scan.id == scan_id))
    scan = result.scalar_one_or_none()

    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")

    if scan.status != "paused":
        raise HTTPException(status_code=400, detail="Scan is not paused")

    # Default used when no agent record (or no remembered phase) is available
    resumed_phase = "testing"

    # Signal the agent to resume
    agent_id = scan_to_agent.get(scan_id)
    if agent_id and agent_id in agent_instances:
        agent_instances[agent_id].resume()
        if agent_id in agent_results:
            agent_state = agent_results[agent_id]
            # NOTE(review): "last_phase" is read the same way by the skip-to
            # endpoint; confirm which component writes it.
            resumed_phase = agent_state.get("last_phase") or "testing"
            agent_state["status"] = "running"
            agent_state["phase"] = resumed_phase

    scan.status = "running"
    scan.current_phase = resumed_phase
    await db.commit()

    await ws_manager.broadcast_log(scan_id, "info", "Scan resumed by user")

    return {"message": "Scan resumed", "scan_id": scan_id}
@router.post("/{scan_id}/skip-to/{target_phase}")
async def skip_to_phase_endpoint(scan_id: str, target_phase: str, db: AsyncSession = Depends(get_db)):
    """Skip the current scan phase and jump to a target phase.

    Valid phases: recon, analyzing, testing, completed
    Can only skip forward (to a phase ahead of current).

    The request is fully validated before any state is touched, so a rejected
    request never resumes a paused scan as a side effect.

    Raises:
        HTTPException: 404 when the scan does not exist; 400 when the scan is
            neither running nor paused, the phase name is unknown, or the
            target is not ahead of the current phase; 500 when the skip
            signal could not be delivered.
    """
    result = await db.execute(select(Scan).where(Scan.id == scan_id))
    scan = result.scalar_one_or_none()

    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")

    if scan.status not in ("running", "paused"):
        raise HTTPException(status_code=400, detail="Scan is not running or paused")

    if target_phase not in PHASE_ORDER:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid phase '{target_phase}'. Valid: {', '.join(PHASE_ORDER[1:])}"
        )

    # Validate forward skip. A phase name that is not in PHASE_ORDER (for
    # example "paused") is treated as index 0, so any later phase is allowed.
    from_phase = scan.current_phase
    current_idx = PHASE_ORDER.index(from_phase) if from_phase in PHASE_ORDER else 0
    target_idx = PHASE_ORDER.index(target_phase)

    if target_idx <= current_idx:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot skip backward. Current: {from_phase}, target: {target_phase}"
        )

    # If paused, resume first so the skip can be processed
    if scan.status == "paused":
        from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
        agent_id = scan_to_agent.get(scan_id)
        if agent_id and agent_id in agent_instances:
            agent_instances[agent_id].resume()
            if agent_id in agent_results:
                agent_results[agent_id]["status"] = "running"
                agent_results[agent_id]["phase"] = agent_results[agent_id].get("last_phase", "testing")
        scan.status = "running"
        await db.commit()

    # Signal the running scan to skip
    success = _skip_to_phase(scan_id, target_phase)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to signal phase skip")

    # Broadcast via WebSocket
    from backend.api.websocket import manager as ws_manager
    await ws_manager.broadcast_log(scan_id, "warning", f">> User requested skip to phase: {target_phase}")
    await ws_manager.broadcast_phase_change(scan_id, f"skipping_to_{target_phase}")

    return {
        "message": f"Skipping to phase: {target_phase}",
        "scan_id": scan_id,
        "from_phase": from_phase,
        "target_phase": target_phase
    }
@router.get("/{scan_id}/status", response_model=ScanProgress)
async def get_scan_status(scan_id: str, db: AsyncSession = Depends(get_db)):
"""Get scan progress and status"""
@@ -369,3 +505,68 @@ async def get_scan_vulnerabilities(
"page": page,
"per_page": per_page
}
class ValidationRequest(BaseModel):
    """Request body for manually changing a vulnerability's validation status."""
    # New status; the endpoint rejects anything outside the set listed here.
    validation_status: str # "validated" | "false_positive" | "ai_confirmed" | "ai_rejected" | "pending_review"
    # Optional reviewer note; when non-empty the endpoint stores it in the
    # vulnerability's ai_rejection_reason field.
    notes: Optional[str] = None
@router.patch("/vulnerabilities/{vuln_id}/validate")
async def validate_vulnerability(
    vuln_id: str,
    body: ValidationRequest,
    db: AsyncSession = Depends(get_db)
):
    """Manually validate or reject a vulnerability finding.

    Updates the finding's validation status (and, when notes are given, its
    rejection reason) and keeps the parent scan's severity counters in step.

    Counter handling: "ai_confirmed" and "validated" findings are counted,
    "ai_rejected" and "false_positive" findings are not. Any move from one
    group to the other adjusts the scan totals, so re-validating a finding
    that was previously marked false positive restores its count.
    "pending_review" never changes the counters.

    Raises:
        HTTPException: 400 for an unknown status, 404 when the vulnerability
            does not exist.
    """
    valid_statuses = {"validated", "false_positive", "ai_confirmed", "ai_rejected", "pending_review"}
    if body.validation_status not in valid_statuses:
        # sorted() keeps the error message stable between requests
        raise HTTPException(
            status_code=400,
            detail=f"Invalid status. Must be one of: {', '.join(sorted(valid_statuses))}"
        )

    result = await db.execute(select(Vulnerability).where(Vulnerability.id == vuln_id))
    vuln = result.scalar_one_or_none()

    if not vuln:
        raise HTTPException(status_code=404, detail="Vulnerability not found")

    # Findings without an explicit status are treated as AI-confirmed
    old_status = vuln.validation_status or "ai_confirmed"
    new_status = body.validation_status
    vuln.validation_status = new_status
    if body.notes:
        vuln.ai_rejection_reason = body.notes

    counted_statuses = {"ai_confirmed", "validated"}
    dismissed_statuses = {"ai_rejected", "false_positive"}
    if old_status in dismissed_statuses and new_status in counted_statuses:
        delta = 1
    elif old_status in counted_statuses and new_status in dismissed_statuses:
        delta = -1
    else:
        delta = 0

    # Update scan severity counts when the finding crosses the counted boundary
    if delta:
        scan_result = await db.execute(select(Scan).where(Scan.id == vuln.scan_id))
        scan = scan_result.scalar_one_or_none()
        if scan:
            scan.total_vulnerabilities = max(0, (scan.total_vulnerabilities or 0) + delta)
            severity = vuln.severity
            if severity in ("critical", "high", "medium", "low", "info"):
                counter = f"{severity}_count"
                # Clamp at zero so a drifted counter can never go negative
                setattr(scan, counter, max(0, (getattr(scan, counter) or 0) + delta))

    await db.commit()

    return {"message": "Vulnerability validation updated", "vulnerability": vuln.to_dict()}
+140
View File
@@ -0,0 +1,140 @@
"""
NeuroSploit v3 - Scheduler API Router
CRUD endpoints for managing scheduled scan jobs.
"""
import json
from pathlib import Path
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from typing import Optional, List, Dict
# Router holding the scheduler endpoints defined in this module.
router = APIRouter()
# config/config.json located four directory levels above this file; read by
# get_agent_roles to list the available agent roles.
CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "config" / "config.json"
class ScheduleJobRequest(BaseModel):
    """Request model for creating a scheduled job."""
    # Caller-chosen identifier; the other job endpoints address the job by it.
    job_id: str
    target: str
    scan_type: str = "quick"
    # At least one of cron_expression / interval_minutes must be truthy,
    # otherwise create_scheduled_job answers 400.
    cron_expression: Optional[str] = None
    interval_minutes: Optional[int] = None
    # Passed through unchanged to scheduler.add_job.
    agent_role: Optional[str] = None
    llm_profile: Optional[str] = None
class ScheduleJobResponse(BaseModel):
    """Response model for a scheduled job."""
    # NOTE(review): the endpoints visible in this module declare Dict /
    # List[Dict] response models rather than this class - confirm where it is used.
    id: str
    target: str
    scan_type: str
    schedule: str
    status: str
    next_run: Optional[str] = None
    last_run: Optional[str] = None
    run_count: int = 0
@router.get("/", response_model=List[Dict])
async def list_scheduled_jobs(request: Request):
    """Return every scheduled scan job, or an empty list when no scheduler is attached to the app."""
    scheduler = getattr(request.app.state, 'scheduler', None)
    return scheduler.list_jobs() if scheduler else []
@router.post("/", response_model=Dict)
async def create_scheduled_job(job: ScheduleJobRequest, request: Request):
    """Register a new scheduled scan job with the application scheduler.

    Raises:
        HTTPException: 503 when no scheduler is attached to the app, 400 when
            neither a cron expression nor an interval is given or when the
            scheduler reports an error for the job.
    """
    scheduler = getattr(request.app.state, 'scheduler', None)
    if not scheduler:
        raise HTTPException(status_code=503, detail="Scheduler not available")

    has_schedule = job.cron_expression or job.interval_minutes
    if not has_schedule:
        raise HTTPException(
            status_code=400,
            detail="Either cron_expression or interval_minutes must be provided"
        )

    job_kwargs = {
        "job_id": job.job_id,
        "target": job.target,
        "scan_type": job.scan_type,
        "cron_expression": job.cron_expression,
        "interval_minutes": job.interval_minutes,
        "agent_role": job.agent_role,
        "llm_profile": job.llm_profile,
    }
    outcome = scheduler.add_job(**job_kwargs)

    # The scheduler reports failures through an "error" key in its result
    if "error" in outcome:
        raise HTTPException(status_code=400, detail=outcome["error"])
    return outcome
@router.delete("/{job_id}")
async def delete_scheduled_job(job_id: str, request: Request):
    """Remove a scheduled scan job.

    Raises:
        HTTPException: 503 when no scheduler is attached to the app, 404 when
            the scheduler reports that the job could not be removed.
    """
    scheduler = getattr(request.app.state, 'scheduler', None)
    if not scheduler:
        raise HTTPException(status_code=503, detail="Scheduler not available")

    if not scheduler.remove_job(job_id):
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

    return {"message": f"Job '{job_id}' deleted", "id": job_id}
@router.post("/{job_id}/pause")
async def pause_scheduled_job(job_id: str, request: Request):
    """Pause a scheduled scan job.

    Raises:
        HTTPException: 503 when no scheduler is attached to the app, 404 when
            the scheduler reports that the job could not be paused.
    """
    scheduler = getattr(request.app.state, 'scheduler', None)
    if not scheduler:
        raise HTTPException(status_code=503, detail="Scheduler not available")

    if not scheduler.pause_job(job_id):
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

    return {"message": f"Job '{job_id}' paused", "id": job_id, "status": "paused"}
@router.post("/{job_id}/resume")
async def resume_scheduled_job(job_id: str, request: Request):
    """Resume a paused scheduled scan job.

    Raises:
        HTTPException: 503 when no scheduler is attached to the app, 404 when
            the scheduler reports that the job could not be resumed.
    """
    scheduler = getattr(request.app.state, 'scheduler', None)
    if not scheduler:
        raise HTTPException(status_code=503, detail="Scheduler not available")

    if not scheduler.resume_job(job_id):
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

    return {"message": f"Job '{job_id}' resumed", "id": job_id, "status": "active"}
@router.get("/agent-roles", response_model=List[Dict])
async def get_agent_roles():
    """Return available agent roles from config.json for scheduler dropdown.

    Roles whose "enabled" flag is false are left out. Any problem reading or
    interpreting the config file results in an empty list rather than an error.
    """
    try:
        if not CONFIG_PATH.exists():
            return []

        roles = json.loads(CONFIG_PATH.read_text()).get("agent_roles", {})
        return [
            {
                "id": role_id,
                "name": role_id.replace("_", " ").title(),
                "description": role_data.get("description", ""),
                "tools": role_data.get("tools_allowed", []),
            }
            for role_id, role_data in roles.items()
            if role_data.get("enabled", True)
        ]
    except Exception:
        # Best effort: the dropdown simply stays empty when the config is unusable
        return []
+164 -18
View File
@@ -1,7 +1,10 @@
"""
NeuroSploit v3 - Settings API Endpoints
"""
from typing import Optional
import os
import re
from pathlib import Path
from typing import Optional, Dict
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, text
@@ -12,16 +15,69 @@ from backend.models import Scan, Target, Endpoint, Vulnerability, VulnerabilityT
# Router holding the settings endpoints defined in this module.
router = APIRouter()
# Path to .env file (project root), i.e. four directory levels above this
# file. Read when settings are loaded and rewritten by _update_env_file.
ENV_FILE_PATH = Path(__file__).parent.parent.parent.parent / ".env"
def _update_env_file(updates: Dict[str, str]) -> bool:
    """
    Update key=value pairs in the .env file without breaking formatting.
    - If the key exists (even commented out), update its value
    - If the key doesn't exist, append it
    - Preserves comments and blank lines

    Line breaks inside a value are removed before writing: a multi-line value
    would otherwise spill onto extra lines and could smuggle additional
    KEY=VALUE entries into the file.

    Args:
        updates: Mapping of environment variable name to its new value.

    Returns:
        True when the file was rewritten; False when the .env file does not
        exist or could not be read/written (a warning is printed in the
        latter case).
    """
    if not ENV_FILE_PATH.exists():
        return False

    try:
        # One physical line per value, no matter what the caller sent
        clean_updates = {
            key: "".join(str(value).splitlines())
            for key, value in updates.items()
        }
        # Match: KEY=..., # KEY=..., #KEY=...  (compiled once, not per line)
        patterns = {
            key: re.compile(rf'^#?\s*{re.escape(key)}\s*=')
            for key in clean_updates
        }

        # UTF-8 is also what python-dotenv assumes when it loads the file
        lines = ENV_FILE_PATH.read_text(encoding="utf-8").splitlines()
        updated_keys = set()
        new_lines = []

        for line in lines:
            stripped = line.strip()
            matched_key = next(
                (key for key, pattern in patterns.items() if pattern.match(stripped)),
                None,
            )
            if matched_key is None:
                new_lines.append(line)
            else:
                # Replace with uncommented key=value
                new_lines.append(f"{matched_key}={clean_updates[matched_key]}")
                updated_keys.add(matched_key)

        # Append any keys that weren't found in existing file
        new_lines.extend(
            f"{key}={value}"
            for key, value in clean_updates.items()
            if key not in updated_keys
        )

        # Write back with trailing newline
        ENV_FILE_PATH.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
        return True
    except Exception as e:
        # Persisting to disk is best effort; the caller keeps its in-memory state
        print(f"Warning: Failed to update .env file: {e}")
        return False
class SettingsUpdate(BaseModel):
    """Settings update schema"""
    # Every field is optional; the update endpoint only applies fields whose
    # value is not None, so omitted fields keep their current value.
    llm_provider: Optional[str] = None
    # API keys: a non-empty value is also exported to the process environment
    # and queued for persistence to the .env file.
    anthropic_api_key: Optional[str] = None
    openai_api_key: Optional[str] = None
    openrouter_api_key: Optional[str] = None
    max_concurrent_scans: Optional[int] = None
    aggressive_mode: Optional[bool] = None
    default_scan_type: Optional[str] = None
    recon_enabled_by_default: Optional[bool] = None
    # Feature flags: written to the environment / .env as "true" or "false".
    enable_model_routing: Optional[bool] = None
    enable_knowledge_augmentation: Optional[bool] = None
    enable_browser_validation: Optional[bool] = None
    max_output_tokens: Optional[int] = None
class SettingsResponse(BaseModel):
@@ -29,56 +85,118 @@ class SettingsResponse(BaseModel):
llm_provider: str = "claude"
has_anthropic_key: bool = False
has_openai_key: bool = False
has_openrouter_key: bool = False
max_concurrent_scans: int = 3
aggressive_mode: bool = False
default_scan_type: str = "full"
recon_enabled_by_default: bool = True
enable_model_routing: bool = False
enable_knowledge_augmentation: bool = False
enable_browser_validation: bool = False
max_output_tokens: Optional[int] = None
# In-memory settings storage (in production, use database or config file)
_settings = {
"llm_provider": "claude",
"anthropic_api_key": "",
"openai_api_key": "",
"max_concurrent_scans": 3,
"aggressive_mode": False,
"default_scan_type": "full",
"recon_enabled_by_default": True
}
def _load_settings_from_env() -> dict:
    """
    Build the settings dictionary from environment variables.

    When the project's .env file exists it is re-read first (overriding
    variables already present in the process), so values persisted to disk
    win over stale in-process ones.

    The provider is inferred from which API key is set, checked in the order
    Anthropic, OpenAI, OpenRouter; with no key at all it stays "claude".
    """
    from dotenv import load_dotenv

    # Re-read .env file to pick up disk-persisted values
    if ENV_FILE_PATH.exists():
        load_dotenv(ENV_FILE_PATH, override=True)

    truthy_words = ("true", "1", "yes")
    falsy_words = ("false", "0", "no")

    def _env_bool(key: str, default: bool = False) -> bool:
        # Unrecognised or missing values fall back to the default
        raw = os.getenv(key, "").strip().lower()
        if raw in truthy_words:
            return True
        return False if raw in falsy_words else default

    def _env_int(key: str, default=None):
        # Missing, blank or non-numeric values fall back to the default
        raw = os.getenv(key, "").strip()
        if not raw:
            return default
        try:
            return int(raw)
        except ValueError:
            return default

    anthropic_key = os.getenv("ANTHROPIC_API_KEY", "")
    openai_key = os.getenv("OPENAI_API_KEY", "")
    openrouter_key = os.getenv("OPENROUTER_API_KEY", "")

    # Detect provider from which keys are set
    provider = "claude"
    if not anthropic_key:
        if openai_key:
            provider = "openai"
        elif openrouter_key:
            provider = "openrouter"

    settings = {
        "llm_provider": provider,
        "anthropic_api_key": anthropic_key,
        "openai_api_key": openai_key,
        "openrouter_api_key": openrouter_key,
    }
    settings["max_concurrent_scans"] = _env_int("MAX_CONCURRENT_SCANS", 3)
    settings["aggressive_mode"] = _env_bool("AGGRESSIVE_MODE", False)
    settings["default_scan_type"] = os.getenv("DEFAULT_SCAN_TYPE", "full")
    settings["recon_enabled_by_default"] = _env_bool("RECON_ENABLED_BY_DEFAULT", True)
    settings["enable_model_routing"] = _env_bool("ENABLE_MODEL_ROUTING", False)
    settings["enable_knowledge_augmentation"] = _env_bool("ENABLE_KNOWLEDGE_AUGMENTATION", False)
    settings["enable_browser_validation"] = _env_bool("ENABLE_BROWSER_VALIDATION", False)
    settings["max_output_tokens"] = _env_int("MAX_OUTPUT_TOKENS", None)
    return settings
# Load settings from .env on module import (server start)
_settings = _load_settings_from_env()
@router.get("", response_model=SettingsResponse)
async def get_settings():
    """Get current settings.

    API keys themselves are never returned; each is reported only as a
    boolean that is true when the key is held in the in-memory settings or
    present in the process environment.
    """
    # Each keyword is passed exactly once: the superseded key checks (memory
    # only) and the duplicate recon_enabled_by_default argument are dropped,
    # since repeating a keyword argument is a syntax error.
    return SettingsResponse(
        llm_provider=_settings["llm_provider"],
        has_anthropic_key=bool(_settings["anthropic_api_key"] or os.getenv("ANTHROPIC_API_KEY")),
        has_openai_key=bool(_settings["openai_api_key"] or os.getenv("OPENAI_API_KEY")),
        has_openrouter_key=bool(_settings["openrouter_api_key"] or os.getenv("OPENROUTER_API_KEY")),
        max_concurrent_scans=_settings["max_concurrent_scans"],
        aggressive_mode=_settings["aggressive_mode"],
        default_scan_type=_settings["default_scan_type"],
        recon_enabled_by_default=_settings["recon_enabled_by_default"],
        enable_model_routing=_settings["enable_model_routing"],
        enable_knowledge_augmentation=_settings["enable_knowledge_augmentation"],
        enable_browser_validation=_settings["enable_browser_validation"],
        max_output_tokens=_settings["max_output_tokens"]
    )
@router.put("", response_model=SettingsResponse)
async def update_settings(settings_data: SettingsUpdate):
"""Update settings"""
"""Update settings - persists to memory, env vars, AND .env file"""
env_updates: Dict[str, str] = {}
if settings_data.llm_provider is not None:
_settings["llm_provider"] = settings_data.llm_provider
if settings_data.anthropic_api_key is not None:
_settings["anthropic_api_key"] = settings_data.anthropic_api_key
# Also update environment variable for LLM calls
import os
if settings_data.anthropic_api_key:
os.environ["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key
env_updates["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key
if settings_data.openai_api_key is not None:
_settings["openai_api_key"] = settings_data.openai_api_key
import os
if settings_data.openai_api_key:
os.environ["OPENAI_API_KEY"] = settings_data.openai_api_key
env_updates["OPENAI_API_KEY"] = settings_data.openai_api_key
if settings_data.openrouter_api_key is not None:
_settings["openrouter_api_key"] = settings_data.openrouter_api_key
if settings_data.openrouter_api_key:
os.environ["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key
env_updates["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key
if settings_data.max_concurrent_scans is not None:
_settings["max_concurrent_scans"] = settings_data.max_concurrent_scans
@@ -92,6 +210,34 @@ async def update_settings(settings_data: SettingsUpdate):
if settings_data.recon_enabled_by_default is not None:
_settings["recon_enabled_by_default"] = settings_data.recon_enabled_by_default
if settings_data.enable_model_routing is not None:
_settings["enable_model_routing"] = settings_data.enable_model_routing
val = str(settings_data.enable_model_routing).lower()
os.environ["ENABLE_MODEL_ROUTING"] = val
env_updates["ENABLE_MODEL_ROUTING"] = val
if settings_data.enable_knowledge_augmentation is not None:
_settings["enable_knowledge_augmentation"] = settings_data.enable_knowledge_augmentation
val = str(settings_data.enable_knowledge_augmentation).lower()
os.environ["ENABLE_KNOWLEDGE_AUGMENTATION"] = val
env_updates["ENABLE_KNOWLEDGE_AUGMENTATION"] = val
if settings_data.enable_browser_validation is not None:
_settings["enable_browser_validation"] = settings_data.enable_browser_validation
val = str(settings_data.enable_browser_validation).lower()
os.environ["ENABLE_BROWSER_VALIDATION"] = val
env_updates["ENABLE_BROWSER_VALIDATION"] = val
if settings_data.max_output_tokens is not None:
_settings["max_output_tokens"] = settings_data.max_output_tokens
if settings_data.max_output_tokens:
os.environ["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens)
env_updates["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens)
# Persist to .env file on disk
if env_updates:
_update_env_file(env_updates)
return await get_settings()
+568
View File
@@ -0,0 +1,568 @@
"""
Terminal Agent API - Interactive infrastructure pentesting via AI chat + Docker sandbox.
Provides session-based terminal interaction with AI-guided command execution,
exploitation path tracking, and VPN status monitoring.
"""
import asyncio
import re
import time
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from core.llm_manager import LLMManager
from core.sandbox_manager import get_sandbox
router = APIRouter()
# ---------------------------------------------------------------------------
# In-memory session store
# ---------------------------------------------------------------------------
# Maps session_id -> session dict (the structure produced by _build_session).
# Held only in process memory.
terminal_sessions: Dict[str, Dict] = {}
# ---------------------------------------------------------------------------
# Pre-built templates
# ---------------------------------------------------------------------------
# Session templates keyed by template id. Each entry provides:
#   name / description - shown when templates are listed
#   system_prompt      - system prompt used for AI messages in that session
#   initial_commands   - starter commands; the literal "{target}" placeholder
#                        is replaced with the session target on creation
TEMPLATES = {
    "network_scanner": {
        "name": "Network Scanner",
        "description": "Host discovery, port scanning, and service detection",
        "system_prompt": (
            "You are an expert network reconnaissance specialist. You guide the "
            "operator through systematic host discovery, port scanning, and service "
            "fingerprinting. Always suggest nmap flags appropriate for the situation, "
            "explain output, and recommend next steps based on discovered services. "
            "Prioritize stealth when asked and suggest timing/fragmentation options."
        ),
        "initial_commands": [
            "nmap -sn {target}",
            "nmap -sV -sC -O -p- {target}",
            "nmap -sU --top-ports 50 {target}",
        ],
    },
    "lateral_movement": {
        "name": "Lateral Movement",
        "description": "Pass-the-hash, SMB/WinRM pivoting, and SSH tunneling",
        "system_prompt": (
            "You are a lateral movement specialist. You help the operator pivot "
            "through compromised networks using techniques such as pass-the-hash, "
            "SMB relay, WinRM sessions, SSH tunneling, and SOCKS proxying. Always "
            "verify credentials before attempting pivots, suggest cleanup steps, "
            "and track which hosts have been compromised."
        ),
        "initial_commands": [
            "crackmapexec smb {target} -u '' -p ''",
            "crackmapexec smb {target} --shares -u '' -p ''",
            "ssh -D 1080 -N -f user@{target}",
        ],
    },
    "privilege_escalation": {
        "name": "Privilege Escalation",
        "description": "SUID binaries, kernel exploits, cron jobs, and writable paths",
        "system_prompt": (
            "You are a privilege escalation expert for Linux and Windows systems. "
            "Guide the operator through enumeration of SUID/SGID binaries, kernel "
            "version checks, misconfigured cron jobs, writable PATH directories, "
            "sudo misconfigurations, and capability abuse. Suggest automated tools "
            "like linpeas/winpeas when appropriate and explain each finding."
        ),
        "initial_commands": [
            "id && whoami && uname -a",
            "find / -perm -4000 -type f 2>/dev/null",
            "cat /etc/crontab && ls -la /etc/cron.*",
            "echo $PATH | tr ':' '\\n' | xargs -I {} ls -ld {}",
        ],
    },
    "vpn_recon": {
        "name": "VPN Reconnaissance",
        "description": "VPN connection management and internal network discovery",
        "system_prompt": (
            "You are a VPN and internal network reconnaissance specialist. You "
            "help the operator connect to target VPNs, verify tunnel status, "
            "discover internal subnets, and enumerate services behind the VPN. "
            "Always confirm connectivity before proceeding with scans and suggest "
            "appropriate scope for internal reconnaissance."
        ),
        "initial_commands": [
            "openvpn --config client.ovpn --daemon",
            "ip addr show tun0",
            "ip route | grep tun",
            "nmap -sn 10.0.0.0/24",
        ],
    },
}
# ---------------------------------------------------------------------------
# Pydantic request / response models
# ---------------------------------------------------------------------------
class CreateSessionRequest(BaseModel):
    """Request body for creating a terminal session."""
    # Must be a key of TEMPLATES when given; unknown ids are rejected with 400.
    template_id: Optional[str] = None
    # Substituted for "{target}" in the template's initial commands.
    target: Optional[str] = ""
    # When empty, the template name or "Session <first 8 id chars>" is used.
    name: Optional[str] = ""
class MessageRequest(BaseModel):
    """Request body carrying a user prompt for the AI."""
    # Stripped before use; a message that is empty after stripping is rejected with 400.
    message: str
class ExecuteCommandRequest(BaseModel):
    """Request body for running a command in a terminal session."""
    # Stripped before use; an empty command is rejected with 400.
    command: str
    # Only the exact value "sandbox" attempts sandbox execution; any other
    # value runs the command directly.
    execution_method: str = "sandbox" # "sandbox" or "direct"
class ExploitationStepRequest(BaseModel):
    """Request body for adding a manual step to the exploitation path."""
    description: str
    command: Optional[str] = ""
    result: Optional[str] = ""
    # Validated by the endpoint against the values listed here; others give 400.
    step_type: str = "recon" # recon | exploit | pivot | escalate | action
class SessionSummary(BaseModel):
    """Lightweight view of a session, as returned by the session listing."""
    session_id: str
    name: str
    target: str
    template_id: Optional[str]
    status: str
    created_at: str
    # Lengths of the session's "messages" and "command_history" lists.
    messages_count: int
    commands_count: int
class MessageResponse(BaseModel):
    """AI reply returned after a user message is processed."""
    role: str
    response: str
    timestamp: str
    # Commands extracted from fenced code blocks in the reply text.
    suggested_commands: List[str]
class CommandResult(BaseModel):
    """Outcome of one executed command, mirroring a command_history record."""
    command: str
    exit_code: int
    stdout: str
    stderr: str
    # Wall-clock seconds, rounded to three decimals by the execute endpoint.
    duration: float
    # "sandbox" or "direct", depending on where the command actually ran.
    execution_method: str
    timestamp: str
class VPNStatus(BaseModel):
    """VPN connection state."""
    # NOTE(review): presumably the shape returned by the vpn-status endpoint - confirm.
    connected: bool
    ip: Optional[str] = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _build_session(
    session_id: str,
    name: str,
    target: str,
    template_id: Optional[str],
) -> Dict:
    """Create the initial state dictionary for a new terminal session.

    The session starts "active" with empty message, command and exploitation
    histories and a disconnected VPN status.
    """
    session: Dict = dict(
        session_id=session_id,
        name=name,
        target=target,
        template_id=template_id,
    )
    session.update(
        status="active",
        created_at=_now_iso(),
        messages=[],
        command_history=[],
        exploitation_path=[],
        vpn_status={"connected": False, "ip": None},
    )
    return session
def _get_session(session_id: str) -> Dict:
    """Return the stored session for *session_id*.

    Raises:
        HTTPException: 404 when no (non-empty) session is stored under that id.
    """
    found = terminal_sessions.get(session_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
def _build_context_string(
messages: List[Dict],
commands: List[Dict],
exploitation: List[Dict],
) -> str:
parts: List[str] = []
if messages:
parts.append("=== Recent Conversation ===")
for msg in messages:
role = msg.get("role", "unknown").upper()
parts.append(f"[{role}] {msg.get('content', '')}")
if commands:
parts.append("\n=== Recent Command Results ===")
for cmd in commands:
parts.append(
f"$ {cmd['command']}\n"
f"Exit code: {cmd['exit_code']}\n"
f"Stdout: {cmd['stdout'][:500]}\n"
f"Stderr: {cmd['stderr'][:300]}"
)
if exploitation:
parts.append("\n=== Exploitation Path ===")
for i, step in enumerate(exploitation, 1):
parts.append(
f"Step {i} [{step['step_type']}]: {step['description']}"
)
if step.get("command"):
parts.append(f" Command: {step['command']}")
if step.get("result"):
parts.append(f" Result: {step['result'][:300]}")
return "\n".join(parts)
def _extract_suggested_commands(text: str) -> List[str]:
"""Extract commands from backtick-fenced code blocks."""
blocks = re.findall(r"```(?:bash|sh|shell)?\n?(.*?)```", text, re.DOTALL)
commands: List[str] = []
for block in blocks:
for line in block.strip().splitlines():
stripped = line.strip()
if stripped and not stripped.startswith("#"):
commands.append(stripped)
return commands
# ---------------------------------------------------------------------------
# Template endpoints
# ---------------------------------------------------------------------------
@router.get("/templates")
async def list_templates():
    """Return id, name, description and initial commands for every template.

    System prompts are not included in the listing.
    """
    return [
        {
            "id": template_id,
            "name": template["name"],
            "description": template["description"],
            "initial_commands": template["initial_commands"],
        }
        for template_id, template in TEMPLATES.items()
    ]
# ---------------------------------------------------------------------------
# Session CRUD
# ---------------------------------------------------------------------------
@router.post("/session")
async def create_session(req: CreateSessionRequest):
    """Create and store a new terminal session, optionally seeded from a template.

    With a template, the session starts with the template's system prompt and
    an assistant message listing the template's initial commands, in which
    "{target}" has been replaced by the requested target.

    Raises:
        HTTPException: 400 when the template id is not known.
    """
    session_id = str(uuid.uuid4())
    target = req.target or ""
    template_id = req.template_id

    if template_id and template_id not in TEMPLATES:
        raise HTTPException(status_code=400, detail=f"Unknown template: {template_id}")

    # Explicit name wins; otherwise fall back to template name or a short id
    if req.name:
        name = req.name
    elif template_id:
        name = TEMPLATES[template_id]["name"]
    else:
        name = f"Session {session_id[:8]}"

    session = _build_session(session_id, name, target, template_id)

    if template_id:
        template = TEMPLATES[template_id]
        history = session["messages"]

        # Seed initial system message from template
        history.append({
            "role": "system",
            "content": template["system_prompt"],
            "timestamp": _now_iso(),
            "metadata": {"template": template_id},
        })

        # Provide initial suggested commands with target interpolated
        starter_commands = [
            command.replace("{target}", target)
            for command in template["initial_commands"]
        ]
        fenced_commands = "\n".join(f"```\n{c}\n```" for c in starter_commands)
        intro = (
            f"Session initialised with the **{template['name']}** template.\n\n"
            f"Target: `{target or '(not set)'}`\n\n"
            "Suggested starting commands:\n"
        )
        history.append({
            "role": "assistant",
            "content": intro + fenced_commands,
            "timestamp": _now_iso(),
            "suggested_commands": starter_commands,
        })

    terminal_sessions[session_id] = session
    return session
@router.get("/sessions")
async def list_sessions():
    """Return a lightweight summary (with message and command counts) for every stored session."""
    return [
        SessionSummary(
            session_id=session_id,
            name=session["name"],
            target=session["target"],
            template_id=session["template_id"],
            status=session["status"],
            created_at=session["created_at"],
            messages_count=len(session["messages"]),
            commands_count=len(session["command_history"]),
        ).model_dump()
        for session_id, session in terminal_sessions.items()
    ]
@router.get("/sessions/{session_id}")
async def get_session(session_id: str):
    """Return the complete stored session: messages, command history and exploitation path.

    A 404 is raised by the lookup helper when the session does not exist.
    """
    session = _get_session(session_id)
    return session
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Remove a terminal session from the in-memory store.

    Raises:
        HTTPException: 404 when no session is stored under that id.
    """
    try:
        del terminal_sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found") from None
    return {"status": "deleted", "session_id": session_id}
# ---------------------------------------------------------------------------
# AI message interaction
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/message")
async def send_message(session_id: str, req: MessageRequest):
    """Forward a user prompt to the AI and return its reply with extracted commands.

    The user message is recorded first. The prompt then combines the last 20
    messages, the last 10 command results and the whole exploitation path
    with the new message. The reply is stored in the session history together
    with any commands found in its fenced code blocks.

    Raises:
        HTTPException: 404 for an unknown session, 400 for an empty message,
            502 when the LLM call fails.
    """
    session = _get_session(session_id)

    user_message = req.message.strip()
    if not user_message:
        raise HTTPException(status_code=400, detail="Message content cannot be empty")

    history = session["messages"]

    # Record user message
    history.append({
        "role": "user",
        "content": user_message,
        "timestamp": _now_iso(),
        "metadata": {},
    })

    # Generic prompt, replaced below when the session uses a known template
    system_prompt = (
        "You are an expert infrastructure penetration tester. Help the "
        "operator plan and execute attacks against the target. Suggest "
        "concrete commands, explain their purpose, and interpret output. "
        "Always wrap commands in fenced code blocks so they can be extracted."
    )
    template_id = session.get("template_id")
    if template_id and template_id in TEMPLATES:
        system_prompt = TEMPLATES[template_id]["system_prompt"]

    # Build context window
    context = _build_context_string(
        history[-20:],
        session["command_history"][-10:],
        session["exploitation_path"],
    )

    # Call LLM
    try:
        response = await LLMManager().generate(
            f"{context}\n\nUser: {user_message}", system_prompt
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"LLM call failed: {exc}")

    suggested_commands = _extract_suggested_commands(response)
    replied_at = _now_iso()

    # Record assistant response
    history.append({
        "role": "assistant",
        "content": response,
        "timestamp": replied_at,
        "suggested_commands": suggested_commands,
    })

    reply = MessageResponse(
        role="assistant",
        response=response,
        timestamp=replied_at,
        suggested_commands=suggested_commands,
    )
    return reply.model_dump()
# ---------------------------------------------------------------------------
# Command execution
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/execute")
async def execute_command(session_id: str, req: ExecuteCommandRequest):
    """Execute a command in the Docker sandbox (fallback: direct shell).

    The result is appended to the session's command history and mirrored into
    its messages (output preview limited to 2000 characters) so later AI
    prompts can see it.

    Direct execution is limited to 120 seconds. On timeout the shell process
    is killed and reaped instead of being left running, and the result
    carries exit code 124.

    Raises:
        HTTPException: 404 for an unknown session, 400 for an empty command.
    """
    session = _get_session(session_id)
    command = req.command.strip()
    if not command:
        raise HTTPException(status_code=400, detail="Command cannot be empty")

    start = time.time()
    stdout = ""
    stderr = ""
    exit_code = -1
    execution_method = "direct"

    # Use requested execution method
    use_sandbox = req.execution_method == "sandbox"

    if use_sandbox:
        try:
            sandbox = await get_sandbox()
            if sandbox and sandbox.is_available:
                result = await sandbox.execute_raw(command)
                stdout = result.stdout
                stderr = result.stderr
                exit_code = result.exit_code
                execution_method = "sandbox"
        except Exception:
            # NOTE(review): a sandbox failure silently downgrades to running the
            # command directly on this host - confirm that this fallback is intended.
            pass # Fall through to direct execution

    # Fallback or direct execution requested
    if execution_method != "sandbox":
        proc = None
        try:
            # The command string is handed to a shell as-is by design (this is
            # an interactive terminal endpoint).
            proc = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            raw_stdout, raw_stderr = await asyncio.wait_for(
                proc.communicate(), timeout=120
            )
            stdout = raw_stdout.decode(errors="replace")
            stderr = raw_stderr.decode(errors="replace")
            exit_code = proc.returncode or 0
            execution_method = "direct"
        except asyncio.TimeoutError:
            # wait_for only cancels communicate(); the child keeps running
            # unless it is killed explicitly. This stops the shell process;
            # processes it spawned may outlive it.
            if proc is not None and proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass # Already exited between the timeout and the kill
                await proc.wait()
            stderr = "Command timed out after 120 seconds"
            exit_code = 124
        except Exception as exc:
            stderr = str(exc)
            exit_code = 1

    duration = round(time.time() - start, 3)

    cmd_record = {
        "command": command,
        "exit_code": exit_code,
        "stdout": stdout,
        "stderr": stderr,
        "duration": duration,
        "execution_method": execution_method,
        "timestamp": _now_iso(),
    }
    session["command_history"].append(cmd_record)

    # Mirror into messages for AI context continuity
    output_preview = stdout[:2000] if stdout else stderr[:2000]
    session["messages"].append({
        "role": "tool",
        "content": f"$ {command}\n[exit {exit_code}] ({execution_method}, {duration}s)\n{output_preview}",
        "timestamp": cmd_record["timestamp"],
        "metadata": {"exit_code": exit_code, "execution_method": execution_method},
    })

    return CommandResult(**cmd_record).model_dump()
# ---------------------------------------------------------------------------
# Exploitation path
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/exploitation-path")
async def add_exploitation_step(session_id: str, req: ExploitationStepRequest):
    """Append a manually entered step to a session's exploitation path.

    The session is resolved first; ``req.step_type`` is then checked against
    the allowed set and rejected with HTTP 400 when it is not one of them.
    The stored step, stamped with the current time, is returned.
    """
    session = _get_session(session_id)

    allowed_step_types = frozenset(
        ("recon", "exploit", "pivot", "escalate", "action")
    )
    if req.step_type not in allowed_step_types:
        raise HTTPException(
            status_code=400,
            detail=f"step_type must be one of {sorted(allowed_step_types)}",
        )

    # Optional request fields are stored as empty strings rather than None.
    entry = dict(
        description=req.description,
        command=req.command or "",
        result=req.result or "",
        timestamp=_now_iso(),
        step_type=req.step_type,
    )
    session["exploitation_path"].append(entry)
    return entry
@router.get("/sessions/{session_id}/exploitation-path")
async def get_exploitation_path(session_id: str):
    """Return every step recorded so far on the session's exploitation path."""
    return _get_session(session_id)["exploitation_path"]
# ---------------------------------------------------------------------------
# VPN status
# ---------------------------------------------------------------------------
async def _run_vpn_probe(*argv: str, timeout: float = 5):
    """Run *argv* without a shell and return ``(returncode, stdout_bytes)``.

    If the command does not finish within *timeout* seconds it is killed and
    reaped before the timeout error propagates, so no child process is left
    running behind the request.
    """
    proc = await asyncio.create_subprocess_exec(
        *argv,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        raw_stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            pass  # the process exited between the timeout and the kill
        await proc.wait()
        raise
    return proc.returncode, raw_stdout


@router.get("/sessions/{session_id}/vpn-status")
async def get_vpn_status(session_id: str):
    """Check OpenVPN process and tun0 interface status.

    ``connected`` is True when ``pgrep -a openvpn`` exits 0 with output.
    Only then is ``ip addr show tun0`` queried and the first IPv4 ``inet``
    address reported as ``ip``. The result is cached on the session under
    ``vpn_status`` and returned as a serialised ``VPNStatus``.

    Both probes are best-effort: a missing binary, a timeout or any other
    failure leaves the corresponding value at its default instead of failing
    the request.
    """
    session = _get_session(session_id)
    connected = False
    ip_addr: Optional[str] = None

    # Check for running openvpn process
    try:
        returncode, raw_stdout = await _run_vpn_probe("pgrep", "-a", "openvpn")
        if returncode == 0 and raw_stdout.strip():
            connected = True
    except Exception:
        # Deliberately best-effort: treat any probe failure as "not connected".
        pass

    # Check tun0 interface for IP
    if connected:
        try:
            returncode, raw_stdout = await _run_vpn_probe("ip", "addr", "show", "tun0")
            if returncode == 0:
                match = re.search(
                    r"inet\s+(\d+\.\d+\.\d+\.\d+)", raw_stdout.decode(errors="replace")
                )
                if match:
                    ip_addr = match.group(1)
        except Exception:
            # Deliberately best-effort: keep connected=True with no IP.
            pass

    vpn = {"connected": connected, "ip": ip_addr}
    session["vpn_status"] = vpn
    return VPNStatus(**vpn).model_dump()
+876
View File
@@ -0,0 +1,876 @@
"""
NeuroSploit v3 - Vulnerability Lab API Endpoints
Isolated vulnerability testing against labs, CTFs, and PortSwigger challenges.
Test individual vuln types one at a time and track results.
"""
from typing import Optional, Dict, List
from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from datetime import datetime
from sqlalchemy import select, func, text
from backend.core.autonomous_agent import AutonomousAgent, OperationMode
from backend.core.vuln_engine.registry import VulnerabilityRegistry
from backend.db.database import async_session_factory
from backend.models import Scan, Target, Vulnerability, Endpoint, Report, VulnLabChallenge
# Import agent.py's shared dicts so ScanDetailsPage can find our scans
from backend.api.v1.agent import (
agent_results, agent_instances, agent_to_scan, scan_to_agent
)
# Router that all Vulnerability Lab endpoints in this module attach to.
router = APIRouter()
# In-memory tracking for running lab tests; both dicts are keyed by challenge_id.
# lab_agents holds the live agent object while a test runs (used by the
# stop/delete endpoints). lab_results holds the status/progress/findings/logs
# snapshot that the real-time endpoints serve; in the code visible here its
# entries are only removed by delete_challenge.
lab_agents: Dict[str, AutonomousAgent] = {}
lab_results: Dict[str, Dict] = {}
# --- Request/Response Models ---
# Request body for POST /run: one target URL plus the single vulnerability
# type to test. The auth_* fields and custom_headers are turned into HTTP
# headers by run_vuln_lab; challenge_name and notes are passed on to the
# agent's prompt by _run_lab_test.
class VulnLabRunRequest(BaseModel):
    target_url: str = Field(..., description="Target URL to test (lab, CTF, etc.)")
    vuln_type: str = Field(..., description="Vulnerability type to test (e.g. xss_reflected)")
    challenge_name: Optional[str] = Field(None, description="Name of the lab/challenge")
    auth_type: Optional[str] = Field(None, description="Auth type: cookie, bearer, basic, header")
    auth_value: Optional[str] = Field(None, description="Auth credential value")
    custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom HTTP headers")
    notes: Optional[str] = Field(None, description="Notes about this challenge")
# Response returned by POST /run. challenge_id identifies the lab run in this
# module's endpoints; agent_id is the key used in agent.py's shared dicts.
class VulnLabResponse(BaseModel):
    challenge_id: str
    agent_id: str
    status: str
    message: str
# Descriptor for a single vulnerability type.
# NOTE(review): no endpoint visible in this module uses this model
# (list_vuln_types returns plain dicts) - confirm whether it is still needed.
class VulnTypeInfo(BaseModel):
    key: str
    title: str
    severity: str
    cwe_id: str
    category: str
# --- Vuln type categories for the selector ---
# Maps a category key to its display label and the vulnerability type keys
# that belong to it. Used by _get_vuln_category to classify a type and by
# list_vuln_types to build the grouped listing. Insertion order is the order
# in which categories and types are emitted.
VULN_CATEGORIES = {
    "injection": {
        "label": "Injection",
        "types": [
            "xss_reflected", "xss_stored", "xss_dom",
            "sqli_error", "sqli_union", "sqli_blind", "sqli_time",
            "command_injection", "ssti", "nosql_injection",
        ]
    },
    "advanced_injection": {
        "label": "Advanced Injection",
        "types": [
            "ldap_injection", "xpath_injection", "graphql_injection",
            "crlf_injection", "header_injection", "email_injection",
            "el_injection", "log_injection", "html_injection",
            "csv_injection", "orm_injection",
        ]
    },
    "file_access": {
        "label": "File Access",
        "types": [
            "lfi", "rfi", "path_traversal", "xxe", "file_upload",
            "arbitrary_file_read", "arbitrary_file_delete", "zip_slip",
        ]
    },
    "request_forgery": {
        "label": "Request Forgery",
        "types": [
            "ssrf", "csrf", "graphql_introspection", "graphql_dos",
        ]
    },
    "authentication": {
        "label": "Authentication",
        "types": [
            "auth_bypass", "jwt_manipulation", "session_fixation",
            "weak_password", "default_credentials", "two_factor_bypass",
            "oauth_misconfig",
        ]
    },
    "authorization": {
        "label": "Authorization",
        "types": [
            "idor", "bola", "privilege_escalation",
            "bfla", "mass_assignment", "forced_browsing",
        ]
    },
    "client_side": {
        "label": "Client-Side",
        "types": [
            "cors_misconfiguration", "clickjacking", "open_redirect",
            "dom_clobbering", "postmessage_vuln", "websocket_hijack",
            "prototype_pollution", "css_injection", "tabnabbing",
        ]
    },
    "infrastructure": {
        "label": "Infrastructure",
        "types": [
            "security_headers", "ssl_issues", "http_methods",
            "directory_listing", "debug_mode", "exposed_admin_panel",
            "exposed_api_docs", "insecure_cookie_flags",
        ]
    },
    "logic": {
        "label": "Business Logic",
        "types": [
            "race_condition", "business_logic", "rate_limit_bypass",
            "parameter_pollution", "type_juggling", "timing_attack",
            "host_header_injection", "http_smuggling", "cache_poisoning",
        ]
    },
    "data_exposure": {
        "label": "Data Exposure",
        "types": [
            "sensitive_data_exposure", "information_disclosure",
            "api_key_exposure", "source_code_disclosure",
            "backup_file_exposure", "version_disclosure",
        ]
    },
    "cloud_supply": {
        "label": "Cloud & Supply Chain",
        "types": [
            "s3_bucket_misconfig", "cloud_metadata_exposure",
            "subdomain_takeover", "vulnerable_dependency",
            "container_escape", "serverless_misconfiguration",
        ]
    },
}
def _get_vuln_category(vuln_type: str) -> str:
    """Return the VULN_CATEGORIES key whose type list contains *vuln_type*.

    The first matching category in definition order wins; ``"other"`` is
    returned when no category lists the type.
    """
    return next(
        (
            name
            for name, spec in VULN_CATEGORIES.items()
            if vuln_type in spec["types"]
        ),
        "other",
    )
# --- Endpoints ---
@router.get("/types")
async def list_vuln_types():
    """Return every selectable vulnerability type, grouped by category.

    Each type is enriched with title, severity, CWE id and a description
    truncated to 120 characters taken from the registry; types the registry
    does not know fall back to a title derived from the key, severity
    "medium" and empty strings.
    """
    known = VulnerabilityRegistry().VULNERABILITY_INFO

    def describe(vtype: str) -> Dict:
        meta = known.get(vtype, {})
        summary = meta.get("description")
        return {
            "key": vtype,
            "title": meta.get("title", vtype.replace("_", " ").title()),
            "severity": meta.get("severity", "medium"),
            "cwe_id": meta.get("cwe_id", ""),
            "description": summary[:120] if summary else "",
        }

    categories = {}
    total_types = 0
    for cat_key, cat_info in VULN_CATEGORIES.items():
        entries = [describe(vtype) for vtype in cat_info["types"]]
        categories[cat_key] = {
            "label": cat_info["label"],
            "types": entries,
            "count": len(entries),
        }
        total_types += len(cat_info["types"])

    return {"categories": categories, "total_types": total_types}
@router.post("/run", response_model=VulnLabResponse)
async def run_vuln_lab(request: VulnLabRunRequest, background_tasks: BackgroundTasks):
    """Launch an isolated vulnerability test for a specific vuln type.

    Validates ``request.vuln_type`` against the registry, stores a
    ``VulnLabChallenge`` row with status "running", seeds the in-memory
    trackers (``lab_results`` and agent.py's shared ``agent_results``) and
    schedules ``_run_lab_test`` as a background task. The response carries
    the generated challenge and agent ids.

    Raises:
        HTTPException: 400 when the vuln type is not in the registry.
    """
    import uuid

    # Validate vuln type exists
    registry = VulnerabilityRegistry()
    if request.vuln_type not in registry.VULNERABILITY_INFO:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown vulnerability type: {request.vuln_type}. Use GET /vuln-lab/types for available types."
        )

    challenge_id = str(uuid.uuid4())
    agent_id = str(uuid.uuid4())[:8]
    category = _get_vuln_category(request.vuln_type)

    # Build auth headers
    auth_headers = {}
    if request.auth_type and request.auth_value:
        if request.auth_type == "cookie":
            auth_headers["Cookie"] = request.auth_value
        elif request.auth_type == "bearer":
            auth_headers["Authorization"] = f"Bearer {request.auth_value}"
        elif request.auth_type == "basic":
            import base64
            auth_headers["Authorization"] = f"Basic {base64.b64encode(request.auth_value.encode()).decode()}"
        elif request.auth_type == "header":
            # Expected form is "Name: value"; a value without a colon is ignored.
            if ":" in request.auth_value:
                name, value = request.auth_value.split(":", 1)
                auth_headers[name.strip()] = value.strip()
    # Custom headers are applied last, so they override auth-derived ones.
    if request.custom_headers:
        auth_headers.update(request.custom_headers)

    # Create DB record
    async with async_session_factory() as db:
        challenge = VulnLabChallenge(
            id=challenge_id,
            target_url=request.target_url,
            challenge_name=request.challenge_name,
            vuln_type=request.vuln_type,
            vuln_category=category,
            auth_type=request.auth_type,
            auth_value=request.auth_value,
            status="running",
            agent_id=agent_id,
            started_at=datetime.utcnow(),
            notes=request.notes,
        )
        db.add(challenge)
        await db.commit()

    # Init in-memory tracking (both local and in agent.py's shared dicts)
    vuln_info = registry.VULNERABILITY_INFO[request.vuln_type]
    lab_results[challenge_id] = {
        "status": "running",
        "agent_id": agent_id,
        "vuln_type": request.vuln_type,
        "target": request.target_url,
        "progress": 0,
        "phase": "initializing",
        "findings": [],
        "logs": [],
    }

    # Also register in agent.py's shared results dict so /agent/status works.
    # The rejected_findings keys mirror the entry shape run_agent creates, so
    # readers of agent_results see the same keys for lab-launched agents.
    agent_results[agent_id] = {
        "status": "running",
        "mode": "full_auto",
        "started_at": datetime.utcnow().isoformat(),
        "target": request.target_url,
        "task": f"VulnLab: {vuln_info.get('title', request.vuln_type)}",
        "logs": [],
        "findings": [],
        "report": None,
        "progress": 0,
        "phase": "initializing",
        "rejected_findings": [],
        "rejected_findings_count": 0,
    }

    # Launch agent in background
    background_tasks.add_task(
        _run_lab_test,
        challenge_id,
        agent_id,
        request.target_url,
        request.vuln_type,
        vuln_info.get("title", request.vuln_type),
        auth_headers,
        request.challenge_name,
        request.notes,
    )

    return VulnLabResponse(
        challenge_id=challenge_id,
        agent_id=agent_id,
        status="running",
        message=f"Testing {vuln_info.get('title', request.vuln_type)} against {request.target_url}"
    )
async def _run_lab_test(
    challenge_id: str,
    agent_id: str,
    target: str,
    vuln_type: str,
    vuln_title: str,
    auth_headers: Dict,
    challenge_name: Optional[str] = None,
    notes: Optional[str] = None,
):
    """Background task: run the agent focused on a single vuln type.

    Creates a Scan and Target row linked to the challenge, runs an
    ``AutonomousAgent`` in full-auto mode with a prompt restricted to
    *vuln_type*, then persists findings, discovered endpoints, a report row
    and the last 500 log entries. Progress, logs and findings are mirrored
    live into ``lab_results[challenge_id]`` and ``agent_results[agent_id]``.

    Any exception marks the in-memory entries as "error" and the DB rows as
    "failed"; a failure while recording that state is reported but never
    re-raised, because this runs as a fire-and-forget background task.

    NOTE(review): stop_challenge sets the status to "stopped"; if
    ``agent.run()`` returns normally after ``cancel()``, the success path
    below overwrites that with "completed" - confirm this is intended.
    """
    logs = []
    findings_list = []
    scan_id = None

    async def log_callback(level: str, message: str):
        source = "llm" if any(tag in message for tag in ["[AI]", "[LLM]", "[USER PROMPT]", "[AI RESPONSE]"]) else "script"
        entry = {"level": level, "message": message, "time": datetime.utcnow().isoformat(), "source": source}
        logs.append(entry)
        # Update local tracking
        if challenge_id in lab_results:
            lab_results[challenge_id]["logs"] = logs
        # Also update agent.py's shared dict so /agent/logs works
        if agent_id in agent_results:
            agent_results[agent_id]["logs"] = logs

    async def progress_callback(progress: int, phase: str):
        if challenge_id in lab_results:
            lab_results[challenge_id]["progress"] = progress
            lab_results[challenge_id]["phase"] = phase
        if agent_id in agent_results:
            agent_results[agent_id]["progress"] = progress
            agent_results[agent_id]["phase"] = phase

    async def finding_callback(finding: Dict):
        findings_list.append(finding)
        if challenge_id in lab_results:
            lab_results[challenge_id]["findings"] = findings_list
        if agent_id in agent_results:
            agent_results[agent_id]["findings"] = findings_list
            agent_results[agent_id]["findings_count"] = len(findings_list)

    try:
        async with async_session_factory() as db:
            # Create a scan record linked to this challenge
            scan = Scan(
                name=f"VulnLab: {vuln_title} - {target[:50]}",
                status="running",
                scan_type="full_auto",
                recon_enabled=True,
                progress=0,
                current_phase="initializing",
                custom_prompt=f"Focus ONLY on testing for {vuln_title} ({vuln_type}). "
                              f"Do NOT test other vulnerability types. "
                              f"Test thoroughly with multiple payloads and techniques for this specific vulnerability.",
            )
            db.add(scan)
            await db.commit()
            await db.refresh(scan)
            scan_id = scan.id

            # Create target record
            target_record = Target(scan_id=scan_id, url=target, status="pending")
            db.add(target_record)
            await db.commit()

            # Update challenge with scan_id
            result = await db.execute(
                select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
            )
            challenge = result.scalar_one_or_none()
            if challenge:
                challenge.scan_id = scan_id
                await db.commit()

            if challenge_id in lab_results:
                lab_results[challenge_id]["scan_id"] = scan_id

            # Register in agent.py's shared mappings so ScanDetailsPage works
            agent_to_scan[agent_id] = scan_id
            scan_to_agent[scan_id] = agent_id
            if agent_id in agent_results:
                agent_results[agent_id]["scan_id"] = scan_id

            # Build focused prompt for isolated testing
            focused_prompt = (
                f"You are testing specifically for {vuln_title} ({vuln_type}). "
                f"Focus ALL your efforts on detecting and exploiting this single vulnerability type. "
                f"Do NOT scan for other vulnerability types. "
                f"Use all relevant payloads and techniques for {vuln_type}. "
                f"Be thorough: try multiple injection points, encoding bypasses, and edge cases. "
                f"This is a lab/CTF challenge - the vulnerability is expected to exist."
            )
            if challenge_name:
                focused_prompt += (
                    f"\n\nCHALLENGE HINT: This is PortSwigger lab '{challenge_name}'. "
                    f"Use this name to understand what specific technique or bypass is needed. "
                    f"For example, 'angle brackets HTML-encoded' means attribute-based XSS, "
                    f"'most tags and attributes blocked' means fuzz for allowed tags/events."
                )
            if notes:
                focused_prompt += f"\n\nUSER NOTES: {notes}"

            lab_ctx = {
                "challenge_name": challenge_name,
                "notes": notes,
                "vuln_type": vuln_type,
                "is_lab": True,
            }

            async with AutonomousAgent(
                target=target,
                mode=OperationMode.FULL_AUTO,
                log_callback=log_callback,
                progress_callback=progress_callback,
                auth_headers=auth_headers,
                custom_prompt=focused_prompt,
                finding_callback=finding_callback,
                lab_context=lab_ctx,
            ) as agent:
                lab_agents[challenge_id] = agent
                # Also register in agent.py's shared instances so stop works
                agent_instances[agent_id] = agent
                report = await agent.run()

            lab_agents.pop(challenge_id, None)
            agent_instances.pop(agent_id, None)

            # Prefer the findings in the final report; fall back to the ones
            # collected through finding_callback when the report has none.
            findings = report.get("findings", []) or findings_list

            severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
            findings_detail = []
            for finding in findings:
                # A missing or explicit-None severity is treated as "medium".
                severity = str(finding.get("severity") or "medium").lower()
                if severity in severity_counts:
                    severity_counts[severity] += 1

                findings_detail.append({
                    "title": finding.get("title", ""),
                    "vulnerability_type": finding.get("vulnerability_type", ""),
                    "severity": severity,
                    "affected_endpoint": finding.get("affected_endpoint", ""),
                    "evidence": (finding.get("evidence", "") or "")[:500],
                    "payload": (finding.get("payload", "") or "")[:200],
                })

                # Save to vulnerabilities table
                vuln = Vulnerability(
                    scan_id=scan_id,
                    title=finding.get("title", finding.get("type", "Unknown")),
                    vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
                    severity=severity,
                    cvss_score=finding.get("cvss_score"),
                    cvss_vector=finding.get("cvss_vector"),
                    cwe_id=finding.get("cwe_id"),
                    description=finding.get("description", finding.get("evidence", "")),
                    affected_endpoint=finding.get("affected_endpoint", finding.get("url", target)),
                    poc_payload=finding.get("payload", finding.get("poc_payload", finding.get("poc_code", ""))),
                    poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
                    poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
                    poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
                    poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
                    impact=finding.get("impact", ""),
                    remediation=finding.get("remediation", ""),
                    references=finding.get("references", []),
                    ai_analysis=finding.get("ai_analysis", ""),
                    screenshots=finding.get("screenshots", []),
                    url=finding.get("url", finding.get("affected_endpoint", "")),
                    parameter=finding.get("parameter", finding.get("poc_parameter", "")),
                )
                db.add(vuln)

            # Save discovered endpoints from recon data
            endpoints_count = 0
            for ep in report.get("recon", {}).get("endpoints", []):
                endpoints_count += 1
                if isinstance(ep, str):
                    # NOTE(review): this stores only the last path segment of
                    # the URL as "path" - confirm that is what consumers expect.
                    endpoint = Endpoint(
                        scan_id=scan_id,
                        target_id=target_record.id,
                        url=ep,
                        method="GET",
                        path=ep.split("?")[0].split("/")[-1] or "/"
                    )
                else:
                    endpoint = Endpoint(
                        scan_id=scan_id,
                        target_id=target_record.id,
                        url=ep.get("url", ""),
                        method=ep.get("method", "GET"),
                        path=ep.get("path", "/")
                    )
                db.add(endpoint)

            # The agent was told to focus on one type, so any finding at all
            # counts as a detection, whether or not its type matches exactly.
            result_status = "detected" if findings else "not_detected"

            # Update scan
            scan.status = "completed"
            scan.completed_at = datetime.utcnow()
            scan.progress = 100
            scan.current_phase = "completed"
            scan.total_vulnerabilities = len(findings)
            scan.total_endpoints = endpoints_count
            scan.critical_count = severity_counts["critical"]
            scan.high_count = severity_counts["high"]
            scan.medium_count = severity_counts["medium"]
            scan.low_count = severity_counts["low"]
            scan.info_count = severity_counts["info"]

            # Auto-generate report
            exec_summary = report.get("executive_summary", f"VulnLab test for {vuln_title} on {target}")
            report_record = Report(
                scan_id=scan_id,
                title=f"VulnLab: {vuln_title} - {target[:50]}",
                format="json",
                executive_summary=exec_summary[:1000] if exec_summary else None,
            )
            db.add(report_record)

            # Persist logs (keep last 500 entries to avoid huge DB rows)
            persisted_logs = logs[-500:]

            # Update challenge record
            result_q = await db.execute(
                select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
            )
            challenge = result_q.scalar_one_or_none()
            if challenge:
                challenge.status = "completed"
                challenge.result = result_status
                challenge.completed_at = datetime.utcnow()
                challenge.duration = int((datetime.utcnow() - challenge.started_at).total_seconds()) if challenge.started_at else 0
                challenge.findings_count = len(findings)
                challenge.critical_count = severity_counts["critical"]
                challenge.high_count = severity_counts["high"]
                challenge.medium_count = severity_counts["medium"]
                challenge.low_count = severity_counts["low"]
                challenge.info_count = severity_counts["info"]
                challenge.findings_detail = findings_detail
                challenge.logs = persisted_logs
                challenge.endpoints_count = endpoints_count

            await db.commit()

        # Update in-memory results
        if challenge_id in lab_results:
            lab_results[challenge_id]["status"] = "completed"
            lab_results[challenge_id]["result"] = result_status
            lab_results[challenge_id]["findings"] = findings
            lab_results[challenge_id]["progress"] = 100
            lab_results[challenge_id]["phase"] = "completed"

        if agent_id in agent_results:
            agent_results[agent_id]["status"] = "completed"
            agent_results[agent_id]["completed_at"] = datetime.utcnow().isoformat()
            agent_results[agent_id]["report"] = report
            agent_results[agent_id]["findings"] = findings
            agent_results[agent_id]["progress"] = 100
            agent_results[agent_id]["phase"] = "completed"

    except Exception as e:
        import traceback
        error_tb = traceback.format_exc()
        print(f"VulnLab error: {error_tb}")

        if challenge_id in lab_results:
            lab_results[challenge_id]["status"] = "error"
            lab_results[challenge_id]["error"] = str(e)

        if agent_id in agent_results:
            agent_results[agent_id]["status"] = "error"
            agent_results[agent_id]["error"] = str(e)

        # Persist logs even on error
        persisted_logs = logs[-500:]

        # Update DB records
        try:
            async with async_session_factory() as db:
                result = await db.execute(
                    select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
                )
                challenge = result.scalar_one_or_none()
                if challenge:
                    challenge.status = "failed"
                    challenge.result = "error"
                    challenge.completed_at = datetime.utcnow()
                    challenge.notes = (challenge.notes or "") + f"\nError: {str(e)}"
                    challenge.logs = persisted_logs
                    await db.commit()

                if scan_id:
                    result = await db.execute(select(Scan).where(Scan.id == scan_id))
                    scan = result.scalar_one_or_none()
                    if scan:
                        scan.status = "failed"
                        scan.error_message = str(e)
                        scan.completed_at = datetime.utcnow()
                        await db.commit()
        except Exception as db_err:
            # Best-effort bookkeeping: report the failure instead of hiding
            # it, but never raise out of the background task.
            print(f"VulnLab error: could not persist failure state: {db_err}")
    finally:
        lab_agents.pop(challenge_id, None)
        agent_instances.pop(agent_id, None)
def _vuln_type_matches(target_type: str, found_type: str) -> bool:
"""Check if a found vuln type matches the target type (flexible matching)"""
if not found_type:
return False
target = target_type.lower().replace("_", " ").replace("-", " ")
found = found_type.lower().replace("_", " ").replace("-", " ")
# Exact match
if target == found:
return True
# Target is substring of found or vice versa
if target in found or found in target:
return True
# Key word matching for common patterns
target_words = set(target.split())
found_words = set(found.split())
# If they share major keywords (xss, sqli, ssrf, etc.)
major_keywords = {"xss", "sqli", "sql", "injection", "ssrf", "csrf", "lfi", "rfi",
"xxe", "ssti", "idor", "cors", "jwt", "redirect", "traversal"}
shared = target_words & found_words & major_keywords
if shared:
return True
return False
@router.get("/challenges")
async def list_challenges(
    vuln_type: Optional[str] = None,
    vuln_category: Optional[str] = None,
    status: Optional[str] = None,
    result: Optional[str] = None,
    limit: int = 50,
):
    """List vulnerability lab challenges, newest first, with optional filtering.

    Each filter that is given narrows the query by equality on the matching
    column; at most *limit* rows are returned. The ``logs`` field is replaced
    by ``logs_count`` to keep the list payload small. ``total`` is the number
    of rows returned, not the number of rows matching the filters.
    """
    async with async_session_factory() as db:
        query = select(VulnLabChallenge).order_by(VulnLabChallenge.created_at.desc())
        if vuln_type:
            query = query.where(VulnLabChallenge.vuln_type == vuln_type)
        if vuln_category:
            query = query.where(VulnLabChallenge.vuln_category == vuln_category)
        if status:
            query = query.where(VulnLabChallenge.status == status)
        if result:
            query = query.where(VulnLabChallenge.result == result)
        query = query.limit(limit)

        db_result = await db.execute(query)
        challenges = db_result.scalars().all()

        # For list view, exclude large logs field to save bandwidth
        result_list = []
        for c in challenges:
            d = c.to_dict()
            # "logs" may be present but None (presumably for rows that have
            # not persisted logs yet), so guard before taking the length.
            d["logs_count"] = len(d.get("logs") or [])
            d.pop("logs", None)  # Don't send full logs in list view
            result_list.append(d)

        return {
            "challenges": result_list,
            "total": len(challenges),
        }
@router.get("/challenges/{challenge_id}")
async def get_challenge(challenge_id: str):
    """Get challenge details including real-time status if running.

    When the challenge has an entry in ``lab_results`` the in-memory snapshot
    is returned (``source`` = "realtime", logs limited to the last 200
    entries). Otherwise the persisted row is returned (``source`` =
    "database").

    Raises:
        HTTPException: 404 when the challenge is in neither place.
    """
    # Check in-memory first for real-time data
    if challenge_id in lab_results:
        mem = lab_results[challenge_id]
        return {
            "challenge_id": challenge_id,
            "status": mem["status"],
            "progress": mem.get("progress", 0),
            "phase": mem.get("phase", ""),
            "findings_count": len(mem.get("findings", [])),
            "findings": mem.get("findings", []),
            "logs_count": len(mem.get("logs", [])),
            "logs": mem.get("logs", [])[-200:],  # Last 200 log entries for real-time
            "error": mem.get("error"),
            "result": mem.get("result"),
            "scan_id": mem.get("scan_id"),
            "agent_id": mem.get("agent_id"),
            "vuln_type": mem.get("vuln_type"),
            "target": mem.get("target"),
            "source": "realtime",
        }

    # Fall back to DB
    async with async_session_factory() as db:
        result = await db.execute(
            select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
        )
        challenge = result.scalar_one_or_none()
        if not challenge:
            raise HTTPException(status_code=404, detail="Challenge not found")

        data = challenge.to_dict()
        data["source"] = "database"
        # "logs" may be present but None, so guard before taking the length.
        data["logs_count"] = len(data.get("logs") or [])
        return data
def _tally_results(rows) -> Dict[str, Dict[str, int]]:
    """Fold ``(group, result, count)`` rows into per-group result tallies.

    A missing result is counted under "error". Counts are accumulated rather
    than assigned, so a NULL-result row and an "error" row for the same group
    add up instead of overwriting each other.
    """
    stats: Dict[str, Dict[str, int]] = {}
    for group, res, count in rows:
        bucket = stats.setdefault(
            group, {"detected": 0, "not_detected": 0, "error": 0, "total": 0}
        )
        key = res or "error"
        bucket[key] = bucket.get(key, 0) + count
        bucket["total"] += count
    return stats


@router.get("/stats")
async def get_lab_stats():
    """Get aggregated stats for all lab challenges.

    Returns row counts per status and per result, result tallies per
    vulnerability type and per category (completed challenges only), the
    number of in-memory runs still marked "running", and the detection rate
    (detected / completed, as a percentage rounded to one decimal).
    """
    async with async_session_factory() as db:
        # Total counts by status
        total_result = await db.execute(
            select(
                VulnLabChallenge.status,
                func.count(VulnLabChallenge.id)
            ).group_by(VulnLabChallenge.status)
        )
        status_counts = {row[0]: row[1] for row in total_result.fetchall()}

        # Results breakdown
        results_q = await db.execute(
            select(
                VulnLabChallenge.result,
                func.count(VulnLabChallenge.id)
            ).where(VulnLabChallenge.result.isnot(None))
            .group_by(VulnLabChallenge.result)
        )
        result_counts = {row[0]: row[1] for row in results_q.fetchall()}

        # Per vuln_type stats
        type_stats_q = await db.execute(
            select(
                VulnLabChallenge.vuln_type,
                VulnLabChallenge.result,
                func.count(VulnLabChallenge.id)
            ).where(VulnLabChallenge.status == "completed")
            .group_by(VulnLabChallenge.vuln_type, VulnLabChallenge.result)
        )
        type_stats = _tally_results(type_stats_q.fetchall())

        # Per category stats
        cat_stats_q = await db.execute(
            select(
                VulnLabChallenge.vuln_category,
                VulnLabChallenge.result,
                func.count(VulnLabChallenge.id)
            ).where(VulnLabChallenge.status == "completed")
            .group_by(VulnLabChallenge.vuln_category, VulnLabChallenge.result)
        )
        cat_stats = _tally_results(cat_stats_q.fetchall())

        # Currently running
        running = sum(1 for r in lab_results.values() if r.get("status") == "running")

        total = sum(status_counts.values())
        detected = result_counts.get("detected", 0)
        completed = status_counts.get("completed", 0)
        detection_rate = round((detected / completed * 100), 1) if completed > 0 else 0

        return {
            "total": total,
            "running": running,
            "status_counts": status_counts,
            "result_counts": result_counts,
            "detection_rate": detection_rate,
            "by_type": type_stats,
            "by_category": cat_stats,
        }
@router.post("/challenges/{challenge_id}/stop")
async def stop_challenge(challenge_id: str):
    """Stop a running lab challenge.

    Cancels the live agent, then marks the challenge "stopped" in the
    database and in ``lab_results``. The database update is best-effort: a
    failure is reported but does not fail the request, because the agent has
    already been cancelled.

    Raises:
        HTTPException: 404 when no agent is currently running for the id.
    """
    agent = lab_agents.get(challenge_id)
    if not agent:
        raise HTTPException(status_code=404, detail="No running agent for this challenge")

    agent.cancel()

    # Update DB
    try:
        async with async_session_factory() as db:
            result = await db.execute(
                select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
            )
            challenge = result.scalar_one_or_none()
            if challenge:
                challenge.status = "stopped"
                challenge.completed_at = datetime.utcnow()
                await db.commit()
    except Exception as exc:
        # Catch Exception, not everything: task cancellation and interpreter
        # exit must still propagate.
        print(f"VulnLab error: could not mark challenge {challenge_id} as stopped: {exc}")

    if challenge_id in lab_results:
        lab_results[challenge_id]["status"] = "stopped"

    return {"message": "Challenge stopped"}
@router.delete("/challenges/{challenge_id}")
async def delete_challenge(challenge_id: str):
    """Remove a lab challenge, cancelling its agent first if one is running.

    The in-memory entries are dropped before the database lookup; a 404 is
    raised when no persisted row exists for the id.
    """
    running_agent = lab_agents.get(challenge_id)
    if running_agent:
        running_agent.cancel()
        lab_agents.pop(challenge_id, None)
    lab_results.pop(challenge_id, None)

    lookup = select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
    async with async_session_factory() as db:
        record = (await db.execute(lookup)).scalar_one_or_none()
        if not record:
            raise HTTPException(status_code=404, detail="Challenge not found")
        await db.delete(record)
        await db.commit()

    return {"message": "Challenge deleted"}
@router.get("/logs/{challenge_id}")
async def get_challenge_logs(challenge_id: str, limit: int = 200):
    """Get logs for a challenge (real-time or from DB).

    Returns the newest *limit* entries together with the total number
    available. A *limit* of zero or less yields an empty ``logs`` list; a
    plain ``[-limit:]`` slice would return everything for 0 and drop the
    oldest entries for negative values.

    Raises:
        HTTPException: 404 when the challenge is neither in memory nor in
            the database.
    """
    count = max(limit, 0)

    def tail(entries):
        return entries[-count:] if count else []

    # Check in-memory first for real-time data
    mem = lab_results.get(challenge_id)
    if mem:
        all_logs = mem.get("logs", [])
        return {
            "challenge_id": challenge_id,
            "total_logs": len(all_logs),
            "logs": tail(all_logs),
            "source": "realtime",
        }

    # Fall back to DB persisted logs
    async with async_session_factory() as db:
        result = await db.execute(
            select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
        )
        challenge = result.scalar_one_or_none()
        if not challenge:
            raise HTTPException(status_code=404, detail="Challenge not found")

        all_logs = challenge.logs or []
        return {
            "challenge_id": challenge_id,
            "total_logs": len(all_logs),
            "logs": tail(all_logs),
            "source": "database",
        }