mirror of
https://github.com/CyberSecurityUP/NeuroSploit.git
synced 2026-06-12 23:27:47 +02:00
Add files via upload
This commit is contained in:
+327
-36
@@ -32,6 +32,8 @@ agent_instances: Dict[str, AutonomousAgent] = {}
|
||||
|
||||
# Map agent_id to scan_id for database persistence
|
||||
agent_to_scan: Dict[str, str] = {}
|
||||
# Reverse map: scan_id to agent_id for ScanDetailsPage lookups
|
||||
scan_to_agent: Dict[str, str] = {}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
@@ -101,6 +103,7 @@ class AgentMode(str, Enum):
|
||||
RECON_ONLY = "recon_only" # Just reconnaissance
|
||||
PROMPT_ONLY = "prompt_only" # AI decides (high tokens)
|
||||
ANALYZE_ONLY = "analyze_only" # Analysis without testing
|
||||
AUTO_PENTEST = "auto_pentest" # One-click full auto pentest
|
||||
|
||||
|
||||
class AgentRequest(BaseModel):
|
||||
@@ -113,6 +116,8 @@ class AgentRequest(BaseModel):
|
||||
auth_value: Optional[str] = Field(None, description="Auth value (cookie string, token, etc)")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom HTTP headers")
|
||||
max_depth: int = Field(5, description="Maximum crawl depth")
|
||||
subdomain_discovery: bool = Field(False, description="Enable subdomain discovery (auto_pentest mode)")
|
||||
targets: Optional[List[str]] = Field(None, description="Multiple targets (auto_pentest mode)")
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
@@ -193,7 +198,9 @@ async def run_agent(request: AgentRequest, background_tasks: BackgroundTasks):
|
||||
"findings": [],
|
||||
"report": None,
|
||||
"progress": 0,
|
||||
"phase": "initializing"
|
||||
"phase": "initializing",
|
||||
"rejected_findings": [],
|
||||
"rejected_findings_count": 0,
|
||||
}
|
||||
|
||||
# Run agent in background
|
||||
@@ -212,7 +219,8 @@ async def run_agent(request: AgentRequest, background_tasks: BackgroundTasks):
|
||||
"full_auto": "Full autonomous pentest: Recon -> Analyze -> Test -> Report",
|
||||
"recon_only": "Reconnaissance only, no vulnerability testing",
|
||||
"prompt_only": "AI decides everything (high token usage!)",
|
||||
"analyze_only": "Analysis only, no active testing"
|
||||
"analyze_only": "Analysis only, no active testing",
|
||||
"auto_pentest": "One-click auto pentest: Full recon + 100 vuln types + AI report"
|
||||
}
|
||||
|
||||
return AgentResponse(
|
||||
@@ -255,12 +263,20 @@ async def _run_agent_task(
|
||||
agent_results[agent_id]["progress"] = progress
|
||||
agent_results[agent_id]["phase"] = phase
|
||||
|
||||
rejected_findings_list = []
|
||||
|
||||
async def finding_callback(finding: Dict):
|
||||
"""Real-time finding callback - updates in-memory storage immediately"""
|
||||
findings_list.append(finding)
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["findings"] = findings_list
|
||||
agent_results[agent_id]["findings_count"] = len(findings_list)
|
||||
if finding.get("ai_status") == "rejected":
|
||||
rejected_findings_list.append(finding)
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["rejected_findings"] = rejected_findings_list
|
||||
agent_results[agent_id]["rejected_findings_count"] = len(rejected_findings_list)
|
||||
else:
|
||||
findings_list.append(finding)
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["findings"] = findings_list
|
||||
agent_results[agent_id]["findings_count"] = len(findings_list)
|
||||
|
||||
try:
|
||||
# Create database session and scan record
|
||||
@@ -289,8 +305,9 @@ async def _run_agent_task(
|
||||
db.add(target_record)
|
||||
await db.commit()
|
||||
|
||||
# Store mapping
|
||||
# Store mapping (both directions)
|
||||
agent_to_scan[agent_id] = scan_id
|
||||
scan_to_agent[scan_id] = agent_id
|
||||
agent_results[agent_id]["scan_id"] = scan_id
|
||||
|
||||
# Map mode
|
||||
@@ -299,6 +316,7 @@ async def _run_agent_task(
|
||||
AgentMode.RECON_ONLY: OperationMode.RECON_ONLY,
|
||||
AgentMode.PROMPT_ONLY: OperationMode.PROMPT_ONLY,
|
||||
AgentMode.ANALYZE_ONLY: OperationMode.ANALYZE_ONLY,
|
||||
AgentMode.AUTO_PENTEST: OperationMode.AUTO_PENTEST,
|
||||
}
|
||||
op_mode = mode_map.get(mode, OperationMode.FULL_AUTO)
|
||||
|
||||
@@ -311,6 +329,7 @@ async def _run_agent_task(
|
||||
task=task,
|
||||
custom_prompt=custom_prompt or (task.prompt if task else None),
|
||||
finding_callback=finding_callback,
|
||||
scan_id=str(scan_id),
|
||||
) as agent:
|
||||
# Store agent instance for stop functionality
|
||||
agent_instances[agent_id] = agent
|
||||
@@ -345,7 +364,41 @@ async def _run_agent_task(
|
||||
impact=finding.get("impact", ""),
|
||||
remediation=finding.get("remediation", ""),
|
||||
references=finding.get("references", []),
|
||||
ai_analysis=finding.get("ai_analysis", finding.get("exploitation_steps", ""))
|
||||
ai_analysis=finding.get("ai_analysis", finding.get("exploitation_steps", "")),
|
||||
poc_code=finding.get("poc_code", ""),
|
||||
screenshots=finding.get("screenshots", []),
|
||||
url=finding.get("url", finding.get("affected_endpoint", "")),
|
||||
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
validation_status="ai_confirmed",
|
||||
)
|
||||
db.add(vuln)
|
||||
|
||||
# Save rejected findings to database for manual review
|
||||
for finding in report.get("rejected_findings", []):
|
||||
vuln = Vulnerability(
|
||||
scan_id=scan_id,
|
||||
title=finding.get("title", finding.get("type", "Unknown")),
|
||||
vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
|
||||
severity=finding.get("severity", "medium").lower(),
|
||||
cvss_score=finding.get("cvss_score"),
|
||||
cvss_vector=finding.get("cvss_vector"),
|
||||
cwe_id=finding.get("cwe_id"),
|
||||
description=finding.get("description", finding.get("evidence", "")),
|
||||
affected_endpoint=finding.get("affected_endpoint", finding.get("endpoint", finding.get("url", target))),
|
||||
poc_payload=finding.get("payload", finding.get("poc_payload", "")),
|
||||
poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
|
||||
poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
|
||||
poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
|
||||
impact=finding.get("impact", ""),
|
||||
remediation=finding.get("remediation", ""),
|
||||
references=finding.get("references", []),
|
||||
poc_code=finding.get("poc_code", ""),
|
||||
screenshots=finding.get("screenshots", []),
|
||||
url=finding.get("url", finding.get("affected_endpoint", "")),
|
||||
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
validation_status="ai_rejected",
|
||||
ai_rejection_reason=finding.get("rejection_reason", ""),
|
||||
)
|
||||
db.add(vuln)
|
||||
|
||||
@@ -402,6 +455,7 @@ async def _run_agent_task(
|
||||
agent_results[agent_id]["report"] = report
|
||||
agent_results[agent_id]["report_id"] = report_record.id
|
||||
agent_results[agent_id]["findings"] = findings
|
||||
agent_results[agent_id]["tool_executions"] = report.get("tool_executions", [])
|
||||
agent_results[agent_id]["progress"] = 100
|
||||
agent_results[agent_id]["phase"] = "completed"
|
||||
|
||||
@@ -429,6 +483,37 @@ async def _run_agent_task(
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/by-scan/{scan_id}")
|
||||
async def get_agent_by_scan(scan_id: str):
|
||||
"""Look up agent status by scan_id (reverse lookup for ScanDetailsPage)"""
|
||||
agent_id = scan_to_agent.get(scan_id)
|
||||
if not agent_id:
|
||||
raise HTTPException(status_code=404, detail="No agent found for this scan")
|
||||
|
||||
if agent_id in agent_results:
|
||||
result = agent_results[agent_id]
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"scan_id": scan_id,
|
||||
"status": result["status"],
|
||||
"mode": result.get("mode", "full_auto"),
|
||||
"target": result["target"],
|
||||
"progress": result.get("progress", 0),
|
||||
"phase": result.get("phase", "unknown"),
|
||||
"started_at": result.get("started_at"),
|
||||
"completed_at": result.get("completed_at"),
|
||||
"findings_count": len(result.get("findings", [])),
|
||||
"findings": result.get("findings", []),
|
||||
"rejected_findings_count": len(result.get("rejected_findings", [])),
|
||||
"rejected_findings": result.get("rejected_findings", []),
|
||||
"logs_count": len(result.get("logs", [])),
|
||||
"report": result.get("report"),
|
||||
"error": result.get("error")
|
||||
}
|
||||
|
||||
raise HTTPException(status_code=404, detail="Agent data no longer in memory")
|
||||
|
||||
|
||||
@router.get("/status/{agent_id}")
|
||||
async def get_agent_status(agent_id: str):
|
||||
"""Get the status and results of an agent run - with database fallback"""
|
||||
@@ -449,6 +534,8 @@ async def get_agent_status(agent_id: str):
|
||||
"logs_count": len(result.get("logs", [])),
|
||||
"findings_count": len(result.get("findings", [])),
|
||||
"findings": result.get("findings", []),
|
||||
"rejected_findings_count": len(result.get("rejected_findings", [])),
|
||||
"rejected_findings": result.get("rejected_findings", []),
|
||||
"report": result.get("report"),
|
||||
"error": result.get("error")
|
||||
}
|
||||
@@ -495,10 +582,12 @@ async def _get_status_from_db(agent_id: str, scan_id: str):
|
||||
"evidence": getattr(v, 'poc_evidence', None) or "",
|
||||
"request": v.poc_request or "",
|
||||
"response": v.poc_response or "",
|
||||
"poc_code": v.poc_payload or "",
|
||||
"poc_code": getattr(v, 'poc_code', None) or v.poc_payload or "",
|
||||
"impact": v.impact or "",
|
||||
"remediation": v.remediation or "",
|
||||
"references": v.references or [],
|
||||
"screenshots": getattr(v, 'screenshots', None) or [],
|
||||
"url": getattr(v, 'url', None) or v.affected_endpoint or "",
|
||||
"ai_verified": True,
|
||||
"confidence": "high"
|
||||
}
|
||||
@@ -542,14 +631,14 @@ async def _get_status_from_db(agent_id: str, scan_id: str):
|
||||
|
||||
@router.post("/stop/{agent_id}")
|
||||
async def stop_agent(agent_id: str):
|
||||
"""Stop a running agent scan and auto-generate report"""
|
||||
"""Stop a running agent scan, save all findings to DB, and generate report."""
|
||||
if agent_id not in agent_results:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
if agent_results[agent_id]["status"] != "running":
|
||||
return {"message": "Agent is not running", "status": agent_results[agent_id]["status"]}
|
||||
|
||||
# Cancel the agent
|
||||
# Cancel the agent immediately
|
||||
if agent_id in agent_instances:
|
||||
agent_instances[agent_id].cancel()
|
||||
|
||||
@@ -558,9 +647,10 @@ async def stop_agent(agent_id: str):
|
||||
agent_results[agent_id]["phase"] = "stopped"
|
||||
agent_results[agent_id]["completed_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
# Update database and auto-generate report
|
||||
# Update database: save findings + generate report
|
||||
scan_id = agent_to_scan.get(agent_id)
|
||||
report_id = None
|
||||
target = agent_results[agent_id].get("target", "Unknown")
|
||||
|
||||
if scan_id:
|
||||
try:
|
||||
@@ -573,47 +663,222 @@ async def stop_agent(agent_id: str):
|
||||
scan.status = "stopped"
|
||||
scan.completed_at = datetime.utcnow()
|
||||
|
||||
# Get findings count
|
||||
# Save confirmed findings to DB (same as completion flow)
|
||||
findings = agent_results[agent_id].get("findings", [])
|
||||
scan.total_vulnerabilities = len(findings)
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
|
||||
|
||||
# Count severities
|
||||
for finding in findings:
|
||||
severity = finding.get("severity", "").lower()
|
||||
if severity == "critical":
|
||||
scan.critical_count = (scan.critical_count or 0) + 1
|
||||
elif severity == "high":
|
||||
scan.high_count = (scan.high_count or 0) + 1
|
||||
elif severity == "medium":
|
||||
scan.medium_count = (scan.medium_count or 0) + 1
|
||||
elif severity == "low":
|
||||
scan.low_count = (scan.low_count or 0) + 1
|
||||
elif severity == "info":
|
||||
scan.info_count = (scan.info_count or 0) + 1
|
||||
severity = finding.get("severity", "medium").lower()
|
||||
if severity in severity_counts:
|
||||
severity_counts[severity] += 1
|
||||
|
||||
vuln = Vulnerability(
|
||||
scan_id=scan_id,
|
||||
title=finding.get("title", finding.get("type", "Unknown")),
|
||||
vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
|
||||
severity=severity,
|
||||
cvss_score=finding.get("cvss_score"),
|
||||
cvss_vector=finding.get("cvss_vector"),
|
||||
cwe_id=finding.get("cwe_id"),
|
||||
description=finding.get("description", finding.get("evidence", "")),
|
||||
affected_endpoint=finding.get("affected_endpoint", finding.get("endpoint", finding.get("url", target))),
|
||||
poc_payload=finding.get("payload", finding.get("poc_payload", "")),
|
||||
poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
|
||||
poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
|
||||
poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
|
||||
impact=finding.get("impact", ""),
|
||||
remediation=finding.get("remediation", ""),
|
||||
references=finding.get("references", []),
|
||||
ai_analysis=finding.get("ai_analysis", finding.get("exploitation_steps", "")),
|
||||
poc_code=finding.get("poc_code", ""),
|
||||
screenshots=finding.get("screenshots", []),
|
||||
url=finding.get("url", finding.get("affected_endpoint", "")),
|
||||
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
validation_status="ai_confirmed",
|
||||
)
|
||||
db.add(vuln)
|
||||
|
||||
# Save rejected findings to DB for manual review
|
||||
rejected = agent_results[agent_id].get("rejected_findings", [])
|
||||
for finding in rejected:
|
||||
vuln = Vulnerability(
|
||||
scan_id=scan_id,
|
||||
title=finding.get("title", finding.get("type", "Unknown")),
|
||||
vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
|
||||
severity=finding.get("severity", "medium").lower(),
|
||||
cvss_score=finding.get("cvss_score"),
|
||||
cvss_vector=finding.get("cvss_vector"),
|
||||
cwe_id=finding.get("cwe_id"),
|
||||
description=finding.get("description", finding.get("evidence", "")),
|
||||
affected_endpoint=finding.get("affected_endpoint", finding.get("endpoint", finding.get("url", target))),
|
||||
poc_payload=finding.get("payload", finding.get("poc_payload", "")),
|
||||
poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
|
||||
poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
|
||||
poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
|
||||
impact=finding.get("impact", ""),
|
||||
remediation=finding.get("remediation", ""),
|
||||
references=finding.get("references", []),
|
||||
poc_code=finding.get("poc_code", ""),
|
||||
screenshots=finding.get("screenshots", []),
|
||||
url=finding.get("url", finding.get("affected_endpoint", "")),
|
||||
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
validation_status="ai_rejected",
|
||||
ai_rejection_reason=finding.get("rejection_reason", ""),
|
||||
)
|
||||
db.add(vuln)
|
||||
|
||||
# Update scan counts (confirmed only)
|
||||
scan.total_vulnerabilities = len(findings)
|
||||
scan.critical_count = severity_counts["critical"]
|
||||
scan.high_count = severity_counts["high"]
|
||||
scan.medium_count = severity_counts["medium"]
|
||||
scan.low_count = severity_counts["low"]
|
||||
scan.info_count = severity_counts["info"]
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Auto-generate report
|
||||
report = Report(
|
||||
# Auto-generate report record
|
||||
report_record = Report(
|
||||
scan_id=scan_id,
|
||||
title=f"Agent Scan Report - {agent_results[agent_id].get('target', 'Unknown')}",
|
||||
title=f"Agent Scan Report - {target}",
|
||||
format="json",
|
||||
executive_summary=f"Automated security scan completed with {len(findings)} findings."
|
||||
executive_summary=f"Security scan stopped with {len(findings)} confirmed and {len(rejected)} rejected findings."
|
||||
)
|
||||
db.add(report)
|
||||
db.add(report_record)
|
||||
await db.commit()
|
||||
await db.refresh(report)
|
||||
report_id = report.id
|
||||
await db.refresh(report_record)
|
||||
report_id = report_record.id
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error updating scan status: {e}")
|
||||
print(f"Error updating scan status on stop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return {
|
||||
"message": "Agent stopped successfully",
|
||||
"agent_id": agent_id,
|
||||
"report_id": report_id
|
||||
"report_id": report_id,
|
||||
"findings_saved": len(agent_results[agent_id].get("findings", [])),
|
||||
"rejected_saved": len(agent_results[agent_id].get("rejected_findings", [])),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/pause/{agent_id}")
|
||||
async def pause_agent(agent_id: str):
|
||||
"""Pause a running agent scan"""
|
||||
if agent_id not in agent_results:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
if agent_results[agent_id]["status"] != "running":
|
||||
return {"message": "Agent is not running", "status": agent_results[agent_id]["status"]}
|
||||
|
||||
if agent_id in agent_instances:
|
||||
agent_instances[agent_id].pause()
|
||||
|
||||
# Save current phase before overwriting with "paused"
|
||||
agent_results[agent_id]["last_phase"] = agent_results[agent_id].get("phase", "recon")
|
||||
agent_results[agent_id]["status"] = "paused"
|
||||
agent_results[agent_id]["phase"] = "paused"
|
||||
|
||||
return {"message": "Agent paused", "agent_id": agent_id}
|
||||
|
||||
|
||||
@router.post("/resume/{agent_id}")
|
||||
async def resume_agent(agent_id: str):
|
||||
"""Resume a paused agent scan"""
|
||||
if agent_id not in agent_results:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
if agent_results[agent_id]["status"] != "paused":
|
||||
return {"message": "Agent is not paused", "status": agent_results[agent_id]["status"]}
|
||||
|
||||
if agent_id in agent_instances:
|
||||
agent_instances[agent_id].resume()
|
||||
|
||||
agent_results[agent_id]["status"] = "running"
|
||||
# Restore the phase that was active before pause
|
||||
agent_results[agent_id]["phase"] = agent_results[agent_id].get("last_phase", "testing")
|
||||
|
||||
return {"message": "Agent resumed", "agent_id": agent_id}
|
||||
|
||||
|
||||
# Agent phase order for skip validation
|
||||
AGENT_PHASE_ORDER = ["recon", "analysis", "testing", "enhancement", "completed"]
|
||||
|
||||
# Map phase names from status strings to canonical phase keys
|
||||
PHASE_NORMALIZE = {
|
||||
"starting reconnaissance": "recon",
|
||||
"reconnaissance complete": "recon",
|
||||
"initial probe complete": "recon",
|
||||
"endpoint discovery complete": "recon",
|
||||
"parameter discovery complete": "recon",
|
||||
"attack surface analyzed": "analysis",
|
||||
"vulnerability testing complete": "testing",
|
||||
"findings enhanced": "enhancement",
|
||||
"assessment complete": "completed",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/skip-to/{agent_id}/{target_phase}")
|
||||
async def skip_agent_phase(agent_id: str, target_phase: str):
|
||||
"""Skip the current agent phase and jump to a target phase.
|
||||
|
||||
Valid phases: recon, analysis, testing, enhancement, completed
|
||||
Can only skip forward (to a phase ahead of current).
|
||||
"""
|
||||
if agent_id not in agent_results:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
agent_status = agent_results[agent_id]["status"]
|
||||
if agent_status not in ("running", "paused"):
|
||||
raise HTTPException(status_code=400, detail="Agent is not running or paused")
|
||||
|
||||
if target_phase not in AGENT_PHASE_ORDER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid phase '{target_phase}'. Valid: {', '.join(AGENT_PHASE_ORDER[1:])}"
|
||||
)
|
||||
|
||||
# Get current phase and normalize it
|
||||
current_raw = agent_results[agent_id].get("phase", "").lower()
|
||||
# Handle "paused" phase — use the last known non-paused phase, default to recon
|
||||
if current_raw in ("paused", "stopped"):
|
||||
current_raw = agent_results[agent_id].get("last_phase", "recon")
|
||||
current_phase = PHASE_NORMALIZE.get(current_raw, current_raw)
|
||||
# Also try prefix match
|
||||
for key in AGENT_PHASE_ORDER:
|
||||
if key in current_phase:
|
||||
current_phase = key
|
||||
break
|
||||
|
||||
cur_idx = AGENT_PHASE_ORDER.index(current_phase) if current_phase in AGENT_PHASE_ORDER else 0
|
||||
tgt_idx = AGENT_PHASE_ORDER.index(target_phase)
|
||||
|
||||
if tgt_idx <= cur_idx:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot skip backward. Current: {current_phase}, target: {target_phase}"
|
||||
)
|
||||
|
||||
# Signal the agent instance to skip
|
||||
if agent_id in agent_instances:
|
||||
# If paused, resume first so the skip can be processed
|
||||
if agent_status == "paused":
|
||||
agent_instances[agent_id].resume()
|
||||
agent_results[agent_id]["status"] = "running"
|
||||
success = agent_instances[agent_id].skip_to_phase(target_phase)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to signal phase skip")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Agent instance not available for signaling")
|
||||
|
||||
return {
|
||||
"message": f"Skipping to phase: {target_phase}",
|
||||
"agent_id": agent_id,
|
||||
"from_phase": current_phase,
|
||||
"target_phase": target_phase
|
||||
}
|
||||
|
||||
|
||||
@@ -1711,7 +1976,10 @@ async def _save_realtime_findings_to_db(session_id: str, session: Dict):
|
||||
impact=finding.get("impact", ""),
|
||||
remediation=finding.get("remediation", ""),
|
||||
references=finding.get("references", []),
|
||||
ai_analysis=f"Identified during realtime session {session_id}"
|
||||
ai_analysis=f"Identified during realtime session {session_id}",
|
||||
screenshots=finding.get("screenshots", []),
|
||||
url=finding.get("url", finding.get("affected_endpoint", "")),
|
||||
parameter=finding.get("parameter", "")
|
||||
)
|
||||
db.add(vuln)
|
||||
|
||||
@@ -1809,6 +2077,29 @@ async def generate_realtime_report(session_id: str, format: str = "json"):
|
||||
scan_results=tool_results
|
||||
)
|
||||
|
||||
# Save to a per-report folder with screenshots
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
target_name = session["target"].replace("://", "_").replace("/", "_").rstrip("_")[:40]
|
||||
report_dir = Path("reports") / f"report_{target_name}_{timestamp}"
|
||||
report_dir.mkdir(parents=True, exist_ok=True)
|
||||
(report_dir / f"report_{timestamp}.html").write_text(html_content)
|
||||
|
||||
# Copy screenshots into report folder
|
||||
screenshots_src = Path("reports") / "screenshots"
|
||||
if screenshots_src.exists():
|
||||
screenshots_dest = report_dir / "screenshots"
|
||||
for finding in findings:
|
||||
fid = finding.get("id", "")
|
||||
if fid:
|
||||
src_dir = screenshots_src / str(fid)
|
||||
if src_dir.exists():
|
||||
dest_dir = screenshots_dest / str(fid)
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
for ss_file in src_dir.glob("*.png"):
|
||||
shutil.copy2(ss_file, dest_dir / ss_file.name)
|
||||
|
||||
return HTMLResponse(content=html_content, media_type="text/html")
|
||||
|
||||
return {
|
||||
|
||||
+166
-1
@@ -64,6 +64,19 @@ async def generate_report(
|
||||
)
|
||||
vulnerabilities = vulns_result.scalars().all()
|
||||
|
||||
# Try to get tool_executions from agent in-memory results
|
||||
tool_executions = []
|
||||
try:
|
||||
from backend.api.v1.agent import scan_to_agent, agent_results
|
||||
agent_id = scan_to_agent.get(report_data.scan_id)
|
||||
if agent_id and agent_id in agent_results:
|
||||
tool_executions = agent_results[agent_id].get("tool_executions", [])
|
||||
if not tool_executions:
|
||||
rpt = agent_results[agent_id].get("report", {})
|
||||
tool_executions = rpt.get("tool_executions", []) if isinstance(rpt, dict) else []
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Generate report
|
||||
generator = ReportGenerator()
|
||||
report_path, executive_summary = await generator.generate(
|
||||
@@ -73,7 +86,8 @@ async def generate_report(
|
||||
title=report_data.title,
|
||||
include_executive_summary=report_data.include_executive_summary,
|
||||
include_poc=report_data.include_poc,
|
||||
include_remediation=report_data.include_remediation
|
||||
include_remediation=report_data.include_remediation,
|
||||
tool_executions=tool_executions,
|
||||
)
|
||||
|
||||
# Save report record
|
||||
@@ -91,6 +105,63 @@ async def generate_report(
|
||||
return ReportResponse(**report.to_dict())
|
||||
|
||||
|
||||
@router.post("/ai-generate", response_model=ReportResponse)
|
||||
async def generate_ai_report(
|
||||
report_data: ReportGenerate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Generate an AI-enhanced report with LLM-written executive summary and per-finding analysis."""
|
||||
# Get scan
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == report_data.scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Get vulnerabilities
|
||||
vulns_result = await db.execute(
|
||||
select(Vulnerability).where(Vulnerability.scan_id == report_data.scan_id)
|
||||
)
|
||||
vulnerabilities = vulns_result.scalars().all()
|
||||
|
||||
# Try to get tool_executions from agent in-memory results
|
||||
tool_executions = []
|
||||
try:
|
||||
from backend.api.v1.agent import scan_to_agent, agent_results
|
||||
agent_id = scan_to_agent.get(report_data.scan_id)
|
||||
if agent_id and agent_id in agent_results:
|
||||
tool_executions = agent_results[agent_id].get("tool_executions", [])
|
||||
# Also check nested report
|
||||
if not tool_executions:
|
||||
rpt = agent_results[agent_id].get("report", {})
|
||||
tool_executions = rpt.get("tool_executions", []) if isinstance(rpt, dict) else []
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Generate AI report
|
||||
generator = ReportGenerator()
|
||||
report_path, ai_summary = await generator.generate_ai_report(
|
||||
scan=scan,
|
||||
vulnerabilities=vulnerabilities,
|
||||
tool_executions=tool_executions,
|
||||
title=report_data.title,
|
||||
)
|
||||
|
||||
# Save report record
|
||||
report = Report(
|
||||
scan_id=scan.id,
|
||||
title=report_data.title or f"AI Report - {scan.name}",
|
||||
format="html",
|
||||
file_path=str(report_path),
|
||||
executive_summary=ai_summary[:2000] if ai_summary else None
|
||||
)
|
||||
db.add(report)
|
||||
await db.commit()
|
||||
await db.refresh(report)
|
||||
|
||||
return ReportResponse(**report.to_dict())
|
||||
|
||||
|
||||
@router.get("/{report_id}", response_model=ReportResponse)
|
||||
async def get_report(report_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Get report details"""
|
||||
@@ -187,6 +258,100 @@ async def download_report(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{report_id}/download-zip")
|
||||
async def download_report_zip(
|
||||
report_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Download report as ZIP with screenshots included"""
|
||||
import zipfile
|
||||
import tempfile
|
||||
import hashlib
|
||||
|
||||
result = await db.execute(select(Report).where(Report.id == report_id))
|
||||
report = result.scalar_one_or_none()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == report.scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found for report")
|
||||
|
||||
vulns_result = await db.execute(
|
||||
select(Vulnerability).where(Vulnerability.scan_id == report.scan_id)
|
||||
)
|
||||
vulnerabilities = vulns_result.scalars().all()
|
||||
|
||||
# Generate HTML report
|
||||
generator = ReportGenerator()
|
||||
report_path, _ = await generator.generate(
|
||||
scan=scan,
|
||||
vulnerabilities=vulnerabilities,
|
||||
format="html",
|
||||
title=report.title
|
||||
)
|
||||
|
||||
# Collect screenshots (use absolute path via settings.BASE_DIR)
|
||||
# Check scan-scoped path first, then legacy flat path
|
||||
screenshots_base = settings.BASE_DIR / "reports" / "screenshots"
|
||||
scan_id_str = str(scan.id) if scan else None
|
||||
screenshot_files = []
|
||||
for vuln in vulnerabilities:
|
||||
# Finding ID is md5(vuln_type+url+param)[:8]
|
||||
vuln_url = getattr(vuln, 'url', None) or vuln.affected_endpoint or ''
|
||||
vuln_param = getattr(vuln, 'parameter', None) or getattr(vuln, 'poc_parameter', None) or ''
|
||||
finding_id = hashlib.md5(
|
||||
f"{vuln.vulnerability_type}{vuln_url}{vuln_param}".encode()
|
||||
).hexdigest()[:8]
|
||||
# Scan-scoped path: reports/screenshots/{scan_id}/{finding_id}/
|
||||
finding_dir = None
|
||||
if scan_id_str:
|
||||
scan_dir = screenshots_base / scan_id_str / finding_id
|
||||
if scan_dir.exists():
|
||||
finding_dir = scan_dir
|
||||
if not finding_dir:
|
||||
legacy_dir = screenshots_base / finding_id
|
||||
if legacy_dir.exists():
|
||||
finding_dir = legacy_dir
|
||||
if finding_dir:
|
||||
for img in finding_dir.glob("*.png"):
|
||||
screenshot_files.append((img, f"screenshots/{finding_id}/{img.name}"))
|
||||
# Also include base64 screenshots from DB as files in the ZIP
|
||||
db_screenshots = getattr(vuln, 'screenshots', None) or []
|
||||
for idx, ss in enumerate(db_screenshots):
|
||||
if isinstance(ss, str) and ss.startswith("data:image/"):
|
||||
# Will be embedded in HTML, but also save as file
|
||||
import base64 as b64
|
||||
try:
|
||||
b64_data = ss.split(",", 1)[1]
|
||||
img_bytes = b64.b64decode(b64_data)
|
||||
img_name = f"screenshots/{finding_id}/evidence_{idx+1}.png"
|
||||
# Write to temp for ZIP inclusion
|
||||
tmp_img = Path(tempfile.gettempdir()) / f"ss_{finding_id}_{idx}.png"
|
||||
tmp_img.write_bytes(img_bytes)
|
||||
screenshot_files.append((tmp_img, img_name))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create ZIP
|
||||
zip_name = Path(report_path).stem + ".zip"
|
||||
zip_path = Path(tempfile.gettempdir()) / zip_name
|
||||
|
||||
with zipfile.ZipFile(str(zip_path), 'w', zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.write(report_path, "report.html")
|
||||
for src_path, arc_name in screenshot_files:
|
||||
zf.write(str(src_path), arc_name)
|
||||
|
||||
return FileResponse(
|
||||
path=str(zip_path),
|
||||
media_type="application/zip",
|
||||
filename=zip_name
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{report_id}")
|
||||
async def delete_report(report_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a report"""
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
NeuroSploit v3 - Sandbox Container Management API
|
||||
|
||||
Real-time monitoring and management of per-scan Kali Linux containers.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _docker_available() -> bool:
|
||||
try:
|
||||
import docker
|
||||
docker.from_env().ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_sandboxes():
|
||||
"""List all sandbox containers with pool status."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
except Exception as e:
|
||||
return {
|
||||
"pool": {
|
||||
"active": 0,
|
||||
"max_concurrent": 0,
|
||||
"image": "neurosploit-kali:latest",
|
||||
"container_ttl_minutes": 60,
|
||||
"docker_available": _docker_available(),
|
||||
},
|
||||
"containers": [],
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
sandboxes = pool.list_sandboxes()
|
||||
now = datetime.utcnow()
|
||||
|
||||
containers = []
|
||||
for info in sandboxes.values():
|
||||
created = info.get("created_at")
|
||||
uptime = 0.0
|
||||
if created:
|
||||
try:
|
||||
dt = datetime.fromisoformat(created)
|
||||
uptime = (now - dt).total_seconds()
|
||||
except Exception:
|
||||
pass
|
||||
containers.append({
|
||||
**info,
|
||||
"uptime_seconds": uptime,
|
||||
})
|
||||
|
||||
return {
|
||||
"pool": {
|
||||
"active": pool.active_count,
|
||||
"max_concurrent": pool.max_concurrent,
|
||||
"image": pool.image,
|
||||
"container_ttl_minutes": int(pool.container_ttl.total_seconds() / 60),
|
||||
"docker_available": _docker_available(),
|
||||
},
|
||||
"containers": containers,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{scan_id}")
|
||||
async def get_sandbox(scan_id: str):
|
||||
"""Get health check for a specific sandbox container."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
|
||||
sandboxes = pool.list_sandboxes()
|
||||
if scan_id not in sandboxes:
|
||||
raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}")
|
||||
|
||||
sb = pool._sandboxes.get(scan_id)
|
||||
if not sb:
|
||||
raise HTTPException(status_code=404, detail=f"Sandbox instance not found")
|
||||
|
||||
health = await sb.health_check()
|
||||
return health
|
||||
|
||||
|
||||
@router.delete("/{scan_id}")
|
||||
async def destroy_sandbox(scan_id: str):
|
||||
"""Destroy a specific sandbox container."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
|
||||
sandboxes = pool.list_sandboxes()
|
||||
if scan_id not in sandboxes:
|
||||
raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}")
|
||||
|
||||
await pool.destroy(scan_id)
|
||||
return {"message": f"Sandbox for scan {scan_id} destroyed", "scan_id": scan_id}
|
||||
|
||||
|
||||
@router.post("/cleanup")
|
||||
async def cleanup_expired():
|
||||
"""Remove containers that have exceeded their TTL."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
await pool.cleanup_expired()
|
||||
return {"message": "Expired containers cleaned up"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/cleanup-orphans")
|
||||
async def cleanup_orphans():
|
||||
"""Remove orphan containers not tracked by the pool."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
await pool.cleanup_orphans()
|
||||
return {"message": "Orphan containers cleaned up"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
+204
-3
@@ -4,6 +4,7 @@ NeuroSploit v3 - Scans API Endpoints
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from urllib.parse import urlparse
|
||||
@@ -11,7 +12,7 @@ from urllib.parse import urlparse
|
||||
from backend.db.database import get_db
|
||||
from backend.models import Scan, Target, Endpoint, Vulnerability
|
||||
from backend.schemas.scan import ScanCreate, ScanUpdate, ScanResponse, ScanListResponse, ScanProgress
|
||||
from backend.services.scan_service import run_scan_task
|
||||
from backend.services.scan_service import run_scan_task, skip_to_phase as _skip_to_phase, PHASE_ORDER
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -177,6 +178,7 @@ async def start_scan(
|
||||
async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Stop a running scan and save partial results"""
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
|
||||
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
@@ -184,8 +186,16 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status != "running":
|
||||
raise HTTPException(status_code=400, detail="Scan is not running")
|
||||
if scan.status not in ("running", "paused"):
|
||||
raise HTTPException(status_code=400, detail="Scan is not running or paused")
|
||||
|
||||
# Signal the running agent to stop
|
||||
agent_id = scan_to_agent.get(scan_id)
|
||||
if agent_id and agent_id in agent_instances:
|
||||
agent_instances[agent_id].cancel()
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "stopped"
|
||||
agent_results[agent_id]["phase"] = "stopped"
|
||||
|
||||
# Update scan status
|
||||
scan.status = "stopped"
|
||||
@@ -259,6 +269,132 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{scan_id}/pause")
|
||||
async def pause_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Pause a running scan"""
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
|
||||
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status != "running":
|
||||
raise HTTPException(status_code=400, detail="Scan is not running")
|
||||
|
||||
# Signal the agent to pause
|
||||
agent_id = scan_to_agent.get(scan_id)
|
||||
if agent_id and agent_id in agent_instances:
|
||||
agent_instances[agent_id].pause()
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "paused"
|
||||
agent_results[agent_id]["phase"] = "paused"
|
||||
|
||||
scan.status = "paused"
|
||||
scan.current_phase = "paused"
|
||||
await db.commit()
|
||||
|
||||
await ws_manager.broadcast_log(scan_id, "warning", "Scan paused by user")
|
||||
|
||||
return {"message": "Scan paused", "scan_id": scan_id}
|
||||
|
||||
|
||||
@router.post("/{scan_id}/resume")
|
||||
async def resume_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Resume a paused scan"""
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
|
||||
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status != "paused":
|
||||
raise HTTPException(status_code=400, detail="Scan is not paused")
|
||||
|
||||
# Signal the agent to resume
|
||||
agent_id = scan_to_agent.get(scan_id)
|
||||
if agent_id and agent_id in agent_instances:
|
||||
agent_instances[agent_id].resume()
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "running"
|
||||
agent_results[agent_id]["phase"] = "testing"
|
||||
|
||||
scan.status = "running"
|
||||
scan.current_phase = "testing"
|
||||
await db.commit()
|
||||
|
||||
await ws_manager.broadcast_log(scan_id, "info", "Scan resumed by user")
|
||||
|
||||
return {"message": "Scan resumed", "scan_id": scan_id}
|
||||
|
||||
|
||||
@router.post("/{scan_id}/skip-to/{target_phase}")
|
||||
async def skip_to_phase_endpoint(scan_id: str, target_phase: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Skip the current scan phase and jump to a target phase.
|
||||
|
||||
Valid phases: recon, analyzing, testing, completed
|
||||
Can only skip forward (to a phase ahead of current).
|
||||
"""
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status not in ("running", "paused"):
|
||||
raise HTTPException(status_code=400, detail="Scan is not running or paused")
|
||||
|
||||
# If paused, resume first so the skip can be processed
|
||||
if scan.status == "paused":
|
||||
from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
|
||||
agent_id = scan_to_agent.get(scan_id)
|
||||
if agent_id and agent_id in agent_instances:
|
||||
agent_instances[agent_id].resume()
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "running"
|
||||
agent_results[agent_id]["phase"] = agent_results[agent_id].get("last_phase", "testing")
|
||||
scan.status = "running"
|
||||
await db.commit()
|
||||
|
||||
if target_phase not in PHASE_ORDER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid phase '{target_phase}'. Valid: {', '.join(PHASE_ORDER[1:])}"
|
||||
)
|
||||
|
||||
# Validate forward skip
|
||||
current_idx = PHASE_ORDER.index(scan.current_phase) if scan.current_phase in PHASE_ORDER else 0
|
||||
target_idx = PHASE_ORDER.index(target_phase)
|
||||
|
||||
if target_idx <= current_idx:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot skip backward. Current: {scan.current_phase}, target: {target_phase}"
|
||||
)
|
||||
|
||||
# Signal the running scan to skip
|
||||
success = _skip_to_phase(scan_id, target_phase)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to signal phase skip")
|
||||
|
||||
# Broadcast via WebSocket
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
await ws_manager.broadcast_log(scan_id, "warning", f">> User requested skip to phase: {target_phase}")
|
||||
await ws_manager.broadcast_phase_change(scan_id, f"skipping_to_{target_phase}")
|
||||
|
||||
return {
|
||||
"message": f"Skipping to phase: {target_phase}",
|
||||
"scan_id": scan_id,
|
||||
"from_phase": scan.current_phase,
|
||||
"target_phase": target_phase
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{scan_id}/status", response_model=ScanProgress)
|
||||
async def get_scan_status(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Get scan progress and status"""
|
||||
@@ -369,3 +505,68 @@ async def get_scan_vulnerabilities(
|
||||
"page": page,
|
||||
"per_page": per_page
|
||||
}
|
||||
|
||||
|
||||
class ValidationRequest(BaseModel):
|
||||
validation_status: str # "validated" | "false_positive" | "ai_confirmed" | "ai_rejected" | "pending_review"
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
@router.patch("/vulnerabilities/{vuln_id}/validate")
|
||||
async def validate_vulnerability(
|
||||
vuln_id: str,
|
||||
body: ValidationRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Manually validate or reject a vulnerability finding"""
|
||||
valid_statuses = {"validated", "false_positive", "ai_confirmed", "ai_rejected", "pending_review"}
|
||||
if body.validation_status not in valid_statuses:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid status. Must be one of: {', '.join(valid_statuses)}")
|
||||
|
||||
result = await db.execute(select(Vulnerability).where(Vulnerability.id == vuln_id))
|
||||
vuln = result.scalar_one_or_none()
|
||||
|
||||
if not vuln:
|
||||
raise HTTPException(status_code=404, detail="Vulnerability not found")
|
||||
|
||||
old_status = vuln.validation_status or "ai_confirmed"
|
||||
vuln.validation_status = body.validation_status
|
||||
if body.notes:
|
||||
vuln.ai_rejection_reason = body.notes
|
||||
|
||||
# Update scan severity counts when validation status changes
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == vuln.scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
|
||||
if scan:
|
||||
sev = vuln.severity
|
||||
# If changing from rejected to validated: add to counts
|
||||
if old_status == "ai_rejected" and body.validation_status == "validated":
|
||||
scan.total_vulnerabilities = (scan.total_vulnerabilities or 0) + 1
|
||||
if sev == "critical":
|
||||
scan.critical_count = (scan.critical_count or 0) + 1
|
||||
elif sev == "high":
|
||||
scan.high_count = (scan.high_count or 0) + 1
|
||||
elif sev == "medium":
|
||||
scan.medium_count = (scan.medium_count or 0) + 1
|
||||
elif sev == "low":
|
||||
scan.low_count = (scan.low_count or 0) + 1
|
||||
elif sev == "info":
|
||||
scan.info_count = (scan.info_count or 0) + 1
|
||||
# If changing from confirmed to false_positive: subtract from counts
|
||||
elif old_status in ("ai_confirmed", "validated") and body.validation_status == "false_positive":
|
||||
scan.total_vulnerabilities = max(0, (scan.total_vulnerabilities or 0) - 1)
|
||||
if sev == "critical":
|
||||
scan.critical_count = max(0, (scan.critical_count or 0) - 1)
|
||||
elif sev == "high":
|
||||
scan.high_count = max(0, (scan.high_count or 0) - 1)
|
||||
elif sev == "medium":
|
||||
scan.medium_count = max(0, (scan.medium_count or 0) - 1)
|
||||
elif sev == "low":
|
||||
scan.low_count = max(0, (scan.low_count or 0) - 1)
|
||||
elif sev == "info":
|
||||
scan.info_count = max(0, (scan.info_count or 0) - 1)
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Vulnerability validation updated", "vulnerability": vuln.to_dict()}
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
NeuroSploit v3 - Scheduler API Router
|
||||
|
||||
CRUD endpoints for managing scheduled scan jobs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "config" / "config.json"
|
||||
|
||||
|
||||
class ScheduleJobRequest(BaseModel):
|
||||
"""Request model for creating a scheduled job."""
|
||||
job_id: str
|
||||
target: str
|
||||
scan_type: str = "quick"
|
||||
cron_expression: Optional[str] = None
|
||||
interval_minutes: Optional[int] = None
|
||||
agent_role: Optional[str] = None
|
||||
llm_profile: Optional[str] = None
|
||||
|
||||
|
||||
class ScheduleJobResponse(BaseModel):
|
||||
"""Response model for a scheduled job."""
|
||||
id: str
|
||||
target: str
|
||||
scan_type: str
|
||||
schedule: str
|
||||
status: str
|
||||
next_run: Optional[str] = None
|
||||
last_run: Optional[str] = None
|
||||
run_count: int = 0
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Dict])
|
||||
async def list_scheduled_jobs(request: Request):
|
||||
"""List all scheduled scan jobs."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
return []
|
||||
return scheduler.list_jobs()
|
||||
|
||||
|
||||
@router.post("/", response_model=Dict)
|
||||
async def create_scheduled_job(job: ScheduleJobRequest, request: Request):
|
||||
"""Create a new scheduled scan job."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
raise HTTPException(status_code=503, detail="Scheduler not available")
|
||||
|
||||
if not job.cron_expression and not job.interval_minutes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either cron_expression or interval_minutes must be provided"
|
||||
)
|
||||
|
||||
result = scheduler.add_job(
|
||||
job_id=job.job_id,
|
||||
target=job.target,
|
||||
scan_type=job.scan_type,
|
||||
cron_expression=job.cron_expression,
|
||||
interval_minutes=job.interval_minutes,
|
||||
agent_role=job.agent_role,
|
||||
llm_profile=job.llm_profile
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=400, detail=result["error"])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/{job_id}")
|
||||
async def delete_scheduled_job(job_id: str, request: Request):
|
||||
"""Delete a scheduled scan job."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
raise HTTPException(status_code=503, detail="Scheduler not available")
|
||||
|
||||
success = scheduler.remove_job(job_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||
|
||||
return {"message": f"Job '{job_id}' deleted", "id": job_id}
|
||||
|
||||
|
||||
@router.post("/{job_id}/pause")
|
||||
async def pause_scheduled_job(job_id: str, request: Request):
|
||||
"""Pause a scheduled scan job."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
raise HTTPException(status_code=503, detail="Scheduler not available")
|
||||
|
||||
success = scheduler.pause_job(job_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||
|
||||
return {"message": f"Job '{job_id}' paused", "id": job_id, "status": "paused"}
|
||||
|
||||
|
||||
@router.post("/{job_id}/resume")
|
||||
async def resume_scheduled_job(job_id: str, request: Request):
|
||||
"""Resume a paused scheduled scan job."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
raise HTTPException(status_code=503, detail="Scheduler not available")
|
||||
|
||||
success = scheduler.resume_job(job_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||
|
||||
return {"message": f"Job '{job_id}' resumed", "id": job_id, "status": "active"}
|
||||
|
||||
|
||||
@router.get("/agent-roles", response_model=List[Dict])
|
||||
async def get_agent_roles():
|
||||
"""Return available agent roles from config.json for scheduler dropdown."""
|
||||
try:
|
||||
if not CONFIG_PATH.exists():
|
||||
return []
|
||||
config = json.loads(CONFIG_PATH.read_text())
|
||||
roles = config.get("agent_roles", {})
|
||||
result = []
|
||||
for role_id, role_data in roles.items():
|
||||
if role_data.get("enabled", True):
|
||||
result.append({
|
||||
"id": role_id,
|
||||
"name": role_id.replace("_", " ").title(),
|
||||
"description": role_data.get("description", ""),
|
||||
"tools": role_data.get("tools_allowed", []),
|
||||
})
|
||||
return result
|
||||
except Exception:
|
||||
return []
|
||||
+164
-18
@@ -1,7 +1,10 @@
|
||||
"""
|
||||
NeuroSploit v3 - Settings API Endpoints
|
||||
"""
|
||||
from typing import Optional
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete, text
|
||||
@@ -12,16 +15,69 @@ from backend.models import Scan, Target, Endpoint, Vulnerability, VulnerabilityT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Path to .env file (project root)
|
||||
ENV_FILE_PATH = Path(__file__).parent.parent.parent.parent / ".env"
|
||||
|
||||
|
||||
def _update_env_file(updates: Dict[str, str]) -> bool:
|
||||
"""
|
||||
Update key=value pairs in the .env file without breaking formatting.
|
||||
- If the key exists (even commented out), update its value
|
||||
- If the key doesn't exist, append it
|
||||
- Preserves comments and blank lines
|
||||
"""
|
||||
if not ENV_FILE_PATH.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
lines = ENV_FILE_PATH.read_text().splitlines()
|
||||
updated_keys = set()
|
||||
|
||||
new_lines = []
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
matched = False
|
||||
|
||||
for key, value in updates.items():
|
||||
# Match: KEY=..., # KEY=..., #KEY=...
|
||||
pattern = rf'^#?\s*{re.escape(key)}\s*='
|
||||
if re.match(pattern, stripped):
|
||||
# Replace with uncommented key=value
|
||||
new_lines.append(f"{key}={value}")
|
||||
updated_keys.add(key)
|
||||
matched = True
|
||||
break
|
||||
|
||||
if not matched:
|
||||
new_lines.append(line)
|
||||
|
||||
# Append any keys that weren't found in existing file
|
||||
for key, value in updates.items():
|
||||
if key not in updated_keys:
|
||||
new_lines.append(f"{key}={value}")
|
||||
|
||||
# Write back with trailing newline
|
||||
ENV_FILE_PATH.write_text("\n".join(new_lines) + "\n")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to update .env file: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class SettingsUpdate(BaseModel):
|
||||
"""Settings update schema"""
|
||||
llm_provider: Optional[str] = None
|
||||
anthropic_api_key: Optional[str] = None
|
||||
openai_api_key: Optional[str] = None
|
||||
openrouter_api_key: Optional[str] = None
|
||||
max_concurrent_scans: Optional[int] = None
|
||||
aggressive_mode: Optional[bool] = None
|
||||
default_scan_type: Optional[str] = None
|
||||
recon_enabled_by_default: Optional[bool] = None
|
||||
enable_model_routing: Optional[bool] = None
|
||||
enable_knowledge_augmentation: Optional[bool] = None
|
||||
enable_browser_validation: Optional[bool] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class SettingsResponse(BaseModel):
|
||||
@@ -29,56 +85,118 @@ class SettingsResponse(BaseModel):
|
||||
llm_provider: str = "claude"
|
||||
has_anthropic_key: bool = False
|
||||
has_openai_key: bool = False
|
||||
has_openrouter_key: bool = False
|
||||
max_concurrent_scans: int = 3
|
||||
aggressive_mode: bool = False
|
||||
default_scan_type: str = "full"
|
||||
recon_enabled_by_default: bool = True
|
||||
enable_model_routing: bool = False
|
||||
enable_knowledge_augmentation: bool = False
|
||||
enable_browser_validation: bool = False
|
||||
max_output_tokens: Optional[int] = None
|
||||
|
||||
|
||||
# In-memory settings storage (in production, use database or config file)
|
||||
_settings = {
|
||||
"llm_provider": "claude",
|
||||
"anthropic_api_key": "",
|
||||
"openai_api_key": "",
|
||||
"max_concurrent_scans": 3,
|
||||
"aggressive_mode": False,
|
||||
"default_scan_type": "full",
|
||||
"recon_enabled_by_default": True
|
||||
}
|
||||
def _load_settings_from_env() -> dict:
|
||||
"""
|
||||
Load settings from environment variables / .env file on startup.
|
||||
This ensures settings persist across server restarts and browser sessions.
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
# Re-read .env file to pick up disk-persisted values
|
||||
if ENV_FILE_PATH.exists():
|
||||
load_dotenv(ENV_FILE_PATH, override=True)
|
||||
|
||||
def _env_bool(key: str, default: bool = False) -> bool:
|
||||
val = os.getenv(key, "").strip().lower()
|
||||
if val in ("true", "1", "yes"):
|
||||
return True
|
||||
if val in ("false", "0", "no"):
|
||||
return False
|
||||
return default
|
||||
|
||||
def _env_int(key: str, default=None):
|
||||
val = os.getenv(key, "").strip()
|
||||
if val:
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
pass
|
||||
return default
|
||||
|
||||
# Detect provider from which keys are set
|
||||
provider = "claude"
|
||||
if os.getenv("ANTHROPIC_API_KEY"):
|
||||
provider = "claude"
|
||||
elif os.getenv("OPENAI_API_KEY"):
|
||||
provider = "openai"
|
||||
elif os.getenv("OPENROUTER_API_KEY"):
|
||||
provider = "openrouter"
|
||||
|
||||
return {
|
||||
"llm_provider": provider,
|
||||
"anthropic_api_key": os.getenv("ANTHROPIC_API_KEY", ""),
|
||||
"openai_api_key": os.getenv("OPENAI_API_KEY", ""),
|
||||
"openrouter_api_key": os.getenv("OPENROUTER_API_KEY", ""),
|
||||
"max_concurrent_scans": _env_int("MAX_CONCURRENT_SCANS", 3),
|
||||
"aggressive_mode": _env_bool("AGGRESSIVE_MODE", False),
|
||||
"default_scan_type": os.getenv("DEFAULT_SCAN_TYPE", "full"),
|
||||
"recon_enabled_by_default": _env_bool("RECON_ENABLED_BY_DEFAULT", True),
|
||||
"enable_model_routing": _env_bool("ENABLE_MODEL_ROUTING", False),
|
||||
"enable_knowledge_augmentation": _env_bool("ENABLE_KNOWLEDGE_AUGMENTATION", False),
|
||||
"enable_browser_validation": _env_bool("ENABLE_BROWSER_VALIDATION", False),
|
||||
"max_output_tokens": _env_int("MAX_OUTPUT_TOKENS", None),
|
||||
}
|
||||
|
||||
|
||||
# Load settings from .env on module import (server start)
|
||||
_settings = _load_settings_from_env()
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsResponse)
|
||||
async def get_settings():
|
||||
"""Get current settings"""
|
||||
import os
|
||||
return SettingsResponse(
|
||||
llm_provider=_settings["llm_provider"],
|
||||
has_anthropic_key=bool(_settings["anthropic_api_key"]),
|
||||
has_openai_key=bool(_settings["openai_api_key"]),
|
||||
has_anthropic_key=bool(_settings["anthropic_api_key"] or os.getenv("ANTHROPIC_API_KEY")),
|
||||
has_openai_key=bool(_settings["openai_api_key"] or os.getenv("OPENAI_API_KEY")),
|
||||
has_openrouter_key=bool(_settings["openrouter_api_key"] or os.getenv("OPENROUTER_API_KEY")),
|
||||
max_concurrent_scans=_settings["max_concurrent_scans"],
|
||||
aggressive_mode=_settings["aggressive_mode"],
|
||||
default_scan_type=_settings["default_scan_type"],
|
||||
recon_enabled_by_default=_settings["recon_enabled_by_default"]
|
||||
recon_enabled_by_default=_settings["recon_enabled_by_default"],
|
||||
enable_model_routing=_settings["enable_model_routing"],
|
||||
enable_knowledge_augmentation=_settings["enable_knowledge_augmentation"],
|
||||
enable_browser_validation=_settings["enable_browser_validation"],
|
||||
max_output_tokens=_settings["max_output_tokens"]
|
||||
)
|
||||
|
||||
|
||||
@router.put("", response_model=SettingsResponse)
|
||||
async def update_settings(settings_data: SettingsUpdate):
|
||||
"""Update settings"""
|
||||
"""Update settings - persists to memory, env vars, AND .env file"""
|
||||
env_updates: Dict[str, str] = {}
|
||||
|
||||
if settings_data.llm_provider is not None:
|
||||
_settings["llm_provider"] = settings_data.llm_provider
|
||||
|
||||
if settings_data.anthropic_api_key is not None:
|
||||
_settings["anthropic_api_key"] = settings_data.anthropic_api_key
|
||||
# Also update environment variable for LLM calls
|
||||
import os
|
||||
if settings_data.anthropic_api_key:
|
||||
os.environ["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key
|
||||
env_updates["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key
|
||||
|
||||
if settings_data.openai_api_key is not None:
|
||||
_settings["openai_api_key"] = settings_data.openai_api_key
|
||||
import os
|
||||
if settings_data.openai_api_key:
|
||||
os.environ["OPENAI_API_KEY"] = settings_data.openai_api_key
|
||||
env_updates["OPENAI_API_KEY"] = settings_data.openai_api_key
|
||||
|
||||
if settings_data.openrouter_api_key is not None:
|
||||
_settings["openrouter_api_key"] = settings_data.openrouter_api_key
|
||||
if settings_data.openrouter_api_key:
|
||||
os.environ["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key
|
||||
env_updates["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key
|
||||
|
||||
if settings_data.max_concurrent_scans is not None:
|
||||
_settings["max_concurrent_scans"] = settings_data.max_concurrent_scans
|
||||
@@ -92,6 +210,34 @@ async def update_settings(settings_data: SettingsUpdate):
|
||||
if settings_data.recon_enabled_by_default is not None:
|
||||
_settings["recon_enabled_by_default"] = settings_data.recon_enabled_by_default
|
||||
|
||||
if settings_data.enable_model_routing is not None:
|
||||
_settings["enable_model_routing"] = settings_data.enable_model_routing
|
||||
val = str(settings_data.enable_model_routing).lower()
|
||||
os.environ["ENABLE_MODEL_ROUTING"] = val
|
||||
env_updates["ENABLE_MODEL_ROUTING"] = val
|
||||
|
||||
if settings_data.enable_knowledge_augmentation is not None:
|
||||
_settings["enable_knowledge_augmentation"] = settings_data.enable_knowledge_augmentation
|
||||
val = str(settings_data.enable_knowledge_augmentation).lower()
|
||||
os.environ["ENABLE_KNOWLEDGE_AUGMENTATION"] = val
|
||||
env_updates["ENABLE_KNOWLEDGE_AUGMENTATION"] = val
|
||||
|
||||
if settings_data.enable_browser_validation is not None:
|
||||
_settings["enable_browser_validation"] = settings_data.enable_browser_validation
|
||||
val = str(settings_data.enable_browser_validation).lower()
|
||||
os.environ["ENABLE_BROWSER_VALIDATION"] = val
|
||||
env_updates["ENABLE_BROWSER_VALIDATION"] = val
|
||||
|
||||
if settings_data.max_output_tokens is not None:
|
||||
_settings["max_output_tokens"] = settings_data.max_output_tokens
|
||||
if settings_data.max_output_tokens:
|
||||
os.environ["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens)
|
||||
env_updates["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens)
|
||||
|
||||
# Persist to .env file on disk
|
||||
if env_updates:
|
||||
_update_env_file(env_updates)
|
||||
|
||||
return await get_settings()
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,568 @@
|
||||
"""
|
||||
Terminal Agent API - Interactive infrastructure pentesting via AI chat + Docker sandbox.
|
||||
|
||||
Provides session-based terminal interaction with AI-guided command execution,
|
||||
exploitation path tracking, and VPN status monitoring.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.llm_manager import LLMManager
|
||||
from core.sandbox_manager import get_sandbox
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory session store
|
||||
# ---------------------------------------------------------------------------
|
||||
terminal_sessions: Dict[str, Dict] = {}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-built templates
|
||||
# ---------------------------------------------------------------------------
|
||||
TEMPLATES = {
|
||||
"network_scanner": {
|
||||
"name": "Network Scanner",
|
||||
"description": "Host discovery, port scanning, and service detection",
|
||||
"system_prompt": (
|
||||
"You are an expert network reconnaissance specialist. You guide the "
|
||||
"operator through systematic host discovery, port scanning, and service "
|
||||
"fingerprinting. Always suggest nmap flags appropriate for the situation, "
|
||||
"explain output, and recommend next steps based on discovered services. "
|
||||
"Prioritize stealth when asked and suggest timing/fragmentation options."
|
||||
),
|
||||
"initial_commands": [
|
||||
"nmap -sn {target}",
|
||||
"nmap -sV -sC -O -p- {target}",
|
||||
"nmap -sU --top-ports 50 {target}",
|
||||
],
|
||||
},
|
||||
"lateral_movement": {
|
||||
"name": "Lateral Movement",
|
||||
"description": "Pass-the-hash, SMB/WinRM pivoting, and SSH tunneling",
|
||||
"system_prompt": (
|
||||
"You are a lateral movement specialist. You help the operator pivot "
|
||||
"through compromised networks using techniques such as pass-the-hash, "
|
||||
"SMB relay, WinRM sessions, SSH tunneling, and SOCKS proxying. Always "
|
||||
"verify credentials before attempting pivots, suggest cleanup steps, "
|
||||
"and track which hosts have been compromised."
|
||||
),
|
||||
"initial_commands": [
|
||||
"crackmapexec smb {target} -u '' -p ''",
|
||||
"crackmapexec smb {target} --shares -u '' -p ''",
|
||||
"ssh -D 1080 -N -f user@{target}",
|
||||
],
|
||||
},
|
||||
"privilege_escalation": {
|
||||
"name": "Privilege Escalation",
|
||||
"description": "SUID binaries, kernel exploits, cron jobs, and writable paths",
|
||||
"system_prompt": (
|
||||
"You are a privilege escalation expert for Linux and Windows systems. "
|
||||
"Guide the operator through enumeration of SUID/SGID binaries, kernel "
|
||||
"version checks, misconfigured cron jobs, writable PATH directories, "
|
||||
"sudo misconfigurations, and capability abuse. Suggest automated tools "
|
||||
"like linpeas/winpeas when appropriate and explain each finding."
|
||||
),
|
||||
"initial_commands": [
|
||||
"id && whoami && uname -a",
|
||||
"find / -perm -4000 -type f 2>/dev/null",
|
||||
"cat /etc/crontab && ls -la /etc/cron.*",
|
||||
"echo $PATH | tr ':' '\\n' | xargs -I {} ls -ld {}",
|
||||
],
|
||||
},
|
||||
"vpn_recon": {
|
||||
"name": "VPN Reconnaissance",
|
||||
"description": "VPN connection management and internal network discovery",
|
||||
"system_prompt": (
|
||||
"You are a VPN and internal network reconnaissance specialist. You "
|
||||
"help the operator connect to target VPNs, verify tunnel status, "
|
||||
"discover internal subnets, and enumerate services behind the VPN. "
|
||||
"Always confirm connectivity before proceeding with scans and suggest "
|
||||
"appropriate scope for internal reconnaissance."
|
||||
),
|
||||
"initial_commands": [
|
||||
"openvpn --config client.ovpn --daemon",
|
||||
"ip addr show tun0",
|
||||
"ip route | grep tun",
|
||||
"nmap -sn 10.0.0.0/24",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic request / response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
template_id: Optional[str] = None
|
||||
target: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
|
||||
|
||||
class MessageRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class ExecuteCommandRequest(BaseModel):
|
||||
command: str
|
||||
execution_method: str = "sandbox" # "sandbox" or "direct"
|
||||
|
||||
|
||||
class ExploitationStepRequest(BaseModel):
|
||||
description: str
|
||||
command: Optional[str] = ""
|
||||
result: Optional[str] = ""
|
||||
step_type: str = "recon" # recon | exploit | pivot | escalate | action
|
||||
|
||||
|
||||
class SessionSummary(BaseModel):
|
||||
session_id: str
|
||||
name: str
|
||||
target: str
|
||||
template_id: Optional[str]
|
||||
status: str
|
||||
created_at: str
|
||||
messages_count: int
|
||||
commands_count: int
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
role: str
|
||||
response: str
|
||||
timestamp: str
|
||||
suggested_commands: List[str]
|
||||
|
||||
|
||||
class CommandResult(BaseModel):
|
||||
command: str
|
||||
exit_code: int
|
||||
stdout: str
|
||||
stderr: str
|
||||
duration: float
|
||||
execution_method: str
|
||||
timestamp: str
|
||||
|
||||
|
||||
class VPNStatus(BaseModel):
|
||||
connected: bool
|
||||
ip: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _build_session(
|
||||
session_id: str,
|
||||
name: str,
|
||||
target: str,
|
||||
template_id: Optional[str],
|
||||
) -> Dict:
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"name": name,
|
||||
"target": target,
|
||||
"template_id": template_id,
|
||||
"status": "active",
|
||||
"created_at": _now_iso(),
|
||||
"messages": [],
|
||||
"command_history": [],
|
||||
"exploitation_path": [],
|
||||
"vpn_status": {"connected": False, "ip": None},
|
||||
}
|
||||
|
||||
|
||||
def _get_session(session_id: str) -> Dict:
|
||||
session = terminal_sessions.get(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return session
|
||||
|
||||
|
||||
def _build_context_string(
|
||||
messages: List[Dict],
|
||||
commands: List[Dict],
|
||||
exploitation: List[Dict],
|
||||
) -> str:
|
||||
parts: List[str] = []
|
||||
|
||||
if messages:
|
||||
parts.append("=== Recent Conversation ===")
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown").upper()
|
||||
parts.append(f"[{role}] {msg.get('content', '')}")
|
||||
|
||||
if commands:
|
||||
parts.append("\n=== Recent Command Results ===")
|
||||
for cmd in commands:
|
||||
parts.append(
|
||||
f"$ {cmd['command']}\n"
|
||||
f"Exit code: {cmd['exit_code']}\n"
|
||||
f"Stdout: {cmd['stdout'][:500]}\n"
|
||||
f"Stderr: {cmd['stderr'][:300]}"
|
||||
)
|
||||
|
||||
if exploitation:
|
||||
parts.append("\n=== Exploitation Path ===")
|
||||
for i, step in enumerate(exploitation, 1):
|
||||
parts.append(
|
||||
f"Step {i} [{step['step_type']}]: {step['description']}"
|
||||
)
|
||||
if step.get("command"):
|
||||
parts.append(f" Command: {step['command']}")
|
||||
if step.get("result"):
|
||||
parts.append(f" Result: {step['result'][:300]}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _extract_suggested_commands(text: str) -> List[str]:
|
||||
"""Extract commands from backtick-fenced code blocks."""
|
||||
blocks = re.findall(r"```(?:bash|sh|shell)?\n?(.*?)```", text, re.DOTALL)
|
||||
commands: List[str] = []
|
||||
for block in blocks:
|
||||
for line in block.strip().splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("#"):
|
||||
commands.append(stripped)
|
||||
return commands
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Template endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/templates")
|
||||
async def list_templates():
|
||||
"""List all available session templates."""
|
||||
result = []
|
||||
for tid, tmpl in TEMPLATES.items():
|
||||
result.append({
|
||||
"id": tid,
|
||||
"name": tmpl["name"],
|
||||
"description": tmpl["description"],
|
||||
"initial_commands": tmpl["initial_commands"],
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/session")
|
||||
async def create_session(req: CreateSessionRequest):
|
||||
"""Create a new terminal session, optionally from a template."""
|
||||
session_id = str(uuid.uuid4())
|
||||
target = req.target or ""
|
||||
template_id = req.template_id
|
||||
|
||||
if template_id and template_id not in TEMPLATES:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown template: {template_id}")
|
||||
|
||||
name = req.name or (
|
||||
TEMPLATES[template_id]["name"] if template_id else f"Session {session_id[:8]}"
|
||||
)
|
||||
|
||||
session = _build_session(session_id, name, target, template_id)
|
||||
|
||||
# Seed initial system message from template
|
||||
if template_id:
|
||||
tmpl = TEMPLATES[template_id]
|
||||
session["messages"].append({
|
||||
"role": "system",
|
||||
"content": tmpl["system_prompt"],
|
||||
"timestamp": _now_iso(),
|
||||
"metadata": {"template": template_id},
|
||||
})
|
||||
# Provide initial suggested commands with target interpolated
|
||||
initial_cmds = [
|
||||
cmd.replace("{target}", target) for cmd in tmpl["initial_commands"]
|
||||
]
|
||||
session["messages"].append({
|
||||
"role": "assistant",
|
||||
"content": (
|
||||
f"Session initialised with the **{tmpl['name']}** template.\n\n"
|
||||
f"Target: `{target or '(not set)'}`\n\n"
|
||||
"Suggested starting commands:\n"
|
||||
+ "\n".join(f"```\n{c}\n```" for c in initial_cmds)
|
||||
),
|
||||
"timestamp": _now_iso(),
|
||||
"suggested_commands": initial_cmds,
|
||||
})
|
||||
|
||||
terminal_sessions[session_id] = session
|
||||
return session
|
||||
|
||||
|
||||
@router.get("/sessions")
|
||||
async def list_sessions():
|
||||
"""Return lightweight summaries of every session."""
|
||||
summaries = []
|
||||
for sid, s in terminal_sessions.items():
|
||||
summaries.append(
|
||||
SessionSummary(
|
||||
session_id=sid,
|
||||
name=s["name"],
|
||||
target=s["target"],
|
||||
template_id=s["template_id"],
|
||||
status=s["status"],
|
||||
created_at=s["created_at"],
|
||||
messages_count=len(s["messages"]),
|
||||
commands_count=len(s["command_history"]),
|
||||
).model_dump()
|
||||
)
|
||||
return summaries
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}")
|
||||
async def get_session(session_id: str):
|
||||
"""Return the full session including messages, commands, and exploitation path."""
|
||||
return _get_session(session_id)
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
async def delete_session(session_id: str):
|
||||
"""Delete a terminal session."""
|
||||
if session_id not in terminal_sessions:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
del terminal_sessions[session_id]
|
||||
return {"status": "deleted", "session_id": session_id}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI message interaction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/message")
|
||||
async def send_message(session_id: str, req: MessageRequest):
|
||||
"""Send a user prompt to the AI and receive a response with suggested commands."""
|
||||
session = _get_session(session_id)
|
||||
user_message = req.message.strip()
|
||||
if not user_message:
|
||||
raise HTTPException(status_code=400, detail="Message content cannot be empty")
|
||||
|
||||
# Record user message
|
||||
session["messages"].append({
|
||||
"role": "user",
|
||||
"content": user_message,
|
||||
"timestamp": _now_iso(),
|
||||
"metadata": {},
|
||||
})
|
||||
|
||||
# Determine system prompt
|
||||
template_id = session.get("template_id")
|
||||
if template_id and template_id in TEMPLATES:
|
||||
system_prompt = TEMPLATES[template_id]["system_prompt"]
|
||||
else:
|
||||
system_prompt = (
|
||||
"You are an expert infrastructure penetration tester. Help the "
|
||||
"operator plan and execute attacks against the target. Suggest "
|
||||
"concrete commands, explain their purpose, and interpret output. "
|
||||
"Always wrap commands in fenced code blocks so they can be extracted."
|
||||
)
|
||||
|
||||
# Build context window
|
||||
context_messages = session["messages"][-20:]
|
||||
context_cmds = session["command_history"][-10:]
|
||||
exploitation = session["exploitation_path"]
|
||||
context = _build_context_string(context_messages, context_cmds, exploitation)
|
||||
|
||||
# Call LLM
|
||||
try:
|
||||
llm = LLMManager()
|
||||
prompt = f"{context}\n\nUser: {user_message}"
|
||||
response = await llm.generate(prompt, system_prompt)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"LLM call failed: {exc}")
|
||||
|
||||
suggested_commands = _extract_suggested_commands(response)
|
||||
|
||||
# Record assistant response
|
||||
session["messages"].append({
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
"timestamp": _now_iso(),
|
||||
"suggested_commands": suggested_commands,
|
||||
})
|
||||
|
||||
return MessageResponse(
|
||||
role="assistant",
|
||||
response=response,
|
||||
timestamp=session["messages"][-1]["timestamp"],
|
||||
suggested_commands=suggested_commands,
|
||||
).model_dump()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command execution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/execute")
|
||||
async def execute_command(session_id: str, req: ExecuteCommandRequest):
|
||||
"""Execute a command in the Docker sandbox (fallback: direct shell)."""
|
||||
session = _get_session(session_id)
|
||||
command = req.command.strip()
|
||||
if not command:
|
||||
raise HTTPException(status_code=400, detail="Command cannot be empty")
|
||||
|
||||
start = time.time()
|
||||
stdout = ""
|
||||
stderr = ""
|
||||
exit_code = -1
|
||||
execution_method = "direct"
|
||||
|
||||
# Use requested execution method
|
||||
use_sandbox = req.execution_method == "sandbox"
|
||||
|
||||
if use_sandbox:
|
||||
try:
|
||||
sandbox = await get_sandbox()
|
||||
if sandbox and sandbox.is_available:
|
||||
result = await sandbox.execute_raw(command)
|
||||
stdout = result.stdout
|
||||
stderr = result.stderr
|
||||
exit_code = result.exit_code
|
||||
execution_method = "sandbox"
|
||||
except Exception:
|
||||
pass # Fall through to direct execution
|
||||
|
||||
# Fallback or direct execution requested
|
||||
if execution_method != "sandbox":
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
raw_stdout, raw_stderr = await asyncio.wait_for(
|
||||
proc.communicate(), timeout=120
|
||||
)
|
||||
stdout = raw_stdout.decode(errors="replace")
|
||||
stderr = raw_stderr.decode(errors="replace")
|
||||
exit_code = proc.returncode or 0
|
||||
execution_method = "direct"
|
||||
except asyncio.TimeoutError:
|
||||
stderr = "Command timed out after 120 seconds"
|
||||
exit_code = 124
|
||||
except Exception as exc:
|
||||
stderr = str(exc)
|
||||
exit_code = 1
|
||||
|
||||
duration = round(time.time() - start, 3)
|
||||
|
||||
cmd_record = {
|
||||
"command": command,
|
||||
"exit_code": exit_code,
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
"duration": duration,
|
||||
"execution_method": execution_method,
|
||||
"timestamp": _now_iso(),
|
||||
}
|
||||
session["command_history"].append(cmd_record)
|
||||
|
||||
# Mirror into messages for AI context continuity
|
||||
output_preview = stdout[:2000] if stdout else stderr[:2000]
|
||||
session["messages"].append({
|
||||
"role": "tool",
|
||||
"content": f"$ {command}\n[exit {exit_code}] ({execution_method}, {duration}s)\n{output_preview}",
|
||||
"timestamp": cmd_record["timestamp"],
|
||||
"metadata": {"exit_code": exit_code, "execution_method": execution_method},
|
||||
})
|
||||
|
||||
return CommandResult(**cmd_record).model_dump()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exploitation path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/exploitation-path")
|
||||
async def add_exploitation_step(session_id: str, req: ExploitationStepRequest):
|
||||
"""Add a manual step to the exploitation path timeline."""
|
||||
session = _get_session(session_id)
|
||||
|
||||
valid_types = {"recon", "exploit", "pivot", "escalate", "action"}
|
||||
if req.step_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"step_type must be one of {sorted(valid_types)}",
|
||||
)
|
||||
|
||||
step = {
|
||||
"description": req.description,
|
||||
"command": req.command or "",
|
||||
"result": req.result or "",
|
||||
"timestamp": _now_iso(),
|
||||
"step_type": req.step_type,
|
||||
}
|
||||
session["exploitation_path"].append(step)
|
||||
return step
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/exploitation-path")
|
||||
async def get_exploitation_path(session_id: str):
|
||||
"""Return the full exploitation path timeline."""
|
||||
session = _get_session(session_id)
|
||||
return session["exploitation_path"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VPN status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions/{session_id}/vpn-status")
|
||||
async def get_vpn_status(session_id: str):
|
||||
"""Check OpenVPN process and tun0 interface status."""
|
||||
session = _get_session(session_id)
|
||||
|
||||
connected = False
|
||||
ip_addr: Optional[str] = None
|
||||
|
||||
# Check for running openvpn process
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
"pgrep -a openvpn",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
raw_stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5)
|
||||
if proc.returncode == 0 and raw_stdout.strip():
|
||||
connected = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check tun0 interface for IP
|
||||
if connected:
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
"ip addr show tun0",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
raw_stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5)
|
||||
if proc.returncode == 0:
|
||||
match = re.search(
|
||||
r"inet\s+(\d+\.\d+\.\d+\.\d+)", raw_stdout.decode(errors="replace")
|
||||
)
|
||||
if match:
|
||||
ip_addr = match.group(1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
vpn = {"connected": connected, "ip": ip_addr}
|
||||
session["vpn_status"] = vpn
|
||||
return VPNStatus(**vpn).model_dump()
|
||||
@@ -0,0 +1,876 @@
|
||||
"""
|
||||
NeuroSploit v3 - Vulnerability Lab API Endpoints
|
||||
|
||||
Isolated vulnerability testing against labs, CTFs, and PortSwigger challenges.
|
||||
Test individual vuln types one at a time and track results.
|
||||
"""
|
||||
from typing import Optional, Dict, List
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, func, text
|
||||
|
||||
from backend.core.autonomous_agent import AutonomousAgent, OperationMode
|
||||
from backend.core.vuln_engine.registry import VulnerabilityRegistry
|
||||
from backend.db.database import async_session_factory
|
||||
from backend.models import Scan, Target, Vulnerability, Endpoint, Report, VulnLabChallenge
|
||||
|
||||
# Import agent.py's shared dicts so ScanDetailsPage can find our scans
|
||||
from backend.api.v1.agent import (
|
||||
agent_results, agent_instances, agent_to_scan, scan_to_agent
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# In-memory tracking for running lab tests
|
||||
lab_agents: Dict[str, AutonomousAgent] = {}
|
||||
lab_results: Dict[str, Dict] = {}
|
||||
|
||||
|
||||
# --- Request/Response Models ---
|
||||
|
||||
class VulnLabRunRequest(BaseModel):
|
||||
target_url: str = Field(..., description="Target URL to test (lab, CTF, etc.)")
|
||||
vuln_type: str = Field(..., description="Vulnerability type to test (e.g. xss_reflected)")
|
||||
challenge_name: Optional[str] = Field(None, description="Name of the lab/challenge")
|
||||
auth_type: Optional[str] = Field(None, description="Auth type: cookie, bearer, basic, header")
|
||||
auth_value: Optional[str] = Field(None, description="Auth credential value")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom HTTP headers")
|
||||
notes: Optional[str] = Field(None, description="Notes about this challenge")
|
||||
|
||||
|
||||
class VulnLabResponse(BaseModel):
|
||||
challenge_id: str
|
||||
agent_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class VulnTypeInfo(BaseModel):
|
||||
key: str
|
||||
title: str
|
||||
severity: str
|
||||
cwe_id: str
|
||||
category: str
|
||||
|
||||
|
||||
# --- Vuln type categories for the selector ---
|
||||
|
||||
VULN_CATEGORIES = {
|
||||
"injection": {
|
||||
"label": "Injection",
|
||||
"types": [
|
||||
"xss_reflected", "xss_stored", "xss_dom",
|
||||
"sqli_error", "sqli_union", "sqli_blind", "sqli_time",
|
||||
"command_injection", "ssti", "nosql_injection",
|
||||
]
|
||||
},
|
||||
"advanced_injection": {
|
||||
"label": "Advanced Injection",
|
||||
"types": [
|
||||
"ldap_injection", "xpath_injection", "graphql_injection",
|
||||
"crlf_injection", "header_injection", "email_injection",
|
||||
"el_injection", "log_injection", "html_injection",
|
||||
"csv_injection", "orm_injection",
|
||||
]
|
||||
},
|
||||
"file_access": {
|
||||
"label": "File Access",
|
||||
"types": [
|
||||
"lfi", "rfi", "path_traversal", "xxe", "file_upload",
|
||||
"arbitrary_file_read", "arbitrary_file_delete", "zip_slip",
|
||||
]
|
||||
},
|
||||
"request_forgery": {
|
||||
"label": "Request Forgery",
|
||||
"types": [
|
||||
"ssrf", "csrf", "graphql_introspection", "graphql_dos",
|
||||
]
|
||||
},
|
||||
"authentication": {
|
||||
"label": "Authentication",
|
||||
"types": [
|
||||
"auth_bypass", "jwt_manipulation", "session_fixation",
|
||||
"weak_password", "default_credentials", "two_factor_bypass",
|
||||
"oauth_misconfig",
|
||||
]
|
||||
},
|
||||
"authorization": {
|
||||
"label": "Authorization",
|
||||
"types": [
|
||||
"idor", "bola", "privilege_escalation",
|
||||
"bfla", "mass_assignment", "forced_browsing",
|
||||
]
|
||||
},
|
||||
"client_side": {
|
||||
"label": "Client-Side",
|
||||
"types": [
|
||||
"cors_misconfiguration", "clickjacking", "open_redirect",
|
||||
"dom_clobbering", "postmessage_vuln", "websocket_hijack",
|
||||
"prototype_pollution", "css_injection", "tabnabbing",
|
||||
]
|
||||
},
|
||||
"infrastructure": {
|
||||
"label": "Infrastructure",
|
||||
"types": [
|
||||
"security_headers", "ssl_issues", "http_methods",
|
||||
"directory_listing", "debug_mode", "exposed_admin_panel",
|
||||
"exposed_api_docs", "insecure_cookie_flags",
|
||||
]
|
||||
},
|
||||
"logic": {
|
||||
"label": "Business Logic",
|
||||
"types": [
|
||||
"race_condition", "business_logic", "rate_limit_bypass",
|
||||
"parameter_pollution", "type_juggling", "timing_attack",
|
||||
"host_header_injection", "http_smuggling", "cache_poisoning",
|
||||
]
|
||||
},
|
||||
"data_exposure": {
|
||||
"label": "Data Exposure",
|
||||
"types": [
|
||||
"sensitive_data_exposure", "information_disclosure",
|
||||
"api_key_exposure", "source_code_disclosure",
|
||||
"backup_file_exposure", "version_disclosure",
|
||||
]
|
||||
},
|
||||
"cloud_supply": {
|
||||
"label": "Cloud & Supply Chain",
|
||||
"types": [
|
||||
"s3_bucket_misconfig", "cloud_metadata_exposure",
|
||||
"subdomain_takeover", "vulnerable_dependency",
|
||||
"container_escape", "serverless_misconfiguration",
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_vuln_category(vuln_type: str) -> str:
|
||||
"""Get category for a vuln type"""
|
||||
for cat_key, cat_info in VULN_CATEGORIES.items():
|
||||
if vuln_type in cat_info["types"]:
|
||||
return cat_key
|
||||
return "other"
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.get("/types")
|
||||
async def list_vuln_types():
|
||||
"""List all available vulnerability types grouped by category"""
|
||||
registry = VulnerabilityRegistry()
|
||||
result = {}
|
||||
|
||||
for cat_key, cat_info in VULN_CATEGORIES.items():
|
||||
types_list = []
|
||||
for vtype in cat_info["types"]:
|
||||
info = registry.VULNERABILITY_INFO.get(vtype, {})
|
||||
types_list.append({
|
||||
"key": vtype,
|
||||
"title": info.get("title", vtype.replace("_", " ").title()),
|
||||
"severity": info.get("severity", "medium"),
|
||||
"cwe_id": info.get("cwe_id", ""),
|
||||
"description": info.get("description", "")[:120] if info.get("description") else "",
|
||||
})
|
||||
result[cat_key] = {
|
||||
"label": cat_info["label"],
|
||||
"types": types_list,
|
||||
"count": len(types_list),
|
||||
}
|
||||
|
||||
return {"categories": result, "total_types": sum(len(c["types"]) for c in VULN_CATEGORIES.values())}
|
||||
|
||||
|
||||
@router.post("/run", response_model=VulnLabResponse)
|
||||
async def run_vuln_lab(request: VulnLabRunRequest, background_tasks: BackgroundTasks):
|
||||
"""Launch an isolated vulnerability test for a specific vuln type"""
|
||||
import uuid
|
||||
|
||||
# Validate vuln type exists
|
||||
registry = VulnerabilityRegistry()
|
||||
if request.vuln_type not in registry.VULNERABILITY_INFO:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unknown vulnerability type: {request.vuln_type}. Use GET /vuln-lab/types for available types."
|
||||
)
|
||||
|
||||
challenge_id = str(uuid.uuid4())
|
||||
agent_id = str(uuid.uuid4())[:8]
|
||||
category = _get_vuln_category(request.vuln_type)
|
||||
|
||||
# Build auth headers
|
||||
auth_headers = {}
|
||||
if request.auth_type and request.auth_value:
|
||||
if request.auth_type == "cookie":
|
||||
auth_headers["Cookie"] = request.auth_value
|
||||
elif request.auth_type == "bearer":
|
||||
auth_headers["Authorization"] = f"Bearer {request.auth_value}"
|
||||
elif request.auth_type == "basic":
|
||||
import base64
|
||||
auth_headers["Authorization"] = f"Basic {base64.b64encode(request.auth_value.encode()).decode()}"
|
||||
elif request.auth_type == "header":
|
||||
if ":" in request.auth_value:
|
||||
name, value = request.auth_value.split(":", 1)
|
||||
auth_headers[name.strip()] = value.strip()
|
||||
|
||||
if request.custom_headers:
|
||||
auth_headers.update(request.custom_headers)
|
||||
|
||||
# Create DB record
|
||||
async with async_session_factory() as db:
|
||||
challenge = VulnLabChallenge(
|
||||
id=challenge_id,
|
||||
target_url=request.target_url,
|
||||
challenge_name=request.challenge_name,
|
||||
vuln_type=request.vuln_type,
|
||||
vuln_category=category,
|
||||
auth_type=request.auth_type,
|
||||
auth_value=request.auth_value,
|
||||
status="running",
|
||||
agent_id=agent_id,
|
||||
started_at=datetime.utcnow(),
|
||||
notes=request.notes,
|
||||
)
|
||||
db.add(challenge)
|
||||
await db.commit()
|
||||
|
||||
# Init in-memory tracking (both local and in agent.py's shared dicts)
|
||||
vuln_info = registry.VULNERABILITY_INFO[request.vuln_type]
|
||||
lab_results[challenge_id] = {
|
||||
"status": "running",
|
||||
"agent_id": agent_id,
|
||||
"vuln_type": request.vuln_type,
|
||||
"target": request.target_url,
|
||||
"progress": 0,
|
||||
"phase": "initializing",
|
||||
"findings": [],
|
||||
"logs": [],
|
||||
}
|
||||
|
||||
# Also register in agent.py's shared results dict so /agent/status works
|
||||
agent_results[agent_id] = {
|
||||
"status": "running",
|
||||
"mode": "full_auto",
|
||||
"started_at": datetime.utcnow().isoformat(),
|
||||
"target": request.target_url,
|
||||
"task": f"VulnLab: {vuln_info.get('title', request.vuln_type)}",
|
||||
"logs": [],
|
||||
"findings": [],
|
||||
"report": None,
|
||||
"progress": 0,
|
||||
"phase": "initializing",
|
||||
}
|
||||
|
||||
# Launch agent in background
|
||||
background_tasks.add_task(
|
||||
_run_lab_test,
|
||||
challenge_id,
|
||||
agent_id,
|
||||
request.target_url,
|
||||
request.vuln_type,
|
||||
vuln_info.get("title", request.vuln_type),
|
||||
auth_headers,
|
||||
request.challenge_name,
|
||||
request.notes,
|
||||
)
|
||||
|
||||
return VulnLabResponse(
|
||||
challenge_id=challenge_id,
|
||||
agent_id=agent_id,
|
||||
status="running",
|
||||
message=f"Testing {vuln_info.get('title', request.vuln_type)} against {request.target_url}"
|
||||
)
|
||||
|
||||
|
||||
async def _run_lab_test(
|
||||
challenge_id: str,
|
||||
agent_id: str,
|
||||
target: str,
|
||||
vuln_type: str,
|
||||
vuln_title: str,
|
||||
auth_headers: Dict,
|
||||
challenge_name: Optional[str] = None,
|
||||
notes: Optional[str] = None,
|
||||
):
|
||||
"""Background task: run the agent focused on a single vuln type"""
|
||||
import asyncio
|
||||
|
||||
logs = []
|
||||
findings_list = []
|
||||
scan_id = None
|
||||
|
||||
async def log_callback(level: str, message: str):
|
||||
source = "llm" if any(tag in message for tag in ["[AI]", "[LLM]", "[USER PROMPT]", "[AI RESPONSE]"]) else "script"
|
||||
entry = {"level": level, "message": message, "time": datetime.utcnow().isoformat(), "source": source}
|
||||
logs.append(entry)
|
||||
# Update local tracking
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["logs"] = logs
|
||||
# Also update agent.py's shared dict so /agent/logs works
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["logs"] = logs
|
||||
|
||||
async def progress_callback(progress: int, phase: str):
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["progress"] = progress
|
||||
lab_results[challenge_id]["phase"] = phase
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["progress"] = progress
|
||||
agent_results[agent_id]["phase"] = phase
|
||||
|
||||
async def finding_callback(finding: Dict):
|
||||
findings_list.append(finding)
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["findings"] = findings_list
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["findings"] = findings_list
|
||||
agent_results[agent_id]["findings_count"] = len(findings_list)
|
||||
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
# Create a scan record linked to this challenge
|
||||
scan = Scan(
|
||||
name=f"VulnLab: {vuln_title} - {target[:50]}",
|
||||
status="running",
|
||||
scan_type="full_auto",
|
||||
recon_enabled=True,
|
||||
progress=0,
|
||||
current_phase="initializing",
|
||||
custom_prompt=f"Focus ONLY on testing for {vuln_title} ({vuln_type}). "
|
||||
f"Do NOT test other vulnerability types. "
|
||||
f"Test thoroughly with multiple payloads and techniques for this specific vulnerability.",
|
||||
)
|
||||
db.add(scan)
|
||||
await db.commit()
|
||||
await db.refresh(scan)
|
||||
scan_id = scan.id
|
||||
|
||||
# Create target record
|
||||
target_record = Target(scan_id=scan_id, url=target, status="pending")
|
||||
db.add(target_record)
|
||||
await db.commit()
|
||||
|
||||
# Update challenge with scan_id
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if challenge:
|
||||
challenge.scan_id = scan_id
|
||||
await db.commit()
|
||||
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["scan_id"] = scan_id
|
||||
|
||||
# Register in agent.py's shared mappings so ScanDetailsPage works
|
||||
agent_to_scan[agent_id] = scan_id
|
||||
scan_to_agent[scan_id] = agent_id
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["scan_id"] = scan_id
|
||||
|
||||
# Build focused prompt for isolated testing
|
||||
focused_prompt = (
|
||||
f"You are testing specifically for {vuln_title} ({vuln_type}). "
|
||||
f"Focus ALL your efforts on detecting and exploiting this single vulnerability type. "
|
||||
f"Do NOT scan for other vulnerability types. "
|
||||
f"Use all relevant payloads and techniques for {vuln_type}. "
|
||||
f"Be thorough: try multiple injection points, encoding bypasses, and edge cases. "
|
||||
f"This is a lab/CTF challenge - the vulnerability is expected to exist."
|
||||
)
|
||||
if challenge_name:
|
||||
focused_prompt += (
|
||||
f"\n\nCHALLENGE HINT: This is PortSwigger lab '{challenge_name}'. "
|
||||
f"Use this name to understand what specific technique or bypass is needed. "
|
||||
f"For example, 'angle brackets HTML-encoded' means attribute-based XSS, "
|
||||
f"'most tags and attributes blocked' means fuzz for allowed tags/events."
|
||||
)
|
||||
if notes:
|
||||
focused_prompt += f"\n\nUSER NOTES: {notes}"
|
||||
|
||||
lab_ctx = {
|
||||
"challenge_name": challenge_name,
|
||||
"notes": notes,
|
||||
"vuln_type": vuln_type,
|
||||
"is_lab": True,
|
||||
}
|
||||
|
||||
async with AutonomousAgent(
|
||||
target=target,
|
||||
mode=OperationMode.FULL_AUTO,
|
||||
log_callback=log_callback,
|
||||
progress_callback=progress_callback,
|
||||
auth_headers=auth_headers,
|
||||
custom_prompt=focused_prompt,
|
||||
finding_callback=finding_callback,
|
||||
lab_context=lab_ctx,
|
||||
) as agent:
|
||||
lab_agents[challenge_id] = agent
|
||||
# Also register in agent.py's shared instances so stop works
|
||||
agent_instances[agent_id] = agent
|
||||
|
||||
report = await agent.run()
|
||||
|
||||
lab_agents.pop(challenge_id, None)
|
||||
agent_instances.pop(agent_id, None)
|
||||
|
||||
# Use findings from report OR from real-time callbacks (fallback)
|
||||
report_findings = report.get("findings", [])
|
||||
# If report findings are empty but we got findings via callback, use those
|
||||
findings = report_findings if report_findings else findings_list
|
||||
# Also merge: if findings_list has entries not in report_findings, add them
|
||||
if not findings and findings_list:
|
||||
findings = findings_list
|
||||
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
|
||||
findings_detail = []
|
||||
|
||||
for finding in findings:
|
||||
severity = finding.get("severity", "medium").lower()
|
||||
if severity in severity_counts:
|
||||
severity_counts[severity] += 1
|
||||
|
||||
findings_detail.append({
|
||||
"title": finding.get("title", ""),
|
||||
"vulnerability_type": finding.get("vulnerability_type", ""),
|
||||
"severity": severity,
|
||||
"affected_endpoint": finding.get("affected_endpoint", ""),
|
||||
"evidence": (finding.get("evidence", "") or "")[:500],
|
||||
"payload": (finding.get("payload", "") or "")[:200],
|
||||
})
|
||||
|
||||
# Save to vulnerabilities table
|
||||
vuln = Vulnerability(
|
||||
scan_id=scan_id,
|
||||
title=finding.get("title", finding.get("type", "Unknown")),
|
||||
vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
|
||||
severity=severity,
|
||||
cvss_score=finding.get("cvss_score"),
|
||||
cvss_vector=finding.get("cvss_vector"),
|
||||
cwe_id=finding.get("cwe_id"),
|
||||
description=finding.get("description", finding.get("evidence", "")),
|
||||
affected_endpoint=finding.get("affected_endpoint", finding.get("url", target)),
|
||||
poc_payload=finding.get("payload", finding.get("poc_payload", finding.get("poc_code", ""))),
|
||||
poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
|
||||
poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
|
||||
poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
|
||||
impact=finding.get("impact", ""),
|
||||
remediation=finding.get("remediation", ""),
|
||||
references=finding.get("references", []),
|
||||
ai_analysis=finding.get("ai_analysis", ""),
|
||||
screenshots=finding.get("screenshots", []),
|
||||
url=finding.get("url", finding.get("affected_endpoint", "")),
|
||||
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
)
|
||||
db.add(vuln)
|
||||
|
||||
# Save discovered endpoints from recon data
|
||||
endpoints_count = 0
|
||||
for ep in report.get("recon", {}).get("endpoints", []):
|
||||
endpoints_count += 1
|
||||
if isinstance(ep, str):
|
||||
endpoint = Endpoint(
|
||||
scan_id=scan_id,
|
||||
target_id=target_record.id,
|
||||
url=ep,
|
||||
method="GET",
|
||||
path=ep.split("?")[0].split("/")[-1] or "/"
|
||||
)
|
||||
else:
|
||||
endpoint = Endpoint(
|
||||
scan_id=scan_id,
|
||||
target_id=target_record.id,
|
||||
url=ep.get("url", ""),
|
||||
method=ep.get("method", "GET"),
|
||||
path=ep.get("path", "/")
|
||||
)
|
||||
db.add(endpoint)
|
||||
|
||||
# Determine result - more flexible matching
|
||||
# Check if any finding matches the target vuln type
|
||||
target_type_findings = [
|
||||
f for f in findings
|
||||
if _vuln_type_matches(vuln_type, f.get("vulnerability_type", ""))
|
||||
]
|
||||
# If the agent found ANY vulnerability, it detected something
|
||||
# (since we told it to focus on one type, any finding is relevant)
|
||||
if target_type_findings:
|
||||
result_status = "detected"
|
||||
elif len(findings) > 0:
|
||||
# Found other vulns but not the exact type
|
||||
result_status = "detected"
|
||||
else:
|
||||
result_status = "not_detected"
|
||||
|
||||
# Update scan
|
||||
scan.status = "completed"
|
||||
scan.completed_at = datetime.utcnow()
|
||||
scan.progress = 100
|
||||
scan.current_phase = "completed"
|
||||
scan.total_vulnerabilities = len(findings)
|
||||
scan.total_endpoints = endpoints_count
|
||||
scan.critical_count = severity_counts["critical"]
|
||||
scan.high_count = severity_counts["high"]
|
||||
scan.medium_count = severity_counts["medium"]
|
||||
scan.low_count = severity_counts["low"]
|
||||
scan.info_count = severity_counts["info"]
|
||||
|
||||
# Auto-generate report
|
||||
exec_summary = report.get("executive_summary", f"VulnLab test for {vuln_title} on {target}")
|
||||
report_record = Report(
|
||||
scan_id=scan_id,
|
||||
title=f"VulnLab: {vuln_title} - {target[:50]}",
|
||||
format="json",
|
||||
executive_summary=exec_summary[:1000] if exec_summary else None,
|
||||
)
|
||||
db.add(report_record)
|
||||
|
||||
# Persist logs (keep last 500 entries to avoid huge DB rows)
|
||||
persisted_logs = logs[-500:] if len(logs) > 500 else logs
|
||||
|
||||
# Update challenge record
|
||||
result_q = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result_q.scalar_one_or_none()
|
||||
if challenge:
|
||||
challenge.status = "completed"
|
||||
challenge.result = result_status
|
||||
challenge.completed_at = datetime.utcnow()
|
||||
challenge.duration = int((datetime.utcnow() - challenge.started_at).total_seconds()) if challenge.started_at else 0
|
||||
challenge.findings_count = len(findings)
|
||||
challenge.critical_count = severity_counts["critical"]
|
||||
challenge.high_count = severity_counts["high"]
|
||||
challenge.medium_count = severity_counts["medium"]
|
||||
challenge.low_count = severity_counts["low"]
|
||||
challenge.info_count = severity_counts["info"]
|
||||
challenge.findings_detail = findings_detail
|
||||
challenge.logs = persisted_logs
|
||||
challenge.endpoints_count = endpoints_count
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Update in-memory results
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["status"] = "completed"
|
||||
lab_results[challenge_id]["result"] = result_status
|
||||
lab_results[challenge_id]["findings"] = findings
|
||||
lab_results[challenge_id]["progress"] = 100
|
||||
lab_results[challenge_id]["phase"] = "completed"
|
||||
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "completed"
|
||||
agent_results[agent_id]["completed_at"] = datetime.utcnow().isoformat()
|
||||
agent_results[agent_id]["report"] = report
|
||||
agent_results[agent_id]["findings"] = findings
|
||||
agent_results[agent_id]["progress"] = 100
|
||||
agent_results[agent_id]["phase"] = "completed"
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_tb = traceback.format_exc()
|
||||
print(f"VulnLab error: {error_tb}")
|
||||
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["status"] = "error"
|
||||
lab_results[challenge_id]["error"] = str(e)
|
||||
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "error"
|
||||
agent_results[agent_id]["error"] = str(e)
|
||||
|
||||
# Persist logs even on error
|
||||
persisted_logs = logs[-500:] if len(logs) > 500 else logs
|
||||
|
||||
# Update DB records
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if challenge:
|
||||
challenge.status = "failed"
|
||||
challenge.result = "error"
|
||||
challenge.completed_at = datetime.utcnow()
|
||||
challenge.notes = (challenge.notes or "") + f"\nError: {str(e)}"
|
||||
challenge.logs = persisted_logs
|
||||
await db.commit()
|
||||
|
||||
if scan_id:
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
if scan:
|
||||
scan.status = "failed"
|
||||
scan.error_message = str(e)
|
||||
scan.completed_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
lab_agents.pop(challenge_id, None)
|
||||
agent_instances.pop(agent_id, None)
|
||||
|
||||
|
||||
def _vuln_type_matches(target_type: str, found_type: str) -> bool:
|
||||
"""Check if a found vuln type matches the target type (flexible matching)"""
|
||||
if not found_type:
|
||||
return False
|
||||
target = target_type.lower().replace("_", " ").replace("-", " ")
|
||||
found = found_type.lower().replace("_", " ").replace("-", " ")
|
||||
# Exact match
|
||||
if target == found:
|
||||
return True
|
||||
# Target is substring of found or vice versa
|
||||
if target in found or found in target:
|
||||
return True
|
||||
# Key word matching for common patterns
|
||||
target_words = set(target.split())
|
||||
found_words = set(found.split())
|
||||
# If they share major keywords (xss, sqli, ssrf, etc.)
|
||||
major_keywords = {"xss", "sqli", "sql", "injection", "ssrf", "csrf", "lfi", "rfi",
|
||||
"xxe", "ssti", "idor", "cors", "jwt", "redirect", "traversal"}
|
||||
shared = target_words & found_words & major_keywords
|
||||
if shared:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@router.get("/challenges")
|
||||
async def list_challenges(
|
||||
vuln_type: Optional[str] = None,
|
||||
vuln_category: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
result: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
):
|
||||
"""List all vulnerability lab challenges with optional filtering"""
|
||||
async with async_session_factory() as db:
|
||||
query = select(VulnLabChallenge).order_by(VulnLabChallenge.created_at.desc())
|
||||
|
||||
if vuln_type:
|
||||
query = query.where(VulnLabChallenge.vuln_type == vuln_type)
|
||||
if vuln_category:
|
||||
query = query.where(VulnLabChallenge.vuln_category == vuln_category)
|
||||
if status:
|
||||
query = query.where(VulnLabChallenge.status == status)
|
||||
if result:
|
||||
query = query.where(VulnLabChallenge.result == result)
|
||||
|
||||
query = query.limit(limit)
|
||||
db_result = await db.execute(query)
|
||||
challenges = db_result.scalars().all()
|
||||
|
||||
# For list view, exclude large logs field to save bandwidth
|
||||
result_list = []
|
||||
for c in challenges:
|
||||
d = c.to_dict()
|
||||
d["logs_count"] = len(d.get("logs", []))
|
||||
d.pop("logs", None) # Don't send full logs in list view
|
||||
result_list.append(d)
|
||||
|
||||
return {
|
||||
"challenges": result_list,
|
||||
"total": len(challenges),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/challenges/{challenge_id}")
|
||||
async def get_challenge(challenge_id: str):
|
||||
"""Get challenge details including real-time status if running"""
|
||||
# Check in-memory first for real-time data
|
||||
if challenge_id in lab_results:
|
||||
mem = lab_results[challenge_id]
|
||||
return {
|
||||
"challenge_id": challenge_id,
|
||||
"status": mem["status"],
|
||||
"progress": mem.get("progress", 0),
|
||||
"phase": mem.get("phase", ""),
|
||||
"findings_count": len(mem.get("findings", [])),
|
||||
"findings": mem.get("findings", []),
|
||||
"logs_count": len(mem.get("logs", [])),
|
||||
"logs": mem.get("logs", [])[-200:], # Last 200 log entries for real-time
|
||||
"error": mem.get("error"),
|
||||
"result": mem.get("result"),
|
||||
"scan_id": mem.get("scan_id"),
|
||||
"agent_id": mem.get("agent_id"),
|
||||
"vuln_type": mem.get("vuln_type"),
|
||||
"target": mem.get("target"),
|
||||
"source": "realtime",
|
||||
}
|
||||
|
||||
# Fall back to DB
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if not challenge:
|
||||
raise HTTPException(status_code=404, detail="Challenge not found")
|
||||
|
||||
data = challenge.to_dict()
|
||||
data["source"] = "database"
|
||||
data["logs_count"] = len(data.get("logs", []))
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_lab_stats():
|
||||
"""Get aggregated stats for all lab challenges"""
|
||||
async with async_session_factory() as db:
|
||||
# Total counts by status
|
||||
total_result = await db.execute(
|
||||
select(
|
||||
VulnLabChallenge.status,
|
||||
func.count(VulnLabChallenge.id)
|
||||
).group_by(VulnLabChallenge.status)
|
||||
)
|
||||
status_counts = {row[0]: row[1] for row in total_result.fetchall()}
|
||||
|
||||
# Results breakdown
|
||||
results_q = await db.execute(
|
||||
select(
|
||||
VulnLabChallenge.result,
|
||||
func.count(VulnLabChallenge.id)
|
||||
).where(VulnLabChallenge.result.isnot(None))
|
||||
.group_by(VulnLabChallenge.result)
|
||||
)
|
||||
result_counts = {row[0]: row[1] for row in results_q.fetchall()}
|
||||
|
||||
# Per vuln_type stats
|
||||
type_stats_q = await db.execute(
|
||||
select(
|
||||
VulnLabChallenge.vuln_type,
|
||||
VulnLabChallenge.result,
|
||||
func.count(VulnLabChallenge.id)
|
||||
).where(VulnLabChallenge.status == "completed")
|
||||
.group_by(VulnLabChallenge.vuln_type, VulnLabChallenge.result)
|
||||
)
|
||||
type_stats = {}
|
||||
for row in type_stats_q.fetchall():
|
||||
vtype, res, count = row
|
||||
if vtype not in type_stats:
|
||||
type_stats[vtype] = {"detected": 0, "not_detected": 0, "error": 0, "total": 0}
|
||||
type_stats[vtype][res or "error"] = count
|
||||
type_stats[vtype]["total"] += count
|
||||
|
||||
# Per category stats
|
||||
cat_stats_q = await db.execute(
|
||||
select(
|
||||
VulnLabChallenge.vuln_category,
|
||||
VulnLabChallenge.result,
|
||||
func.count(VulnLabChallenge.id)
|
||||
).where(VulnLabChallenge.status == "completed")
|
||||
.group_by(VulnLabChallenge.vuln_category, VulnLabChallenge.result)
|
||||
)
|
||||
cat_stats = {}
|
||||
for row in cat_stats_q.fetchall():
|
||||
cat, res, count = row
|
||||
if cat not in cat_stats:
|
||||
cat_stats[cat] = {"detected": 0, "not_detected": 0, "error": 0, "total": 0}
|
||||
cat_stats[cat][res or "error"] = count
|
||||
cat_stats[cat]["total"] += count
|
||||
|
||||
# Currently running
|
||||
running = len([cid for cid, r in lab_results.items() if r.get("status") == "running"])
|
||||
|
||||
total = sum(status_counts.values())
|
||||
detected = result_counts.get("detected", 0)
|
||||
completed = status_counts.get("completed", 0)
|
||||
detection_rate = round((detected / completed * 100), 1) if completed > 0 else 0
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"running": running,
|
||||
"status_counts": status_counts,
|
||||
"result_counts": result_counts,
|
||||
"detection_rate": detection_rate,
|
||||
"by_type": type_stats,
|
||||
"by_category": cat_stats,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/challenges/{challenge_id}/stop")
|
||||
async def stop_challenge(challenge_id: str):
|
||||
"""Stop a running lab challenge"""
|
||||
agent = lab_agents.get(challenge_id)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="No running agent for this challenge")
|
||||
|
||||
agent.cancel()
|
||||
|
||||
# Update DB
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if challenge:
|
||||
challenge.status = "stopped"
|
||||
challenge.completed_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["status"] = "stopped"
|
||||
|
||||
return {"message": "Challenge stopped"}
|
||||
|
||||
|
||||
@router.delete("/challenges/{challenge_id}")
|
||||
async def delete_challenge(challenge_id: str):
|
||||
"""Delete a lab challenge record"""
|
||||
# Stop if running
|
||||
agent = lab_agents.get(challenge_id)
|
||||
if agent:
|
||||
agent.cancel()
|
||||
lab_agents.pop(challenge_id, None)
|
||||
|
||||
lab_results.pop(challenge_id, None)
|
||||
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if not challenge:
|
||||
raise HTTPException(status_code=404, detail="Challenge not found")
|
||||
|
||||
await db.delete(challenge)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Challenge deleted"}
|
||||
|
||||
|
||||
@router.get("/logs/{challenge_id}")
|
||||
async def get_challenge_logs(challenge_id: str, limit: int = 200):
|
||||
"""Get logs for a challenge (real-time or from DB)"""
|
||||
# Check in-memory first for real-time data
|
||||
mem = lab_results.get(challenge_id)
|
||||
if mem:
|
||||
all_logs = mem.get("logs", [])
|
||||
return {
|
||||
"challenge_id": challenge_id,
|
||||
"total_logs": len(all_logs),
|
||||
"logs": all_logs[-limit:],
|
||||
"source": "realtime",
|
||||
}
|
||||
|
||||
# Fall back to DB persisted logs
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if not challenge:
|
||||
raise HTTPException(status_code=404, detail="Challenge not found")
|
||||
|
||||
all_logs = challenge.logs or []
|
||||
return {
|
||||
"challenge_id": challenge_id,
|
||||
"total_logs": len(all_logs),
|
||||
"logs": all_logs[-limit:],
|
||||
"source": "database",
|
||||
}
|
||||
Reference in New Issue
Block a user