NeuroSploit v3.2 - Autonomous AI Penetration Testing Platform

116 modules | 100 vuln types | 18 API routes | 18 frontend pages

Major features:
- VulnEngine: 100 vuln types, 526+ payloads, 12 testers, anti-hallucination prompts
- Autonomous Agent: 3-stream auto pentest, multi-session (5 concurrent), pause/resume/stop
- CLI Agent: Claude Code / Gemini CLI / Codex CLI inside Kali containers
- Validation Pipeline: negative controls, proof of execution, confidence scoring, judge
- AI Reasoning: ReACT engine, token budget, endpoint classifier, CVE hunter, deep recon
- Multi-Agent: 5 specialists + orchestrator + researcher AI + vuln type agents
- RAG System: BM25/TF-IDF/ChromaDB vectorstore, few-shot, reasoning templates
- Smart Router: 20 providers (8 CLI OAuth + 12 API), tier failover, token refresh
- Kali Sandbox: container-per-scan, 56 tools, VPN support, on-demand install
- Full IA Testing: methodology-driven comprehensive pentest sessions
- Notifications: Discord, Telegram, WhatsApp/Twilio multi-channel alerts
- Frontend: React/TypeScript with 18 pages, real-time WebSocket updates
This commit is contained in:
CyberSecurityUP
2026-02-22 17:58:12 -03:00
commit e0935793c5
271 changed files with 132462 additions and 0 deletions

1
backend/api/__init__.py Executable file
View File

@@ -0,0 +1 @@
# API package

1
backend/api/v1/__init__.py Executable file
View File

@@ -0,0 +1 @@
# API v1 package

3128
backend/api/v1/agent.py Executable file

File diff suppressed because it is too large Load Diff

176
backend/api/v1/agent_tasks.py Executable file
View File

@@ -0,0 +1,176 @@
"""
NeuroSploit v3 - Agent Tasks API Endpoints
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from backend.db.database import get_db
from backend.models import AgentTask, Scan
from backend.schemas.agent_task import (
AgentTaskResponse,
AgentTaskListResponse,
AgentTaskSummary
)
router = APIRouter()
@router.get("", response_model=AgentTaskListResponse)
async def list_agent_tasks(
scan_id: str,
status: Optional[str] = None,
task_type: Optional[str] = None,
page: int = 1,
per_page: int = 50,
db: AsyncSession = Depends(get_db)
):
"""List all agent tasks for a scan"""
# Verify scan exists
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
# Build query
query = select(AgentTask).where(AgentTask.scan_id == scan_id)
if status:
query = query.where(AgentTask.status == status)
if task_type:
query = query.where(AgentTask.task_type == task_type)
query = query.order_by(AgentTask.created_at.desc())
# Get total count
count_query = select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
if status:
count_query = count_query.where(AgentTask.status == status)
if task_type:
count_query = count_query.where(AgentTask.task_type == task_type)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# Apply pagination
query = query.offset((page - 1) * per_page).limit(per_page)
result = await db.execute(query)
tasks = result.scalars().all()
return AgentTaskListResponse(
tasks=[AgentTaskResponse(**t.to_dict()) for t in tasks],
total=total,
scan_id=scan_id
)
@router.get("/summary", response_model=AgentTaskSummary)
async def get_agent_tasks_summary(
scan_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get summary statistics for agent tasks in a scan"""
# Verify scan exists
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
# Total count
total_result = await db.execute(
select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
)
total = total_result.scalar() or 0
# Count by status
status_counts = {}
for status in ["pending", "running", "completed", "failed"]:
count_result = await db.execute(
select(func.count()).select_from(AgentTask)
.where(AgentTask.scan_id == scan_id)
.where(AgentTask.status == status)
)
status_counts[status] = count_result.scalar() or 0
# Count by task type
type_query = select(
AgentTask.task_type,
func.count(AgentTask.id).label("count")
).where(AgentTask.scan_id == scan_id).group_by(AgentTask.task_type)
type_result = await db.execute(type_query)
by_type = {row[0]: row[1] for row in type_result.all()}
# Count by tool
tool_query = select(
AgentTask.tool_name,
func.count(AgentTask.id).label("count")
).where(AgentTask.scan_id == scan_id).where(AgentTask.tool_name.isnot(None)).group_by(AgentTask.tool_name)
tool_result = await db.execute(tool_query)
by_tool = {row[0]: row[1] for row in tool_result.all()}
return AgentTaskSummary(
total=total,
pending=status_counts.get("pending", 0),
running=status_counts.get("running", 0),
completed=status_counts.get("completed", 0),
failed=status_counts.get("failed", 0),
by_type=by_type,
by_tool=by_tool
)
@router.get("/{task_id}", response_model=AgentTaskResponse)
async def get_agent_task(
task_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get a specific agent task by ID"""
result = await db.execute(select(AgentTask).where(AgentTask.id == task_id))
task = result.scalar_one_or_none()
if not task:
raise HTTPException(status_code=404, detail="Agent task not found")
return AgentTaskResponse(**task.to_dict())
@router.get("/scan/{scan_id}/timeline")
async def get_agent_tasks_timeline(
scan_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get agent tasks as a timeline for visualization"""
# Verify scan exists
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
# Get all tasks ordered by creation time
query = select(AgentTask).where(AgentTask.scan_id == scan_id).order_by(AgentTask.created_at.asc())
result = await db.execute(query)
tasks = result.scalars().all()
timeline = []
for task in tasks:
timeline_item = {
"id": task.id,
"task_name": task.task_name,
"task_type": task.task_type,
"tool_name": task.tool_name,
"status": task.status,
"started_at": task.started_at.isoformat() if task.started_at else None,
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
"duration_ms": task.duration_ms,
"items_processed": task.items_processed,
"items_found": task.items_found,
"result_summary": task.result_summary,
"error_message": task.error_message
}
timeline.append(timeline_item)
return {
"scan_id": scan_id,
"timeline": timeline,
"total": len(timeline)
}

144
backend/api/v1/cli_agent.py Normal file
View File

@@ -0,0 +1,144 @@
"""
CLI Agent API - Endpoints for CLI agent provider detection and methodology listing.
"""
import os
import glob
import logging
from typing import List, Dict, Optional
from fastapi import APIRouter
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/cli-agent", tags=["CLI Agent"])
# CLI providers that can run as autonomous agents
CLI_AGENT_PROVIDER_IDS = ["claude_code", "gemini_cli", "codex_cli"]
@router.get("/providers")
async def get_cli_providers() -> Dict:
"""List available CLI agent providers with connection status from SmartRouter."""
providers = []
try:
from backend.core.smart_router import get_registry
registry = get_registry()
except Exception:
registry = None
for pid in CLI_AGENT_PROVIDER_IDS:
provider_info = {
"id": pid,
"name": pid,
"connected": False,
"account_label": None,
"source": None,
}
if registry:
provider = registry.get_provider(pid)
if provider:
provider_info["name"] = provider.name
accounts = registry.get_active_accounts(pid)
if accounts:
provider_info["connected"] = True
provider_info["account_label"] = accounts[0].label
provider_info["source"] = accounts[0].source
providers.append(provider_info)
# Also check env var API keys as fallback
env_fallbacks = {
"claude_code": "ANTHROPIC_API_KEY",
"gemini_cli": "GEMINI_API_KEY",
"codex_cli": "OPENAI_API_KEY",
}
for p in providers:
if not p["connected"]:
env_key = env_fallbacks.get(p["id"], "")
if env_key and os.getenv(env_key, ""):
p["connected"] = True
p["source"] = "env_var"
p["account_label"] = f"${env_key}"
enabled = os.getenv("ENABLE_CLI_AGENT", "false").lower() == "true"
return {
"enabled": enabled,
"providers": providers,
"connected_count": sum(1 for p in providers if p["connected"]),
}
@router.get("/methodologies")
async def list_methodologies() -> Dict:
"""List available methodology .md files for CLI agent."""
methodologies: List[Dict] = []
seen_paths: set = set()
# 1. Check METHODOLOGY_FILE env var (default)
default_path = os.getenv("METHODOLOGY_FILE", "")
if default_path and os.path.exists(default_path):
size = os.path.getsize(default_path)
methodologies.append({
"name": os.path.basename(default_path),
"path": default_path,
"size": size,
"size_human": _human_size(size),
"is_default": True,
})
seen_paths.add(os.path.abspath(default_path))
# 2. Scan /opt/Prompts-PenTest/ for .md files
prompts_dir = "/opt/Prompts-PenTest"
if os.path.isdir(prompts_dir):
for md_file in sorted(glob.glob(os.path.join(prompts_dir, "*.md"))):
abs_path = os.path.abspath(md_file)
if abs_path in seen_paths:
continue
seen_paths.add(abs_path)
name = os.path.basename(md_file)
size = os.path.getsize(md_file)
# Only include pentest-related files (skip research reports, etc.)
name_lower = name.lower()
if any(kw in name_lower for kw in ["pentest", "prompt", "bugbounty", "methodology", "chunk"]):
methodologies.append({
"name": name,
"path": md_file,
"size": size,
"size_human": _human_size(size),
"is_default": False,
})
# 3. Check data/ directory
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data")
if os.path.isdir(data_dir):
for md_file in glob.glob(os.path.join(data_dir, "*methodology*.md")):
abs_path = os.path.abspath(md_file)
if abs_path not in seen_paths:
seen_paths.add(abs_path)
size = os.path.getsize(md_file)
methodologies.append({
"name": os.path.basename(md_file),
"path": md_file,
"size": size,
"size_human": _human_size(size),
"is_default": False,
})
return {
"methodologies": methodologies,
"total": len(methodologies),
}
def _human_size(size_bytes: int) -> str:
    """Convert a byte count to a human-readable string (B, KB, MB, GB).

    Extends the previous version with a GB tier so multi-gigabyte files
    no longer render as e.g. "2048.0 MB".
    """
    size = float(size_bytes)
    for unit in ("B", "KB", "MB"):
        if size < 1024:
            # Bytes are shown as an integer; larger units with one decimal.
            return f"{size_bytes} B" if unit == "B" else f"{size:.1f} {unit}"
        size /= 1024
    return f"{size:.1f} GB"

299
backend/api/v1/dashboard.py Executable file
View File

@@ -0,0 +1,299 @@
"""
NeuroSploit v3 - Dashboard API Endpoints
"""
from typing import List, Optional
from datetime import datetime, timedelta

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func

from backend.db.database import get_db
from backend.models import Scan, Vulnerability, Endpoint, AgentTask, Report
router = APIRouter()
@router.get("/stats")
async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
"""Get overall dashboard statistics"""
# Total scans
total_scans_result = await db.execute(select(func.count()).select_from(Scan))
total_scans = total_scans_result.scalar() or 0
# Scans by status
running_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "running")
)
running_scans = running_result.scalar() or 0
completed_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "completed")
)
completed_scans = completed_result.scalar() or 0
stopped_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "stopped")
)
stopped_scans = stopped_result.scalar() or 0
failed_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "failed")
)
failed_scans = failed_result.scalar() or 0
pending_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "pending")
)
pending_scans = pending_result.scalar() or 0
# Total vulnerabilities by severity
vuln_counts = {}
for severity in ["critical", "high", "medium", "low", "info"]:
result = await db.execute(
select(func.count()).select_from(Vulnerability).where(Vulnerability.severity == severity)
)
vuln_counts[severity] = result.scalar() or 0
total_vulns = sum(vuln_counts.values())
# Total endpoints
endpoints_result = await db.execute(select(func.count()).select_from(Endpoint))
total_endpoints = endpoints_result.scalar() or 0
# Recent activity (last 7 days)
week_ago = datetime.utcnow() - timedelta(days=7)
recent_scans_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.created_at >= week_ago)
)
recent_scans = recent_scans_result.scalar() or 0
recent_vulns_result = await db.execute(
select(func.count()).select_from(Vulnerability).where(Vulnerability.created_at >= week_ago)
)
recent_vulns = recent_vulns_result.scalar() or 0
return {
"scans": {
"total": total_scans,
"running": running_scans,
"completed": completed_scans,
"stopped": stopped_scans,
"failed": failed_scans,
"pending": pending_scans,
"recent": recent_scans
},
"vulnerabilities": {
"total": total_vulns,
"critical": vuln_counts["critical"],
"high": vuln_counts["high"],
"medium": vuln_counts["medium"],
"low": vuln_counts["low"],
"info": vuln_counts["info"],
"recent": recent_vulns
},
"endpoints": {
"total": total_endpoints
}
}
@router.get("/recent")
async def get_recent_activity(
limit: int = 10,
db: AsyncSession = Depends(get_db)
):
"""Get recent scan activity"""
# Recent scans
scans_query = select(Scan).order_by(Scan.created_at.desc()).limit(limit)
scans_result = await db.execute(scans_query)
recent_scans = scans_result.scalars().all()
# Recent vulnerabilities
vulns_query = select(Vulnerability).order_by(Vulnerability.created_at.desc()).limit(limit)
vulns_result = await db.execute(vulns_query)
recent_vulns = vulns_result.scalars().all()
return {
"recent_scans": [s.to_dict() for s in recent_scans],
"recent_vulnerabilities": [v.to_dict() for v in recent_vulns]
}
@router.get("/findings")
async def get_recent_findings(
limit: int = 20,
severity: str = None,
db: AsyncSession = Depends(get_db)
):
"""Get recent vulnerability findings"""
query = select(Vulnerability).order_by(Vulnerability.created_at.desc())
if severity:
query = query.where(Vulnerability.severity == severity)
query = query.limit(limit)
result = await db.execute(query)
vulnerabilities = result.scalars().all()
return {
"findings": [v.to_dict() for v in vulnerabilities],
"total": len(vulnerabilities)
}
@router.get("/vulnerability-types")
async def get_vulnerability_distribution(db: AsyncSession = Depends(get_db)):
"""Get vulnerability distribution by type"""
query = select(
Vulnerability.vulnerability_type,
func.count(Vulnerability.id).label("count")
).group_by(Vulnerability.vulnerability_type)
result = await db.execute(query)
distribution = result.all()
return {
"distribution": [
{"type": row[0], "count": row[1]}
for row in distribution
]
}
@router.get("/scan-history")
async def get_scan_history(
days: int = 30,
db: AsyncSession = Depends(get_db)
):
"""Get scan history for charts"""
start_date = datetime.utcnow() - timedelta(days=days)
# Get scans grouped by date
scans = await db.execute(
select(Scan).where(Scan.created_at >= start_date).order_by(Scan.created_at)
)
all_scans = scans.scalars().all()
# Group by date
history = {}
for scan in all_scans:
date_str = scan.created_at.strftime("%Y-%m-%d")
if date_str not in history:
history[date_str] = {
"date": date_str,
"scans": 0,
"vulnerabilities": 0,
"critical": 0,
"high": 0
}
history[date_str]["scans"] += 1
history[date_str]["vulnerabilities"] += scan.total_vulnerabilities
history[date_str]["critical"] += scan.critical_count
history[date_str]["high"] += scan.high_count
return {"history": list(history.values())}
@router.get("/agent-tasks")
async def get_recent_agent_tasks(
limit: int = 20,
db: AsyncSession = Depends(get_db)
):
"""Get recent agent tasks across all scans"""
query = (
select(AgentTask)
.order_by(AgentTask.created_at.desc())
.limit(limit)
)
result = await db.execute(query)
tasks = result.scalars().all()
return {
"agent_tasks": [t.to_dict() for t in tasks],
"total": len(tasks)
}
@router.get("/activity-feed")
async def get_activity_feed(
limit: int = 30,
db: AsyncSession = Depends(get_db)
):
"""Get unified activity feed with all recent events"""
activities = []
# Get recent scans
scans_result = await db.execute(
select(Scan).order_by(Scan.created_at.desc()).limit(limit // 3)
)
for scan in scans_result.scalars().all():
activities.append({
"type": "scan",
"action": f"Scan {scan.status}",
"title": scan.name or "Unnamed Scan",
"description": f"{scan.total_vulnerabilities} vulnerabilities found",
"status": scan.status,
"severity": None,
"timestamp": scan.created_at.isoformat(),
"scan_id": scan.id,
"link": f"/scan/{scan.id}"
})
# Get recent vulnerabilities
vulns_result = await db.execute(
select(Vulnerability).order_by(Vulnerability.created_at.desc()).limit(limit // 3)
)
for vuln in vulns_result.scalars().all():
activities.append({
"type": "vulnerability",
"action": "Vulnerability found",
"title": vuln.title,
"description": vuln.affected_endpoint or "",
"status": None,
"severity": vuln.severity,
"timestamp": vuln.created_at.isoformat(),
"scan_id": vuln.scan_id,
"link": f"/scan/{vuln.scan_id}"
})
# Get recent agent tasks
tasks_result = await db.execute(
select(AgentTask).order_by(AgentTask.created_at.desc()).limit(limit // 3)
)
for task in tasks_result.scalars().all():
activities.append({
"type": "agent_task",
"action": f"Task {task.status}",
"title": task.task_name,
"description": task.result_summary or task.description or "",
"status": task.status,
"severity": None,
"timestamp": task.created_at.isoformat(),
"scan_id": task.scan_id,
"link": f"/scan/{task.scan_id}"
})
# Get recent reports
reports_result = await db.execute(
select(Report).order_by(Report.generated_at.desc()).limit(limit // 4)
)
for report in reports_result.scalars().all():
activities.append({
"type": "report",
"action": "Report generated" if report.auto_generated else "Report created",
"title": report.title or "Report",
"description": f"{report.format.upper()} format",
"status": "auto" if report.auto_generated else "manual",
"severity": None,
"timestamp": report.generated_at.isoformat(),
"scan_id": report.scan_id,
"link": f"/reports"
})
# Sort all activities by timestamp (newest first)
activities.sort(key=lambda x: x["timestamp"], reverse=True)
return {
"activities": activities[:limit],
"total": len(activities)
}

38
backend/api/v1/full_ia.py Normal file
View File

@@ -0,0 +1,38 @@
"""
NeuroSploit v3 - FULL AI Testing API
Serves the comprehensive pentest prompt and manages FULL AI testing sessions.
"""
import logging
from pathlib import Path
from fastapi import APIRouter, HTTPException
logger = logging.getLogger(__name__)
router = APIRouter()
# Default prompt file path - English translation preferred, fallback to original
PROMPT_PATH_EN = Path("/opt/Prompts-PenTest/pentestcompleto_en.md")
PROMPT_PATH_PT = Path("/opt/Prompts-PenTest/pentestcompleto.md")
PROMPT_PATH = PROMPT_PATH_EN if PROMPT_PATH_EN.exists() else PROMPT_PATH_PT
@router.get("/prompt")
async def get_full_ia_prompt():
"""Return the comprehensive pentest prompt content."""
if not PROMPT_PATH.exists():
raise HTTPException(
status_code=404,
detail=f"Pentest prompt file not found at {PROMPT_PATH}"
)
try:
content = PROMPT_PATH.read_text(encoding="utf-8")
return {
"content": content,
"path": str(PROMPT_PATH),
"size": len(content),
"lines": content.count("\n") + 1,
}
except Exception as e:
logger.error(f"Failed to read prompt file: {e}")
raise HTTPException(status_code=500, detail=str(e))

172
backend/api/v1/knowledge.py Normal file
View File

@@ -0,0 +1,172 @@
"""
NeuroSploit v3 - Knowledge Management API
Upload, manage, and query custom security knowledge documents.
"""
import os
from typing import Optional, List
from fastapi import APIRouter, HTTPException, UploadFile, File, Query
from pydantic import BaseModel
router = APIRouter()
# Lazy-loaded processor instance
_processor = None
def _get_processor():
    """Return the shared KnowledgeProcessor, creating it lazily on first use."""
    global _processor
    if _processor is not None:
        return _processor
    from backend.core.knowledge_processor import KnowledgeProcessor
    # Attach an LLM client for AI-assisted analysis when one is available;
    # any failure here just means we run without AI analysis.
    llm = None
    try:
        from backend.core.autonomous_agent import LLMClient
        candidate = LLMClient()
        if candidate.is_available():
            llm = candidate
    except Exception:
        pass
    _processor = KnowledgeProcessor(llm_client=llm)
    return _processor
# --- Schemas ---
class KnowledgeDocumentResponse(BaseModel):
    """Metadata for an uploaded and indexed knowledge document."""
    id: str
    filename: str
    title: str
    source_type: str
    uploaded_at: str
    processed: bool
    file_size_bytes: int
    summary: str
    vuln_types: List[str]
    entries_count: int
class KnowledgeEntryResponse(BaseModel):
    """One extracted knowledge entry, keyed by vulnerability type."""
    vuln_type: str
    methodology: str = ""
    payloads: List[str] = []
    key_insights: str = ""
    bypass_techniques: List[str] = []
    source_document: str = ""
class KnowledgeStatsResponse(BaseModel):
    """Aggregate statistics over the knowledge base."""
    total_documents: int
    total_entries: int
    vuln_types_covered: List[str]
    storage_bytes: int
# --- Endpoints ---
@router.post("/upload", response_model=KnowledgeDocumentResponse)
async def upload_knowledge(file: UploadFile = File(...)):
"""Upload a security document for knowledge extraction.
Supported formats: PDF, Markdown (.md), Text (.txt), HTML
The document will be analyzed and indexed by vulnerability type.
"""
if not file.filename:
raise HTTPException(400, "Filename is required")
# Read file content
content = await file.read()
if len(content) > 50 * 1024 * 1024: # 50MB limit
raise HTTPException(413, "File too large (max 50MB)")
if len(content) == 0:
raise HTTPException(400, "Empty file")
processor = _get_processor()
try:
doc = await processor.process_upload(content, file.filename)
except ValueError as e:
raise HTTPException(400, str(e))
except Exception as e:
raise HTTPException(500, f"Processing failed: {str(e)}")
return KnowledgeDocumentResponse(
id=doc["id"],
filename=doc["filename"],
title=doc["title"],
source_type=doc["source_type"],
uploaded_at=doc["uploaded_at"],
processed=doc["processed"],
file_size_bytes=doc["file_size_bytes"],
summary=doc["summary"],
vuln_types=doc["vuln_types"],
entries_count=len(doc.get("knowledge_entries", [])),
)
@router.get("/documents", response_model=List[KnowledgeDocumentResponse])
async def list_documents():
"""List all indexed knowledge documents."""
processor = _get_processor()
docs = processor.get_documents()
return [
KnowledgeDocumentResponse(
id=d["id"],
filename=d["filename"],
title=d["title"],
source_type=d["source_type"],
uploaded_at=d["uploaded_at"],
processed=d["processed"],
file_size_bytes=d["file_size_bytes"],
summary=d["summary"],
vuln_types=d["vuln_types"],
entries_count=d["entries_count"],
)
for d in docs
]
@router.get("/documents/{doc_id}")
async def get_document(doc_id: str):
"""Get a specific document with its full knowledge entries."""
processor = _get_processor()
doc = processor.get_document(doc_id)
if not doc:
raise HTTPException(404, f"Document '{doc_id}' not found")
return doc
@router.delete("/documents/{doc_id}")
async def delete_document(doc_id: str):
"""Delete a knowledge document and its index entries."""
processor = _get_processor()
deleted = processor.delete_document(doc_id)
if not deleted:
raise HTTPException(404, f"Document '{doc_id}' not found")
return {"message": f"Document '{doc_id}' deleted", "id": doc_id}
@router.get("/search", response_model=List[KnowledgeEntryResponse])
async def search_knowledge(vuln_type: str = Query(..., description="Vulnerability type to search")):
"""Search knowledge entries by vulnerability type."""
processor = _get_processor()
entries = processor.search_by_vuln_type(vuln_type)
return [
KnowledgeEntryResponse(
vuln_type=e.get("vuln_type", ""),
methodology=e.get("methodology", ""),
payloads=e.get("payloads", []),
key_insights=e.get("key_insights", ""),
bypass_techniques=e.get("bypass_techniques", []),
source_document=e.get("source_document", ""),
)
for e in entries
]
@router.get("/stats", response_model=KnowledgeStatsResponse)
async def get_stats():
"""Get knowledge base statistics."""
processor = _get_processor()
stats = processor.get_stats()
return KnowledgeStatsResponse(**stats)

320
backend/api/v1/mcp.py Normal file
View File

@@ -0,0 +1,320 @@
"""
NeuroSploit v3 - MCP Server Management API
CRUD for Model Context Protocol server connections.
Persists to config/config.json mcp_servers section.
"""
import json
import asyncio
from pathlib import Path
from typing import Optional, List, Dict
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
router = APIRouter()
CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "config" / "config.json"
BUILTIN_SERVER = "neurosploit_tools"
# --- Schemas ---
class MCPServerCreate(BaseModel):
    """Request body for registering a new MCP server connection."""
    name: str = Field(..., min_length=1, max_length=100, description="Unique server identifier")
    transport: str = Field("stdio", description="Transport type: stdio or sse")
    command: Optional[str] = Field(None, description="Command for stdio transport")
    args: Optional[List[str]] = Field(None, description="Args for stdio transport")
    url: Optional[str] = Field(None, description="URL for sse transport")
    env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
    description: str = Field("", description="Server description")
    enabled: bool = Field(True, description="Whether server is enabled")
class MCPServerUpdate(BaseModel):
    """Partial update body; fields left as None are not changed."""
    transport: Optional[str] = None
    command: Optional[str] = None
    args: Optional[List[str]] = None
    url: Optional[str] = None
    env: Optional[Dict[str, str]] = None
    description: Optional[str] = None
    enabled: Optional[bool] = None
class MCPServerResponse(BaseModel):
    """Serialized MCP server configuration as returned by the API."""
    name: str
    transport: str
    command: Optional[str] = None
    args: Optional[List[str]] = None
    url: Optional[str] = None
    env: Optional[Dict[str, str]] = None
    description: str = ""
    enabled: bool = True
    # True only for the built-in NeuroSploit tools server
    is_builtin: bool = False
class MCPToolResponse(BaseModel):
    """A single tool exposed by an MCP server."""
    name: str
    description: str
    server_name: str
# --- Config helpers ---
def _read_config() -> dict:
    """Load config.json as a dict; return {} when the file does not exist."""
    if not CONFIG_PATH.exists():
        return {}
    # Explicit encoding so parsing does not depend on the platform default
    # (descriptions may contain non-ASCII text).
    return json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
def _write_config(config: dict):
    """Persist the config dict to config.json, pretty-printed."""
    # Explicit encoding to match _read_config and avoid platform defaults.
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
def _get_mcp_servers(config: dict) -> dict:
    """Return the live mcp_servers section of *config* ({} when absent).

    Note: callers mutate the returned dict and then persist *config*,
    so this must NOT return a copy.
    """
    try:
        return config["mcp_servers"]
    except KeyError:
        return {}
def _server_to_response(name: str, server: dict) -> MCPServerResponse:
    """Build an API response model from a stored server config dict."""
    optional = {key: server.get(key) for key in ("command", "args", "url", "env")}
    return MCPServerResponse(
        name=name,
        transport=server.get("transport", "stdio"),
        description=server.get("description", ""),
        enabled=server.get("enabled", True),
        is_builtin=(name == BUILTIN_SERVER),
        **optional,
    )
# --- Endpoints ---
@router.get("/servers", response_model=List[MCPServerResponse])
async def list_servers():
"""List all configured MCP servers."""
config = _read_config()
servers = _get_mcp_servers(config)
return [_server_to_response(name, srv) for name, srv in servers.items()]
@router.get("/servers/{name}", response_model=MCPServerResponse)
async def get_server(name: str):
"""Get a specific MCP server configuration."""
config = _read_config()
servers = _get_mcp_servers(config)
if name not in servers:
raise HTTPException(404, f"MCP server '{name}' not found")
return _server_to_response(name, servers[name])
@router.post("/servers", response_model=MCPServerResponse)
async def create_server(body: MCPServerCreate):
"""Add a new MCP server configuration."""
config = _read_config()
if "mcp_servers" not in config:
config["mcp_servers"] = {}
servers = config["mcp_servers"]
if body.name in servers:
raise HTTPException(409, f"Server '{body.name}' already exists")
# Validate transport-specific fields
if body.transport == "stdio" and not body.command:
raise HTTPException(400, "stdio transport requires 'command' field")
if body.transport == "sse" and not body.url:
raise HTTPException(400, "sse transport requires 'url' field")
server_config = {
"transport": body.transport,
"description": body.description,
"enabled": body.enabled,
}
if body.command:
server_config["command"] = body.command
if body.args:
server_config["args"] = body.args
if body.url:
server_config["url"] = body.url
if body.env:
server_config["env"] = body.env
servers[body.name] = server_config
_write_config(config)
return _server_to_response(body.name, server_config)
@router.put("/servers/{name}", response_model=MCPServerResponse)
async def update_server(name: str, body: MCPServerUpdate):
"""Update an MCP server configuration."""
config = _read_config()
servers = _get_mcp_servers(config)
if name not in servers:
raise HTTPException(404, f"MCP server '{name}' not found")
srv = servers[name]
if body.transport is not None:
srv["transport"] = body.transport
if body.command is not None:
srv["command"] = body.command
if body.args is not None:
srv["args"] = body.args
if body.url is not None:
srv["url"] = body.url
if body.env is not None:
srv["env"] = body.env
if body.description is not None:
srv["description"] = body.description
if body.enabled is not None:
srv["enabled"] = body.enabled
_write_config(config)
return _server_to_response(name, srv)
@router.delete("/servers/{name}")
async def delete_server(name: str):
"""Delete an MCP server configuration."""
if name == BUILTIN_SERVER:
raise HTTPException(403, f"Cannot delete built-in server '{BUILTIN_SERVER}'")
config = _read_config()
servers = _get_mcp_servers(config)
if name not in servers:
raise HTTPException(404, f"MCP server '{name}' not found")
del servers[name]
_write_config(config)
return {"message": f"Server '{name}' deleted"}
@router.post("/servers/{name}/toggle", response_model=MCPServerResponse)
async def toggle_server(name: str):
"""Toggle a server's enabled state."""
config = _read_config()
servers = _get_mcp_servers(config)
if name not in servers:
raise HTTPException(404, f"MCP server '{name}' not found")
srv = servers[name]
srv["enabled"] = not srv.get("enabled", True)
_write_config(config)
return _server_to_response(name, srv)
@router.post("/servers/{name}/test")
async def test_server_connection(name: str):
"""Test connection to an MCP server."""
config = _read_config()
servers = _get_mcp_servers(config)
if name not in servers:
raise HTTPException(404, f"MCP server '{name}' not found")
srv = servers[name]
transport = srv.get("transport", "stdio")
try:
if transport == "sse":
# Test SSE endpoint
import aiohttp
url = srv.get("url", "")
if not url:
return {"success": False, "error": "No URL configured", "tools_count": 0}
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
if resp.status < 400:
return {"success": True, "message": f"SSE endpoint reachable (HTTP {resp.status})", "tools_count": 0}
return {"success": False, "error": f"HTTP {resp.status}", "tools_count": 0}
elif transport == "stdio":
# Test stdio by checking command exists
import shutil
command = srv.get("command", "")
if not command:
return {"success": False, "error": "No command configured", "tools_count": 0}
if shutil.which(command):
return {"success": True, "message": f"Command '{command}' found in PATH", "tools_count": 0}
else:
return {"success": False, "error": f"Command '{command}' not found in PATH", "tools_count": 0}
except asyncio.TimeoutError:
return {"success": False, "error": "Connection timed out (5s)", "tools_count": 0}
except Exception as e:
return {"success": False, "error": str(e), "tools_count": 0}
@router.get("/servers/{name}/tools", response_model=List[MCPToolResponse])
async def list_server_tools(name: str):
    """List available tools from an MCP server.

    For the built-in server, returns tools from the registry.
    For external servers, attempts to connect and query.

    Raises:
        HTTPException: 404 unknown server, 501 missing client library,
            502 connect/list failure, 504 connect timeout.
    """
    config = _read_config()
    servers = _get_mcp_servers(config)
    if name not in servers:
        raise HTTPException(404, f"MCP server '{name}' not found")
    # For builtin server, return tools from the MCP server module
    if name == BUILTIN_SERVER:
        try:
            from core.mcp_server import TOOLS
            return [
                MCPToolResponse(
                    name=t["name"],
                    description=t.get("description", ""),
                    server_name=name,
                )
                for t in TOOLS
            ]
        except ImportError:
            return []
    # For external servers, try to connect via MCPToolClient
    try:
        from core.mcp_client import MCPToolClient
        # Build minimal config for this single server
        client_config = {
            "mcp_servers": {
                "enabled": True,
                "servers": {name: servers[name]}
            }
        }
        client = MCPToolClient(client_config)
        connected = await asyncio.wait_for(client.connect(name), timeout=10)
        if not connected:
            raise HTTPException(502, f"Failed to connect to MCP server '{name}'")
        try:
            tools_dict = await client.list_tools(name)
            tool_list = tools_dict.get(name, [])
        finally:
            # Fix: always tear down the connection, even if list_tools
            # raises; previously an exception here leaked the client.
            try:
                await client.disconnect_all()
            except Exception:
                pass  # teardown is best-effort
        return [
            MCPToolResponse(
                name=t.get("name", ""),
                description=t.get("description", ""),
                server_name=name,
            )
            for t in tool_list
        ]
    except ImportError:
        raise HTTPException(501, "MCP client library not installed")
    except asyncio.TimeoutError:
        raise HTTPException(504, "Connection to MCP server timed out (10s)")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(502, f"Failed to list tools: {str(e)}")

372
backend/api/v1/prompts.py Executable file
View File

@@ -0,0 +1,372 @@
"""
NeuroSploit v3 - Prompts API Endpoints
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from backend.db.database import get_db
from backend.models import Prompt
from backend.schemas.prompt import (
PromptCreate, PromptUpdate, PromptResponse, PromptParse, PromptParseResult, PromptPreset
)
from backend.core.prompt_engine.parser import PromptParser
router = APIRouter()
# Built-in preset prompts surfaced via the /presets endpoints. Each entry
# carries an id, display metadata, and the full prompt content handed to the
# testing agent. These are read-only; only custom prompts live in the DB.
PRESET_PROMPTS = [
    {
        "id": "full_pentest",
        "name": "Full Penetration Test",
        "description": "Comprehensive security assessment covering all vulnerability categories",
        "category": "pentest",
        "content": """Perform a comprehensive penetration test on the target application.
Test for ALL vulnerability categories:
- Injection vulnerabilities (XSS, SQL Injection, Command Injection, LDAP, XPath, Template Injection)
- Authentication flaws (Broken auth, session management, JWT issues, OAuth flaws)
- Authorization issues (IDOR, BOLA, privilege escalation, access control bypass)
- File handling vulnerabilities (LFI, RFI, path traversal, file upload, XXE)
- Request forgery (SSRF, CSRF)
- API security issues (rate limiting, mass assignment, excessive data exposure)
- Client-side vulnerabilities (CORS misconfig, clickjacking, open redirect)
- Information disclosure (error messages, stack traces, sensitive data exposure)
- Infrastructure issues (security headers, SSL/TLS, HTTP methods)
- Business logic flaws (race conditions, workflow bypass)
Use thorough testing with multiple payloads and bypass techniques.
Generate detailed PoC for each vulnerability found.
Provide remediation recommendations."""
    },
    {
        "id": "owasp_top10",
        "name": "OWASP Top 10",
        "description": "Test for OWASP Top 10 2021 vulnerabilities",
        "category": "compliance",
        "content": """Test for OWASP Top 10 2021 vulnerabilities:
A01:2021 - Broken Access Control
- IDOR, privilege escalation, access control bypass, CORS misconfig
A02:2021 - Cryptographic Failures
- Sensitive data exposure, weak encryption, cleartext transmission
A03:2021 - Injection
- SQL injection, XSS, command injection, LDAP injection
A04:2021 - Insecure Design
- Business logic flaws, missing security controls
A05:2021 - Security Misconfiguration
- Default configs, unnecessary features, missing headers
A06:2021 - Vulnerable Components
- Outdated libraries, known CVEs
A07:2021 - Identification and Authentication Failures
- Weak passwords, session fixation, credential stuffing
A08:2021 - Software and Data Integrity Failures
- Insecure deserialization, CI/CD vulnerabilities
A09:2021 - Security Logging and Monitoring Failures
- Missing audit logs, insufficient monitoring
A10:2021 - Server-Side Request Forgery (SSRF)
- Internal network access, cloud metadata exposure"""
    },
    {
        "id": "api_security",
        "name": "API Security Testing",
        "description": "Focused testing for REST and GraphQL APIs",
        "category": "api",
        "content": """Perform API security testing:
Authentication & Authorization:
- Test JWT implementation (algorithm confusion, signature bypass, claim manipulation)
- OAuth/OIDC flow testing
- API key exposure and validation
- Rate limiting bypass
- BOLA/IDOR on all endpoints
Input Validation:
- SQL injection on API parameters
- NoSQL injection
- Command injection
- Parameter pollution
- Mass assignment vulnerabilities
Data Exposure:
- Excessive data exposure in responses
- Sensitive data in error messages
- Information disclosure in headers
- Debug endpoints exposure
GraphQL Specific (if applicable):
- Introspection enabled
- Query depth attacks
- Batching attacks
- Field suggestion exploitation
API Abuse:
- Rate limiting effectiveness
- Resource exhaustion
- Denial of service vectors"""
    },
    {
        "id": "bug_bounty",
        "name": "Bug Bounty Hunter",
        "description": "Focus on high-impact, bounty-worthy vulnerabilities",
        "category": "bug_bounty",
        "content": """Hunt for high-impact vulnerabilities suitable for bug bounty:
Priority 1 - Critical Impact:
- Remote Code Execution (RCE)
- SQL Injection leading to data breach
- Authentication bypass
- SSRF to internal services/cloud metadata
- Privilege escalation to admin
Priority 2 - High Impact:
- Stored XSS
- IDOR on sensitive resources
- Account takeover vectors
- Payment/billing manipulation
- PII exposure
Priority 3 - Medium Impact:
- Reflected XSS
- CSRF on sensitive actions
- Information disclosure
- Rate limiting bypass
- Open redirects (if exploitable)
Look for:
- Unique attack chains
- Business logic flaws
- Edge cases and race conditions
- Bypass techniques for existing security controls
Document with clear PoC and impact assessment."""
    },
    {
        "id": "quick_scan",
        "name": "Quick Security Scan",
        "description": "Fast scan for common vulnerabilities",
        "category": "quick",
        "content": """Perform a quick security scan for common vulnerabilities:
- Reflected XSS on input parameters
- Basic SQL injection testing
- Directory traversal/LFI
- Security headers check
- SSL/TLS configuration
- Common misconfigurations
- Information disclosure
Use minimal payloads for speed.
Focus on quick wins and obvious issues."""
    },
    {
        "id": "auth_testing",
        "name": "Authentication Testing",
        "description": "Focus on authentication and session management",
        "category": "auth",
        "content": """Test authentication and session management:
Login Functionality:
- Username enumeration
- Password brute force protection
- Account lockout bypass
- Credential stuffing protection
- SQL injection in login
Session Management:
- Session token entropy
- Session fixation
- Session timeout
- Cookie security flags (HttpOnly, Secure, SameSite)
- Session invalidation on logout
Password Reset:
- Token predictability
- Token expiration
- Account enumeration
- Host header injection
Multi-Factor Authentication:
- MFA bypass techniques
- Backup codes weakness
- Rate limiting on OTP
OAuth/SSO:
- State parameter validation
- Redirect URI manipulation
- Token leakage"""
    }
]
@router.get("/presets", response_model=List[PromptPreset])
async def get_preset_prompts():
    """Return summary metadata for every built-in preset prompt."""
    summaries = []
    for preset in PRESET_PROMPTS:
        summaries.append(
            PromptPreset(
                id=preset["id"],
                name=preset["name"],
                description=preset["description"],
                category=preset["category"],
                # Line count of the content, used as a rough size indicator.
                vulnerability_count=len(preset["content"].split("\n")),
            )
        )
    return summaries
@router.get("/presets/{preset_id}")
async def get_preset_prompt(preset_id: str):
    """Return one built-in preset prompt, or 404 if the id is unknown."""
    match = next((entry for entry in PRESET_PROMPTS if entry["id"] == preset_id), None)
    if match is None:
        raise HTTPException(status_code=404, detail="Preset not found")
    return match
@router.post("/parse", response_model=PromptParseResult)
async def parse_prompt(prompt_data: PromptParse):
    """Extract vulnerability types and testing scope from raw prompt text."""
    return await PromptParser().parse(prompt_data.content)
@router.get("", response_model=List[PromptResponse])
async def list_prompts(
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List custom (non-preset) prompts, newest first, optionally by category."""
    stmt = select(Prompt).where(Prompt.is_preset == False)  # noqa: E712 (SQLAlchemy comparison)
    if category:
        stmt = stmt.where(Prompt.category == category)
    stmt = stmt.order_by(Prompt.created_at.desc())
    rows = (await db.execute(stmt)).scalars().all()
    return [PromptResponse(**row.to_dict()) for row in rows]
@router.post("", response_model=PromptResponse)
async def create_prompt(prompt_data: PromptCreate, db: AsyncSession = Depends(get_db)):
    """Create a custom prompt, parsing its vulnerability list up front."""
    # Derive the vulnerability list from the content so it is persisted
    # alongside the record and does not need re-parsing on every read.
    parsed = await PromptParser().parse(prompt_data.content)
    record = Prompt(
        name=prompt_data.name,
        description=prompt_data.description,
        content=prompt_data.content,
        category=prompt_data.category,
        is_preset=False,
        parsed_vulnerabilities=[v.dict() for v in parsed.vulnerabilities_to_test],
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return PromptResponse(**record.to_dict())
@router.get("/{prompt_id}", response_model=PromptResponse)
async def get_prompt(prompt_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single prompt by its id; 404 if it does not exist."""
    found = (
        await db.execute(select(Prompt).where(Prompt.id == prompt_id))
    ).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Prompt not found")
    return PromptResponse(**found.to_dict())
@router.put("/{prompt_id}", response_model=PromptResponse)
async def update_prompt(
    prompt_id: str,
    prompt_data: PromptUpdate,
    db: AsyncSession = Depends(get_db)
):
    """Apply a partial update to a custom prompt (presets are immutable)."""
    row = (
        await db.execute(select(Prompt).where(Prompt.id == prompt_id))
    ).scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=404, detail="Prompt not found")
    if row.is_preset:
        raise HTTPException(status_code=400, detail="Cannot modify preset prompts")
    if prompt_data.name is not None:
        row.name = prompt_data.name
    if prompt_data.description is not None:
        row.description = prompt_data.description
    if prompt_data.content is not None:
        row.content = prompt_data.content
        # Content changed, so the cached vulnerability list must be rebuilt.
        reparsed = await PromptParser().parse(prompt_data.content)
        row.parsed_vulnerabilities = [v.dict() for v in reparsed.vulnerabilities_to_test]
    if prompt_data.category is not None:
        row.category = prompt_data.category
    await db.commit()
    await db.refresh(row)
    return PromptResponse(**row.to_dict())
@router.delete("/{prompt_id}")
async def delete_prompt(prompt_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a custom prompt; presets cannot be removed."""
    row = (
        await db.execute(select(Prompt).where(Prompt.id == prompt_id))
    ).scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=404, detail="Prompt not found")
    if row.is_preset:
        raise HTTPException(status_code=400, detail="Cannot delete preset prompts")
    await db.delete(row)
    await db.commit()
    return {"message": "Prompt deleted"}
@router.post("/upload")
async def upload_prompt(file: UploadFile = File(...)):
    """Accept a .md/.txt prompt upload; return its text plus parse results."""
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided")
    if "." in file.filename:
        suffix = "." + file.filename.rsplit(".", 1)[-1].lower()
    else:
        suffix = ""
    if suffix not in {".md", ".txt"}:
        raise HTTPException(status_code=400, detail="Invalid file type. Use .md or .txt")
    raw = await file.read()
    try:
        text = raw.decode("utf-8")
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="Unable to decode file")
    # Parse immediately so the caller receives structured results along
    # with the raw text in a single round trip.
    parsed = await PromptParser().parse(text)
    return {
        "filename": file.filename,
        "content": text,
        "parsed": parsed.dict(),
    }

403
backend/api/v1/providers.py Normal file
View File

@@ -0,0 +1,403 @@
"""
NeuroSploit v3 - Providers API
REST endpoints for managing LLM providers and accounts.
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional
router = APIRouter()
class ConnectRequest(BaseModel):
    # Payload for manually attaching a credential to a provider account.
    label: str = "Manual API Key"          # human-readable account label
    credential: str                        # the API key / token itself
    credential_type: str = "api_key"       # kind of credential being stored
    model_override: Optional[str] = None   # pin this account to a specific model
@router.get("")
async def list_providers():
    """List all providers with their accounts and status.

    Returns ``{"enabled": False, "providers": []}`` when the Smart Router
    registry has not been initialised; otherwise one entry per provider
    with its serialized accounts and a ``connected`` flag.
    """
    from backend.core.smart_router import get_registry
    registry = get_registry()
    if not registry:
        return {"enabled": False, "providers": []}
    providers = []
    for p in registry.get_all_providers():
        accounts = []
        for a in p.accounts.values():
            accounts.append({
                "id": a.id,
                "label": a.label,
                "source": a.source,
                "credential_type": a.credential_type,
                "is_active": a.is_active,
                "tokens_used": a.tokens_used,
                "last_used": a.last_used,
                "expires_at": a.expires_at,
                "model_override": a.model_override,
            })
        providers.append({
            "id": p.id,
            "name": p.name,
            "auth_type": p.auth_type,
            "api_format": p.api_format,
            "tier": p.tier,
            "default_model": p.default_model,
            "accounts": accounts,
            # NOTE(review): reaches into the registry's private _credentials
            # mapping; consider a public "has_credential" accessor instead.
            "connected": any(
                a.is_active and a.id in registry._credentials
                for a in p.accounts.values()
            ),
            # Older registry objects may lack the attribute; default to enabled.
            "enabled": getattr(p, "enabled", True),
        })
    return {"enabled": True, "providers": providers}
@router.get("/status")
async def providers_status():
    """Return quota and usage summary from the Smart Router, if enabled."""
    from backend.core.smart_router import get_router
    smart = get_router()
    if not smart:
        return {"enabled": False}
    status = smart.get_status()
    return {"enabled": True, **status}
@router.post("/{provider_id}/detect")
async def detect_cli_token(provider_id: str):
    """Auto-detect a locally stored CLI token for one provider and register it."""
    from backend.core.smart_router import get_registry, get_extractor
    registry = get_registry()
    extractor = get_extractor()
    if not registry or not extractor:
        raise HTTPException(400, "Smart Router not enabled")
    token = extractor.detect(provider_id)
    if not token:
        return {"detected": False, "message": f"No CLI token found for {provider_id}"}
    # Register the freshly detected credential so the router can use it.
    account_id = registry.add_account(
        provider_id=provider_id,
        label=token.label,
        credential=token.token,
        credential_type=token.credential_type,
        source="cli_detect",
        refresh_token=token.refresh_token,
        expires_at=token.expires_at,
    )
    return {
        "detected": True,
        "account_id": account_id,
        "label": token.label,
        "credential_type": token.credential_type,
        "has_refresh_token": token.refresh_token is not None,
        "expires_at": token.expires_at,
    }
@router.post("/{provider_id}/connect")
async def connect_provider(provider_id: str, req: ConnectRequest):
    """Manually attach an API key or other credential to a provider."""
    from backend.core.smart_router import get_registry
    registry = get_registry()
    if not registry:
        raise HTTPException(400, "Smart Router not enabled")
    account_id = registry.add_account(
        provider_id=provider_id,
        label=req.label,
        credential=req.credential,
        credential_type=req.credential_type,
        source="manual",
        model_override=req.model_override,
    )
    if account_id:
        return {"success": True, "account_id": account_id}
    # add_account returns a falsy id when the provider id is unrecognised.
    raise HTTPException(404, f"Unknown provider: {provider_id}")
@router.delete("/{provider_id}/accounts/{account_id}")
async def remove_account(provider_id: str, account_id: str):
    """Detach an account from a provider; 404 if it is not registered."""
    from backend.core.smart_router import get_registry
    registry = get_registry()
    if not registry:
        raise HTTPException(400, "Smart Router not enabled")
    if not registry.remove_account(provider_id, account_id):
        raise HTTPException(404, "Account not found")
    return {"success": True}
@router.post("/test/{provider_id}/{account_id}")
async def test_connection(provider_id: str, account_id: str):
    """Run a live connectivity check for one provider account."""
    from backend.core.smart_router import get_router
    smart = get_router()
    if not smart:
        raise HTTPException(400, "Smart Router not enabled")
    ok, msg = await smart.test_account(provider_id, account_id)
    return {"success": ok, "message": msg}
# Known models per provider, used to populate dropdown selection in the UI.
# Providers missing from this map fall back to their registry default_model
# (see available_models below). Model ids are passed through verbatim to the
# provider APIs, so they must match each provider's naming scheme exactly.
PROVIDER_MODELS = {
    "claude_code": [
        "claude-opus-4-6-20250918",
        "claude-sonnet-4-5-20250929",
        "claude-haiku-4-5-20251001",
        "claude-sonnet-4-20250514",
        "claude-opus-4-20250514",
        "claude-haiku-4-20250514",
    ],
    "kiro": [
        "claude-opus-4-6-20250918",
        "claude-sonnet-4-5-20250929",
        "claude-haiku-4-5-20251001",
        "claude-sonnet-4-20250514",
        "claude-opus-4-20250514",
        "claude-haiku-4-20250514",
    ],
    "anthropic": [
        "claude-opus-4-6-20250918",
        "claude-sonnet-4-5-20250929",
        "claude-haiku-4-5-20251001",
        "claude-sonnet-4-20250514",
        "claude-opus-4-20250514",
        "claude-haiku-4-20250514",
        "claude-3-5-sonnet-20241022",
    ],
    "codex_cli": [
        "gpt-4o",
        "gpt-4o-mini",
        "o3-mini",
        "o4-mini",
        "gpt-4.1",
        "gpt-4.1-mini",
        "gpt-4.1-nano",
    ],
    "openai": [
        "gpt-4o",
        "gpt-4o-mini",
        "o3-mini",
        "o4-mini",
        "gpt-4.1",
        "gpt-4.1-mini",
        "gpt-4.1-nano",
    ],
    "gemini_cli": [
        "gemini-3.0-pro",
        "gemini-2.5-pro",
        "gemini-2.5-flash",
        "gemini-2.0-flash",
        "gemini-2.0-flash-lite",
    ],
    "gemini": [
        "gemini-3.0-pro",
        "gemini-2.5-pro",
        "gemini-2.5-flash",
        "gemini-2.0-flash",
        "gemini-2.0-flash-lite",
    ],
    "cursor": [
        "cursor-fast",
        "cursor-small",
        "gpt-4o",
        "claude-sonnet-4-5-20250929",
        "claude-3-5-sonnet-20241022",
    ],
    "copilot": [
        "gpt-4o",
        "gpt-4o-mini",
        "claude-sonnet-4-5-20250929",
        "claude-3-5-sonnet-20241022",
    ],
    "openrouter": [
        "anthropic/claude-opus-4-6",
        "anthropic/claude-sonnet-4-5",
        "anthropic/claude-haiku-4-5",
        "anthropic/claude-sonnet-4",
        "anthropic/claude-opus-4",
        "openai/gpt-4o",
        "google/gemini-3.0-pro",
        "google/gemini-2.5-pro",
        "google/gemini-2.5-flash",
        "meta-llama/llama-4-maverick",
        "deepseek/deepseek-r1",
    ],
    "together": [
        "meta-llama/Llama-3-70b-chat-hf",
        "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        "deepseek-ai/DeepSeek-R1",
        "Qwen/Qwen2.5-72B-Instruct-Turbo",
    ],
    "fireworks": [
        "accounts/fireworks/models/llama-v3p1-70b-instruct",
        "accounts/fireworks/models/llama-v3p3-70b-instruct",
        "accounts/fireworks/models/deepseek-r1",
    ],
    "iflow": ["kimi-k2"],
    "qwen_code": ["qwen3-coder", "qwen-max"],
    # Local runtimes: model availability depends on the user's installation.
    "ollama": ["llama3", "llama3.2", "mistral", "codellama", "deepseek-r1"],
    "lmstudio": ["local-model"],
}
@router.get("/available-models")
async def available_models():
    """List provider+model combinations usable in selection dropdowns."""
    from backend.core.smart_router import get_registry
    registry = get_registry()
    if not registry:
        return {"models": []}
    entries = []
    for provider in registry.get_all_providers():
        # Only offer providers that have at least one active account.
        if not registry.get_active_accounts(provider.id):
            continue
        entries.append({
            "provider_id": provider.id,
            "provider_name": provider.name,
            "default_model": provider.default_model,
            "tier": provider.tier,
            "available_models": PROVIDER_MODELS.get(provider.id, [provider.default_model]),
        })
    # Stable ordering: tier first (paid tiers sort ahead), then provider name.
    entries.sort(key=lambda entry: (entry["tier"], entry["provider_name"]))
    return {"models": entries}
@router.post("/detect-all")
async def detect_all_tokens():
    """Scan every supported CLI tool for tokens and register each one found."""
    from backend.core.smart_router import get_registry, get_extractor
    registry = get_registry()
    extractor = get_extractor()
    if not registry or not extractor:
        raise HTTPException(400, "Smart Router not enabled")
    results = []
    for tok in extractor.detect_all():
        account_id = registry.add_account(
            provider_id=tok.provider_id,
            label=tok.label,
            credential=tok.token,
            credential_type=tok.credential_type,
            source="cli_detect",
            refresh_token=tok.refresh_token,
            expires_at=tok.expires_at,
        )
        results.append({
            "provider_id": tok.provider_id,
            "label": tok.label,
            "account_id": account_id,
        })
    return {
        "detected_count": len(results),
        "results": results,
    }
class ToggleRequest(BaseModel):
    # Request body for enabling/disabling a provider.
    enabled: bool  # True to enable, False to have the router skip it
@router.post("/{provider_id}/toggle")
async def toggle_provider(provider_id: str, req: ToggleRequest):
    """Enable or disable a provider; disabled providers are skipped by the router."""
    from backend.core.smart_router import get_registry
    registry = get_registry()
    if not registry:
        raise HTTPException(400, "Smart Router not enabled")
    if not registry.toggle_provider(provider_id, req.enabled):
        raise HTTPException(404, f"Unknown provider: {provider_id}")
    return {"success": True, "provider_id": provider_id, "enabled": req.enabled}
# Whitelist of env keys that can be modified via UI; anything outside this
# set is rejected by update_env_key with HTTP 400.
ALLOWED_ENV_KEYS = {
    "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GEMINI_API_KEY", "GOOGLE_API_KEY",
    "OPENROUTER_API_KEY", "TOGETHER_API_KEY", "FIREWORKS_API_KEY",
    "OLLAMA_HOST", "LMSTUDIO_HOST",
    "ENABLE_SMART_ROUTER", "ENABLE_REASONING", "ENABLE_CVE_HUNT",
    "ENABLE_MULTI_AGENT", "ENABLE_RESEARCHER_AI",
    "NVD_API_KEY", "GITHUB_TOKEN", "TOKEN_BUDGET",
}
class EnvUpdateRequest(BaseModel):
    # Request body for updating a single whitelisted env var.
    key: str    # must be a member of ALLOWED_ENV_KEYS
    value: str  # new raw value; applied in-process and persisted to .env
@router.get("/env")
async def get_env_keys():
    """Get current values of allowed env keys (masked for secrets).

    Secret-looking keys (names containing ``KEY`` or ``TOKEN``) are masked
    to their first 8 and last 4 characters; non-secret toggles and budgets
    are returned verbatim.
    """
    import os
    # Keys whose names merely *look* secret but hold plain config values.
    non_secret = {"ENABLE_SMART_ROUTER", "ENABLE_REASONING",
                  "ENABLE_CVE_HUNT", "ENABLE_MULTI_AGENT",
                  "ENABLE_RESEARCHER_AI", "TOKEN_BUDGET"}
    result = {}
    for key in sorted(ALLOWED_ENV_KEYS):
        val = os.getenv(key, "")
        # Fix: also mask *_TOKEN secrets (e.g. GITHUB_TOKEN), which the
        # previous "KEY"-only check returned unmasked.
        is_secret = ("KEY" in key or "TOKEN" in key) and key not in non_secret
        if val and is_secret:
            # Mask secrets: show first 8 and last 4 chars when long enough.
            if len(val) > 16:
                result[key] = val[:8] + "..." + val[-4:]
            else:
                result[key] = "****"
        else:
            result[key] = val
    return {"env": result, "allowed_keys": sorted(ALLOWED_ENV_KEYS)}
@router.post("/env")
async def update_env_key(req: EnvUpdateRequest):
    """Update an env var and persist to .env file.

    Applies the change to the running process first, then rewrites the
    project-root .env file, replacing (or un-commenting) any existing line
    for the key. File-write failures are reported in the response rather
    than raised, since the in-process update already succeeded.
    """
    import os
    from pathlib import Path
    if req.key not in ALLOWED_ENV_KEYS:
        raise HTTPException(400, f"Key '{req.key}' is not in the allowed whitelist")
    # Update in-process env
    os.environ[req.key] = req.value
    # Persist to .env file (four levels up: backend/api/v1/ -> project root)
    env_path = Path(__file__).parent.parent.parent.parent / ".env"
    try:
        lines = []
        found = False
        if env_path.exists():
            for line in env_path.read_text().splitlines():
                stripped = line.strip()
                # Replace an existing assignment, including a commented-out one.
                if stripped.startswith(f"{req.key}=") or stripped.startswith(f"# {req.key}="):
                    lines.append(f"{req.key}={req.value}")
                    found = True
                else:
                    lines.append(line)
        if not found:
            lines.append(f"{req.key}={req.value}")
        env_path.write_text("\n".join(lines) + "\n")
    except Exception as e:
        # Still updated in-process even if file write failed
        return {"success": True, "persisted": False, "error": str(e)}
    return {"success": True, "persisted": True}

387
backend/api/v1/reports.py Executable file
View File

@@ -0,0 +1,387 @@
"""
NeuroSploit v3 - Reports API Endpoints
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse, HTMLResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from pathlib import Path
from backend.db.database import get_db
from backend.models import Scan, Report, Vulnerability, Endpoint
from backend.schemas.report import ReportGenerate, ReportResponse, ReportListResponse
from backend.core.report_engine.generator import ReportGenerator
from backend.config import settings
router = APIRouter()
@router.get("", response_model=ReportListResponse)
async def list_reports(
    scan_id: Optional[str] = None,
    auto_generated: Optional[bool] = None,
    is_partial: Optional[bool] = None,
    db: AsyncSession = Depends(get_db)
):
    """List reports, newest first, with optional scan/flag filters."""
    stmt = select(Report).order_by(Report.generated_at.desc())
    if scan_id:
        stmt = stmt.where(Report.scan_id == scan_id)
    if auto_generated is not None:
        stmt = stmt.where(Report.auto_generated == auto_generated)
    if is_partial is not None:
        stmt = stmt.where(Report.is_partial == is_partial)
    rows = (await db.execute(stmt)).scalars().all()
    return ReportListResponse(
        reports=[ReportResponse(**row.to_dict()) for row in rows],
        total=len(rows),
    )
@router.post("", response_model=ReportResponse)
async def generate_report(
    report_data: ReportGenerate,
    db: AsyncSession = Depends(get_db)
):
    """Generate a new report for a scan.

    Loads the scan, its vulnerabilities and endpoints from the database,
    best-effort enriches the report with tool executions held in the
    agent's in-memory results, renders the report file, and persists a
    Report row pointing at the rendered file.

    Raises:
        HTTPException: 404 if the scan does not exist.
    """
    # Get scan
    scan_result = await db.execute(select(Scan).where(Scan.id == report_data.scan_id))
    scan = scan_result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")
    # Get vulnerabilities
    vulns_result = await db.execute(
        select(Vulnerability).where(Vulnerability.scan_id == report_data.scan_id)
    )
    vulnerabilities = vulns_result.scalars().all()
    # Try to get tool_executions from agent in-memory results.
    # Best-effort: the agent module keeps these only while the process is
    # alive, so a missing entry is not an error.
    tool_executions = []
    try:
        from backend.api.v1.agent import scan_to_agent, agent_results
        agent_id = scan_to_agent.get(report_data.scan_id)
        if agent_id and agent_id in agent_results:
            tool_executions = agent_results[agent_id].get("tool_executions", [])
            if not tool_executions:
                # Fall back to the nested report payload, if present.
                rpt = agent_results[agent_id].get("report", {})
                tool_executions = rpt.get("tool_executions", []) if isinstance(rpt, dict) else []
    except Exception:
        pass
    # Get endpoints
    endpoints_result = await db.execute(
        select(Endpoint).where(Endpoint.scan_id == report_data.scan_id)
    )
    endpoints = endpoints_result.scalars().all()
    # Generate report
    generator = ReportGenerator()
    report_path, executive_summary = await generator.generate(
        scan=scan,
        vulnerabilities=vulnerabilities,
        format=report_data.format,
        title=report_data.title,
        include_executive_summary=report_data.include_executive_summary,
        include_poc=report_data.include_poc,
        include_remediation=report_data.include_remediation,
        tool_executions=tool_executions,
        endpoints=endpoints,
    )
    # Save report record
    report = Report(
        scan_id=scan.id,
        title=report_data.title or f"Report - {scan.name}",
        format=report_data.format,
        file_path=str(report_path),
        executive_summary=executive_summary
    )
    db.add(report)
    await db.commit()
    await db.refresh(report)
    return ReportResponse(**report.to_dict())
@router.post("/ai-generate", response_model=ReportResponse)
async def generate_ai_report(
    report_data: ReportGenerate,
    db: AsyncSession = Depends(get_db)
):
    """Generate an AI-enhanced report with LLM-written executive summary and per-finding analysis.

    Same data-gathering as generate_report, but delegates rendering to
    ``ReportGenerator.generate_ai_report`` (always HTML output). The AI
    summary is truncated to 2000 chars when stored on the Report row.

    Raises:
        HTTPException: 404 if the scan does not exist; 500 if AI
            generation fails.
    """
    # Get scan
    scan_result = await db.execute(select(Scan).where(Scan.id == report_data.scan_id))
    scan = scan_result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")
    # Get vulnerabilities
    vulns_result = await db.execute(
        select(Vulnerability).where(Vulnerability.scan_id == report_data.scan_id)
    )
    vulnerabilities = vulns_result.scalars().all()
    # Try to get tool_executions from agent in-memory results
    # (best-effort; absent entries are tolerated).
    tool_executions = []
    try:
        from backend.api.v1.agent import scan_to_agent, agent_results
        agent_id = scan_to_agent.get(report_data.scan_id)
        if agent_id and agent_id in agent_results:
            tool_executions = agent_results[agent_id].get("tool_executions", [])
            # Also check nested report
            if not tool_executions:
                rpt = agent_results[agent_id].get("report", {})
                tool_executions = rpt.get("tool_executions", []) if isinstance(rpt, dict) else []
    except Exception:
        pass
    # Generate AI report
    generator = ReportGenerator()
    try:
        report_path, ai_summary = await generator.generate_ai_report(
            scan=scan,
            vulnerabilities=vulnerabilities,
            tool_executions=tool_executions,
            title=report_data.title,
            preferred_provider=report_data.preferred_provider,
            preferred_model=report_data.preferred_model,
        )
    except Exception as e:
        import logging
        logging.getLogger(__name__).error(f"AI report generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"AI report generation failed: {str(e)}")
    # Save report record
    report = Report(
        scan_id=scan.id,
        title=report_data.title or f"AI Report - {scan.name}",
        format="html",
        file_path=str(report_path),
        executive_summary=ai_summary[:2000] if ai_summary else None
    )
    db.add(report)
    await db.commit()
    await db.refresh(report)
    return ReportResponse(**report.to_dict())
@router.get("/{report_id}", response_model=ReportResponse)
async def get_report(report_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single report's details; 404 if it does not exist."""
    found = (
        await db.execute(select(Report).where(Report.id == report_id))
    ).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Report not found")
    return ReportResponse(**found.to_dict())
@router.get("/{report_id}/view")
async def view_report(report_id: str, db: AsyncSession = Depends(get_db)):
    """View report in browser (HTML); other formats stream as a download.

    Raises:
        HTTPException: 404 if the report row, its file_path, or the file
            on disk is missing.
    """
    result = await db.execute(select(Report).where(Report.id == report_id))
    report = result.scalar_one_or_none()
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")
    if not report.file_path:
        raise HTTPException(status_code=404, detail="Report file not found")
    file_path = Path(report.file_path)
    if not file_path.exists():
        raise HTTPException(status_code=404, detail="Report file not found on disk")
    if report.format == "html":
        # Fix: read explicitly as UTF-8 so rendering does not depend on the
        # host locale's default encoding.
        content = file_path.read_text(encoding="utf-8")
        return HTMLResponse(content=content)
    else:
        return FileResponse(
            path=str(file_path),
            media_type="application/octet-stream",
            filename=file_path.name
        )
@router.get("/{report_id}/download/{format}")
async def download_report(
    report_id: str,
    format: str,
    db: AsyncSession = Depends(get_db)
):
    """Download report in specified format.

    The file is regenerated fresh on every call (this also covers
    auto-generated reports that never had a file_path). If the report row
    had no file_path, the newly rendered file and format are persisted
    onto it.

    Raises:
        HTTPException: 404 if report, its scan, or the rendered file is missing.
    """
    result = await db.execute(select(Report).where(Report.id == report_id))
    report = result.scalar_one_or_none()
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")
    # Get scan and vulnerabilities for generating report
    scan_result = await db.execute(select(Scan).where(Scan.id == report.scan_id))
    scan = scan_result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found for report")
    vulns_result = await db.execute(
        select(Vulnerability).where(Vulnerability.scan_id == report.scan_id)
    )
    vulnerabilities = vulns_result.scalars().all()
    # Always generate fresh report file (handles auto-generated reports without file_path)
    generator = ReportGenerator()
    report_path, _ = await generator.generate(
        scan=scan,
        vulnerabilities=vulnerabilities,
        format=format,
        title=report.title
    )
    file_path = Path(report_path)
    # Update report with file path if not set
    if not report.file_path:
        report.file_path = str(file_path)
        report.format = format
        await db.commit()
    if not file_path.exists():
        raise HTTPException(status_code=404, detail="Report file not found")
    # Map known formats to content types; anything else downloads as binary.
    media_types = {
        "html": "text/html",
        "pdf": "application/pdf",
        "json": "application/json"
    }
    return FileResponse(
        path=str(file_path),
        media_type=media_types.get(format, "application/octet-stream"),
        filename=file_path.name
    )
@router.get("/{report_id}/download-zip")
async def download_report_zip(
    report_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Download report as ZIP with screenshots included.

    Regenerates the HTML report, then gathers per-finding screenshots from
    disk (scan-scoped path first, legacy flat path as fallback) and any
    base64 screenshots stored on the vulnerability rows, packaging
    everything into a single ZIP in the system temp directory.

    Raises:
        HTTPException: 404 if the report or its scan is missing.
    """
    import zipfile
    import tempfile
    import hashlib
    result = await db.execute(select(Report).where(Report.id == report_id))
    report = result.scalar_one_or_none()
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")
    scan_result = await db.execute(select(Scan).where(Scan.id == report.scan_id))
    scan = scan_result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found for report")
    vulns_result = await db.execute(
        select(Vulnerability).where(Vulnerability.scan_id == report.scan_id)
    )
    vulnerabilities = vulns_result.scalars().all()
    # Generate HTML report
    generator = ReportGenerator()
    report_path, _ = await generator.generate(
        scan=scan,
        vulnerabilities=vulnerabilities,
        format="html",
        title=report.title
    )
    # Collect screenshots (use absolute path via settings.BASE_DIR)
    # Check scan-scoped path first, then legacy flat path
    screenshots_base = settings.BASE_DIR / "reports" / "screenshots"
    scan_id_str = str(scan.id) if scan else None
    screenshot_files = []  # (source path on disk, archive name inside the ZIP)
    for vuln in vulnerabilities:
        # Finding ID is md5(vuln_type+url+param)[:8] — must match the scheme
        # used when the screenshots were originally written.
        vuln_url = getattr(vuln, 'url', None) or vuln.affected_endpoint or ''
        vuln_param = getattr(vuln, 'parameter', None) or getattr(vuln, 'poc_parameter', None) or ''
        finding_id = hashlib.md5(
            f"{vuln.vulnerability_type}{vuln_url}{vuln_param}".encode()
        ).hexdigest()[:8]
        # Scan-scoped path: reports/screenshots/{scan_id}/{finding_id}/
        finding_dir = None
        if scan_id_str:
            scan_dir = screenshots_base / scan_id_str / finding_id
            if scan_dir.exists():
                finding_dir = scan_dir
        if not finding_dir:
            legacy_dir = screenshots_base / finding_id
            if legacy_dir.exists():
                finding_dir = legacy_dir
        if finding_dir:
            for img in finding_dir.glob("*.png"):
                screenshot_files.append((img, f"screenshots/{finding_id}/{img.name}"))
        # Also include base64 screenshots from DB as files in the ZIP
        db_screenshots = getattr(vuln, 'screenshots', None) or []
        for idx, ss in enumerate(db_screenshots):
            if isinstance(ss, str) and ss.startswith("data:image/"):
                # Will be embedded in HTML, but also save as file
                import base64 as b64
                try:
                    b64_data = ss.split(",", 1)[1]
                    img_bytes = b64.b64decode(b64_data)
                    img_name = f"screenshots/{finding_id}/evidence_{idx+1}.png"
                    # Write to temp for ZIP inclusion
                    tmp_img = Path(tempfile.gettempdir()) / f"ss_{finding_id}_{idx}.png"
                    tmp_img.write_bytes(img_bytes)
                    screenshot_files.append((tmp_img, img_name))
                except Exception:
                    pass  # skip malformed data-URIs rather than failing the whole ZIP
    # Create ZIP
    zip_name = Path(report_path).stem + ".zip"
    zip_path = Path(tempfile.gettempdir()) / zip_name
    with zipfile.ZipFile(str(zip_path), 'w', zipfile.ZIP_DEFLATED) as zf:
        zf.write(report_path, "report.html")
        for src_path, arc_name in screenshot_files:
            zf.write(str(src_path), arc_name)
    return FileResponse(
        path=str(zip_path),
        media_type="application/zip",
        filename=zip_name
    )
@router.delete("/{report_id}")
async def delete_report(report_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a report record and its on-disk artifact (if any)."""
    lookup = await db.execute(select(Report).where(Report.id == report_id))
    report = lookup.scalar_one_or_none()
    if report is None:
        raise HTTPException(status_code=404, detail="Report not found")
    # Remove the generated file from disk before dropping the DB row.
    artifact = Path(report.file_path) if report.file_path else None
    if artifact is not None and artifact.exists():
        artifact.unlink()
    await db.delete(report)
    await db.commit()
    return {"message": "Report deleted"}

130
backend/api/v1/sandbox.py Executable file
View File

@@ -0,0 +1,130 @@
"""
NeuroSploit v3 - Sandbox Container Management API
Real-time monitoring and management of per-scan Kali Linux containers.
"""
from datetime import datetime
from fastapi import APIRouter, HTTPException
router = APIRouter()
def _docker_available() -> bool:
    """Return True when the Docker SDK is importable and the daemon answers a ping."""
    try:
        import docker
        client = docker.from_env()
        client.ping()
    except Exception:
        # A missing SDK and an unreachable daemon both mean "not available".
        return False
    return True
@router.get("/")
async def list_sandboxes():
    """List all sandbox containers with pool status."""
    # The pool import/initialization can fail (e.g. Docker absent); in that
    # case return an empty listing that still carries the error string.
    try:
        from core.container_pool import get_pool
        pool = get_pool()
    except Exception as e:
        return {
            "pool": {
                "active": 0,
                "max_concurrent": 0,
                "image": "neurosploit-kali:latest",
                "container_ttl_minutes": 60,
                "docker_available": _docker_available(),
            },
            "containers": [],
            "error": str(e),
        }
    sandboxes = pool.list_sandboxes()
    now = datetime.utcnow()
    containers = []
    for info in sandboxes.values():
        created = info.get("created_at")
        uptime = 0.0
        if created:
            # created_at is assumed to be a naive ISO-8601 string — TODO confirm;
            # any parse failure simply leaves uptime_seconds at 0.
            try:
                dt = datetime.fromisoformat(created)
                uptime = (now - dt).total_seconds()
            except Exception:
                pass
        containers.append({
            **info,
            "uptime_seconds": uptime,
        })
    return {
        "pool": {
            "active": pool.active_count,
            "max_concurrent": pool.max_concurrent,
            "image": pool.image,
            "container_ttl_minutes": int(pool.container_ttl.total_seconds() / 60),
            "docker_available": _docker_available(),
        },
        "containers": containers,
    }
@router.get("/{scan_id}")
async def get_sandbox(scan_id: str):
    """Get health check for a specific sandbox container."""
    try:
        from core.container_pool import get_pool
        pool = get_pool()
    except Exception as e:
        raise HTTPException(status_code=503, detail=str(e))
    if scan_id not in pool.list_sandboxes():
        raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}")
    # NOTE(review): reaches into the pool's private _sandboxes mapping; a public
    # accessor on the pool would be cleaner.
    instance = pool._sandboxes.get(scan_id)
    if instance is None:
        raise HTTPException(status_code=404, detail="Sandbox instance not found")
    return await instance.health_check()
@router.delete("/{scan_id}")
async def destroy_sandbox(scan_id: str):
    """Destroy a specific sandbox container."""
    try:
        from core.container_pool import get_pool
        pool = get_pool()
    except Exception as e:
        raise HTTPException(status_code=503, detail=str(e))
    # 404 when the pool is not tracking a container for this scan.
    if scan_id not in pool.list_sandboxes():
        raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}")
    await pool.destroy(scan_id)
    return {"message": f"Sandbox for scan {scan_id} destroyed", "scan_id": scan_id}
@router.post("/cleanup")
async def cleanup_expired():
    """Remove containers that have exceeded their TTL."""
    try:
        from core.container_pool import get_pool
        await get_pool().cleanup_expired()
    except Exception as e:
        # Pool unavailable (e.g. no Docker) surfaces as 503.
        raise HTTPException(status_code=503, detail=str(e))
    return {"message": "Expired containers cleaned up"}
@router.post("/cleanup-orphans")
async def cleanup_orphans():
    """Remove orphan containers not tracked by the pool."""
    try:
        from core.container_pool import get_pool
        await get_pool().cleanup_orphans()
    except Exception as e:
        # Pool unavailable (e.g. no Docker) surfaces as 503.
        raise HTTPException(status_code=503, detail=str(e))
    return {"message": "Orphan containers cleaned up"}

656
backend/api/v1/scans.py Executable file
View File

@@ -0,0 +1,656 @@
"""
NeuroSploit v3 - Scans API Endpoints
"""
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from urllib.parse import urlparse
from backend.db.database import get_db
from backend.models import Scan, Target, Endpoint, Vulnerability
from backend.schemas.scan import ScanCreate, ScanUpdate, ScanResponse, ScanListResponse, ScanProgress
from backend.services.scan_service import run_scan_task, skip_to_phase as _skip_to_phase, PHASE_ORDER
router = APIRouter()
@router.get("", response_model=ScanListResponse)
async def list_scans(
    page: int = 1,
    per_page: int = 10,
    status: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List all scans with pagination"""
    # Build the optional status filter once and reuse it for both queries.
    filters = [Scan.status == status] if status else []
    total = (
        await db.execute(select(func.count()).select_from(Scan).where(*filters))
    ).scalar()
    page_rows = (
        await db.execute(
            select(Scan)
            .where(*filters)
            .order_by(Scan.created_at.desc())
            .offset((page - 1) * per_page)
            .limit(per_page)
        )
    ).scalars().all()
    # NOTE(review): one targets query per scan (N+1); tolerable at per_page=10.
    scan_responses = []
    for scan in page_rows:
        targets = (
            await db.execute(select(Target).where(Target.scan_id == scan.id))
        ).scalars().all()
        payload = scan.to_dict()
        payload["targets"] = [t.to_dict() for t in targets]
        scan_responses.append(ScanResponse(**payload))
    return ScanListResponse(
        scans=scan_responses,
        total=total,
        page=page,
        per_page=per_page
    )
@router.post("", response_model=ScanResponse)
async def create_scan(
    scan_data: ScanCreate,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """Create a new scan with optional authentication for authenticated testing"""
    # Process authentication config: flatten the optional auth block into
    # (auth_type, credentials dict), storing only the fields actually supplied.
    auth_type = None
    auth_credentials = None
    if scan_data.auth:
        auth_type = scan_data.auth.auth_type
        auth_credentials = {}
        if scan_data.auth.cookie:
            auth_credentials["cookie"] = scan_data.auth.cookie
        if scan_data.auth.bearer_token:
            auth_credentials["bearer_token"] = scan_data.auth.bearer_token
        if scan_data.auth.username:
            auth_credentials["username"] = scan_data.auth.username
        if scan_data.auth.password:
            auth_credentials["password"] = scan_data.auth.password
        # A custom-header credential requires both the name and the value.
        if scan_data.auth.header_name and scan_data.auth.header_value:
            auth_credentials["header_name"] = scan_data.auth.header_name
            auth_credentials["header_value"] = scan_data.auth.header_value
    # Create scan (status starts at "pending"; /start transitions it to running)
    scan = Scan(
        name=scan_data.name or f"Scan {datetime.now().strftime('%Y-%m-%d %H:%M')}",
        scan_type=scan_data.scan_type,
        recon_enabled=scan_data.recon_enabled,
        custom_prompt=scan_data.custom_prompt,
        prompt_id=scan_data.prompt_id,
        config=scan_data.config,
        auth_type=auth_type,
        auth_credentials=auth_credentials,
        custom_headers=scan_data.custom_headers,
        status="pending"
    )
    db.add(scan)
    # flush() assigns scan.id so the targets below can reference it pre-commit.
    await db.flush()
    # Create targets, one row per submitted URL.
    targets = []
    for url in scan_data.targets:
        parsed = urlparse(url)
        target = Target(
            scan_id=scan.id,
            url=url,
            hostname=parsed.hostname,
            # Default the port from the scheme when the URL does not carry one.
            port=parsed.port or (443 if parsed.scheme == "https" else 80),
            protocol=parsed.scheme or "https",
            path=parsed.path or "/"
        )
        db.add(target)
        targets.append(target)
    await db.commit()
    await db.refresh(scan)
    scan_dict = scan.to_dict()
    scan_dict["targets"] = [t.to_dict() for t in targets]
    return ScanResponse(**scan_dict)
@router.get("/{scan_id}", response_model=ScanResponse)
async def get_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
    """Get scan details by ID"""
    scan = (
        await db.execute(select(Scan).where(Scan.id == scan_id))
    ).scalar_one_or_none()
    if scan is None:
        raise HTTPException(status_code=404, detail="Scan not found")
    # Attach the scan's targets to the serialized payload.
    targets = (
        await db.execute(select(Target).where(Target.scan_id == scan_id))
    ).scalars().all()
    payload = scan.to_dict()
    payload["targets"] = [t.to_dict() for t in targets]
    return ScanResponse(**payload)
@router.post("/{scan_id}/start")
async def start_scan(
    scan_id: str,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """Start a scan execution"""
    scan = (
        await db.execute(select(Scan).where(Scan.id == scan_id))
    ).scalar_one_or_none()
    if scan is None:
        raise HTTPException(status_code=404, detail="Scan not found")
    if scan.status == "running":
        raise HTTPException(status_code=400, detail="Scan is already running")
    # Mark the scan running before scheduling the worker so status queries
    # reflect the new state immediately.
    scan.status = "running"
    scan.started_at = datetime.utcnow()
    scan.current_phase = "initializing"
    scan.progress = 0
    await db.commit()
    # The background task opens its own DB session; only pass the id.
    background_tasks.add_task(run_scan_task, scan_id)
    return {"message": "Scan started", "scan_id": scan_id}
@router.post("/{scan_id}/stop")
async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
    """Stop a running scan and save partial results"""
    # NOTE(review): imported inside the handler, presumably to avoid a circular
    # import between routers — confirm before hoisting to module level.
    from backend.api.websocket import manager as ws_manager
    from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
    result = await db.execute(select(Scan).where(Scan.id == scan_id))
    scan = result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")
    if scan.status not in ("running", "paused"):
        raise HTTPException(status_code=400, detail="Scan is not running or paused")
    # Signal the in-process agent (if any) to stop and mirror the state in
    # its shared result record.
    agent_id = scan_to_agent.get(scan_id)
    if agent_id and agent_id in agent_instances:
        agent_instances[agent_id].cancel()
    if agent_id in agent_results:
        agent_results[agent_id]["status"] = "stopped"
        agent_results[agent_id]["phase"] = "stopped"
    # Update scan status
    scan.status = "stopped"
    scan.completed_at = datetime.utcnow()
    scan.current_phase = "stopped"
    # Calculate duration (only when the scan was actually started)
    if scan.started_at:
        duration = (scan.completed_at - scan.started_at).total_seconds()
        scan.duration = int(duration)
    # Compute final vulnerability statistics from database
    # (one COUNT query per severity bucket).
    for severity in ["critical", "high", "medium", "low", "info"]:
        count_result = await db.execute(
            select(func.count()).select_from(Vulnerability)
            .where(Vulnerability.scan_id == scan_id)
            .where(Vulnerability.severity == severity)
        )
        setattr(scan, f"{severity}_count", count_result.scalar() or 0)
    # Get total vulnerability count
    total_vuln_result = await db.execute(
        select(func.count()).select_from(Vulnerability)
        .where(Vulnerability.scan_id == scan_id)
    )
    scan.total_vulnerabilities = total_vuln_result.scalar() or 0
    # Get total endpoint count
    total_endpoint_result = await db.execute(
        select(func.count()).select_from(Endpoint)
        .where(Endpoint.scan_id == scan_id)
    )
    scan.total_endpoints = total_endpoint_result.scalar() or 0
    await db.commit()
    # Build summary for WebSocket broadcast
    summary = {
        "total_endpoints": scan.total_endpoints,
        "total_vulnerabilities": scan.total_vulnerabilities,
        "critical": scan.critical_count,
        "high": scan.high_count,
        "medium": scan.medium_count,
        "low": scan.low_count,
        "info": scan.info_count,
        "duration": scan.duration,
        "progress": scan.progress
    }
    # Broadcast stop event via WebSocket
    await ws_manager.broadcast_scan_stopped(scan_id, summary)
    await ws_manager.broadcast_log(scan_id, "warning", "Scan stopped by user")
    await ws_manager.broadcast_log(scan_id, "info", f"Partial results: {scan.total_vulnerabilities} vulnerabilities found")
    # Auto-generate partial report. Best-effort: a failure is surfaced on the
    # WebSocket log stream but does not fail the stop request itself.
    report_data = None
    try:
        from backend.services.report_service import auto_generate_report
        await ws_manager.broadcast_log(scan_id, "info", "Generating partial report...")
        report = await auto_generate_report(db, scan_id, is_partial=True)
        report_data = report.to_dict()
        await ws_manager.broadcast_log(scan_id, "info", f"Partial report generated: {report.title}")
    except Exception as report_error:
        await ws_manager.broadcast_log(scan_id, "warning", f"Failed to generate partial report: {str(report_error)}")
    return {
        "message": "Scan stopped",
        "scan_id": scan_id,
        "summary": summary,
        "report": report_data
    }
@router.post("/{scan_id}/pause")
async def pause_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
    """Pause a running scan"""
    from backend.api.websocket import manager as ws_manager
    from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
    scan = (
        await db.execute(select(Scan).where(Scan.id == scan_id))
    ).scalar_one_or_none()
    if scan is None:
        raise HTTPException(status_code=404, detail="Scan not found")
    if scan.status != "running":
        raise HTTPException(status_code=400, detail="Scan is not running")
    # Signal the in-process agent to pause and mirror the state in its
    # shared result record.
    agent_id = scan_to_agent.get(scan_id)
    if agent_id and agent_id in agent_instances:
        agent_instances[agent_id].pause()
    if agent_id in agent_results:
        agent_results[agent_id]["status"] = "paused"
        agent_results[agent_id]["phase"] = "paused"
    scan.status = "paused"
    scan.current_phase = "paused"
    await db.commit()
    await ws_manager.broadcast_log(scan_id, "warning", "Scan paused by user")
    return {"message": "Scan paused", "scan_id": scan_id}
@router.post("/{scan_id}/resume")
async def resume_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
    """Resume a paused scan"""
    from backend.api.websocket import manager as ws_manager
    from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
    scan = (
        await db.execute(select(Scan).where(Scan.id == scan_id))
    ).scalar_one_or_none()
    if scan is None:
        raise HTTPException(status_code=404, detail="Scan not found")
    if scan.status != "paused":
        raise HTTPException(status_code=400, detail="Scan is not paused")
    # Wake the in-process agent and mirror the state in its result record.
    agent_id = scan_to_agent.get(scan_id)
    if agent_id and agent_id in agent_instances:
        agent_instances[agent_id].resume()
    if agent_id in agent_results:
        agent_results[agent_id]["status"] = "running"
        agent_results[agent_id]["phase"] = "testing"
    scan.status = "running"
    scan.current_phase = "testing"
    await db.commit()
    await ws_manager.broadcast_log(scan_id, "info", "Scan resumed by user")
    return {"message": "Scan resumed", "scan_id": scan_id}
@router.post("/{scan_id}/skip-to/{target_phase}")
async def skip_to_phase_endpoint(scan_id: str, target_phase: str, db: AsyncSession = Depends(get_db)):
    """Skip the current scan phase and jump to a target phase.

    Valid phases: recon, analyzing, testing, completed
    Can only skip forward (to a phase ahead of current).

    Raises:
        HTTPException 404: scan does not exist.
        HTTPException 400: scan not running/paused, unknown phase, or backward skip.
        HTTPException 500: the running scan could not be signalled.
    """
    result = await db.execute(select(Scan).where(Scan.id == scan_id))
    scan = result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")
    if scan.status not in ("running", "paused"):
        raise HTTPException(status_code=400, detail="Scan is not running or paused")
    # Fix: validate the request *before* any side effects. Previously an
    # invalid target_phase still resumed a paused scan.
    if target_phase not in PHASE_ORDER:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid phase '{target_phase}'. Valid: {', '.join(PHASE_ORDER[1:])}"
        )
    # Only forward skips are allowed; an unknown current phase counts as index 0.
    current_idx = PHASE_ORDER.index(scan.current_phase) if scan.current_phase in PHASE_ORDER else 0
    target_idx = PHASE_ORDER.index(target_phase)
    if target_idx <= current_idx:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot skip backward. Current: {scan.current_phase}, target: {target_phase}"
        )
    # If paused, resume first so the skip signal can be processed.
    if scan.status == "paused":
        from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
        agent_id = scan_to_agent.get(scan_id)
        if agent_id and agent_id in agent_instances:
            agent_instances[agent_id].resume()
        if agent_id in agent_results:
            agent_results[agent_id]["status"] = "running"
            agent_results[agent_id]["phase"] = agent_results[agent_id].get("last_phase", "testing")
        scan.status = "running"
        await db.commit()
    # Signal the running scan loop to skip.
    success = _skip_to_phase(scan_id, target_phase)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to signal phase skip")
    # Broadcast via WebSocket
    from backend.api.websocket import manager as ws_manager
    await ws_manager.broadcast_log(scan_id, "warning", f">> User requested skip to phase: {target_phase}")
    await ws_manager.broadcast_phase_change(scan_id, f"skipping_to_{target_phase}")
    return {
        "message": f"Skipping to phase: {target_phase}",
        "scan_id": scan_id,
        "from_phase": scan.current_phase,
        "target_phase": target_phase
    }
@router.get("/{scan_id}/status", response_model=ScanProgress)
async def get_scan_status(scan_id: str, db: AsyncSession = Depends(get_db)):
    """Get scan progress and status"""
    scan = (
        await db.execute(select(Scan).where(Scan.id == scan_id))
    ).scalar_one_or_none()
    if scan is None:
        raise HTTPException(status_code=404, detail="Scan not found")
    # Lightweight progress snapshot (no targets/vulns are loaded here).
    return ScanProgress(
        scan_id=scan.id,
        status=scan.status,
        progress=scan.progress,
        current_phase=scan.current_phase,
        total_endpoints=scan.total_endpoints,
        total_vulnerabilities=scan.total_vulnerabilities
    )
@router.delete("/{scan_id}")
async def delete_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a scan"""
    scan = (
        await db.execute(select(Scan).where(Scan.id == scan_id))
    ).scalar_one_or_none()
    if scan is None:
        raise HTTPException(status_code=404, detail="Scan not found")
    # Refuse deletion while a worker may still be writing results.
    if scan.status == "running":
        raise HTTPException(status_code=400, detail="Cannot delete running scan")
    await db.delete(scan)
    await db.commit()
    return {"message": "Scan deleted", "scan_id": scan_id}
@router.get("/{scan_id}/endpoints")
async def get_scan_endpoints(
    scan_id: str,
    page: int = 1,
    per_page: int = 50,
    db: AsyncSession = Depends(get_db)
):
    """Get endpoints discovered in a scan"""
    scan = (
        await db.execute(select(Scan).where(Scan.id == scan_id))
    ).scalar_one_or_none()
    if scan is None:
        raise HTTPException(status_code=404, detail="Scan not found")
    total = (
        await db.execute(
            select(func.count()).select_from(Endpoint).where(Endpoint.scan_id == scan_id)
        )
    ).scalar()
    # Newest-first page of endpoints.
    rows = (
        await db.execute(
            select(Endpoint)
            .where(Endpoint.scan_id == scan_id)
            .order_by(Endpoint.discovered_at.desc())
            .offset((page - 1) * per_page)
            .limit(per_page)
        )
    ).scalars().all()
    return {
        "endpoints": [e.to_dict() for e in rows],
        "total": total,
        "page": page,
        "per_page": per_page
    }
@router.get("/{scan_id}/vulnerabilities")
async def get_scan_vulnerabilities(
    scan_id: str,
    severity: Optional[str] = None,
    page: int = 1,
    per_page: int = 50,
    db: AsyncSession = Depends(get_db)
):
    """Get vulnerabilities found in a scan"""
    scan = (
        await db.execute(select(Scan).where(Scan.id == scan_id))
    ).scalar_one_or_none()
    if scan is None:
        raise HTTPException(status_code=404, detail="Scan not found")
    # Shared filter list: scan id plus the optional severity filter.
    conditions = [Vulnerability.scan_id == scan_id]
    if severity:
        conditions.append(Vulnerability.severity == severity)
    total = (
        await db.execute(
            select(func.count()).select_from(Vulnerability).where(*conditions)
        )
    ).scalar()
    rows = (
        await db.execute(
            select(Vulnerability)
            .where(*conditions)
            .order_by(Vulnerability.created_at.desc())
            .offset((page - 1) * per_page)
            .limit(per_page)
        )
    ).scalars().all()
    return {
        "vulnerabilities": [v.to_dict() for v in rows],
        "total": total,
        "page": page,
        "per_page": per_page
    }
class ValidationRequest(BaseModel):
    """Request body for manual vulnerability validation."""
    validation_status: str  # "validated" | "false_positive" | "ai_confirmed" | "ai_rejected" | "pending_review"
    notes: Optional[str] = None  # optional reviewer notes (stored on the finding as ai_rejection_reason)
@router.patch("/vulnerabilities/{vuln_id}/validate")
async def validate_vulnerability(
    vuln_id: str,
    body: ValidationRequest,
    db: AsyncSession = Depends(get_db)
):
    """Manually validate or reject a vulnerability finding.

    Updates the finding's validation_status and keeps the parent scan's
    aggregate severity counters in sync with the status transition.

    Raises:
        HTTPException 400: unknown validation status.
        HTTPException 404: vulnerability does not exist.
    """
    valid_statuses = {"validated", "false_positive", "ai_confirmed", "ai_rejected", "pending_review"}
    if body.validation_status not in valid_statuses:
        raise HTTPException(status_code=400, detail=f"Invalid status. Must be one of: {', '.join(valid_statuses)}")
    result = await db.execute(select(Vulnerability).where(Vulnerability.id == vuln_id))
    vuln = result.scalar_one_or_none()
    if not vuln:
        raise HTTPException(status_code=404, detail="Vulnerability not found")
    old_status = vuln.validation_status or "ai_confirmed"
    vuln.validation_status = body.validation_status
    if body.notes:
        vuln.ai_rejection_reason = body.notes
    # Keep scan-level counters consistent with the status transition.
    # Refactor: the two five-way if/elif severity ladders are collapsed into
    # the same f"{severity}_count" attribute pattern used by the feedback
    # endpoint; behavior is unchanged (unknown severities are skipped).
    scan_result = await db.execute(select(Scan).where(Scan.id == vuln.scan_id))
    scan = scan_result.scalar_one_or_none()
    if scan:
        count_attr = f"{vuln.severity}_count"  # e.g. "critical_count"
        if old_status == "ai_rejected" and body.validation_status == "validated":
            # Rejected -> validated: the finding re-enters the totals.
            scan.total_vulnerabilities = (scan.total_vulnerabilities or 0) + 1
            if hasattr(scan, count_attr):
                setattr(scan, count_attr, (getattr(scan, count_attr) or 0) + 1)
        elif old_status in ("ai_confirmed", "validated") and body.validation_status == "false_positive":
            # Confirmed -> false positive: remove it, clamping at zero.
            scan.total_vulnerabilities = max(0, (scan.total_vulnerabilities or 0) - 1)
            if hasattr(scan, count_attr):
                setattr(scan, count_attr, max(0, (getattr(scan, count_attr) or 0) - 1))
    await db.commit()
    return {"message": "Vulnerability validation updated", "vulnerability": vuln.to_dict()}
# --- Adaptive Learning Feedback ---
class FeedbackRequest(BaseModel):
    """TP/FP feedback on a finding, fed into the adaptive learner."""
    is_true_positive: bool  # True = confirmed vulnerability, False = false positive
    explanation: str = ""  # required (min 3 chars) when reporting a false positive
@router.post("/vulnerabilities/{vuln_id}/feedback")
async def submit_vulnerability_feedback(
    vuln_id: str,
    body: FeedbackRequest,
    db: AsyncSession = Depends(get_db)
):
    """Submit TP/FP feedback for a vulnerability finding.

    Records feedback in the adaptive learner so the agent improves over time.
    Also updates the validation_status in the database.

    Raises:
        HTTPException 404: vulnerability does not exist.
        HTTPException 400: false-positive feedback without an explanation.
    """
    import logging
    # Fix: `logger` was referenced in the except-block below but never defined
    # anywhere in this module, so a learner failure raised NameError instead
    # of being logged.
    logger = logging.getLogger(__name__)
    result = await db.execute(select(Vulnerability).where(Vulnerability.id == vuln_id))
    vuln = result.scalar_one_or_none()
    if not vuln:
        raise HTTPException(status_code=404, detail="Vulnerability not found")
    if len(body.explanation) < 3 and not body.is_true_positive:
        raise HTTPException(status_code=400, detail="Explanation required for false positive feedback (min 3 chars)")
    # Update DB validation status
    vuln.validation_status = "validated" if body.is_true_positive else "false_positive"
    if body.explanation:
        vuln.ai_rejection_reason = body.explanation
    # A false positive is removed from the scan's aggregate counters
    # (clamped at zero so repeated feedback cannot go negative).
    scan_result = await db.execute(select(Scan).where(Scan.id == vuln.scan_id))
    scan = scan_result.scalar_one_or_none()
    if scan and not body.is_true_positive:
        sev = vuln.severity
        scan.total_vulnerabilities = max(0, (scan.total_vulnerabilities or 0) - 1)
        count_attr = f"{sev}_count"
        if hasattr(scan, count_attr):
            setattr(scan, count_attr, max(0, (getattr(scan, count_attr) or 0) - 1))
    await db.commit()
    # Record in adaptive learner (best-effort; failures are logged only).
    pattern_count = 0
    try:
        from backend.core.adaptive_learner import AdaptiveLearner
        learner = AdaptiveLearner()
        vuln_dict = vuln.to_dict()
        learner.record_feedback(
            vuln_id=vuln_id,
            vuln_type=vuln_dict.get("vuln_type", "unknown"),
            endpoint=vuln_dict.get("url", ""),
            param=vuln_dict.get("parameter", ""),
            payload=vuln_dict.get("payload", ""),
            is_tp=body.is_true_positive,
            explanation=body.explanation,
            severity=vuln_dict.get("severity", "medium"),
            domain=vuln_dict.get("url", ""),
        )
        stats = learner.get_stats()
        pattern_count = stats.get("total_patterns", 0)
    except Exception as e:
        logger.warning(f"Adaptive learner feedback failed: {e}")
    return {
        "message": "Feedback recorded",
        "vulnerability_id": vuln_id,
        "is_true_positive": body.is_true_positive,
        "pattern_count": pattern_count,
    }
@router.get("/vulnerabilities/learning/stats")
async def get_learning_stats():
    """Get adaptive learning statistics."""
    try:
        from backend.core.adaptive_learner import AdaptiveLearner
        return AdaptiveLearner().get_stats()
    except Exception as e:
        # Degrade to an empty stats payload instead of a 500.
        return {"error": str(e), "total_feedback": 0, "total_patterns": 0}

140
backend/api/v1/scheduler.py Executable file
View File

@@ -0,0 +1,140 @@
"""
NeuroSploit v3 - Scheduler API Router
CRUD endpoints for managing scheduled scan jobs.
"""
import json
from pathlib import Path
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from typing import Optional, List, Dict
router = APIRouter()
CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "config" / "config.json"
class ScheduleJobRequest(BaseModel):
    """Request model for creating a scheduled job."""
    job_id: str  # caller-chosen unique job identifier
    target: str  # target to scan
    scan_type: str = "quick"
    # Exactly one of the two trigger fields must be provided (enforced in the route).
    cron_expression: Optional[str] = None  # cron-style schedule
    interval_minutes: Optional[int] = None  # fixed-interval schedule
    agent_role: Optional[str] = None  # optional role id (see /agent-roles)
    llm_profile: Optional[str] = None  # optional LLM profile override
class ScheduleJobResponse(BaseModel):
    """Response model for a scheduled job."""
    id: str  # job identifier
    target: str  # scan target
    scan_type: str
    schedule: str  # schedule specification (cron or interval)
    status: str  # e.g. "active" / "paused" (see pause/resume routes)
    next_run: Optional[str] = None  # next execution time, if scheduled
    last_run: Optional[str] = None  # most recent execution time
    run_count: int = 0  # number of runs performed so far
@router.get("/", response_model=List[Dict])
async def list_scheduled_jobs(request: Request):
    """List all scheduled scan jobs."""
    scheduler = getattr(request.app.state, 'scheduler', None)
    # No scheduler wired into app state -> nothing to list.
    return scheduler.list_jobs() if scheduler else []
@router.post("/", response_model=Dict)
async def create_scheduled_job(job: ScheduleJobRequest, request: Request):
    """Create a new scheduled scan job."""
    scheduler = getattr(request.app.state, 'scheduler', None)
    if scheduler is None:
        raise HTTPException(status_code=503, detail="Scheduler not available")
    # A job needs at least one trigger source: cron or fixed interval.
    if not (job.cron_expression or job.interval_minutes):
        raise HTTPException(
            status_code=400,
            detail="Either cron_expression or interval_minutes must be provided"
        )
    outcome = scheduler.add_job(
        job_id=job.job_id,
        target=job.target,
        scan_type=job.scan_type,
        cron_expression=job.cron_expression,
        interval_minutes=job.interval_minutes,
        agent_role=job.agent_role,
        llm_profile=job.llm_profile
    )
    if "error" in outcome:
        raise HTTPException(status_code=400, detail=outcome["error"])
    return outcome
@router.delete("/{job_id}")
async def delete_scheduled_job(job_id: str, request: Request):
    """Delete a scheduled scan job."""
    scheduler = getattr(request.app.state, 'scheduler', None)
    if scheduler is None:
        raise HTTPException(status_code=503, detail="Scheduler not available")
    # remove_job returns False for unknown ids.
    if not scheduler.remove_job(job_id):
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
    return {"message": f"Job '{job_id}' deleted", "id": job_id}
@router.post("/{job_id}/pause")
async def pause_scheduled_job(job_id: str, request: Request):
    """Pause a scheduled scan job."""
    scheduler = getattr(request.app.state, 'scheduler', None)
    if scheduler is None:
        raise HTTPException(status_code=503, detail="Scheduler not available")
    # pause_job returns False for unknown ids.
    if not scheduler.pause_job(job_id):
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
    return {"message": f"Job '{job_id}' paused", "id": job_id, "status": "paused"}
@router.post("/{job_id}/resume")
async def resume_scheduled_job(job_id: str, request: Request):
    """Resume a paused scheduled scan job."""
    scheduler = getattr(request.app.state, 'scheduler', None)
    if scheduler is None:
        raise HTTPException(status_code=503, detail="Scheduler not available")
    # resume_job returns False for unknown ids.
    if not scheduler.resume_job(job_id):
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
    return {"message": f"Job '{job_id}' resumed", "id": job_id, "status": "active"}
@router.get("/agent-roles", response_model=List[Dict])
async def get_agent_roles():
    """Return available agent roles from config.json for scheduler dropdown."""
    try:
        if not CONFIG_PATH.exists():
            return []
        roles = json.loads(CONFIG_PATH.read_text()).get("agent_roles", {})
        # Only roles not explicitly disabled are surfaced to the UI.
        return [
            {
                "id": role_id,
                "name": role_id.replace("_", " ").title(),
                "description": spec.get("description", ""),
                "tools": spec.get("tools_allowed", []),
            }
            for role_id, spec in roles.items()
            if spec.get("enabled", True)
        ]
    except Exception:
        # A malformed config file should not break the scheduler page.
        return []

707
backend/api/v1/settings.py Executable file
View File

@@ -0,0 +1,707 @@
"""
NeuroSploit v3 - Settings API Endpoints
"""
import os
import re
import time
from pathlib import Path
from typing import Optional, Dict, List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, text
from pydantic import BaseModel
from backend.db.database import get_db, engine
from backend.models import Scan, Target, Endpoint, Vulnerability, VulnerabilityTest, Report
router = APIRouter()
# Path to .env file (project root)
ENV_FILE_PATH = Path(__file__).parent.parent.parent.parent / ".env"
def _update_env_file(updates: Dict[str, str], env_path: Optional[Path] = None) -> bool:
    """
    Update key=value pairs in a .env file without breaking formatting.

    - If the key exists (even commented out), update its value
    - If the key doesn't exist, append it
    - Preserves comments and blank lines

    Args:
        updates: mapping of environment variable name -> new value.
        env_path: file to edit; defaults to the project-root .env. Added as a
            backward-compatible parameter so the rewriting logic is testable.

    Returns:
        True when the file was rewritten; False when it does not exist or the
        write failed.
    """
    path = Path(env_path) if env_path is not None else ENV_FILE_PATH
    if not path.exists():
        return False
    try:
        updated_keys = set()
        new_lines = []
        for line in path.read_text().splitlines():
            stripped = line.strip()
            matched = False
            for key, value in updates.items():
                # Match: KEY=..., # KEY=..., #KEY=...
                pattern = rf'^#?\s*{re.escape(key)}\s*='
                if re.match(pattern, stripped):
                    # Replace with an uncommented key=value line.
                    new_lines.append(f"{key}={value}")
                    updated_keys.add(key)
                    matched = True
                    break
            if not matched:
                new_lines.append(line)
        # Append any keys that weren't found in the existing file.
        for key, value in updates.items():
            if key not in updated_keys:
                new_lines.append(f"{key}={value}")
        # Write back with a trailing newline.
        path.write_text("\n".join(new_lines) + "\n")
        return True
    except Exception as e:
        print(f"Warning: Failed to update .env file: {e}")
        return False
class SettingsUpdate(BaseModel):
    """Settings update schema"""
    # All fields are optional: only the fields supplied by the client are applied.
    # Provider/model selection
    llm_provider: Optional[str] = None
    llm_model: Optional[str] = None
    # Provider API keys (the response schema only echoes has_* booleans back)
    anthropic_api_key: Optional[str] = None
    openai_api_key: Optional[str] = None
    openrouter_api_key: Optional[str] = None
    gemini_api_key: Optional[str] = None
    together_api_key: Optional[str] = None
    fireworks_api_key: Optional[str] = None
    # Local model server endpoints
    ollama_base_url: Optional[str] = None
    lmstudio_base_url: Optional[str] = None
    # Scan behavior defaults
    max_concurrent_scans: Optional[int] = None
    aggressive_mode: Optional[bool] = None
    default_scan_type: Optional[str] = None
    recon_enabled_by_default: Optional[bool] = None
    enable_model_routing: Optional[bool] = None
    enable_knowledge_augmentation: Optional[bool] = None
    enable_browser_validation: Optional[bool] = None
    max_output_tokens: Optional[int] = None
    # Notifications
    enable_notifications: Optional[bool] = None
    discord_webhook_url: Optional[str] = None
    telegram_bot_token: Optional[str] = None
    telegram_chat_id: Optional[str] = None
    twilio_account_sid: Optional[str] = None
    twilio_auth_token: Optional[str] = None
    twilio_from_number: Optional[str] = None
    twilio_to_number: Optional[str] = None
    notification_severity_filter: Optional[str] = None
class SettingsResponse(BaseModel):
    """Settings response schema.

    Secrets are never echoed back: each credential is reported only as a
    has_* boolean so the frontend can show configured/not-configured state.
    """
    llm_provider: str = "claude"
    llm_model: str = ""
    # Presence flags for API keys (true if set in memory or the environment).
    has_anthropic_key: bool = False
    has_openai_key: bool = False
    has_openrouter_key: bool = False
    has_gemini_key: bool = False
    has_together_key: bool = False
    has_fireworks_key: bool = False
    # Local inference endpoints are not secret, so they are returned as-is.
    ollama_base_url: str = ""
    lmstudio_base_url: str = ""
    max_concurrent_scans: int = 3
    aggressive_mode: bool = False
    default_scan_type: str = "full"
    recon_enabled_by_default: bool = True
    enable_model_routing: bool = False
    enable_knowledge_augmentation: bool = False
    enable_browser_validation: bool = False
    max_output_tokens: Optional[int] = None
    # Notifications
    enable_notifications: bool = False
    has_discord_webhook: bool = False
    has_telegram_bot: bool = False
    has_twilio_credentials: bool = False
    notification_severity_filter: str = "critical,high"
class ModelInfo(BaseModel):
    """Info about an available LLM model."""
    provider: str
    model_id: str
    display_name: str
    # Human-readable size string (only populated for local Ollama models).
    size: Optional[str] = None
    context_length: Optional[int] = None
    # True for models served by a local runtime (Ollama / LM Studio).
    is_local: bool = False
class ModelCatalogResponse(BaseModel):
    """Response from model catalog endpoint."""
    provider: str
    models: List[ModelInfo]
    # False when the provider could not be queried; see `error` for the cause.
    available: bool
    error: Optional[str] = None
def _load_settings_from_env() -> dict:
    """
    Build the settings dict from environment variables / the .env file.

    Invoked once at module import (server start) so settings survive
    restarts and do not depend on any browser session.
    """
    from dotenv import load_dotenv

    # Re-read the .env file so values persisted to disk win over stale env.
    if ENV_FILE_PATH.exists():
        load_dotenv(ENV_FILE_PATH, override=True)

    truthy = ("true", "1", "yes")
    falsy = ("false", "0", "no")

    def _as_bool(name: str, default: bool = False) -> bool:
        raw = os.getenv(name, "").strip().lower()
        if raw in truthy:
            return True
        return False if raw in falsy else default

    def _as_int(name: str, default=None):
        raw = os.getenv(name, "").strip()
        try:
            return int(raw) if raw else default
        except ValueError:
            return default

    # Infer the active provider from whichever API key is set; first match
    # wins, and "claude" is the fallback when no key is present.
    provider = "claude"
    for env_key, env_provider in (
        ("ANTHROPIC_API_KEY", "claude"),
        ("OPENAI_API_KEY", "openai"),
        ("OPENROUTER_API_KEY", "openrouter"),
    ):
        if os.getenv(env_key):
            provider = env_provider
            break

    return {
        "llm_provider": provider,
        "llm_model": os.getenv("DEFAULT_LLM_MODEL", ""),
        "anthropic_api_key": os.getenv("ANTHROPIC_API_KEY", ""),
        "openai_api_key": os.getenv("OPENAI_API_KEY", ""),
        "openrouter_api_key": os.getenv("OPENROUTER_API_KEY", ""),
        "gemini_api_key": os.getenv("GEMINI_API_KEY", ""),
        "together_api_key": os.getenv("TOGETHER_API_KEY", ""),
        "fireworks_api_key": os.getenv("FIREWORKS_API_KEY", ""),
        # Both new (_BASE_URL) and legacy (_URL) variable names are honored.
        "ollama_base_url": os.getenv("OLLAMA_BASE_URL", os.getenv("OLLAMA_URL", "")),
        "lmstudio_base_url": os.getenv("LMSTUDIO_BASE_URL", os.getenv("LMSTUDIO_URL", "")),
        "max_concurrent_scans": _as_int("MAX_CONCURRENT_SCANS", 3),
        "aggressive_mode": _as_bool("AGGRESSIVE_MODE", False),
        "default_scan_type": os.getenv("DEFAULT_SCAN_TYPE", "full"),
        "recon_enabled_by_default": _as_bool("RECON_ENABLED_BY_DEFAULT", True),
        "enable_model_routing": _as_bool("ENABLE_MODEL_ROUTING", False),
        "enable_knowledge_augmentation": _as_bool("ENABLE_KNOWLEDGE_AUGMENTATION", False),
        "enable_browser_validation": _as_bool("ENABLE_BROWSER_VALIDATION", False),
        "max_output_tokens": _as_int("MAX_OUTPUT_TOKENS", None),
        # Notifications
        "enable_notifications": _as_bool("ENABLE_NOTIFICATIONS", False),
        "discord_webhook_url": os.getenv("DISCORD_WEBHOOK_URL", ""),
        "telegram_bot_token": os.getenv("TELEGRAM_BOT_TOKEN", ""),
        "telegram_chat_id": os.getenv("TELEGRAM_CHAT_ID", ""),
        "twilio_account_sid": os.getenv("TWILIO_ACCOUNT_SID", ""),
        "twilio_auth_token": os.getenv("TWILIO_AUTH_TOKEN", ""),
        "twilio_from_number": os.getenv("TWILIO_FROM_NUMBER", ""),
        "twilio_to_number": os.getenv("TWILIO_TO_NUMBER", ""),
        "notification_severity_filter": os.getenv("NOTIFICATION_SEVERITY_FILTER", "critical,high"),
    }
# Load settings from .env on module import (server start).
# Module-level cache shared by all request handlers; update_settings mutates
# it in place and get_settings reads from it.
_settings = _load_settings_from_env()
@router.get("", response_model=SettingsResponse)
async def get_settings():
"""Get current settings"""
import os
return SettingsResponse(
llm_provider=_settings["llm_provider"],
llm_model=_settings.get("llm_model", ""),
has_anthropic_key=bool(_settings["anthropic_api_key"] or os.getenv("ANTHROPIC_API_KEY")),
has_openai_key=bool(_settings["openai_api_key"] or os.getenv("OPENAI_API_KEY")),
has_openrouter_key=bool(_settings["openrouter_api_key"] or os.getenv("OPENROUTER_API_KEY")),
has_gemini_key=bool(_settings.get("gemini_api_key") or os.getenv("GEMINI_API_KEY")),
has_together_key=bool(_settings.get("together_api_key") or os.getenv("TOGETHER_API_KEY")),
has_fireworks_key=bool(_settings.get("fireworks_api_key") or os.getenv("FIREWORKS_API_KEY")),
ollama_base_url=_settings.get("ollama_base_url", ""),
lmstudio_base_url=_settings.get("lmstudio_base_url", ""),
max_concurrent_scans=_settings["max_concurrent_scans"],
aggressive_mode=_settings["aggressive_mode"],
default_scan_type=_settings["default_scan_type"],
recon_enabled_by_default=_settings["recon_enabled_by_default"],
enable_model_routing=_settings["enable_model_routing"],
enable_knowledge_augmentation=_settings["enable_knowledge_augmentation"],
enable_browser_validation=_settings["enable_browser_validation"],
max_output_tokens=_settings["max_output_tokens"],
# Notifications
enable_notifications=_settings.get("enable_notifications", False),
has_discord_webhook=bool(_settings.get("discord_webhook_url")),
has_telegram_bot=bool(_settings.get("telegram_bot_token") and _settings.get("telegram_chat_id")),
has_twilio_credentials=bool(
_settings.get("twilio_account_sid") and _settings.get("twilio_auth_token")
and _settings.get("twilio_from_number") and _settings.get("twilio_to_number")
),
notification_severity_filter=_settings.get("notification_severity_filter", "critical,high"),
)
@router.put("", response_model=SettingsResponse)
async def update_settings(settings_data: SettingsUpdate):
"""Update settings - persists to memory, env vars, AND .env file"""
env_updates: Dict[str, str] = {}
if settings_data.llm_provider is not None:
_settings["llm_provider"] = settings_data.llm_provider
if settings_data.llm_model is not None:
_settings["llm_model"] = settings_data.llm_model
os.environ["DEFAULT_LLM_MODEL"] = settings_data.llm_model
env_updates["DEFAULT_LLM_MODEL"] = settings_data.llm_model
if settings_data.anthropic_api_key is not None:
_settings["anthropic_api_key"] = settings_data.anthropic_api_key
if settings_data.anthropic_api_key:
os.environ["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key
env_updates["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key
if settings_data.openai_api_key is not None:
_settings["openai_api_key"] = settings_data.openai_api_key
if settings_data.openai_api_key:
os.environ["OPENAI_API_KEY"] = settings_data.openai_api_key
env_updates["OPENAI_API_KEY"] = settings_data.openai_api_key
if settings_data.openrouter_api_key is not None:
_settings["openrouter_api_key"] = settings_data.openrouter_api_key
if settings_data.openrouter_api_key:
os.environ["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key
env_updates["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key
if settings_data.gemini_api_key is not None:
_settings["gemini_api_key"] = settings_data.gemini_api_key
if settings_data.gemini_api_key:
os.environ["GEMINI_API_KEY"] = settings_data.gemini_api_key
env_updates["GEMINI_API_KEY"] = settings_data.gemini_api_key
if settings_data.together_api_key is not None:
_settings["together_api_key"] = settings_data.together_api_key
if settings_data.together_api_key:
os.environ["TOGETHER_API_KEY"] = settings_data.together_api_key
env_updates["TOGETHER_API_KEY"] = settings_data.together_api_key
if settings_data.fireworks_api_key is not None:
_settings["fireworks_api_key"] = settings_data.fireworks_api_key
if settings_data.fireworks_api_key:
os.environ["FIREWORKS_API_KEY"] = settings_data.fireworks_api_key
env_updates["FIREWORKS_API_KEY"] = settings_data.fireworks_api_key
if settings_data.ollama_base_url is not None:
_settings["ollama_base_url"] = settings_data.ollama_base_url
if settings_data.ollama_base_url:
os.environ["OLLAMA_BASE_URL"] = settings_data.ollama_base_url
env_updates["OLLAMA_BASE_URL"] = settings_data.ollama_base_url
if settings_data.lmstudio_base_url is not None:
_settings["lmstudio_base_url"] = settings_data.lmstudio_base_url
if settings_data.lmstudio_base_url:
os.environ["LMSTUDIO_BASE_URL"] = settings_data.lmstudio_base_url
env_updates["LMSTUDIO_BASE_URL"] = settings_data.lmstudio_base_url
if settings_data.max_concurrent_scans is not None:
_settings["max_concurrent_scans"] = settings_data.max_concurrent_scans
if settings_data.aggressive_mode is not None:
_settings["aggressive_mode"] = settings_data.aggressive_mode
if settings_data.default_scan_type is not None:
_settings["default_scan_type"] = settings_data.default_scan_type
if settings_data.recon_enabled_by_default is not None:
_settings["recon_enabled_by_default"] = settings_data.recon_enabled_by_default
if settings_data.enable_model_routing is not None:
_settings["enable_model_routing"] = settings_data.enable_model_routing
val = str(settings_data.enable_model_routing).lower()
os.environ["ENABLE_MODEL_ROUTING"] = val
env_updates["ENABLE_MODEL_ROUTING"] = val
if settings_data.enable_knowledge_augmentation is not None:
_settings["enable_knowledge_augmentation"] = settings_data.enable_knowledge_augmentation
val = str(settings_data.enable_knowledge_augmentation).lower()
os.environ["ENABLE_KNOWLEDGE_AUGMENTATION"] = val
env_updates["ENABLE_KNOWLEDGE_AUGMENTATION"] = val
if settings_data.enable_browser_validation is not None:
_settings["enable_browser_validation"] = settings_data.enable_browser_validation
val = str(settings_data.enable_browser_validation).lower()
os.environ["ENABLE_BROWSER_VALIDATION"] = val
env_updates["ENABLE_BROWSER_VALIDATION"] = val
if settings_data.max_output_tokens is not None:
_settings["max_output_tokens"] = settings_data.max_output_tokens
if settings_data.max_output_tokens:
os.environ["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens)
env_updates["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens)
# Notifications
if settings_data.enable_notifications is not None:
_settings["enable_notifications"] = settings_data.enable_notifications
val = str(settings_data.enable_notifications).lower()
os.environ["ENABLE_NOTIFICATIONS"] = val
env_updates["ENABLE_NOTIFICATIONS"] = val
if settings_data.discord_webhook_url is not None:
_settings["discord_webhook_url"] = settings_data.discord_webhook_url
os.environ["DISCORD_WEBHOOK_URL"] = settings_data.discord_webhook_url
env_updates["DISCORD_WEBHOOK_URL"] = settings_data.discord_webhook_url
if settings_data.telegram_bot_token is not None:
_settings["telegram_bot_token"] = settings_data.telegram_bot_token
os.environ["TELEGRAM_BOT_TOKEN"] = settings_data.telegram_bot_token
env_updates["TELEGRAM_BOT_TOKEN"] = settings_data.telegram_bot_token
if settings_data.telegram_chat_id is not None:
_settings["telegram_chat_id"] = settings_data.telegram_chat_id
os.environ["TELEGRAM_CHAT_ID"] = settings_data.telegram_chat_id
env_updates["TELEGRAM_CHAT_ID"] = settings_data.telegram_chat_id
if settings_data.twilio_account_sid is not None:
_settings["twilio_account_sid"] = settings_data.twilio_account_sid
os.environ["TWILIO_ACCOUNT_SID"] = settings_data.twilio_account_sid
env_updates["TWILIO_ACCOUNT_SID"] = settings_data.twilio_account_sid
if settings_data.twilio_auth_token is not None:
_settings["twilio_auth_token"] = settings_data.twilio_auth_token
os.environ["TWILIO_AUTH_TOKEN"] = settings_data.twilio_auth_token
env_updates["TWILIO_AUTH_TOKEN"] = settings_data.twilio_auth_token
if settings_data.twilio_from_number is not None:
_settings["twilio_from_number"] = settings_data.twilio_from_number
os.environ["TWILIO_FROM_NUMBER"] = settings_data.twilio_from_number
env_updates["TWILIO_FROM_NUMBER"] = settings_data.twilio_from_number
if settings_data.twilio_to_number is not None:
_settings["twilio_to_number"] = settings_data.twilio_to_number
os.environ["TWILIO_TO_NUMBER"] = settings_data.twilio_to_number
env_updates["TWILIO_TO_NUMBER"] = settings_data.twilio_to_number
if settings_data.notification_severity_filter is not None:
_settings["notification_severity_filter"] = settings_data.notification_severity_filter
os.environ["NOTIFICATION_SEVERITY_FILTER"] = settings_data.notification_severity_filter
env_updates["NOTIFICATION_SEVERITY_FILTER"] = settings_data.notification_severity_filter
# Persist to .env file on disk
if env_updates:
_update_env_file(env_updates)
# Reload notification config if any notification-related fields changed
try:
from backend.core.notification_manager import notification_manager
notification_manager.reload_config()
except ImportError:
pass
return await get_settings()
@router.post("/notifications/test/{channel}")
async def test_notification_channel(channel: str):
"""Send a test notification to a specific channel (discord, telegram, whatsapp)."""
try:
from backend.core.notification_manager import notification_manager
result = await notification_manager.test_channel(channel)
return result
except ImportError:
raise HTTPException(500, "Notification manager not available")
@router.post("/clear-database")
async def clear_database(db: AsyncSession = Depends(get_db)):
"""Clear all data from the database (reset to fresh state)"""
try:
# Delete in correct order to respect foreign key constraints
await db.execute(delete(VulnerabilityTest))
await db.execute(delete(Vulnerability))
await db.execute(delete(Endpoint))
await db.execute(delete(Report))
await db.execute(delete(Target))
await db.execute(delete(Scan))
await db.commit()
return {
"message": "Database cleared successfully",
"status": "success"
}
except Exception as e:
await db.rollback()
raise HTTPException(status_code=500, detail=f"Failed to clear database: {str(e)}")
@router.get("/stats")
async def get_database_stats(db: AsyncSession = Depends(get_db)):
"""Get database statistics"""
from sqlalchemy import func
scans_count = (await db.execute(select(func.count()).select_from(Scan))).scalar() or 0
vulns_count = (await db.execute(select(func.count()).select_from(Vulnerability))).scalar() or 0
endpoints_count = (await db.execute(select(func.count()).select_from(Endpoint))).scalar() or 0
reports_count = (await db.execute(select(func.count()).select_from(Report))).scalar() or 0
return {
"scans": scans_count,
"vulnerabilities": vulns_count,
"endpoints": endpoints_count,
"reports": reports_count
}
@router.get("/tools")
async def get_installed_tools():
"""Check which security tools are installed"""
import asyncio
import shutil
# Complete list of 40+ tools
tools = {
"recon": [
"subfinder", "amass", "assetfinder", "chaos", "uncover",
"dnsx", "massdns", "puredns", "cero", "tlsx", "cdncheck"
],
"web_discovery": [
"httpx", "httprobe", "katana", "gospider", "hakrawler",
"gau", "waybackurls", "cariddi", "getJS", "gowitness"
],
"fuzzing": [
"ffuf", "gobuster", "dirb", "dirsearch", "wfuzz", "arjun", "paramspider"
],
"vulnerability_scanning": [
"nuclei", "nikto", "sqlmap", "xsstrike", "dalfox", "crlfuzz"
],
"port_scanning": [
"nmap", "naabu", "rustscan"
],
"utilities": [
"gf", "qsreplace", "unfurl", "anew", "uro", "jq"
],
"tech_detection": [
"whatweb", "wafw00f"
],
"exploitation": [
"hydra", "medusa", "john", "hashcat"
],
"network": [
"curl", "wget", "dig", "whois"
]
}
results = {}
total_installed = 0
total_tools = 0
for category, tool_list in tools.items():
results[category] = {}
for tool in tool_list:
total_tools += 1
# Check if tool exists in PATH
is_installed = shutil.which(tool) is not None
results[category][tool] = is_installed
if is_installed:
total_installed += 1
return {
"tools": results,
"summary": {
"total": total_tools,
"installed": total_installed,
"missing": total_tools - total_installed,
"percentage": round((total_installed / total_tools) * 100, 1)
}
}
# --- Model Catalog ---
# Cache for model catalog queries (60-second TTL).
# Keyed by provider name; _model_cache_time holds the time.time() stamp of
# the last fetch so stale entries are refreshed by get_provider_models.
_model_cache: Dict[str, dict] = {}
_model_cache_time: Dict[str, float] = {}
MODEL_CACHE_TTL = 60  # seconds
# Common cloud models for dropdown suggestions.
# Static fallback catalog for providers whose APIs are not queried live
# (ollama/lmstudio/openrouter are fetched dynamically instead).
# NOTE(review): model ids/dates are hard-coded and will age; verify
# periodically against each provider's published model list.
CLOUD_MODELS: Dict[str, List[dict]] = {
    "claude": [
        {"model_id": "claude-sonnet-4-20250514", "display_name": "Claude Sonnet 4", "context_length": 200000},
        {"model_id": "claude-opus-4-20250514", "display_name": "Claude Opus 4", "context_length": 200000},
        {"model_id": "claude-haiku-4-20250514", "display_name": "Claude Haiku 4", "context_length": 200000},
    ],
    "openai": [
        {"model_id": "gpt-4-turbo-preview", "display_name": "GPT-4 Turbo", "context_length": 128000},
        {"model_id": "gpt-4o", "display_name": "GPT-4o", "context_length": 128000},
        {"model_id": "gpt-4o-mini", "display_name": "GPT-4o Mini", "context_length": 128000},
        {"model_id": "o1-preview", "display_name": "O1 Preview", "context_length": 128000},
        {"model_id": "o1-mini", "display_name": "O1 Mini", "context_length": 128000},
    ],
    "gemini": [
        {"model_id": "gemini-pro", "display_name": "Gemini Pro", "context_length": 30720},
        {"model_id": "gemini-1.5-pro", "display_name": "Gemini 1.5 Pro", "context_length": 1048576},
        {"model_id": "gemini-1.5-flash", "display_name": "Gemini 1.5 Flash", "context_length": 1048576},
        {"model_id": "gemini-2.0-flash", "display_name": "Gemini 2.0 Flash", "context_length": 1048576},
    ],
    "together": [
        {"model_id": "meta-llama/Llama-3.3-70B-Instruct-Turbo", "display_name": "Llama 3.3 70B", "context_length": 131072},
        {"model_id": "Qwen/Qwen2.5-72B-Instruct-Turbo", "display_name": "Qwen 2.5 72B", "context_length": 32768},
        {"model_id": "deepseek-ai/DeepSeek-R1", "display_name": "DeepSeek R1", "context_length": 65536},
        {"model_id": "mistralai/Mixtral-8x22B-Instruct-v0.1", "display_name": "Mixtral 8x22B", "context_length": 65536},
    ],
    "fireworks": [
        {"model_id": "accounts/fireworks/models/llama-v3p3-70b-instruct", "display_name": "Llama 3.3 70B", "context_length": 131072},
        {"model_id": "accounts/fireworks/models/qwen2p5-72b-instruct", "display_name": "Qwen 2.5 72B", "context_length": 32768},
        {"model_id": "accounts/fireworks/models/deepseek-r1", "display_name": "DeepSeek R1", "context_length": 65536},
    ],
    "codex": [
        {"model_id": "codex-mini-latest", "display_name": "Codex Mini", "context_length": 192000},
    ],
}
@router.get("/models/{provider}", response_model=ModelCatalogResponse)
async def get_provider_models(provider: str):
"""Get available models for a specific provider.
For local providers (ollama, lmstudio), queries the running service.
For cloud providers, returns common model suggestions.
For openrouter, queries the API for available models.
"""
import aiohttp
# Check cache
now = time.time()
if provider in _model_cache and (now - _model_cache_time.get(provider, 0)) < MODEL_CACHE_TTL:
return ModelCatalogResponse(**_model_cache[provider])
if provider == "ollama":
result = await _get_ollama_models()
elif provider == "lmstudio":
result = await _get_lmstudio_models()
elif provider == "openrouter":
result = await _get_openrouter_models()
elif provider in CLOUD_MODELS:
result = {
"provider": provider,
"models": [
ModelInfo(
provider=provider,
model_id=m["model_id"],
display_name=m["display_name"],
context_length=m.get("context_length"),
is_local=False,
).dict()
for m in CLOUD_MODELS[provider]
],
"available": True,
"error": None,
}
else:
raise HTTPException(400, f"Unknown provider: {provider}")
# Cache the result
_model_cache[provider] = result
_model_cache_time[provider] = now
return ModelCatalogResponse(**result)
async def _get_ollama_models() -> dict:
    """Query Ollama for installed models (``GET /api/tags``)."""
    import aiohttp

    base_url = os.getenv("OLLAMA_BASE_URL", os.getenv("OLLAMA_URL", "http://localhost:11434"))
    try:
        timeout = aiohttp.ClientTimeout(total=3)
        async with aiohttp.ClientSession() as http:
            async with http.get(f"{base_url}/api/tags", timeout=timeout) as resp:
                if resp.status != 200:
                    return {"provider": "ollama", "models": [], "available": False, "error": f"HTTP {resp.status}"}
                payload = await resp.json()

        catalog = []
        for entry in payload.get("models", []):
            tag = entry.get("name", "")
            nbytes = entry.get("size", 0)
            meta = entry.get("details", {})
            catalog.append(ModelInfo(
                provider="ollama",
                model_id=tag,
                display_name=tag,
                # Disk size rendered from the reported byte count.
                size=f"{nbytes / 1e9:.1f}B" if nbytes else None,
                context_length=meta.get("context_length"),
                is_local=True,
            ).dict())
        return {"provider": "ollama", "models": catalog, "available": True, "error": None}
    except Exception as e:
        # Service down / unreachable: report unavailable instead of raising.
        return {"provider": "ollama", "models": [], "available": False, "error": str(e)}
async def _get_lmstudio_models() -> dict:
    """Query LM Studio for loaded models (OpenAI-compatible ``/v1/models``)."""
    import aiohttp

    base_url = os.getenv("LMSTUDIO_BASE_URL", os.getenv("LMSTUDIO_URL", "http://localhost:1234"))
    try:
        timeout = aiohttp.ClientTimeout(total=3)
        async with aiohttp.ClientSession() as http:
            async with http.get(f"{base_url}/v1/models", timeout=timeout) as resp:
                if resp.status != 200:
                    return {"provider": "lmstudio", "models": [], "available": False, "error": f"HTTP {resp.status}"}
                payload = await resp.json()

        catalog = [
            ModelInfo(
                provider="lmstudio",
                model_id=entry.get("id", ""),
                display_name=entry.get("id", ""),
                is_local=True,
            ).dict()
            for entry in payload.get("data", [])
        ]
        return {"provider": "lmstudio", "models": catalog, "available": True, "error": None}
    except Exception as e:
        # Service down / unreachable: report unavailable instead of raising.
        return {"provider": "lmstudio", "models": [], "available": False, "error": str(e)}
async def _get_openrouter_models() -> dict:
    """Query OpenRouter's public catalog (capped at the first 100 models)."""
    import aiohttp

    api_key = os.getenv("OPENROUTER_API_KEY", "")
    # The catalog endpoint works anonymously; the key is attached when present.
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    try:
        timeout = aiohttp.ClientTimeout(total=5)
        async with aiohttp.ClientSession() as http:
            async with http.get(
                "https://openrouter.ai/api/v1/models",
                headers=headers,
                timeout=timeout,
            ) as resp:
                if resp.status != 200:
                    return {"provider": "openrouter", "models": [], "available": False, "error": f"HTTP {resp.status}"}
                payload = await resp.json()

        catalog = []
        for entry in payload.get("data", [])[:100]:  # Limit to 100 models
            ident = entry.get("id", "")
            catalog.append(ModelInfo(
                provider="openrouter",
                model_id=ident,
                display_name=entry.get("name", ident),
                context_length=entry.get("context_length"),
                is_local=False,
            ).dict())
        return {"provider": "openrouter", "models": catalog, "available": True, "error": None}
    except Exception as e:
        return {"provider": "openrouter", "models": [], "available": False, "error": str(e)}

142
backend/api/v1/targets.py Executable file
View File

@@ -0,0 +1,142 @@
"""
NeuroSploit v3 - Targets API Endpoints
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession
from urllib.parse import urlparse
import re
from backend.db.database import get_db
from backend.schemas.target import TargetCreate, TargetBulkCreate, TargetValidation, TargetResponse
router = APIRouter()
def validate_url(url: str) -> TargetValidation:
    """Validate and parse a URL, normalizing scheme-less input to https://."""
    url = url.strip()
    if not url:
        return TargetValidation(url=url, valid=False, error="URL is empty")
    # http(s) scheme + domain / localhost / dotted-quad IP + optional port/path.
    url_pattern = re.compile(
        r'^https?://'
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'
        r'localhost|'
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'
        r'(?::\d+)?'
        r'(?:/?|[/?]\S+)$', re.IGNORECASE)
    # Accept the URL as given, otherwise retry with an https:// prefix.
    normalized = None
    for candidate in (url, f"https://{url}"):
        if url_pattern.match(candidate):
            normalized = candidate
            break
    if normalized is None:
        return TargetValidation(url=url, valid=False, error="Invalid URL format")
    parsed = urlparse(normalized)
    default_port = 443 if parsed.scheme == "https" else 80
    return TargetValidation(
        url=url,
        valid=True,
        normalized_url=normalized,
        hostname=parsed.hostname,
        port=parsed.port or default_port,
        protocol=parsed.scheme,
    )
@router.post("/validate", response_model=TargetValidation)
async def validate_target(target: TargetCreate):
"""Validate a single target URL"""
return validate_url(target.url)
@router.post("/validate/bulk", response_model=List[TargetValidation])
async def validate_targets_bulk(targets: TargetBulkCreate):
"""Validate multiple target URLs"""
results = []
for url in targets.urls:
results.append(validate_url(url))
return results
@router.post("/upload", response_model=List[TargetValidation])
async def upload_targets(file: UploadFile = File(...)):
"""Upload a file with URLs (one per line)"""
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided")
# Check file extension
allowed_extensions = {".txt", ".csv", ".lst"}
ext = "." + file.filename.split(".")[-1].lower() if "." in file.filename else ""
if ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Invalid file type. Allowed: {', '.join(allowed_extensions)}"
)
# Read file content
content = await file.read()
try:
text = content.decode("utf-8")
except UnicodeDecodeError:
try:
text = content.decode("latin-1")
except Exception:
raise HTTPException(status_code=400, detail="Unable to decode file")
# Parse URLs (one per line, or comma-separated)
urls = []
for line in text.split("\n"):
line = line.strip()
if not line or line.startswith("#"):
continue
# Handle comma-separated URLs
if "," in line and "://" in line:
for url in line.split(","):
url = url.strip()
if url:
urls.append(url)
else:
urls.append(line)
if not urls:
raise HTTPException(status_code=400, detail="No URLs found in file")
# Validate all URLs
results = []
for url in urls:
results.append(validate_url(url))
return results
@router.post("/parse-input", response_model=List[TargetValidation])
async def parse_target_input(input_text: str):
"""Parse target input (comma-separated or newline-separated)"""
urls = []
# Split by newlines first
for line in input_text.split("\n"):
line = line.strip()
if not line:
continue
# Then split by commas
for url in line.split(","):
url = url.strip()
if url:
urls.append(url)
if not urls:
raise HTTPException(status_code=400, detail="No URLs provided")
results = []
for url in urls:
results.append(validate_url(url))
return results

753
backend/api/v1/terminal.py Executable file
View File

@@ -0,0 +1,753 @@
"""
Terminal Agent API - Interactive infrastructure pentesting via AI chat + Docker sandbox.
Provides session-based terminal interaction with AI-guided command execution,
exploitation path tracking, and VPN status monitoring.
"""
import asyncio
import logging
import re
import time
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
from core.llm_manager import LLMManager
from core.sandbox_manager import get_sandbox
logger = logging.getLogger(__name__)
router = APIRouter()
# ---------------------------------------------------------------------------
# In-memory session store
# ---------------------------------------------------------------------------
# session_id -> session state dict (shape built by _build_session).
# NOTE(review): process-local, unbounded, and lost on restart — fine for a
# single worker; confirm before running multiple server processes.
terminal_sessions: Dict[str, Dict] = {}
# Map session_id -> KaliSandbox instance (per-session container)
session_sandboxes: Dict[str, object] = {}
# ---------------------------------------------------------------------------
# Pre-built templates
# ---------------------------------------------------------------------------
# Each template bundles a display name, an AI system prompt, and starter
# commands. NOTE(review): initial_commands contain {target} placeholders —
# presumably substituted when a session starts; confirm in create_session.
TEMPLATES: Dict[str, Dict] = {
    "network_scanner": {
        "name": "Network Scanner",
        "description": "Host discovery, port scanning, and service detection",
        "system_prompt": (
            "You are an expert network reconnaissance specialist. You guide the "
            "operator through systematic host discovery, port scanning, and service "
            "fingerprinting. Always suggest nmap flags appropriate for the situation, "
            "explain output, and recommend next steps based on discovered services. "
            "Prioritize stealth when asked and suggest timing/fragmentation options."
        ),
        "initial_commands": [
            "nmap -sn {target}",
            "nmap -sV -sC -O -p- {target}",
            "nmap -sU --top-ports 50 {target}",
        ],
    },
    "lateral_movement": {
        "name": "Lateral Movement",
        "description": "Pass-the-hash, SMB/WinRM pivoting, and SSH tunneling",
        "system_prompt": (
            "You are a lateral movement specialist. You help the operator pivot "
            "through compromised networks using techniques such as pass-the-hash, "
            "SMB relay, WinRM sessions, SSH tunneling, and SOCKS proxying. Always "
            "verify credentials before attempting pivots, suggest cleanup steps, "
            "and track which hosts have been compromised."
        ),
        "initial_commands": [
            "crackmapexec smb {target} -u '' -p ''",
            "crackmapexec smb {target} --shares -u '' -p ''",
            "ssh -D 1080 -N -f user@{target}",
        ],
    },
    "privilege_escalation": {
        "name": "Privilege Escalation",
        "description": "SUID binaries, kernel exploits, cron jobs, and writable paths",
        "system_prompt": (
            "You are a privilege escalation expert for Linux and Windows systems. "
            "Guide the operator through enumeration of SUID/SGID binaries, kernel "
            "version checks, misconfigured cron jobs, writable PATH directories, "
            "sudo misconfigurations, and capability abuse. Suggest automated tools "
            "like linpeas/winpeas when appropriate and explain each finding."
        ),
        "initial_commands": [
            "id && whoami && uname -a",
            "find / -perm -4000 -type f 2>/dev/null",
            "cat /etc/crontab && ls -la /etc/cron.*",
            "echo $PATH | tr ':' '\\n' | xargs -I {} ls -ld {}",
        ],
    },
    "vpn_recon": {
        "name": "VPN Reconnaissance",
        "description": "VPN connection management and internal network discovery",
        "system_prompt": (
            "You are a VPN and internal network reconnaissance specialist. You "
            "help the operator connect to target VPNs, verify tunnel status, "
            "discover internal subnets, and enumerate services behind the VPN. "
            "Always confirm connectivity before proceeding with scans and suggest "
            "appropriate scope for internal reconnaissance."
        ),
        "initial_commands": [
            "openvpn --config client.ovpn --daemon",
            "ip addr show tun0",
            "ip route | grep tun",
            "nmap -sn 10.0.0.0/24",
        ],
    },
}
# ---------------------------------------------------------------------------
# Pydantic request / response models
# ---------------------------------------------------------------------------
class CreateSessionRequest(BaseModel):
    """Request body for creating a terminal session."""
    template_id: Optional[str] = None  # key into TEMPLATES; None = blank session
    target: Optional[str] = ""
    name: Optional[str] = ""
class MessageRequest(BaseModel):
    """Request body for sending a chat message to the session's AI."""
    message: str
class ExecuteCommandRequest(BaseModel):
    """Request body for running a shell command in a session."""
    command: str
    execution_method: str = "sandbox"  # "sandbox" or "direct"
class ExploitationStepRequest(BaseModel):
    """Request body for appending a step to a session's exploitation path."""
    description: str
    command: Optional[str] = ""
    result: Optional[str] = ""
    step_type: str = "recon"  # recon | exploit | pivot | escalate | action
class SessionSummary(BaseModel):
    """Lightweight view of a session for listing endpoints."""
    session_id: str
    name: str
    target: str
    template_id: Optional[str]
    status: str
    created_at: str  # ISO-8601 timestamp
    messages_count: int
    commands_count: int
class MessageResponse(BaseModel):
    """AI chat reply, including any commands extracted from its code blocks."""
    role: str
    response: str
    timestamp: str  # ISO-8601 timestamp
    suggested_commands: List[str]
class CommandResult(BaseModel):
    """Outcome of a single executed shell command."""
    command: str
    exit_code: int
    stdout: str
    stderr: str
    duration: float  # wall-clock seconds
    execution_method: str  # "sandbox" or "direct"
    timestamp: str  # ISO-8601 timestamp
class VPNStatus(BaseModel):
    """Current VPN tunnel state for a session's sandbox container."""
    connected: bool
    ip: Optional[str] = None
    interface: Optional[str] = None  # e.g. tun0
    container_name: Optional[str] = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _build_session(
    session_id: str,
    name: str,
    target: str,
    template_id: Optional[str],
) -> Dict:
    """Create the initial in-memory state dict for a new terminal session."""
    state: Dict = dict(
        session_id=session_id,
        name=name,
        target=target,
        template_id=template_id,
        status="active",
        created_at=_now_iso(),
        # Running history, all appended to over the session's lifetime.
        messages=[],
        command_history=[],
        exploitation_path=[],
        # VPN/container state starts disconnected and unprovisioned.
        vpn_status={"connected": False, "ip": None},
        container_name=None,
        vpn_config_uploaded=False,
    )
    return state
def _get_session(session_id: str) -> Dict:
    """Look up a session by id, raising 404 when it is absent."""
    found = terminal_sessions.get(session_id)
    if not found:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return found
def _build_context_string(
messages: List[Dict],
commands: List[Dict],
exploitation: List[Dict],
) -> str:
parts: List[str] = []
if messages:
parts.append("=== Recent Conversation ===")
for msg in messages:
role = msg.get("role", "unknown").upper()
parts.append(f"[{role}] {msg.get('content', '')}")
if commands:
parts.append("\n=== Recent Command Results ===")
for cmd in commands:
parts.append(
f"$ {cmd['command']}\n"
f"Exit code: {cmd['exit_code']}\n"
f"Stdout: {cmd['stdout'][:500]}\n"
f"Stderr: {cmd['stderr'][:300]}"
)
if exploitation:
parts.append("\n=== Exploitation Path ===")
for i, step in enumerate(exploitation, 1):
parts.append(
f"Step {i} [{step['step_type']}]: {step['description']}"
)
if step.get("command"):
parts.append(f" Command: {step['command']}")
if step.get("result"):
parts.append(f" Result: {step['result'][:300]}")
return "\n".join(parts)
def _extract_suggested_commands(text: str) -> List[str]:
"""Extract commands from backtick-fenced code blocks."""
blocks = re.findall(r"```(?:bash|sh|shell)?\n?(.*?)```", text, re.DOTALL)
commands: List[str] = []
for block in blocks:
for line in block.strip().splitlines():
stripped = line.strip()
if stripped and not stripped.startswith("#"):
commands.append(stripped)
return commands
# ---------------------------------------------------------------------------
# Template endpoints
# ---------------------------------------------------------------------------
@router.get("/templates")
async def list_templates():
    """List all available session templates."""
    # Flatten the TEMPLATES registry into the public list shape.
    return [
        {
            "id": template_id,
            "name": template["name"],
            "description": template["description"],
            "initial_commands": template["initial_commands"],
        }
        for template_id, template in TEMPLATES.items()
    ]
# ---------------------------------------------------------------------------
# Session CRUD
# ---------------------------------------------------------------------------
@router.post("/session")
async def create_session(req: CreateSessionRequest):
    """Create a new terminal session, optionally from a template.

    Provisions a dedicated Kali container for the session (best-effort —
    the session remains usable without one) and, when a template is given,
    seeds the conversation with the template's system prompt plus
    target-interpolated starter commands. Returns the full session record.
    """
    session_id = str(uuid.uuid4())
    target = req.target or ""
    template_id = req.template_id
    if template_id and template_id not in TEMPLATES:
        raise HTTPException(status_code=400, detail=f"Unknown template: {template_id}")
    name = req.name or (
        TEMPLATES[template_id]["name"] if template_id else f"Session {session_id[:8]}"
    )
    session = _build_session(session_id, name, target, template_id)
    # Provision a per-session Kali container (best-effort)
    try:
        from core.container_pool import get_pool
        pool = get_pool()
        sandbox = await pool.get_or_create(f"terminal-{session_id}", enable_vpn=True)
        session_sandboxes[session_id] = sandbox
        session["container_name"] = sandbox.container_name
    except Exception as exc:
        # Execution falls back to the shared sandbox / host shell later.
        logger.warning(f"Failed to provision Kali container for terminal session: {exc}")
    # Seed initial system message from template
    if template_id:
        tmpl = TEMPLATES[template_id]
        session["messages"].append({
            "role": "system",
            "content": tmpl["system_prompt"],
            "timestamp": _now_iso(),
            "metadata": {"template": template_id},
        })
        # Provide initial suggested commands with target interpolated
        initial_cmds = [
            cmd.replace("{target}", target) for cmd in tmpl["initial_commands"]
        ]
        session["messages"].append({
            "role": "assistant",
            "content": (
                f"Session initialised with the **{tmpl['name']}** template.\n\n"
                f"Target: `{target or '(not set)'}`\n\n"
                "Suggested starting commands:\n"
                + "\n".join(f"```\n{c}\n```" for c in initial_cmds)
            ),
            "timestamp": _now_iso(),
            "suggested_commands": initial_cmds,
        })
    # Register only after the record is fully built.
    terminal_sessions[session_id] = session
    return session
@router.get("/sessions")
async def list_sessions():
    """Return lightweight summaries of every session."""
    return [
        SessionSummary(
            session_id=sid,
            name=info["name"],
            target=info["target"],
            template_id=info["template_id"],
            status=info["status"],
            created_at=info["created_at"],
            messages_count=len(info["messages"]),
            commands_count=len(info["command_history"]),
        ).model_dump()
        for sid, info in terminal_sessions.items()
    ]
@router.get("/sessions/{session_id}")
async def get_session(session_id: str):
    """Return the full session including messages, commands, and exploitation path."""
    # _get_session raises a 404 for unknown session ids.
    return _get_session(session_id)
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Delete a terminal session and its Kali container."""
    if session_id not in terminal_sessions:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    # Tear down the per-session container first (best-effort).
    sandbox = session_sandboxes.pop(session_id, None)
    if sandbox:
        try:
            from core.container_pool import get_pool
            await get_pool().destroy(f"terminal-{session_id}")
        except Exception as exc:
            logger.warning(f"Failed to destroy container for session {session_id}: {exc}")
    terminal_sessions.pop(session_id)
    return {"status": "deleted", "session_id": session_id}
# ---------------------------------------------------------------------------
# AI message interaction
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/message")
async def send_message(session_id: str, req: MessageRequest):
    """Send a user prompt to the AI and receive a response with suggested commands.

    The LLM context is built from the last 20 messages, the last 10 command
    results, and the full exploitation path. Commands found in fenced code
    blocks of the reply are returned as ``suggested_commands``. Both the
    user prompt and the assistant reply are appended to the session history.
    """
    session = _get_session(session_id)
    user_message = req.message.strip()
    if not user_message:
        raise HTTPException(status_code=400, detail="Message content cannot be empty")
    # Record user message
    session["messages"].append({
        "role": "user",
        "content": user_message,
        "timestamp": _now_iso(),
        "metadata": {},
    })
    # Determine system prompt: template-specific, else a generic pentest persona.
    template_id = session.get("template_id")
    if template_id and template_id in TEMPLATES:
        system_prompt = TEMPLATES[template_id]["system_prompt"]
    else:
        system_prompt = (
            "You are an expert infrastructure penetration tester. Help the "
            "operator plan and execute attacks against the target. Suggest "
            "concrete commands, explain their purpose, and interpret output. "
            "Always wrap commands in fenced code blocks so they can be extracted."
        )
    # Build context window (bounded to keep the prompt within budget)
    context_messages = session["messages"][-20:]
    context_cmds = session["command_history"][-10:]
    exploitation = session["exploitation_path"]
    context = _build_context_string(context_messages, context_cmds, exploitation)
    # Call LLM
    try:
        llm = LLMManager()
        prompt = f"{context}\n\nUser: {user_message}"
        response = await llm.generate(prompt, system_prompt)
    except Exception as exc:
        # Surface provider failures as a 502 (upstream error), not a generic 500.
        raise HTTPException(status_code=502, detail=f"LLM call failed: {exc}")
    suggested_commands = _extract_suggested_commands(response)
    # Record assistant response
    session["messages"].append({
        "role": "assistant",
        "content": response,
        "timestamp": _now_iso(),
        "suggested_commands": suggested_commands,
    })
    return MessageResponse(
        role="assistant",
        response=response,
        timestamp=session["messages"][-1]["timestamp"],
        suggested_commands=suggested_commands,
    ).model_dump()
# ---------------------------------------------------------------------------
# Command execution
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/execute")
async def execute_command(session_id: str, req: ExecuteCommandRequest):
    """Execute a command in the Docker sandbox (fallback: direct shell).

    Resolution order when ``execution_method == "sandbox"``: the session's
    own Kali container, then the shared sandbox, then the host shell.
    ``execution_method == "direct"`` goes straight to the host shell.
    The result is appended to the command history and mirrored into the
    message log so the AI sees command output as conversational context.
    """
    session = _get_session(session_id)
    command = req.command.strip()
    if not command:
        raise HTTPException(status_code=400, detail="Command cannot be empty")
    start = time.time()
    stdout = ""
    stderr = ""
    exit_code = -1
    # "direct" doubles as the "not executed yet" sentinel for the fallbacks below.
    execution_method = "direct"
    # Use requested execution method
    use_sandbox = req.execution_method == "sandbox"
    if use_sandbox:
        # Prefer session's own Kali container
        sandbox = session_sandboxes.get(session_id)
        if sandbox and sandbox.is_available:
            try:
                result = await sandbox.execute_raw(command)
                stdout = result.stdout
                stderr = result.stderr
                exit_code = result.exit_code
                execution_method = "kali-sandbox"
            except Exception:
                pass  # best-effort: fall through to the shared sandbox
        # Fallback to shared sandbox
        if execution_method == "direct":
            try:
                shared = await get_sandbox()
                if shared and shared.is_available:
                    result = await shared.execute_raw(command)
                    stdout = result.stdout
                    stderr = result.stderr
                    exit_code = result.exit_code
                    execution_method = "sandbox"
            except Exception:
                pass  # best-effort: fall through to host execution
    # Fallback or direct execution requested
    if execution_method not in ("kali-sandbox", "sandbox"):
        # NOTE(review): runs the raw command through the host shell —
        # appropriate for an operator console, never for untrusted input.
        try:
            proc = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            raw_stdout, raw_stderr = await asyncio.wait_for(
                proc.communicate(), timeout=120
            )
            stdout = raw_stdout.decode(errors="replace")
            stderr = raw_stderr.decode(errors="replace")
            exit_code = proc.returncode or 0
            execution_method = "direct"
        except asyncio.TimeoutError:
            stderr = "Command timed out after 120 seconds"
            exit_code = 124  # matches the coreutils `timeout` exit convention
        except Exception as exc:
            stderr = str(exc)
            exit_code = 1
    duration = round(time.time() - start, 3)
    cmd_record = {
        "command": command,
        "exit_code": exit_code,
        "stdout": stdout,
        "stderr": stderr,
        "duration": duration,
        "execution_method": execution_method,
        "timestamp": _now_iso(),
    }
    session["command_history"].append(cmd_record)
    # Mirror into messages for AI context continuity
    output_preview = stdout[:2000] if stdout else stderr[:2000]
    session["messages"].append({
        "role": "tool",
        "content": f"$ {command}\n[exit {exit_code}] ({execution_method}, {duration}s)\n{output_preview}",
        "timestamp": cmd_record["timestamp"],
        "metadata": {"exit_code": exit_code, "execution_method": execution_method},
    })
    return CommandResult(**cmd_record).model_dump()
# ---------------------------------------------------------------------------
# Exploitation path
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/exploitation-path")
async def add_exploitation_step(session_id: str, req: ExploitationStepRequest):
    """Add a manual step to the exploitation path timeline."""
    session = _get_session(session_id)
    allowed_types = {"recon", "exploit", "pivot", "escalate", "action"}
    if req.step_type not in allowed_types:
        raise HTTPException(
            status_code=400,
            detail=f"step_type must be one of {sorted(allowed_types)}",
        )
    # Timestamped step record; optional fields default to empty strings.
    entry = {
        "description": req.description,
        "command": req.command or "",
        "result": req.result or "",
        "timestamp": _now_iso(),
        "step_type": req.step_type,
    }
    session["exploitation_path"].append(entry)
    return entry
@router.get("/sessions/{session_id}/exploitation-path")
async def get_exploitation_path(session_id: str):
    """Return the full exploitation path timeline."""
    # _get_session raises a 404 for unknown session ids.
    session = _get_session(session_id)
    return session["exploitation_path"]
# ---------------------------------------------------------------------------
# VPN management
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/vpn/upload")
async def upload_vpn_config(
    session_id: str,
    ovpn_file: UploadFile = File(...),
    username: Optional[str] = Form(None),
    password: Optional[str] = Form(None),
):
    """Upload .ovpn config and optionally credentials into the session's container.

    The config lands at /etc/openvpn/client.ovpn. When credentials are
    supplied they are written to /etc/openvpn/auth.txt and the config is
    rewritten to reference that auth file, so connect_vpn can start
    OpenVPN non-interactively.
    """
    session = _get_session(session_id)
    sandbox = session_sandboxes.get(session_id)
    if not sandbox or not sandbox.is_available:
        raise HTTPException(
            status_code=503,
            detail="No Kali container available for this session.",
        )
    content = await ovpn_file.read()
    if len(content) > 1_000_000:
        raise HTTPException(status_code=400, detail="File too large (max 1MB)")
    if not (ovpn_file.filename or "").endswith((".ovpn", ".conf")):
        raise HTTPException(status_code=400, detail="File must be .ovpn or .conf")
    # Upload config to container
    dest = "/etc/openvpn/client.ovpn"
    ok = await sandbox.upload_file(content, dest)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to upload config to container")
    # Write auth file if credentials provided
    if username and password:
        auth_bytes = f"{username}\n{password}\n".encode()
        await sandbox.upload_file(auth_bytes, "/etc/openvpn/auth.txt")
        await sandbox._exec("chmod 600 /etc/openvpn/auth.txt", timeout=5)
        # Append an auth-user-pass directive only if the config lacks one...
        await sandbox._exec(
            "grep -q 'auth-user-pass' /etc/openvpn/client.ovpn || "
            "echo 'auth-user-pass /etc/openvpn/auth.txt' >> /etc/openvpn/client.ovpn",
            timeout=5,
        )
        # ...then point any bare 'auth-user-pass' directive at the auth file.
        await sandbox._exec(
            "sed -i 's|auth-user-pass$|auth-user-pass /etc/openvpn/auth.txt|' /etc/openvpn/client.ovpn",
            timeout=5,
        )
    session["vpn_config_uploaded"] = True
    return {
        "status": "uploaded",
        "filename": ovpn_file.filename,
        "credentials_set": bool(username),
    }
@router.post("/sessions/{session_id}/vpn/connect")
async def connect_vpn(session_id: str):
    """Start VPN connection using previously uploaded config.

    Ensures a TUN device exists, kills any prior OpenVPN process, starts
    OpenVPN as a daemon, and polls up to 20 seconds for tun0 to come up.
    Raises 504 with the tail of the OpenVPN log on timeout.
    """
    session = _get_session(session_id)
    sandbox = session_sandboxes.get(session_id)
    if not sandbox or not sandbox.is_available:
        raise HTTPException(status_code=503, detail="No Kali container for this session")
    if not session.get("vpn_config_uploaded"):
        raise HTTPException(status_code=400, detail="No VPN config uploaded. Upload .ovpn first.")
    # Create TUN device (containers don't always ship /dev/net/tun)
    await sandbox._exec(
        "mkdir -p /dev/net && "
        "[ -c /dev/net/tun ] || mknod /dev/net/tun c 10 200; "
        "chmod 600 /dev/net/tun",
        timeout=5,
    )
    # Kill any existing VPN
    await sandbox._exec("pkill -9 openvpn 2>/dev/null", timeout=5)
    # Start OpenVPN
    result = await sandbox._exec(
        "openvpn --config /etc/openvpn/client.ovpn --daemon "
        "--log /var/log/openvpn.log --writepid /var/run/openvpn.pid",
        timeout=15,
    )
    if result.exit_code != 0:
        raise HTTPException(
            status_code=500,
            detail=f"OpenVPN failed to start: {result.stderr or result.stdout}",
        )
    # Wait for tunnel (max 20s): poll tun0 once per second
    for _ in range(20):
        await asyncio.sleep(1)
        check = await sandbox._exec("ip addr show tun0 2>/dev/null", timeout=5)
        if check.exit_code == 0 and "inet " in check.stdout:
            match = re.search(r"inet\s+(\d+\.\d+\.\d+\.\d+)", check.stdout)
            ip = match.group(1) if match else None
            vpn = {"connected": True, "ip": ip}
            session["vpn_status"] = vpn
            return {"status": "connected", "ip": ip}
    # Timeout: include the log tail so the operator can diagnose
    log_result = await sandbox._exec("tail -30 /var/log/openvpn.log 2>/dev/null", timeout=5)
    raise HTTPException(
        status_code=504,
        detail=f"VPN connection timed out (20s). Log:\n{(log_result.stdout or '')[-500:]}",
    )
@router.post("/sessions/{session_id}/vpn/disconnect")
async def disconnect_vpn(session_id: str):
    """Kill VPN connection inside the container."""
    session = _get_session(session_id)
    sandbox = session_sandboxes.get(session_id)
    if not sandbox or not sandbox.is_available:
        raise HTTPException(status_code=503, detail="No Kali container for this session")
    # Pidfile-based kill first, then force-kill any remaining processes.
    await sandbox._exec(
        "kill $(cat /var/run/openvpn.pid 2>/dev/null) 2>/dev/null; "
        "pkill -9 openvpn 2>/dev/null",
        timeout=10,
    )
    offline_status = {"connected": False, "ip": None}
    session["vpn_status"] = offline_status
    return {"status": "disconnected"}
# ---------------------------------------------------------------------------
# VPN status
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/vpn-status")
async def get_vpn_status(session_id: str):
    """Check VPN status inside the session's Kali container (fallback: host).

    Prefers the container's own status probe; when no container is
    available, checks for an openvpn process and a tun0 address on the
    host (legacy behavior). The result is cached on the session record.
    """
    session = _get_session(session_id)
    sandbox = session_sandboxes.get(session_id)
    # Check inside container if available
    if sandbox and sandbox.is_available:
        vpn_data = await sandbox.get_vpn_status()
        session["vpn_status"] = vpn_data
        return VPNStatus(
            connected=vpn_data["connected"],
            ip=vpn_data.get("ip"),
            interface=vpn_data.get("interface"),
            container_name=sandbox.container_name,
        ).model_dump()
    # Fallback: check on host (legacy behavior)
    connected = False
    ip_addr: Optional[str] = None
    try:
        proc = await asyncio.create_subprocess_shell(
            "pgrep -a openvpn",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        raw_stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5)
        if proc.returncode == 0 and raw_stdout.strip():
            connected = True
    except Exception:
        pass  # best-effort probe; treat failures as "not connected"
    if connected:
        # Resolve the tunnel IP only when a VPN process is present.
        try:
            proc = await asyncio.create_subprocess_shell(
                "ip addr show tun0",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            raw_stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5)
            if proc.returncode == 0:
                match = re.search(
                    r"inet\s+(\d+\.\d+\.\d+\.\d+)", raw_stdout.decode(errors="replace")
                )
                if match:
                    ip_addr = match.group(1)
        except Exception:
            pass  # IP stays None; "connected" without an address is still reported
    vpn = {"connected": connected, "ip": ip_addr}
    session["vpn_status"] = vpn
    return VPNStatus(**vpn).model_dump()

876
backend/api/v1/vuln_lab.py Executable file
View File

@@ -0,0 +1,876 @@
"""
NeuroSploit v3 - Vulnerability Lab API Endpoints
Isolated vulnerability testing against labs, CTFs, and PortSwigger challenges.
Test individual vuln types one at a time and track results.
"""
from typing import Optional, Dict, List
from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from datetime import datetime
from sqlalchemy import select, func, text
from backend.core.autonomous_agent import AutonomousAgent, OperationMode
from backend.core.vuln_engine.registry import VulnerabilityRegistry
from backend.db.database import async_session_factory
from backend.models import Scan, Target, Vulnerability, Endpoint, Report, VulnLabChallenge
# Import agent.py's shared dicts so ScanDetailsPage can find our scans
from backend.api.v1.agent import (
agent_results, agent_instances, agent_to_scan, scan_to_agent
)
# Mounted by the API v1 package.
router = APIRouter()
# In-memory tracking for running lab tests
# (process-local; entries are removed when a run completes or errors)
lab_agents: Dict[str, AutonomousAgent] = {}  # challenge_id -> live agent
lab_results: Dict[str, Dict] = {}  # challenge_id -> progress/result snapshot
# --- Request/Response Models ---
class VulnLabRunRequest(BaseModel):
    """Request body for launching an isolated single-vuln-type lab test."""
    target_url: str = Field(..., description="Target URL to test (lab, CTF, etc.)")
    vuln_type: str = Field(..., description="Vulnerability type to test (e.g. xss_reflected)")
    challenge_name: Optional[str] = Field(None, description="Name of the lab/challenge")
    auth_type: Optional[str] = Field(None, description="Auth type: cookie, bearer, basic, header")
    auth_value: Optional[str] = Field(None, description="Auth credential value")
    custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom HTTP headers")
    notes: Optional[str] = Field(None, description="Notes about this challenge")
class VulnLabResponse(BaseModel):
    """Acknowledgement returned when a lab run is launched."""
    challenge_id: str  # VulnLabChallenge DB id, also the lab_results key
    agent_id: str  # short id used by the shared /agent endpoints
    status: str
    message: str
class VulnTypeInfo(BaseModel):
    """Metadata describing a single testable vulnerability type."""
    key: str
    title: str
    severity: str
    cwe_id: str
    category: str
# --- Vuln type categories for the selector ---
# Maps category key -> {"label": display name, "types": [vuln type keys]}.
# Consumed by /types for grouping and by _get_vuln_category for reverse lookup;
# type keys are expected to exist in VulnerabilityRegistry.VULNERABILITY_INFO.
VULN_CATEGORIES = {
    "injection": {
        "label": "Injection",
        "types": [
            "xss_reflected", "xss_stored", "xss_dom",
            "sqli_error", "sqli_union", "sqli_blind", "sqli_time",
            "command_injection", "ssti", "nosql_injection",
        ]
    },
    "advanced_injection": {
        "label": "Advanced Injection",
        "types": [
            "ldap_injection", "xpath_injection", "graphql_injection",
            "crlf_injection", "header_injection", "email_injection",
            "el_injection", "log_injection", "html_injection",
            "csv_injection", "orm_injection",
        ]
    },
    "file_access": {
        "label": "File Access",
        "types": [
            "lfi", "rfi", "path_traversal", "xxe", "file_upload",
            "arbitrary_file_read", "arbitrary_file_delete", "zip_slip",
        ]
    },
    "request_forgery": {
        "label": "Request Forgery",
        "types": [
            "ssrf", "csrf", "graphql_introspection", "graphql_dos",
        ]
    },
    "authentication": {
        "label": "Authentication",
        "types": [
            "auth_bypass", "jwt_manipulation", "session_fixation",
            "weak_password", "default_credentials", "two_factor_bypass",
            "oauth_misconfig",
        ]
    },
    "authorization": {
        "label": "Authorization",
        "types": [
            "idor", "bola", "privilege_escalation",
            "bfla", "mass_assignment", "forced_browsing",
        ]
    },
    "client_side": {
        "label": "Client-Side",
        "types": [
            "cors_misconfiguration", "clickjacking", "open_redirect",
            "dom_clobbering", "postmessage_vuln", "websocket_hijack",
            "prototype_pollution", "css_injection", "tabnabbing",
        ]
    },
    "infrastructure": {
        "label": "Infrastructure",
        "types": [
            "security_headers", "ssl_issues", "http_methods",
            "directory_listing", "debug_mode", "exposed_admin_panel",
            "exposed_api_docs", "insecure_cookie_flags",
        ]
    },
    "logic": {
        "label": "Business Logic",
        "types": [
            "race_condition", "business_logic", "rate_limit_bypass",
            "parameter_pollution", "type_juggling", "timing_attack",
            "host_header_injection", "http_smuggling", "cache_poisoning",
        ]
    },
    "data_exposure": {
        "label": "Data Exposure",
        "types": [
            "sensitive_data_exposure", "information_disclosure",
            "api_key_exposure", "source_code_disclosure",
            "backup_file_exposure", "version_disclosure",
        ]
    },
    "cloud_supply": {
        "label": "Cloud & Supply Chain",
        "types": [
            "s3_bucket_misconfig", "cloud_metadata_exposure",
            "subdomain_takeover", "vulnerable_dependency",
            "container_escape", "serverless_misconfiguration",
        ]
    },
}
def _get_vuln_category(vuln_type: str) -> str:
    """Resolve the VULN_CATEGORIES key containing *vuln_type* ("other" if none)."""
    return next(
        (
            category_key
            for category_key, category in VULN_CATEGORIES.items()
            if vuln_type in category["types"]
        ),
        "other",
    )
# --- Endpoints ---
@router.get("/types")
async def list_vuln_types():
    """List all available vulnerability types grouped by category"""
    registry = VulnerabilityRegistry()
    categories = {}
    for category_key, category in VULN_CATEGORIES.items():
        entries = []
        for vtype in category["types"]:
            meta = registry.VULNERABILITY_INFO.get(vtype, {})
            description = meta.get("description")
            entries.append({
                "key": vtype,
                "title": meta.get("title", vtype.replace("_", " ").title()),
                "severity": meta.get("severity", "medium"),
                "cwe_id": meta.get("cwe_id", ""),
                "description": description[:120] if description else "",
            })
        categories[category_key] = {
            "label": category["label"],
            "types": entries,
            "count": len(entries),
        }
    total = sum(len(c["types"]) for c in VULN_CATEGORIES.values())
    return {"categories": categories, "total_types": total}
@router.post("/run", response_model=VulnLabResponse)
async def run_vuln_lab(request: VulnLabRunRequest, background_tasks: BackgroundTasks):
    """Launch an isolated vulnerability test for a specific vuln type.

    Validates the type against the registry, persists a VulnLabChallenge
    record, registers progress trackers in both this module's dicts and
    agent.py's shared dicts, then hands the actual run to a background task.
    """
    import uuid
    # Validate vuln type exists
    registry = VulnerabilityRegistry()
    if request.vuln_type not in registry.VULNERABILITY_INFO:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown vulnerability type: {request.vuln_type}. Use GET /vuln-lab/types for available types."
        )
    challenge_id = str(uuid.uuid4())
    agent_id = str(uuid.uuid4())[:8]
    category = _get_vuln_category(request.vuln_type)
    # Build auth headers from the declared auth scheme
    auth_headers = {}
    if request.auth_type and request.auth_value:
        if request.auth_type == "cookie":
            auth_headers["Cookie"] = request.auth_value
        elif request.auth_type == "bearer":
            auth_headers["Authorization"] = f"Bearer {request.auth_value}"
        elif request.auth_type == "basic":
            import base64
            auth_headers["Authorization"] = f"Basic {base64.b64encode(request.auth_value.encode()).decode()}"
        elif request.auth_type == "header":
            # "Name: value" pairs; silently ignored when the colon is missing
            if ":" in request.auth_value:
                name, value = request.auth_value.split(":", 1)
                auth_headers[name.strip()] = value.strip()
    if request.custom_headers:
        auth_headers.update(request.custom_headers)
    # Create DB record
    # NOTE(review): auth_value is stored as-is in the challenge row — confirm
    # whether credentials should be persisted in cleartext.
    async with async_session_factory() as db:
        challenge = VulnLabChallenge(
            id=challenge_id,
            target_url=request.target_url,
            challenge_name=request.challenge_name,
            vuln_type=request.vuln_type,
            vuln_category=category,
            auth_type=request.auth_type,
            auth_value=request.auth_value,
            status="running",
            agent_id=agent_id,
            started_at=datetime.utcnow(),
            notes=request.notes,
        )
        db.add(challenge)
        await db.commit()
    # Init in-memory tracking (both local and in agent.py's shared dicts)
    vuln_info = registry.VULNERABILITY_INFO[request.vuln_type]
    lab_results[challenge_id] = {
        "status": "running",
        "agent_id": agent_id,
        "vuln_type": request.vuln_type,
        "target": request.target_url,
        "progress": 0,
        "phase": "initializing",
        "findings": [],
        "logs": [],
    }
    # Also register in agent.py's shared results dict so /agent/status works
    agent_results[agent_id] = {
        "status": "running",
        "mode": "full_auto",
        "started_at": datetime.utcnow().isoformat(),
        "target": request.target_url,
        "task": f"VulnLab: {vuln_info.get('title', request.vuln_type)}",
        "logs": [],
        "findings": [],
        "report": None,
        "progress": 0,
        "phase": "initializing",
    }
    # Launch agent in background
    background_tasks.add_task(
        _run_lab_test,
        challenge_id,
        agent_id,
        request.target_url,
        request.vuln_type,
        vuln_info.get("title", request.vuln_type),
        auth_headers,
        request.challenge_name,
        request.notes,
    )
    return VulnLabResponse(
        challenge_id=challenge_id,
        agent_id=agent_id,
        status="running",
        message=f"Testing {vuln_info.get('title', request.vuln_type)} against {request.target_url}"
    )
async def _run_lab_test(
    challenge_id: str,
    agent_id: str,
    target: str,
    vuln_type: str,
    vuln_title: str,
    auth_headers: Dict,
    challenge_name: Optional[str] = None,
    notes: Optional[str] = None,
):
    """Background task: run the agent focused on a single vuln type.

    Creates Scan/Target DB records, runs an AutonomousAgent constrained by
    a focused prompt to the one vulnerability type, persists findings,
    endpoints, a report, and the challenge outcome, and mirrors live
    progress into both ``lab_results`` and agent.py's shared
    ``agent_results`` so the generic /agent endpoints work for lab runs.
    """
    import asyncio  # NOTE(review): appears unused in this function — confirm before removing
    logs = []
    findings_list = []
    scan_id = None
    async def log_callback(level: str, message: str):
        # Tag LLM-originated lines so the UI can distinguish them from tooling output.
        source = "llm" if any(tag in message for tag in ["[AI]", "[LLM]", "[USER PROMPT]", "[AI RESPONSE]"]) else "script"
        entry = {"level": level, "message": message, "time": datetime.utcnow().isoformat(), "source": source}
        logs.append(entry)
        # Update local tracking
        if challenge_id in lab_results:
            lab_results[challenge_id]["logs"] = logs
        # Also update agent.py's shared dict so /agent/logs works
        if agent_id in agent_results:
            agent_results[agent_id]["logs"] = logs
    async def progress_callback(progress: int, phase: str):
        # Mirror progress into both tracking dicts.
        if challenge_id in lab_results:
            lab_results[challenge_id]["progress"] = progress
            lab_results[challenge_id]["phase"] = phase
        if agent_id in agent_results:
            agent_results[agent_id]["progress"] = progress
            agent_results[agent_id]["phase"] = phase
    async def finding_callback(finding: Dict):
        # Collect findings in real time; used as a fallback if the final
        # report comes back with an empty findings list.
        findings_list.append(finding)
        if challenge_id in lab_results:
            lab_results[challenge_id]["findings"] = findings_list
        if agent_id in agent_results:
            agent_results[agent_id]["findings"] = findings_list
            agent_results[agent_id]["findings_count"] = len(findings_list)
    try:
        async with async_session_factory() as db:
            # Create a scan record linked to this challenge
            scan = Scan(
                name=f"VulnLab: {vuln_title} - {target[:50]}",
                status="running",
                scan_type="full_auto",
                recon_enabled=True,
                progress=0,
                current_phase="initializing",
                custom_prompt=f"Focus ONLY on testing for {vuln_title} ({vuln_type}). "
                f"Do NOT test other vulnerability types. "
                f"Test thoroughly with multiple payloads and techniques for this specific vulnerability.",
            )
            db.add(scan)
            await db.commit()
            await db.refresh(scan)
            scan_id = scan.id
            # Create target record
            target_record = Target(scan_id=scan_id, url=target, status="pending")
            db.add(target_record)
            await db.commit()
            # Update challenge with scan_id
            result = await db.execute(
                select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
            )
            challenge = result.scalar_one_or_none()
            if challenge:
                challenge.scan_id = scan_id
                await db.commit()
            if challenge_id in lab_results:
                lab_results[challenge_id]["scan_id"] = scan_id
            # Register in agent.py's shared mappings so ScanDetailsPage works
            agent_to_scan[agent_id] = scan_id
            scan_to_agent[scan_id] = agent_id
            if agent_id in agent_results:
                agent_results[agent_id]["scan_id"] = scan_id
            # Build focused prompt for isolated testing
            focused_prompt = (
                f"You are testing specifically for {vuln_title} ({vuln_type}). "
                f"Focus ALL your efforts on detecting and exploiting this single vulnerability type. "
                f"Do NOT scan for other vulnerability types. "
                f"Use all relevant payloads and techniques for {vuln_type}. "
                f"Be thorough: try multiple injection points, encoding bypasses, and edge cases. "
                f"This is a lab/CTF challenge - the vulnerability is expected to exist."
            )
            if challenge_name:
                focused_prompt += (
                    f"\n\nCHALLENGE HINT: This is PortSwigger lab '{challenge_name}'. "
                    f"Use this name to understand what specific technique or bypass is needed. "
                    f"For example, 'angle brackets HTML-encoded' means attribute-based XSS, "
                    f"'most tags and attributes blocked' means fuzz for allowed tags/events."
                )
            if notes:
                focused_prompt += f"\n\nUSER NOTES: {notes}"
            lab_ctx = {
                "challenge_name": challenge_name,
                "notes": notes,
                "vuln_type": vuln_type,
                "is_lab": True,
            }
            async with AutonomousAgent(
                target=target,
                mode=OperationMode.FULL_AUTO,
                log_callback=log_callback,
                progress_callback=progress_callback,
                auth_headers=auth_headers,
                custom_prompt=focused_prompt,
                finding_callback=finding_callback,
                lab_context=lab_ctx,
            ) as agent:
                lab_agents[challenge_id] = agent
                # Also register in agent.py's shared instances so stop works
                agent_instances[agent_id] = agent
                report = await agent.run()
            lab_agents.pop(challenge_id, None)
            agent_instances.pop(agent_id, None)
            # Use findings from report OR from real-time callbacks (fallback)
            report_findings = report.get("findings", [])
            # If report findings are empty but we got findings via callback, use those
            findings = report_findings if report_findings else findings_list
            # Also merge: if findings_list has entries not in report_findings, add them
            # NOTE(review): given the line above, this branch can never change
            # `findings` — it is dead code; confirm before removing.
            if not findings and findings_list:
                findings = findings_list
            severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
            findings_detail = []
            for finding in findings:
                severity = finding.get("severity", "medium").lower()
                if severity in severity_counts:
                    severity_counts[severity] += 1
                # Compact per-finding summary stored on the challenge row.
                findings_detail.append({
                    "title": finding.get("title", ""),
                    "vulnerability_type": finding.get("vulnerability_type", ""),
                    "severity": severity,
                    "affected_endpoint": finding.get("affected_endpoint", ""),
                    "evidence": (finding.get("evidence", "") or "")[:500],
                    "payload": (finding.get("payload", "") or "")[:200],
                })
                # Save to vulnerabilities table
                vuln = Vulnerability(
                    scan_id=scan_id,
                    title=finding.get("title", finding.get("type", "Unknown")),
                    vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
                    severity=severity,
                    cvss_score=finding.get("cvss_score"),
                    cvss_vector=finding.get("cvss_vector"),
                    cwe_id=finding.get("cwe_id"),
                    description=finding.get("description", finding.get("evidence", "")),
                    affected_endpoint=finding.get("affected_endpoint", finding.get("url", target)),
                    poc_payload=finding.get("payload", finding.get("poc_payload", finding.get("poc_code", ""))),
                    poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
                    poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
                    poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
                    poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
                    impact=finding.get("impact", ""),
                    remediation=finding.get("remediation", ""),
                    references=finding.get("references", []),
                    ai_analysis=finding.get("ai_analysis", ""),
                    screenshots=finding.get("screenshots", []),
                    url=finding.get("url", finding.get("affected_endpoint", "")),
                    parameter=finding.get("parameter", finding.get("poc_parameter", "")),
                )
                db.add(vuln)
            # Save discovered endpoints from recon data
            endpoints_count = 0
            for ep in report.get("recon", {}).get("endpoints", []):
                endpoints_count += 1
                if isinstance(ep, str):
                    # Bare URL string: derive a path segment from it.
                    endpoint = Endpoint(
                        scan_id=scan_id,
                        target_id=target_record.id,
                        url=ep,
                        method="GET",
                        path=ep.split("?")[0].split("/")[-1] or "/"
                    )
                else:
                    endpoint = Endpoint(
                        scan_id=scan_id,
                        target_id=target_record.id,
                        url=ep.get("url", ""),
                        method=ep.get("method", "GET"),
                        path=ep.get("path", "/")
                    )
                db.add(endpoint)
            # Determine result - more flexible matching
            # Check if any finding matches the target vuln type
            target_type_findings = [
                f for f in findings
                if _vuln_type_matches(vuln_type, f.get("vulnerability_type", ""))
            ]
            # If the agent found ANY vulnerability, it detected something
            # (since we told it to focus on one type, any finding is relevant)
            if target_type_findings:
                result_status = "detected"
            elif len(findings) > 0:
                # Found other vulns but not the exact type
                result_status = "detected"
            else:
                result_status = "not_detected"
            # Update scan
            scan.status = "completed"
            scan.completed_at = datetime.utcnow()
            scan.progress = 100
            scan.current_phase = "completed"
            scan.total_vulnerabilities = len(findings)
            scan.total_endpoints = endpoints_count
            scan.critical_count = severity_counts["critical"]
            scan.high_count = severity_counts["high"]
            scan.medium_count = severity_counts["medium"]
            scan.low_count = severity_counts["low"]
            scan.info_count = severity_counts["info"]
            # Auto-generate report
            exec_summary = report.get("executive_summary", f"VulnLab test for {vuln_title} on {target}")
            report_record = Report(
                scan_id=scan_id,
                title=f"VulnLab: {vuln_title} - {target[:50]}",
                format="json",
                executive_summary=exec_summary[:1000] if exec_summary else None,
            )
            db.add(report_record)
            # Persist logs (keep last 500 entries to avoid huge DB rows)
            persisted_logs = logs[-500:] if len(logs) > 500 else logs
            # Update challenge record
            result_q = await db.execute(
                select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
            )
            challenge = result_q.scalar_one_or_none()
            if challenge:
                challenge.status = "completed"
                challenge.result = result_status
                challenge.completed_at = datetime.utcnow()
                challenge.duration = int((datetime.utcnow() - challenge.started_at).total_seconds()) if challenge.started_at else 0
                challenge.findings_count = len(findings)
                challenge.critical_count = severity_counts["critical"]
                challenge.high_count = severity_counts["high"]
                challenge.medium_count = severity_counts["medium"]
                challenge.low_count = severity_counts["low"]
                challenge.info_count = severity_counts["info"]
                challenge.findings_detail = findings_detail
                challenge.logs = persisted_logs
                challenge.endpoints_count = endpoints_count
            await db.commit()
            # Update in-memory results
            if challenge_id in lab_results:
                lab_results[challenge_id]["status"] = "completed"
                lab_results[challenge_id]["result"] = result_status
                lab_results[challenge_id]["findings"] = findings
                lab_results[challenge_id]["progress"] = 100
                lab_results[challenge_id]["phase"] = "completed"
            if agent_id in agent_results:
                agent_results[agent_id]["status"] = "completed"
                agent_results[agent_id]["completed_at"] = datetime.utcnow().isoformat()
                agent_results[agent_id]["report"] = report
                agent_results[agent_id]["findings"] = findings
                agent_results[agent_id]["progress"] = 100
                agent_results[agent_id]["phase"] = "completed"
    except Exception as e:
        import traceback
        error_tb = traceback.format_exc()
        # NOTE(review): prefer logger over print for background-task errors.
        print(f"VulnLab error: {error_tb}")
        if challenge_id in lab_results:
            lab_results[challenge_id]["status"] = "error"
            lab_results[challenge_id]["error"] = str(e)
        if agent_id in agent_results:
            agent_results[agent_id]["status"] = "error"
            agent_results[agent_id]["error"] = str(e)
        # Persist logs even on error
        persisted_logs = logs[-500:] if len(logs) > 500 else logs
        # Update DB records
        try:
            async with async_session_factory() as db:
                result = await db.execute(
                    select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
                )
                challenge = result.scalar_one_or_none()
                if challenge:
                    challenge.status = "failed"
                    challenge.result = "error"
                    challenge.completed_at = datetime.utcnow()
                    challenge.notes = (challenge.notes or "") + f"\nError: {str(e)}"
                    challenge.logs = persisted_logs
                await db.commit()
                if scan_id:
                    result = await db.execute(select(Scan).where(Scan.id == scan_id))
                    scan = result.scalar_one_or_none()
                    if scan:
                        scan.status = "failed"
                        scan.error_message = str(e)
                        scan.completed_at = datetime.utcnow()
                        await db.commit()
        # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit;
        # narrowing to `except Exception:` would be safer. Intentional best-effort.
        except:
            pass
    finally:
        # Always release the live-agent references, success or failure.
        lab_agents.pop(challenge_id, None)
        agent_instances.pop(agent_id, None)
def _vuln_type_matches(target_type: str, found_type: str) -> bool:
"""Check if a found vuln type matches the target type (flexible matching)"""
if not found_type:
return False
target = target_type.lower().replace("_", " ").replace("-", " ")
found = found_type.lower().replace("_", " ").replace("-", " ")
# Exact match
if target == found:
return True
# Target is substring of found or vice versa
if target in found or found in target:
return True
# Key word matching for common patterns
target_words = set(target.split())
found_words = set(found.split())
# If they share major keywords (xss, sqli, ssrf, etc.)
major_keywords = {"xss", "sqli", "sql", "injection", "ssrf", "csrf", "lfi", "rfi",
"xxe", "ssti", "idor", "cors", "jwt", "redirect", "traversal"}
shared = target_words & found_words & major_keywords
if shared:
return True
return False
@router.get("/challenges")
async def list_challenges(
    vuln_type: Optional[str] = None,
    vuln_category: Optional[str] = None,
    status: Optional[str] = None,
    result: Optional[str] = None,
    limit: int = 50,
):
    """List all vulnerability lab challenges with optional filtering."""
    # (value, column) pairs: a filter is applied only when its value is truthy.
    optional_filters = (
        (vuln_type, VulnLabChallenge.vuln_type),
        (vuln_category, VulnLabChallenge.vuln_category),
        (status, VulnLabChallenge.status),
        (result, VulnLabChallenge.result),
    )
    async with async_session_factory() as db:
        stmt = select(VulnLabChallenge).order_by(VulnLabChallenge.created_at.desc())
        for value, column in optional_filters:
            if value:
                stmt = stmt.where(column == value)
        stmt = stmt.limit(limit)
        rows = (await db.execute(stmt)).scalars().all()
        # Strip the bulky logs payload from the list view; expose only a count.
        challenges = []
        for row in rows:
            item = row.to_dict()
            item["logs_count"] = len(item.get("logs", []))
            item.pop("logs", None)
            challenges.append(item)
        return {
            "challenges": challenges,
            "total": len(rows),
        }
@router.get("/challenges/{challenge_id}")
async def get_challenge(challenge_id: str):
    """Get challenge details, preferring real-time in-memory state if running."""
    # While the agent is active, the in-memory snapshot is fresher than the DB.
    if challenge_id in lab_results:
        mem = lab_results[challenge_id]
        findings = mem.get("findings", [])
        logs = mem.get("logs", [])
        return {
            "challenge_id": challenge_id,
            "status": mem["status"],
            "progress": mem.get("progress", 0),
            "phase": mem.get("phase", ""),
            "findings_count": len(findings),
            "findings": findings,
            "logs_count": len(logs),
            "logs": logs[-200:],  # cap real-time log payload at 200 entries
            "error": mem.get("error"),
            "result": mem.get("result"),
            "scan_id": mem.get("scan_id"),
            "agent_id": mem.get("agent_id"),
            "vuln_type": mem.get("vuln_type"),
            "target": mem.get("target"),
            "source": "realtime",
        }
    # Otherwise serve the persisted record.
    async with async_session_factory() as db:
        challenge = (
            await db.execute(
                select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
            )
        ).scalar_one_or_none()
        if challenge is None:
            raise HTTPException(status_code=404, detail="Challenge not found")
        data = challenge.to_dict()
        data["source"] = "database"
        data["logs_count"] = len(data.get("logs", []))
        return data
@router.get("/stats")
async def get_lab_stats():
    """Get aggregated stats for all lab challenges."""

    def _bucketize(rows):
        # Fold (key, result, count) rows into per-key result counters.
        stats = {}
        for key, res, count in rows:
            bucket = stats.setdefault(
                key, {"detected": 0, "not_detected": 0, "error": 0, "total": 0}
            )
            bucket[res or "error"] = count
            bucket["total"] += count
        return stats

    async with async_session_factory() as db:
        # Total counts by status
        status_rows = await db.execute(
            select(
                VulnLabChallenge.status,
                func.count(VulnLabChallenge.id),
            ).group_by(VulnLabChallenge.status)
        )
        status_counts = {status: count for status, count in status_rows.fetchall()}
        # Results breakdown
        result_rows = await db.execute(
            select(
                VulnLabChallenge.result,
                func.count(VulnLabChallenge.id),
            ).where(VulnLabChallenge.result.isnot(None))
            .group_by(VulnLabChallenge.result)
        )
        result_counts = {res: count for res, count in result_rows.fetchall()}
        # Per vuln_type stats (completed runs only)
        type_rows = await db.execute(
            select(
                VulnLabChallenge.vuln_type,
                VulnLabChallenge.result,
                func.count(VulnLabChallenge.id),
            ).where(VulnLabChallenge.status == "completed")
            .group_by(VulnLabChallenge.vuln_type, VulnLabChallenge.result)
        )
        type_stats = _bucketize(type_rows.fetchall())
        # Per category stats (completed runs only)
        cat_rows = await db.execute(
            select(
                VulnLabChallenge.vuln_category,
                VulnLabChallenge.result,
                func.count(VulnLabChallenge.id),
            ).where(VulnLabChallenge.status == "completed")
            .group_by(VulnLabChallenge.vuln_category, VulnLabChallenge.result)
        )
        cat_stats = _bucketize(cat_rows.fetchall())
        # Currently running (from the in-memory tracker, not the DB)
        running = sum(
            1 for info in lab_results.values() if info.get("status") == "running"
        )
        total = sum(status_counts.values())
        completed = status_counts.get("completed", 0)
        detected = result_counts.get("detected", 0)
        detection_rate = round(detected / completed * 100, 1) if completed else 0
        return {
            "total": total,
            "running": running,
            "status_counts": status_counts,
            "result_counts": result_counts,
            "detection_rate": detection_rate,
            "by_type": type_stats,
            "by_category": cat_stats,
        }
@router.post("/challenges/{challenge_id}/stop")
async def stop_challenge(challenge_id: str):
    """Stop a running lab challenge.

    Cancels the in-flight agent, then best-effort marks the DB record and
    the in-memory snapshot as stopped.

    Raises:
        HTTPException: 404 when no agent is currently running for this id.
    """
    agent = lab_agents.get(challenge_id)
    if not agent:
        raise HTTPException(status_code=404, detail="No running agent for this challenge")
    agent.cancel()
    # Update DB (best-effort: the stop must succeed even if persistence fails).
    # Use `except Exception` rather than a bare `except:` so BaseExceptions —
    # notably asyncio.CancelledError — still propagate, and log the failure
    # instead of swallowing it silently.
    try:
        async with async_session_factory() as db:
            result = await db.execute(
                select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
            )
            challenge = result.scalar_one_or_none()
            if challenge:
                challenge.status = "stopped"
                challenge.completed_at = datetime.utcnow()
                await db.commit()
    except Exception as exc:
        print(f"Failed to persist stopped status for challenge {challenge_id}: {exc}")
    if challenge_id in lab_results:
        lab_results[challenge_id]["status"] = "stopped"
    return {"message": "Challenge stopped"}
@router.delete("/challenges/{challenge_id}")
async def delete_challenge(challenge_id: str):
    """Delete a lab challenge record."""
    # Cancel and forget any still-running agent before removing the record.
    running_agent = lab_agents.pop(challenge_id, None)
    if running_agent is not None:
        running_agent.cancel()
    lab_results.pop(challenge_id, None)
    async with async_session_factory() as db:
        challenge = (
            await db.execute(
                select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
            )
        ).scalar_one_or_none()
        if challenge is None:
            raise HTTPException(status_code=404, detail="Challenge not found")
        await db.delete(challenge)
        await db.commit()
    return {"message": "Challenge deleted"}
@router.get("/logs/{challenge_id}")
async def get_challenge_logs(challenge_id: str, limit: int = 200):
    """Get logs for a challenge (real-time while running, otherwise from DB)."""

    def _payload(entries, source):
        # Shared response shape: full count plus the trailing `limit` entries.
        return {
            "challenge_id": challenge_id,
            "total_logs": len(entries),
            "logs": entries[-limit:],
            "source": source,
        }

    # Prefer the in-memory snapshot while the agent is producing data.
    mem = lab_results.get(challenge_id)
    if mem:
        return _payload(mem.get("logs", []), "realtime")
    # Fall back to the logs persisted on the challenge record.
    async with async_session_factory() as db:
        challenge = (
            await db.execute(
                select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
            )
        ).scalar_one_or_none()
        if not challenge:
            raise HTTPException(status_code=404, detail="Challenge not found")
        return _payload(challenge.logs or [], "database")

389
backend/api/v1/vulnerabilities.py Executable file
View File

@@ -0,0 +1,389 @@
"""
NeuroSploit v3 - Vulnerabilities API Endpoints
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from backend.db.database import get_db
from backend.models import Vulnerability
from backend.schemas.vulnerability import VulnerabilityResponse, VulnerabilityTypeInfo
router = APIRouter()
# Vulnerability type definitions.
# Static catalog served by the /types endpoints below. Shape:
#   {category: {type_key: {name, description, severity_range,
#                          owasp_category, cwe_ids}}}
# severity_range is a free-form hint ("medium-high", "critical", "varies"),
# owasp_category references the OWASP Top 10 2021, cwe_ids lists CWE tags.
VULNERABILITY_TYPES = {
    # --- Injection-style flaws (OWASP A03) ---
    "injection": {
        "xss_reflected": {
            "name": "Reflected XSS",
            "description": "Cross-site scripting via user input reflected in response",
            "severity_range": "medium-high",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-79"]
        },
        "xss_stored": {
            "name": "Stored XSS",
            "description": "Cross-site scripting stored in application database",
            "severity_range": "high-critical",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-79"]
        },
        "xss_dom": {
            "name": "DOM-based XSS",
            "description": "Cross-site scripting via DOM manipulation",
            "severity_range": "medium-high",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-79"]
        },
        "sqli_error": {
            "name": "Error-based SQL Injection",
            "description": "SQL injection detected via error messages",
            "severity_range": "high-critical",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-89"]
        },
        "sqli_union": {
            "name": "Union-based SQL Injection",
            "description": "SQL injection exploitable via UNION queries",
            "severity_range": "critical",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-89"]
        },
        "sqli_blind": {
            "name": "Blind SQL Injection",
            "description": "SQL injection without visible output",
            "severity_range": "high-critical",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-89"]
        },
        "sqli_time": {
            "name": "Time-based SQL Injection",
            "description": "SQL injection detected via response time",
            "severity_range": "high-critical",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-89"]
        },
        "command_injection": {
            "name": "Command Injection",
            "description": "OS command injection vulnerability",
            "severity_range": "critical",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-78"]
        },
        "ssti": {
            "name": "Server-Side Template Injection",
            "description": "Template injection allowing code execution",
            "severity_range": "high-critical",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-94"]
        },
        "ldap_injection": {
            "name": "LDAP Injection",
            "description": "LDAP query injection",
            "severity_range": "high",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-90"]
        },
        "xpath_injection": {
            "name": "XPath Injection",
            "description": "XPath query injection",
            "severity_range": "medium-high",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-643"]
        },
        "nosql_injection": {
            "name": "NoSQL Injection",
            "description": "NoSQL database injection",
            "severity_range": "high-critical",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-943"]
        },
        "header_injection": {
            "name": "HTTP Header Injection",
            "description": "Injection into HTTP headers",
            "severity_range": "medium-high",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-113"]
        },
        "crlf_injection": {
            "name": "CRLF Injection",
            "description": "Carriage return line feed injection",
            "severity_range": "medium",
            "owasp_category": "A03:2021",
            "cwe_ids": ["CWE-93"]
        }
    },
    # --- File inclusion / path and upload abuse ---
    "file_access": {
        "lfi": {
            "name": "Local File Inclusion",
            "description": "Include local files via path manipulation",
            "severity_range": "high-critical",
            "owasp_category": "A01:2021",
            "cwe_ids": ["CWE-98"]
        },
        "rfi": {
            "name": "Remote File Inclusion",
            "description": "Include remote files for code execution",
            "severity_range": "critical",
            "owasp_category": "A01:2021",
            "cwe_ids": ["CWE-98"]
        },
        "path_traversal": {
            "name": "Path Traversal",
            "description": "Access files outside web root",
            "severity_range": "high",
            "owasp_category": "A01:2021",
            "cwe_ids": ["CWE-22"]
        },
        "file_upload": {
            "name": "Arbitrary File Upload",
            "description": "Upload malicious files",
            "severity_range": "high-critical",
            "owasp_category": "A04:2021",
            "cwe_ids": ["CWE-434"]
        },
        "xxe": {
            "name": "XML External Entity",
            "description": "XXE injection vulnerability",
            "severity_range": "high-critical",
            "owasp_category": "A05:2021",
            "cwe_ids": ["CWE-611"]
        }
    },
    # --- Server- and client-side request forgery ---
    "request_forgery": {
        "ssrf": {
            "name": "Server-Side Request Forgery",
            "description": "Forge requests from the server",
            "severity_range": "high-critical",
            "owasp_category": "A10:2021",
            "cwe_ids": ["CWE-918"]
        },
        "ssrf_cloud": {
            "name": "SSRF to Cloud Metadata",
            "description": "SSRF accessing cloud provider metadata",
            "severity_range": "critical",
            "owasp_category": "A10:2021",
            "cwe_ids": ["CWE-918"]
        },
        "csrf": {
            "name": "Cross-Site Request Forgery",
            "description": "Forge requests as authenticated user",
            "severity_range": "medium-high",
            "owasp_category": "A01:2021",
            "cwe_ids": ["CWE-352"]
        }
    },
    # --- Authentication weaknesses ---
    "authentication": {
        "auth_bypass": {
            "name": "Authentication Bypass",
            "description": "Bypass authentication mechanisms",
            "severity_range": "critical",
            "owasp_category": "A07:2021",
            "cwe_ids": ["CWE-287"]
        },
        "session_fixation": {
            "name": "Session Fixation",
            "description": "Force known session ID on user",
            "severity_range": "high",
            "owasp_category": "A07:2021",
            "cwe_ids": ["CWE-384"]
        },
        "jwt_manipulation": {
            "name": "JWT Token Manipulation",
            "description": "Manipulate JWT tokens for auth bypass",
            "severity_range": "high-critical",
            "owasp_category": "A07:2021",
            "cwe_ids": ["CWE-347"]
        },
        "weak_password_policy": {
            "name": "Weak Password Policy",
            "description": "Application accepts weak passwords",
            "severity_range": "medium",
            "owasp_category": "A07:2021",
            "cwe_ids": ["CWE-521"]
        }
    },
    # --- Authorization / access-control flaws ---
    "authorization": {
        "idor": {
            "name": "Insecure Direct Object Reference",
            "description": "Access objects without proper authorization",
            "severity_range": "high",
            "owasp_category": "A01:2021",
            "cwe_ids": ["CWE-639"]
        },
        "bola": {
            "name": "Broken Object Level Authorization",
            "description": "API-level object authorization bypass",
            "severity_range": "high",
            "owasp_category": "A01:2021",
            "cwe_ids": ["CWE-639"]
        },
        "privilege_escalation": {
            "name": "Privilege Escalation",
            "description": "Escalate to higher privilege level",
            "severity_range": "critical",
            "owasp_category": "A01:2021",
            "cwe_ids": ["CWE-269"]
        }
    },
    # --- API-specific weaknesses ---
    "api_security": {
        "rate_limiting": {
            "name": "Missing Rate Limiting",
            "description": "No rate limiting on sensitive endpoints",
            "severity_range": "medium",
            "owasp_category": "A04:2021",
            "cwe_ids": ["CWE-770"]
        },
        "mass_assignment": {
            "name": "Mass Assignment",
            "description": "Modify unintended object properties",
            "severity_range": "high",
            "owasp_category": "A04:2021",
            "cwe_ids": ["CWE-915"]
        },
        "excessive_data": {
            "name": "Excessive Data Exposure",
            "description": "API returns more data than needed",
            "severity_range": "medium-high",
            "owasp_category": "A01:2021",
            "cwe_ids": ["CWE-200"]
        },
        "graphql_introspection": {
            "name": "GraphQL Introspection Enabled",
            "description": "GraphQL schema exposed via introspection",
            "severity_range": "low-medium",
            "owasp_category": "A05:2021",
            "cwe_ids": ["CWE-200"]
        }
    },
    # --- Browser/client-side issues ---
    "client_side": {
        "cors_misconfig": {
            "name": "CORS Misconfiguration",
            "description": "Permissive CORS policy",
            "severity_range": "medium-high",
            "owasp_category": "A05:2021",
            "cwe_ids": ["CWE-942"]
        },
        "clickjacking": {
            "name": "Clickjacking",
            "description": "Page can be framed for clickjacking",
            "severity_range": "medium",
            "owasp_category": "A05:2021",
            "cwe_ids": ["CWE-1021"]
        },
        "open_redirect": {
            "name": "Open Redirect",
            "description": "Redirect to arbitrary URLs",
            "severity_range": "low-medium",
            "owasp_category": "A01:2021",
            "cwe_ids": ["CWE-601"]
        }
    },
    # --- Information disclosure ---
    "information_disclosure": {
        "error_disclosure": {
            "name": "Error Message Disclosure",
            "description": "Detailed error messages exposed",
            "severity_range": "low-medium",
            "owasp_category": "A05:2021",
            "cwe_ids": ["CWE-209"]
        },
        "sensitive_data": {
            "name": "Sensitive Data Exposure",
            "description": "Sensitive information exposed",
            "severity_range": "medium-high",
            "owasp_category": "A02:2021",
            "cwe_ids": ["CWE-200"]
        },
        "debug_endpoints": {
            "name": "Debug Endpoints Exposed",
            "description": "Debug/admin endpoints accessible",
            "severity_range": "high",
            "owasp_category": "A05:2021",
            "cwe_ids": ["CWE-489"]
        }
    },
    # --- Infrastructure / transport configuration ---
    "infrastructure": {
        "security_headers": {
            "name": "Missing Security Headers",
            "description": "Important security headers not set",
            "severity_range": "low-medium",
            "owasp_category": "A05:2021",
            "cwe_ids": ["CWE-693"]
        },
        "ssl_issues": {
            "name": "SSL/TLS Issues",
            "description": "Weak SSL/TLS configuration",
            "severity_range": "medium",
            "owasp_category": "A02:2021",
            "cwe_ids": ["CWE-326"]
        },
        "http_methods": {
            "name": "Dangerous HTTP Methods",
            "description": "Dangerous HTTP methods enabled",
            "severity_range": "low-medium",
            "owasp_category": "A05:2021",
            "cwe_ids": ["CWE-749"]
        }
    },
    # --- Business-logic and timing flaws ---
    "logic_flaws": {
        "race_condition": {
            "name": "Race Condition",
            "description": "Exploitable race condition",
            "severity_range": "medium-high",
            "owasp_category": "A04:2021",
            "cwe_ids": ["CWE-362"]
        },
        "business_logic": {
            "name": "Business Logic Flaw",
            "description": "Exploitable business logic error",
            "severity_range": "varies",
            "owasp_category": "A04:2021",
            "cwe_ids": ["CWE-840"]
        }
    }
}
@router.get("/types")
async def get_vulnerability_types():
    """Get all vulnerability types organized by category"""
    # Returns the static module-level catalog as-is:
    # {category: {type_key: {name, description, severity_range, owasp_category, cwe_ids}}}
    return VULNERABILITY_TYPES
@router.get("/types/{category}")
async def get_vulnerability_types_by_category(category: str):
    """Get vulnerability types for a specific category."""
    # EAFP lookup: an unknown category surfaces as a 404.
    try:
        return VULNERABILITY_TYPES[category]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Category '{category}' not found")
@router.get("/types/{category}/{vuln_type}", response_model=VulnerabilityTypeInfo)
async def get_vulnerability_type_info(category: str, vuln_type: str):
    """Get detailed info for a specific vulnerability type."""
    # Guard clauses: each missing level of the catalog yields its own 404.
    category_types = VULNERABILITY_TYPES.get(category)
    if category_types is None:
        raise HTTPException(status_code=404, detail=f"Category '{category}' not found")
    info = category_types.get(vuln_type)
    if info is None:
        raise HTTPException(status_code=404, detail=f"Type '{vuln_type}' not found in category '{category}'")
    return VulnerabilityTypeInfo(type=vuln_type, category=category, **info)
@router.get("/{vuln_id}", response_model=VulnerabilityResponse)
async def get_vulnerability(vuln_id: str, db: AsyncSession = Depends(get_db)):
    """Get a specific vulnerability by ID."""
    query = select(Vulnerability).where(Vulnerability.id == vuln_id)
    vuln = (await db.execute(query)).scalar_one_or_none()
    if vuln is None:
        raise HTTPException(status_code=404, detail="Vulnerability not found")
    return VulnerabilityResponse(**vuln.to_dict())

247
backend/api/websocket.py Executable file
View File

@@ -0,0 +1,247 @@
"""
NeuroSploit v3 - WebSocket Manager
"""
from typing import Dict, List, Optional
from fastapi import WebSocket
import json
import asyncio
try:
from backend.core.notification_manager import notification_manager, NotificationEvent
HAS_NOTIFICATIONS = True
except ImportError:
HAS_NOTIFICATIONS = False
class ConnectionManager:
    """Manages WebSocket connections for real-time updates.

    Connections are registered per scan id; each ``broadcast_*`` helper fans a
    JSON message out to every socket watching that scan. Selected events also
    fire an external notification when the notification manager is importable.
    """

    def __init__(self):
        # scan_id -> list of websocket connections
        self.active_connections: Dict[str, List[WebSocket]] = {}
        self._lock = asyncio.Lock()
        # Strong references to fire-and-forget notification tasks. asyncio
        # holds only weak references to tasks, so without this set a pending
        # notification task could be garbage-collected before it runs.
        self._notify_tasks: set = set()

    def _schedule_notification(self, event, payload: dict) -> None:
        """Fire an external notification in the background (if available).

        Keeps a strong reference to the created task until it completes, per
        the asyncio.create_task documentation.
        """
        if not HAS_NOTIFICATIONS:
            return
        task = asyncio.create_task(notification_manager.notify(event, payload))
        self._notify_tasks.add(task)
        task.add_done_callback(self._notify_tasks.discard)

    async def connect(self, websocket: WebSocket, scan_id: str):
        """Accept a WebSocket connection and register it for a scan"""
        await websocket.accept()
        async with self._lock:
            self.active_connections.setdefault(scan_id, []).append(websocket)
        print(f"WebSocket connected for scan: {scan_id}")

    def disconnect(self, websocket: WebSocket, scan_id: str):
        """Remove a WebSocket connection.

        Runs without the lock: it never awaits, so on a single event loop it
        executes atomically relative to connect/send.
        """
        connections = self.active_connections.get(scan_id)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        # Drop the scan entry entirely once its last watcher is gone.
        if not connections:
            del self.active_connections[scan_id]
        print(f"WebSocket disconnected for scan: {scan_id}")

    async def send_to_scan(self, scan_id: str, message: dict):
        """Send a message to all connections watching a specific scan"""
        connections = self.active_connections.get(scan_id)
        if not connections:
            return
        # Serialize once for all recipients.
        text = json.dumps(message)
        dead_connections = []
        # Iterate over a snapshot: awaiting send_text yields to the event
        # loop, during which a disconnect could mutate the underlying list.
        for connection in list(connections):
            try:
                await connection.send_text(text)
            except Exception:
                dead_connections.append(connection)
        # Clean up dead connections
        for conn in dead_connections:
            self.disconnect(conn, scan_id)

    async def broadcast_scan_started(self, scan_id: str, target: str = ""):
        """Notify that a scan has started"""
        await self.send_to_scan(scan_id, {
            "type": "scan_started",
            "scan_id": scan_id
        })
        self._schedule_notification(
            NotificationEvent.SCAN_STARTED, {"target": target, "scan_id": scan_id}
        ) if HAS_NOTIFICATIONS else None

    async def broadcast_phase_change(self, scan_id: str, phase: str):
        """Notify phase change (recon, testing, reporting)"""
        await self.send_to_scan(scan_id, {
            "type": "phase_change",
            "scan_id": scan_id,
            "phase": phase
        })

    async def broadcast_progress(self, scan_id: str, progress: int, message: Optional[str] = None):
        """Send progress update"""
        await self.send_to_scan(scan_id, {
            "type": "progress_update",
            "scan_id": scan_id,
            "progress": progress,
            "message": message
        })

    async def broadcast_endpoint_found(self, scan_id: str, endpoint: dict):
        """Notify a new endpoint was discovered"""
        await self.send_to_scan(scan_id, {
            "type": "endpoint_found",
            "scan_id": scan_id,
            "endpoint": endpoint
        })

    async def broadcast_path_crawled(self, scan_id: str, path: str, status: int):
        """Notify a path was crawled"""
        await self.send_to_scan(scan_id, {
            "type": "path_crawled",
            "scan_id": scan_id,
            "path": path,
            "status": status
        })

    async def broadcast_url_discovered(self, scan_id: str, url: str):
        """Notify a URL was discovered"""
        await self.send_to_scan(scan_id, {
            "type": "url_discovered",
            "scan_id": scan_id,
            "url": url
        })

    async def broadcast_test_started(self, scan_id: str, vuln_type: str, endpoint: str):
        """Notify a vulnerability test has started"""
        await self.send_to_scan(scan_id, {
            "type": "test_started",
            "scan_id": scan_id,
            "vulnerability_type": vuln_type,
            "endpoint": endpoint
        })

    async def broadcast_test_completed(self, scan_id: str, vuln_type: str, endpoint: str, is_vulnerable: bool):
        """Notify a vulnerability test has completed"""
        await self.send_to_scan(scan_id, {
            "type": "test_completed",
            "scan_id": scan_id,
            "vulnerability_type": vuln_type,
            "endpoint": endpoint,
            "is_vulnerable": is_vulnerable
        })

    async def broadcast_vulnerability_found(self, scan_id: str, vulnerability: dict):
        """Notify a vulnerability was found"""
        await self.send_to_scan(scan_id, {
            "type": "vuln_found",
            "scan_id": scan_id,
            "vulnerability": vulnerability
        })
        if HAS_NOTIFICATIONS:
            self._schedule_notification(
                NotificationEvent.VULN_FOUND, {
                    "title": vulnerability.get("title", "Vulnerability Found"),
                    "severity": vulnerability.get("severity", "medium"),
                    "vulnerability_type": vulnerability.get("vulnerability_type", "unknown"),
                    "endpoint": vulnerability.get("endpoint", ""),
                    "description": vulnerability.get("description", ""),
                }
            )

    async def broadcast_log(self, scan_id: str, level: str, message: str):
        """Send a log message"""
        await self.send_to_scan(scan_id, {
            "type": "log_message",
            "scan_id": scan_id,
            "level": level,
            "message": message
        })

    async def broadcast_scan_completed(self, scan_id: str, summary: dict):
        """Notify that a scan has completed"""
        await self.send_to_scan(scan_id, {
            "type": "scan_completed",
            "scan_id": scan_id,
            "summary": summary
        })
        if HAS_NOTIFICATIONS:
            self._schedule_notification(
                NotificationEvent.SCAN_COMPLETED, {
                    "total_vulnerabilities": summary.get("total_vulnerabilities", 0),
                    "critical": summary.get("critical", 0),
                    "high": summary.get("high", 0),
                    "medium": summary.get("medium", 0),
                }
            )

    async def broadcast_scan_stopped(self, scan_id: str, summary: dict):
        """Notify that a scan was stopped by user"""
        await self.send_to_scan(scan_id, {
            "type": "scan_stopped",
            "scan_id": scan_id,
            "status": "stopped",
            "summary": summary
        })

    async def broadcast_scan_failed(self, scan_id: str, error: str, summary: dict = None):
        """Notify that a scan has failed"""
        await self.send_to_scan(scan_id, {
            "type": "scan_failed",
            "scan_id": scan_id,
            "status": "failed",
            "error": error,
            "summary": summary or {}
        })
        if HAS_NOTIFICATIONS:
            self._schedule_notification(
                NotificationEvent.SCAN_FAILED, {"error": error}
            )

    async def broadcast_stats_update(self, scan_id: str, stats: dict):
        """Broadcast updated scan statistics"""
        await self.send_to_scan(scan_id, {
            "type": "stats_update",
            "scan_id": scan_id,
            "stats": stats
        })

    async def broadcast_agent_task(self, scan_id: str, task: dict):
        """Broadcast agent task update (created, started, completed, failed)"""
        await self.send_to_scan(scan_id, {
            "type": "agent_task",
            "scan_id": scan_id,
            "task": task
        })

    async def broadcast_agent_task_started(self, scan_id: str, task: dict):
        """Broadcast when an agent task starts"""
        await self.send_to_scan(scan_id, {
            "type": "agent_task_started",
            "scan_id": scan_id,
            "task": task
        })

    async def broadcast_agent_task_completed(self, scan_id: str, task: dict):
        """Broadcast when an agent task completes"""
        await self.send_to_scan(scan_id, {
            "type": "agent_task_completed",
            "scan_id": scan_id,
            "task": task
        })

    async def broadcast_report_generated(self, scan_id: str, report: dict):
        """Broadcast when a report is generated"""
        await self.send_to_scan(scan_id, {
            "type": "report_generated",
            "scan_id": scan_id,
            "report": report
        })

    async def broadcast_error(self, scan_id: str, error: str):
        """Notify an error occurred"""
        await self.send_to_scan(scan_id, {
            "type": "error",
            "scan_id": scan_id,
            "error": error
        })
# Global instance
# Module-level singleton imported by the API routes and scan workers so that
# all of them share one connection registry.
manager = ConnectionManager()