mirror of
https://github.com/CyberSecurityUP/NeuroSploit.git
synced 2026-03-26 05:00:24 +01:00
NeuroSploit v3.2 - Autonomous AI Penetration Testing Platform
116 modules | 100 vuln types | 18 API routes | 18 frontend pages Major features: - VulnEngine: 100 vuln types, 526+ payloads, 12 testers, anti-hallucination prompts - Autonomous Agent: 3-stream auto pentest, multi-session (5 concurrent), pause/resume/stop - CLI Agent: Claude Code / Gemini CLI / Codex CLI inside Kali containers - Validation Pipeline: negative controls, proof of execution, confidence scoring, judge - AI Reasoning: ReACT engine, token budget, endpoint classifier, CVE hunter, deep recon - Multi-Agent: 5 specialists + orchestrator + researcher AI + vuln type agents - RAG System: BM25/TF-IDF/ChromaDB vectorstore, few-shot, reasoning templates - Smart Router: 20 providers (8 CLI OAuth + 12 API), tier failover, token refresh - Kali Sandbox: container-per-scan, 56 tools, VPN support, on-demand install - Full IA Testing: methodology-driven comprehensive pentest sessions - Notifications: Discord, Telegram, WhatsApp/Twilio multi-channel alerts - Frontend: React/TypeScript with 18 pages, real-time WebSocket updates
This commit is contained in:
1
backend/api/__init__.py
Executable file
1
backend/api/__init__.py
Executable file
@@ -0,0 +1 @@
|
||||
# API package
|
||||
1
backend/api/v1/__init__.py
Executable file
1
backend/api/v1/__init__.py
Executable file
@@ -0,0 +1 @@
|
||||
# API v1 package
|
||||
3128
backend/api/v1/agent.py
Executable file
3128
backend/api/v1/agent.py
Executable file
File diff suppressed because it is too large
Load Diff
176
backend/api/v1/agent_tasks.py
Executable file
176
backend/api/v1/agent_tasks.py
Executable file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
NeuroSploit v3 - Agent Tasks API Endpoints
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
from backend.db.database import get_db
|
||||
from backend.models import AgentTask, Scan
|
||||
from backend.schemas.agent_task import (
|
||||
AgentTaskResponse,
|
||||
AgentTaskListResponse,
|
||||
AgentTaskSummary
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=AgentTaskListResponse)
|
||||
async def list_agent_tasks(
|
||||
scan_id: str,
|
||||
status: Optional[str] = None,
|
||||
task_type: Optional[str] = None,
|
||||
page: int = 1,
|
||||
per_page: int = 50,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""List all agent tasks for a scan"""
|
||||
# Verify scan exists
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Build query
|
||||
query = select(AgentTask).where(AgentTask.scan_id == scan_id)
|
||||
|
||||
if status:
|
||||
query = query.where(AgentTask.status == status)
|
||||
if task_type:
|
||||
query = query.where(AgentTask.task_type == task_type)
|
||||
|
||||
query = query.order_by(AgentTask.created_at.desc())
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
|
||||
if status:
|
||||
count_query = count_query.where(AgentTask.status == status)
|
||||
if task_type:
|
||||
count_query = count_query.where(AgentTask.task_type == task_type)
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset((page - 1) * per_page).limit(per_page)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
|
||||
return AgentTaskListResponse(
|
||||
tasks=[AgentTaskResponse(**t.to_dict()) for t in tasks],
|
||||
total=total,
|
||||
scan_id=scan_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/summary", response_model=AgentTaskSummary)
|
||||
async def get_agent_tasks_summary(
|
||||
scan_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get summary statistics for agent tasks in a scan"""
|
||||
# Verify scan exists
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Total count
|
||||
total_result = await db.execute(
|
||||
select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
|
||||
)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# Count by status
|
||||
status_counts = {}
|
||||
for status in ["pending", "running", "completed", "failed"]:
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(AgentTask)
|
||||
.where(AgentTask.scan_id == scan_id)
|
||||
.where(AgentTask.status == status)
|
||||
)
|
||||
status_counts[status] = count_result.scalar() or 0
|
||||
|
||||
# Count by task type
|
||||
type_query = select(
|
||||
AgentTask.task_type,
|
||||
func.count(AgentTask.id).label("count")
|
||||
).where(AgentTask.scan_id == scan_id).group_by(AgentTask.task_type)
|
||||
type_result = await db.execute(type_query)
|
||||
by_type = {row[0]: row[1] for row in type_result.all()}
|
||||
|
||||
# Count by tool
|
||||
tool_query = select(
|
||||
AgentTask.tool_name,
|
||||
func.count(AgentTask.id).label("count")
|
||||
).where(AgentTask.scan_id == scan_id).where(AgentTask.tool_name.isnot(None)).group_by(AgentTask.tool_name)
|
||||
tool_result = await db.execute(tool_query)
|
||||
by_tool = {row[0]: row[1] for row in tool_result.all()}
|
||||
|
||||
return AgentTaskSummary(
|
||||
total=total,
|
||||
pending=status_counts.get("pending", 0),
|
||||
running=status_counts.get("running", 0),
|
||||
completed=status_counts.get("completed", 0),
|
||||
failed=status_counts.get("failed", 0),
|
||||
by_type=by_type,
|
||||
by_tool=by_tool
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=AgentTaskResponse)
|
||||
async def get_agent_task(
|
||||
task_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get a specific agent task by ID"""
|
||||
result = await db.execute(select(AgentTask).where(AgentTask.id == task_id))
|
||||
task = result.scalar_one_or_none()
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Agent task not found")
|
||||
|
||||
return AgentTaskResponse(**task.to_dict())
|
||||
|
||||
|
||||
@router.get("/scan/{scan_id}/timeline")
|
||||
async def get_agent_tasks_timeline(
|
||||
scan_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get agent tasks as a timeline for visualization"""
|
||||
# Verify scan exists
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Get all tasks ordered by creation time
|
||||
query = select(AgentTask).where(AgentTask.scan_id == scan_id).order_by(AgentTask.created_at.asc())
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
|
||||
timeline = []
|
||||
for task in tasks:
|
||||
timeline_item = {
|
||||
"id": task.id,
|
||||
"task_name": task.task_name,
|
||||
"task_type": task.task_type,
|
||||
"tool_name": task.tool_name,
|
||||
"status": task.status,
|
||||
"started_at": task.started_at.isoformat() if task.started_at else None,
|
||||
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
|
||||
"duration_ms": task.duration_ms,
|
||||
"items_processed": task.items_processed,
|
||||
"items_found": task.items_found,
|
||||
"result_summary": task.result_summary,
|
||||
"error_message": task.error_message
|
||||
}
|
||||
timeline.append(timeline_item)
|
||||
|
||||
return {
|
||||
"scan_id": scan_id,
|
||||
"timeline": timeline,
|
||||
"total": len(timeline)
|
||||
}
|
||||
144
backend/api/v1/cli_agent.py
Normal file
144
backend/api/v1/cli_agent.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
CLI Agent API - Endpoints for CLI agent provider detection and methodology listing.
|
||||
"""
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
from typing import List, Dict, Optional
|
||||
from fastapi import APIRouter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/cli-agent", tags=["CLI Agent"])
|
||||
|
||||
# CLI providers that can run as autonomous agents
|
||||
CLI_AGENT_PROVIDER_IDS = ["claude_code", "gemini_cli", "codex_cli"]
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
async def get_cli_providers() -> Dict:
|
||||
"""List available CLI agent providers with connection status from SmartRouter."""
|
||||
providers = []
|
||||
|
||||
try:
|
||||
from backend.core.smart_router import get_registry
|
||||
registry = get_registry()
|
||||
except Exception:
|
||||
registry = None
|
||||
|
||||
for pid in CLI_AGENT_PROVIDER_IDS:
|
||||
provider_info = {
|
||||
"id": pid,
|
||||
"name": pid,
|
||||
"connected": False,
|
||||
"account_label": None,
|
||||
"source": None,
|
||||
}
|
||||
|
||||
if registry:
|
||||
provider = registry.get_provider(pid)
|
||||
if provider:
|
||||
provider_info["name"] = provider.name
|
||||
accounts = registry.get_active_accounts(pid)
|
||||
if accounts:
|
||||
provider_info["connected"] = True
|
||||
provider_info["account_label"] = accounts[0].label
|
||||
provider_info["source"] = accounts[0].source
|
||||
|
||||
providers.append(provider_info)
|
||||
|
||||
# Also check env var API keys as fallback
|
||||
env_fallbacks = {
|
||||
"claude_code": "ANTHROPIC_API_KEY",
|
||||
"gemini_cli": "GEMINI_API_KEY",
|
||||
"codex_cli": "OPENAI_API_KEY",
|
||||
}
|
||||
for p in providers:
|
||||
if not p["connected"]:
|
||||
env_key = env_fallbacks.get(p["id"], "")
|
||||
if env_key and os.getenv(env_key, ""):
|
||||
p["connected"] = True
|
||||
p["source"] = "env_var"
|
||||
p["account_label"] = f"${env_key}"
|
||||
|
||||
enabled = os.getenv("ENABLE_CLI_AGENT", "false").lower() == "true"
|
||||
|
||||
return {
|
||||
"enabled": enabled,
|
||||
"providers": providers,
|
||||
"connected_count": sum(1 for p in providers if p["connected"]),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/methodologies")
|
||||
async def list_methodologies() -> Dict:
|
||||
"""List available methodology .md files for CLI agent."""
|
||||
methodologies: List[Dict] = []
|
||||
seen_paths: set = set()
|
||||
|
||||
# 1. Check METHODOLOGY_FILE env var (default)
|
||||
default_path = os.getenv("METHODOLOGY_FILE", "")
|
||||
if default_path and os.path.exists(default_path):
|
||||
size = os.path.getsize(default_path)
|
||||
methodologies.append({
|
||||
"name": os.path.basename(default_path),
|
||||
"path": default_path,
|
||||
"size": size,
|
||||
"size_human": _human_size(size),
|
||||
"is_default": True,
|
||||
})
|
||||
seen_paths.add(os.path.abspath(default_path))
|
||||
|
||||
# 2. Scan /opt/Prompts-PenTest/ for .md files
|
||||
prompts_dir = "/opt/Prompts-PenTest"
|
||||
if os.path.isdir(prompts_dir):
|
||||
for md_file in sorted(glob.glob(os.path.join(prompts_dir, "*.md"))):
|
||||
abs_path = os.path.abspath(md_file)
|
||||
if abs_path in seen_paths:
|
||||
continue
|
||||
seen_paths.add(abs_path)
|
||||
|
||||
name = os.path.basename(md_file)
|
||||
size = os.path.getsize(md_file)
|
||||
|
||||
# Only include pentest-related files (skip research reports, etc.)
|
||||
name_lower = name.lower()
|
||||
if any(kw in name_lower for kw in ["pentest", "prompt", "bugbounty", "methodology", "chunk"]):
|
||||
methodologies.append({
|
||||
"name": name,
|
||||
"path": md_file,
|
||||
"size": size,
|
||||
"size_human": _human_size(size),
|
||||
"is_default": False,
|
||||
})
|
||||
|
||||
# 3. Check data/ directory
|
||||
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data")
|
||||
if os.path.isdir(data_dir):
|
||||
for md_file in glob.glob(os.path.join(data_dir, "*methodology*.md")):
|
||||
abs_path = os.path.abspath(md_file)
|
||||
if abs_path not in seen_paths:
|
||||
seen_paths.add(abs_path)
|
||||
size = os.path.getsize(md_file)
|
||||
methodologies.append({
|
||||
"name": os.path.basename(md_file),
|
||||
"path": md_file,
|
||||
"size": size,
|
||||
"size_human": _human_size(size),
|
||||
"is_default": False,
|
||||
})
|
||||
|
||||
return {
|
||||
"methodologies": methodologies,
|
||||
"total": len(methodologies),
|
||||
}
|
||||
|
||||
|
||||
def _human_size(size_bytes: int) -> str:
|
||||
"""Convert bytes to human-readable size."""
|
||||
if size_bytes < 1024:
|
||||
return f"{size_bytes} B"
|
||||
elif size_bytes < 1024 * 1024:
|
||||
return f"{size_bytes / 1024:.1f} KB"
|
||||
else:
|
||||
return f"{size_bytes / (1024 * 1024):.1f} MB"
|
||||
299
backend/api/v1/dashboard.py
Executable file
299
backend/api/v1/dashboard.py
Executable file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
NeuroSploit v3 - Dashboard API Endpoints
|
||||
"""
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from backend.db.database import get_db
|
||||
from backend.models import Scan, Vulnerability, Endpoint, AgentTask, Report
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
|
||||
"""Get overall dashboard statistics"""
|
||||
# Total scans
|
||||
total_scans_result = await db.execute(select(func.count()).select_from(Scan))
|
||||
total_scans = total_scans_result.scalar() or 0
|
||||
|
||||
# Scans by status
|
||||
running_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "running")
|
||||
)
|
||||
running_scans = running_result.scalar() or 0
|
||||
|
||||
completed_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "completed")
|
||||
)
|
||||
completed_scans = completed_result.scalar() or 0
|
||||
|
||||
stopped_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "stopped")
|
||||
)
|
||||
stopped_scans = stopped_result.scalar() or 0
|
||||
|
||||
failed_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "failed")
|
||||
)
|
||||
failed_scans = failed_result.scalar() or 0
|
||||
|
||||
pending_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "pending")
|
||||
)
|
||||
pending_scans = pending_result.scalar() or 0
|
||||
|
||||
# Total vulnerabilities by severity
|
||||
vuln_counts = {}
|
||||
for severity in ["critical", "high", "medium", "low", "info"]:
|
||||
result = await db.execute(
|
||||
select(func.count()).select_from(Vulnerability).where(Vulnerability.severity == severity)
|
||||
)
|
||||
vuln_counts[severity] = result.scalar() or 0
|
||||
|
||||
total_vulns = sum(vuln_counts.values())
|
||||
|
||||
# Total endpoints
|
||||
endpoints_result = await db.execute(select(func.count()).select_from(Endpoint))
|
||||
total_endpoints = endpoints_result.scalar() or 0
|
||||
|
||||
# Recent activity (last 7 days)
|
||||
week_ago = datetime.utcnow() - timedelta(days=7)
|
||||
recent_scans_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.created_at >= week_ago)
|
||||
)
|
||||
recent_scans = recent_scans_result.scalar() or 0
|
||||
|
||||
recent_vulns_result = await db.execute(
|
||||
select(func.count()).select_from(Vulnerability).where(Vulnerability.created_at >= week_ago)
|
||||
)
|
||||
recent_vulns = recent_vulns_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"scans": {
|
||||
"total": total_scans,
|
||||
"running": running_scans,
|
||||
"completed": completed_scans,
|
||||
"stopped": stopped_scans,
|
||||
"failed": failed_scans,
|
||||
"pending": pending_scans,
|
||||
"recent": recent_scans
|
||||
},
|
||||
"vulnerabilities": {
|
||||
"total": total_vulns,
|
||||
"critical": vuln_counts["critical"],
|
||||
"high": vuln_counts["high"],
|
||||
"medium": vuln_counts["medium"],
|
||||
"low": vuln_counts["low"],
|
||||
"info": vuln_counts["info"],
|
||||
"recent": recent_vulns
|
||||
},
|
||||
"endpoints": {
|
||||
"total": total_endpoints
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/recent")
|
||||
async def get_recent_activity(
|
||||
limit: int = 10,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get recent scan activity"""
|
||||
# Recent scans
|
||||
scans_query = select(Scan).order_by(Scan.created_at.desc()).limit(limit)
|
||||
scans_result = await db.execute(scans_query)
|
||||
recent_scans = scans_result.scalars().all()
|
||||
|
||||
# Recent vulnerabilities
|
||||
vulns_query = select(Vulnerability).order_by(Vulnerability.created_at.desc()).limit(limit)
|
||||
vulns_result = await db.execute(vulns_query)
|
||||
recent_vulns = vulns_result.scalars().all()
|
||||
|
||||
return {
|
||||
"recent_scans": [s.to_dict() for s in recent_scans],
|
||||
"recent_vulnerabilities": [v.to_dict() for v in recent_vulns]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/findings")
|
||||
async def get_recent_findings(
|
||||
limit: int = 20,
|
||||
severity: str = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get recent vulnerability findings"""
|
||||
query = select(Vulnerability).order_by(Vulnerability.created_at.desc())
|
||||
|
||||
if severity:
|
||||
query = query.where(Vulnerability.severity == severity)
|
||||
|
||||
query = query.limit(limit)
|
||||
result = await db.execute(query)
|
||||
vulnerabilities = result.scalars().all()
|
||||
|
||||
return {
|
||||
"findings": [v.to_dict() for v in vulnerabilities],
|
||||
"total": len(vulnerabilities)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/vulnerability-types")
|
||||
async def get_vulnerability_distribution(db: AsyncSession = Depends(get_db)):
|
||||
"""Get vulnerability distribution by type"""
|
||||
query = select(
|
||||
Vulnerability.vulnerability_type,
|
||||
func.count(Vulnerability.id).label("count")
|
||||
).group_by(Vulnerability.vulnerability_type)
|
||||
|
||||
result = await db.execute(query)
|
||||
distribution = result.all()
|
||||
|
||||
return {
|
||||
"distribution": [
|
||||
{"type": row[0], "count": row[1]}
|
||||
for row in distribution
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/scan-history")
|
||||
async def get_scan_history(
|
||||
days: int = 30,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get scan history for charts"""
|
||||
start_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
# Get scans grouped by date
|
||||
scans = await db.execute(
|
||||
select(Scan).where(Scan.created_at >= start_date).order_by(Scan.created_at)
|
||||
)
|
||||
all_scans = scans.scalars().all()
|
||||
|
||||
# Group by date
|
||||
history = {}
|
||||
for scan in all_scans:
|
||||
date_str = scan.created_at.strftime("%Y-%m-%d")
|
||||
if date_str not in history:
|
||||
history[date_str] = {
|
||||
"date": date_str,
|
||||
"scans": 0,
|
||||
"vulnerabilities": 0,
|
||||
"critical": 0,
|
||||
"high": 0
|
||||
}
|
||||
history[date_str]["scans"] += 1
|
||||
history[date_str]["vulnerabilities"] += scan.total_vulnerabilities
|
||||
history[date_str]["critical"] += scan.critical_count
|
||||
history[date_str]["high"] += scan.high_count
|
||||
|
||||
return {"history": list(history.values())}
|
||||
|
||||
|
||||
@router.get("/agent-tasks")
|
||||
async def get_recent_agent_tasks(
|
||||
limit: int = 20,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get recent agent tasks across all scans"""
|
||||
query = (
|
||||
select(AgentTask)
|
||||
.order_by(AgentTask.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
|
||||
return {
|
||||
"agent_tasks": [t.to_dict() for t in tasks],
|
||||
"total": len(tasks)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/activity-feed")
|
||||
async def get_activity_feed(
|
||||
limit: int = 30,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get unified activity feed with all recent events"""
|
||||
activities = []
|
||||
|
||||
# Get recent scans
|
||||
scans_result = await db.execute(
|
||||
select(Scan).order_by(Scan.created_at.desc()).limit(limit // 3)
|
||||
)
|
||||
for scan in scans_result.scalars().all():
|
||||
activities.append({
|
||||
"type": "scan",
|
||||
"action": f"Scan {scan.status}",
|
||||
"title": scan.name or "Unnamed Scan",
|
||||
"description": f"{scan.total_vulnerabilities} vulnerabilities found",
|
||||
"status": scan.status,
|
||||
"severity": None,
|
||||
"timestamp": scan.created_at.isoformat(),
|
||||
"scan_id": scan.id,
|
||||
"link": f"/scan/{scan.id}"
|
||||
})
|
||||
|
||||
# Get recent vulnerabilities
|
||||
vulns_result = await db.execute(
|
||||
select(Vulnerability).order_by(Vulnerability.created_at.desc()).limit(limit // 3)
|
||||
)
|
||||
for vuln in vulns_result.scalars().all():
|
||||
activities.append({
|
||||
"type": "vulnerability",
|
||||
"action": "Vulnerability found",
|
||||
"title": vuln.title,
|
||||
"description": vuln.affected_endpoint or "",
|
||||
"status": None,
|
||||
"severity": vuln.severity,
|
||||
"timestamp": vuln.created_at.isoformat(),
|
||||
"scan_id": vuln.scan_id,
|
||||
"link": f"/scan/{vuln.scan_id}"
|
||||
})
|
||||
|
||||
# Get recent agent tasks
|
||||
tasks_result = await db.execute(
|
||||
select(AgentTask).order_by(AgentTask.created_at.desc()).limit(limit // 3)
|
||||
)
|
||||
for task in tasks_result.scalars().all():
|
||||
activities.append({
|
||||
"type": "agent_task",
|
||||
"action": f"Task {task.status}",
|
||||
"title": task.task_name,
|
||||
"description": task.result_summary or task.description or "",
|
||||
"status": task.status,
|
||||
"severity": None,
|
||||
"timestamp": task.created_at.isoformat(),
|
||||
"scan_id": task.scan_id,
|
||||
"link": f"/scan/{task.scan_id}"
|
||||
})
|
||||
|
||||
# Get recent reports
|
||||
reports_result = await db.execute(
|
||||
select(Report).order_by(Report.generated_at.desc()).limit(limit // 4)
|
||||
)
|
||||
for report in reports_result.scalars().all():
|
||||
activities.append({
|
||||
"type": "report",
|
||||
"action": "Report generated" if report.auto_generated else "Report created",
|
||||
"title": report.title or "Report",
|
||||
"description": f"{report.format.upper()} format",
|
||||
"status": "auto" if report.auto_generated else "manual",
|
||||
"severity": None,
|
||||
"timestamp": report.generated_at.isoformat(),
|
||||
"scan_id": report.scan_id,
|
||||
"link": f"/reports"
|
||||
})
|
||||
|
||||
# Sort all activities by timestamp (newest first)
|
||||
activities.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||
|
||||
return {
|
||||
"activities": activities[:limit],
|
||||
"total": len(activities)
|
||||
}
|
||||
38
backend/api/v1/full_ia.py
Normal file
38
backend/api/v1/full_ia.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
NeuroSploit v3 - FULL AI Testing API
|
||||
|
||||
Serves the comprehensive pentest prompt and manages FULL AI testing sessions.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Default prompt file path - English translation preferred, fallback to original
|
||||
PROMPT_PATH_EN = Path("/opt/Prompts-PenTest/pentestcompleto_en.md")
|
||||
PROMPT_PATH_PT = Path("/opt/Prompts-PenTest/pentestcompleto.md")
|
||||
PROMPT_PATH = PROMPT_PATH_EN if PROMPT_PATH_EN.exists() else PROMPT_PATH_PT
|
||||
|
||||
|
||||
@router.get("/prompt")
|
||||
async def get_full_ia_prompt():
|
||||
"""Return the comprehensive pentest prompt content."""
|
||||
if not PROMPT_PATH.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Pentest prompt file not found at {PROMPT_PATH}"
|
||||
)
|
||||
try:
|
||||
content = PROMPT_PATH.read_text(encoding="utf-8")
|
||||
return {
|
||||
"content": content,
|
||||
"path": str(PROMPT_PATH),
|
||||
"size": len(content),
|
||||
"lines": content.count("\n") + 1,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read prompt file: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
172
backend/api/v1/knowledge.py
Normal file
172
backend/api/v1/knowledge.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
NeuroSploit v3 - Knowledge Management API
|
||||
|
||||
Upload, manage, and query custom security knowledge documents.
|
||||
"""
|
||||
import os
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Lazy-loaded processor instance
|
||||
_processor = None
|
||||
|
||||
|
||||
def _get_processor():
|
||||
global _processor
|
||||
if _processor is None:
|
||||
from backend.core.knowledge_processor import KnowledgeProcessor
|
||||
# Try to get LLM client for AI analysis
|
||||
llm = None
|
||||
try:
|
||||
from backend.core.autonomous_agent import LLMClient
|
||||
client = LLMClient()
|
||||
if client.is_available():
|
||||
llm = client
|
||||
except Exception:
|
||||
pass
|
||||
_processor = KnowledgeProcessor(llm_client=llm)
|
||||
return _processor
|
||||
|
||||
|
||||
# --- Schemas ---
|
||||
|
||||
class KnowledgeDocumentResponse(BaseModel):
|
||||
id: str
|
||||
filename: str
|
||||
title: str
|
||||
source_type: str
|
||||
uploaded_at: str
|
||||
processed: bool
|
||||
file_size_bytes: int
|
||||
summary: str
|
||||
vuln_types: List[str]
|
||||
entries_count: int
|
||||
|
||||
|
||||
class KnowledgeEntryResponse(BaseModel):
|
||||
vuln_type: str
|
||||
methodology: str = ""
|
||||
payloads: List[str] = []
|
||||
key_insights: str = ""
|
||||
bypass_techniques: List[str] = []
|
||||
source_document: str = ""
|
||||
|
||||
|
||||
class KnowledgeStatsResponse(BaseModel):
|
||||
total_documents: int
|
||||
total_entries: int
|
||||
vuln_types_covered: List[str]
|
||||
storage_bytes: int
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.post("/upload", response_model=KnowledgeDocumentResponse)
|
||||
async def upload_knowledge(file: UploadFile = File(...)):
|
||||
"""Upload a security document for knowledge extraction.
|
||||
|
||||
Supported formats: PDF, Markdown (.md), Text (.txt), HTML
|
||||
The document will be analyzed and indexed by vulnerability type.
|
||||
"""
|
||||
if not file.filename:
|
||||
raise HTTPException(400, "Filename is required")
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
if len(content) > 50 * 1024 * 1024: # 50MB limit
|
||||
raise HTTPException(413, "File too large (max 50MB)")
|
||||
if len(content) == 0:
|
||||
raise HTTPException(400, "Empty file")
|
||||
|
||||
processor = _get_processor()
|
||||
|
||||
try:
|
||||
doc = await processor.process_upload(content, file.filename)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"Processing failed: {str(e)}")
|
||||
|
||||
return KnowledgeDocumentResponse(
|
||||
id=doc["id"],
|
||||
filename=doc["filename"],
|
||||
title=doc["title"],
|
||||
source_type=doc["source_type"],
|
||||
uploaded_at=doc["uploaded_at"],
|
||||
processed=doc["processed"],
|
||||
file_size_bytes=doc["file_size_bytes"],
|
||||
summary=doc["summary"],
|
||||
vuln_types=doc["vuln_types"],
|
||||
entries_count=len(doc.get("knowledge_entries", [])),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/documents", response_model=List[KnowledgeDocumentResponse])
|
||||
async def list_documents():
|
||||
"""List all indexed knowledge documents."""
|
||||
processor = _get_processor()
|
||||
docs = processor.get_documents()
|
||||
return [
|
||||
KnowledgeDocumentResponse(
|
||||
id=d["id"],
|
||||
filename=d["filename"],
|
||||
title=d["title"],
|
||||
source_type=d["source_type"],
|
||||
uploaded_at=d["uploaded_at"],
|
||||
processed=d["processed"],
|
||||
file_size_bytes=d["file_size_bytes"],
|
||||
summary=d["summary"],
|
||||
vuln_types=d["vuln_types"],
|
||||
entries_count=d["entries_count"],
|
||||
)
|
||||
for d in docs
|
||||
]
|
||||
|
||||
|
||||
@router.get("/documents/{doc_id}")
|
||||
async def get_document(doc_id: str):
|
||||
"""Get a specific document with its full knowledge entries."""
|
||||
processor = _get_processor()
|
||||
doc = processor.get_document(doc_id)
|
||||
if not doc:
|
||||
raise HTTPException(404, f"Document '{doc_id}' not found")
|
||||
return doc
|
||||
|
||||
|
||||
@router.delete("/documents/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
"""Delete a knowledge document and its index entries."""
|
||||
processor = _get_processor()
|
||||
deleted = processor.delete_document(doc_id)
|
||||
if not deleted:
|
||||
raise HTTPException(404, f"Document '{doc_id}' not found")
|
||||
return {"message": f"Document '{doc_id}' deleted", "id": doc_id}
|
||||
|
||||
|
||||
@router.get("/search", response_model=List[KnowledgeEntryResponse])
|
||||
async def search_knowledge(vuln_type: str = Query(..., description="Vulnerability type to search")):
|
||||
"""Search knowledge entries by vulnerability type."""
|
||||
processor = _get_processor()
|
||||
entries = processor.search_by_vuln_type(vuln_type)
|
||||
return [
|
||||
KnowledgeEntryResponse(
|
||||
vuln_type=e.get("vuln_type", ""),
|
||||
methodology=e.get("methodology", ""),
|
||||
payloads=e.get("payloads", []),
|
||||
key_insights=e.get("key_insights", ""),
|
||||
bypass_techniques=e.get("bypass_techniques", []),
|
||||
source_document=e.get("source_document", ""),
|
||||
)
|
||||
for e in entries
|
||||
]
|
||||
|
||||
|
||||
@router.get("/stats", response_model=KnowledgeStatsResponse)
|
||||
async def get_stats():
|
||||
"""Get knowledge base statistics."""
|
||||
processor = _get_processor()
|
||||
stats = processor.get_stats()
|
||||
return KnowledgeStatsResponse(**stats)
|
||||
320
backend/api/v1/mcp.py
Normal file
320
backend/api/v1/mcp.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
NeuroSploit v3 - MCP Server Management API
|
||||
|
||||
CRUD for Model Context Protocol server connections.
|
||||
Persists to config/config.json mcp_servers section.
|
||||
"""
|
||||
import json
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "config" / "config.json"
|
||||
|
||||
BUILTIN_SERVER = "neurosploit_tools"
|
||||
|
||||
|
||||
# --- Schemas ---
|
||||
|
||||
class MCPServerCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=100, description="Unique server identifier")
|
||||
transport: str = Field("stdio", description="Transport type: stdio or sse")
|
||||
command: Optional[str] = Field(None, description="Command for stdio transport")
|
||||
args: Optional[List[str]] = Field(None, description="Args for stdio transport")
|
||||
url: Optional[str] = Field(None, description="URL for sse transport")
|
||||
env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
|
||||
description: str = Field("", description="Server description")
|
||||
enabled: bool = Field(True, description="Whether server is enabled")
|
||||
|
||||
|
||||
class MCPServerUpdate(BaseModel):
|
||||
transport: Optional[str] = None
|
||||
command: Optional[str] = None
|
||||
args: Optional[List[str]] = None
|
||||
url: Optional[str] = None
|
||||
env: Optional[Dict[str, str]] = None
|
||||
description: Optional[str] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class MCPServerResponse(BaseModel):
|
||||
name: str
|
||||
transport: str
|
||||
command: Optional[str] = None
|
||||
args: Optional[List[str]] = None
|
||||
url: Optional[str] = None
|
||||
env: Optional[Dict[str, str]] = None
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
is_builtin: bool = False
|
||||
|
||||
|
||||
class MCPToolResponse(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
server_name: str
|
||||
|
||||
|
||||
# --- Config helpers ---
|
||||
|
||||
def _read_config() -> dict:
|
||||
if not CONFIG_PATH.exists():
|
||||
return {}
|
||||
with open(CONFIG_PATH) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _write_config(config: dict):
|
||||
with open(CONFIG_PATH, "w") as f:
|
||||
json.dump(config, f, indent=4)
|
||||
|
||||
|
||||
def _get_mcp_servers(config: dict) -> dict:
|
||||
return config.get("mcp_servers", {})
|
||||
|
||||
|
||||
def _server_to_response(name: str, server: dict) -> MCPServerResponse:
|
||||
return MCPServerResponse(
|
||||
name=name,
|
||||
transport=server.get("transport", "stdio"),
|
||||
command=server.get("command"),
|
||||
args=server.get("args"),
|
||||
url=server.get("url"),
|
||||
env=server.get("env"),
|
||||
description=server.get("description", ""),
|
||||
enabled=server.get("enabled", True),
|
||||
is_builtin=(name == BUILTIN_SERVER),
|
||||
)
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.get("/servers", response_model=List[MCPServerResponse])
|
||||
async def list_servers():
|
||||
"""List all configured MCP servers."""
|
||||
config = _read_config()
|
||||
servers = _get_mcp_servers(config)
|
||||
return [_server_to_response(name, srv) for name, srv in servers.items()]
|
||||
|
||||
|
||||
@router.get("/servers/{name}", response_model=MCPServerResponse)
|
||||
async def get_server(name: str):
|
||||
"""Get a specific MCP server configuration."""
|
||||
config = _read_config()
|
||||
servers = _get_mcp_servers(config)
|
||||
if name not in servers:
|
||||
raise HTTPException(404, f"MCP server '{name}' not found")
|
||||
return _server_to_response(name, servers[name])
|
||||
|
||||
|
||||
@router.post("/servers", response_model=MCPServerResponse)
|
||||
async def create_server(body: MCPServerCreate):
|
||||
"""Add a new MCP server configuration."""
|
||||
config = _read_config()
|
||||
if "mcp_servers" not in config:
|
||||
config["mcp_servers"] = {}
|
||||
|
||||
servers = config["mcp_servers"]
|
||||
if body.name in servers:
|
||||
raise HTTPException(409, f"Server '{body.name}' already exists")
|
||||
|
||||
# Validate transport-specific fields
|
||||
if body.transport == "stdio" and not body.command:
|
||||
raise HTTPException(400, "stdio transport requires 'command' field")
|
||||
if body.transport == "sse" and not body.url:
|
||||
raise HTTPException(400, "sse transport requires 'url' field")
|
||||
|
||||
server_config = {
|
||||
"transport": body.transport,
|
||||
"description": body.description,
|
||||
"enabled": body.enabled,
|
||||
}
|
||||
if body.command:
|
||||
server_config["command"] = body.command
|
||||
if body.args:
|
||||
server_config["args"] = body.args
|
||||
if body.url:
|
||||
server_config["url"] = body.url
|
||||
if body.env:
|
||||
server_config["env"] = body.env
|
||||
|
||||
servers[body.name] = server_config
|
||||
_write_config(config)
|
||||
|
||||
return _server_to_response(body.name, server_config)
|
||||
|
||||
|
||||
@router.put("/servers/{name}", response_model=MCPServerResponse)
|
||||
async def update_server(name: str, body: MCPServerUpdate):
|
||||
"""Update an MCP server configuration."""
|
||||
config = _read_config()
|
||||
servers = _get_mcp_servers(config)
|
||||
|
||||
if name not in servers:
|
||||
raise HTTPException(404, f"MCP server '{name}' not found")
|
||||
|
||||
srv = servers[name]
|
||||
if body.transport is not None:
|
||||
srv["transport"] = body.transport
|
||||
if body.command is not None:
|
||||
srv["command"] = body.command
|
||||
if body.args is not None:
|
||||
srv["args"] = body.args
|
||||
if body.url is not None:
|
||||
srv["url"] = body.url
|
||||
if body.env is not None:
|
||||
srv["env"] = body.env
|
||||
if body.description is not None:
|
||||
srv["description"] = body.description
|
||||
if body.enabled is not None:
|
||||
srv["enabled"] = body.enabled
|
||||
|
||||
_write_config(config)
|
||||
return _server_to_response(name, srv)
|
||||
|
||||
|
||||
@router.delete("/servers/{name}")
|
||||
async def delete_server(name: str):
|
||||
"""Delete an MCP server configuration."""
|
||||
if name == BUILTIN_SERVER:
|
||||
raise HTTPException(403, f"Cannot delete built-in server '{BUILTIN_SERVER}'")
|
||||
|
||||
config = _read_config()
|
||||
servers = _get_mcp_servers(config)
|
||||
|
||||
if name not in servers:
|
||||
raise HTTPException(404, f"MCP server '{name}' not found")
|
||||
|
||||
del servers[name]
|
||||
_write_config(config)
|
||||
return {"message": f"Server '{name}' deleted"}
|
||||
|
||||
|
||||
@router.post("/servers/{name}/toggle", response_model=MCPServerResponse)
|
||||
async def toggle_server(name: str):
|
||||
"""Toggle a server's enabled state."""
|
||||
config = _read_config()
|
||||
servers = _get_mcp_servers(config)
|
||||
|
||||
if name not in servers:
|
||||
raise HTTPException(404, f"MCP server '{name}' not found")
|
||||
|
||||
srv = servers[name]
|
||||
srv["enabled"] = not srv.get("enabled", True)
|
||||
_write_config(config)
|
||||
return _server_to_response(name, srv)
|
||||
|
||||
|
||||
@router.post("/servers/{name}/test")
|
||||
async def test_server_connection(name: str):
|
||||
"""Test connection to an MCP server."""
|
||||
config = _read_config()
|
||||
servers = _get_mcp_servers(config)
|
||||
|
||||
if name not in servers:
|
||||
raise HTTPException(404, f"MCP server '{name}' not found")
|
||||
|
||||
srv = servers[name]
|
||||
transport = srv.get("transport", "stdio")
|
||||
|
||||
try:
|
||||
if transport == "sse":
|
||||
# Test SSE endpoint
|
||||
import aiohttp
|
||||
url = srv.get("url", "")
|
||||
if not url:
|
||||
return {"success": False, "error": "No URL configured", "tools_count": 0}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||
if resp.status < 400:
|
||||
return {"success": True, "message": f"SSE endpoint reachable (HTTP {resp.status})", "tools_count": 0}
|
||||
return {"success": False, "error": f"HTTP {resp.status}", "tools_count": 0}
|
||||
|
||||
elif transport == "stdio":
|
||||
# Test stdio by checking command exists
|
||||
import shutil
|
||||
command = srv.get("command", "")
|
||||
if not command:
|
||||
return {"success": False, "error": "No command configured", "tools_count": 0}
|
||||
|
||||
if shutil.which(command):
|
||||
return {"success": True, "message": f"Command '{command}' found in PATH", "tools_count": 0}
|
||||
else:
|
||||
return {"success": False, "error": f"Command '{command}' not found in PATH", "tools_count": 0}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return {"success": False, "error": "Connection timed out (5s)", "tools_count": 0}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e), "tools_count": 0}
|
||||
|
||||
|
||||
@router.get("/servers/{name}/tools", response_model=List[MCPToolResponse])
|
||||
async def list_server_tools(name: str):
|
||||
"""List available tools from an MCP server.
|
||||
|
||||
For the built-in server, returns tools from the registry.
|
||||
For external servers, attempts to connect and query.
|
||||
"""
|
||||
config = _read_config()
|
||||
servers = _get_mcp_servers(config)
|
||||
|
||||
if name not in servers:
|
||||
raise HTTPException(404, f"MCP server '{name}' not found")
|
||||
|
||||
# For builtin server, return tools from the MCP server module
|
||||
if name == BUILTIN_SERVER:
|
||||
try:
|
||||
from core.mcp_server import TOOLS
|
||||
return [
|
||||
MCPToolResponse(
|
||||
name=t["name"],
|
||||
description=t.get("description", ""),
|
||||
server_name=name,
|
||||
)
|
||||
for t in TOOLS
|
||||
]
|
||||
except ImportError:
|
||||
return []
|
||||
|
||||
# For external servers, try to connect via MCPToolClient
|
||||
try:
|
||||
from core.mcp_client import MCPToolClient
|
||||
|
||||
# Build minimal config for this single server
|
||||
client_config = {
|
||||
"mcp_servers": {
|
||||
"enabled": True,
|
||||
"servers": {name: servers[name]}
|
||||
}
|
||||
}
|
||||
client = MCPToolClient(client_config)
|
||||
|
||||
connected = await asyncio.wait_for(client.connect(name), timeout=10)
|
||||
if not connected:
|
||||
raise HTTPException(502, f"Failed to connect to MCP server '{name}'")
|
||||
|
||||
tools_dict = await client.list_tools(name)
|
||||
tool_list = tools_dict.get(name, [])
|
||||
|
||||
await client.disconnect_all()
|
||||
|
||||
return [
|
||||
MCPToolResponse(
|
||||
name=t.get("name", ""),
|
||||
description=t.get("description", ""),
|
||||
server_name=name,
|
||||
)
|
||||
for t in tool_list
|
||||
]
|
||||
except ImportError:
|
||||
raise HTTPException(501, "MCP client library not installed")
|
||||
except asyncio.TimeoutError:
|
||||
raise HTTPException(504, "Connection to MCP server timed out (10s)")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"Failed to list tools: {str(e)}")
|
||||
372
backend/api/v1/prompts.py
Executable file
372
backend/api/v1/prompts.py
Executable file
@@ -0,0 +1,372 @@
|
||||
"""
|
||||
NeuroSploit v3 - Prompts API Endpoints
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from backend.db.database import get_db
|
||||
from backend.models import Prompt
|
||||
from backend.schemas.prompt import (
|
||||
PromptCreate, PromptUpdate, PromptResponse, PromptParse, PromptParseResult, PromptPreset
|
||||
)
|
||||
from backend.core.prompt_engine.parser import PromptParser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Preset prompts
|
||||
PRESET_PROMPTS = [
|
||||
{
|
||||
"id": "full_pentest",
|
||||
"name": "Full Penetration Test",
|
||||
"description": "Comprehensive security assessment covering all vulnerability categories",
|
||||
"category": "pentest",
|
||||
"content": """Perform a comprehensive penetration test on the target application.
|
||||
|
||||
Test for ALL vulnerability categories:
|
||||
- Injection vulnerabilities (XSS, SQL Injection, Command Injection, LDAP, XPath, Template Injection)
|
||||
- Authentication flaws (Broken auth, session management, JWT issues, OAuth flaws)
|
||||
- Authorization issues (IDOR, BOLA, privilege escalation, access control bypass)
|
||||
- File handling vulnerabilities (LFI, RFI, path traversal, file upload, XXE)
|
||||
- Request forgery (SSRF, CSRF)
|
||||
- API security issues (rate limiting, mass assignment, excessive data exposure)
|
||||
- Client-side vulnerabilities (CORS misconfig, clickjacking, open redirect)
|
||||
- Information disclosure (error messages, stack traces, sensitive data exposure)
|
||||
- Infrastructure issues (security headers, SSL/TLS, HTTP methods)
|
||||
- Business logic flaws (race conditions, workflow bypass)
|
||||
|
||||
Use thorough testing with multiple payloads and bypass techniques.
|
||||
Generate detailed PoC for each vulnerability found.
|
||||
Provide remediation recommendations."""
|
||||
},
|
||||
{
|
||||
"id": "owasp_top10",
|
||||
"name": "OWASP Top 10",
|
||||
"description": "Test for OWASP Top 10 2021 vulnerabilities",
|
||||
"category": "compliance",
|
||||
"content": """Test for OWASP Top 10 2021 vulnerabilities:
|
||||
|
||||
A01:2021 - Broken Access Control
|
||||
- IDOR, privilege escalation, access control bypass, CORS misconfig
|
||||
|
||||
A02:2021 - Cryptographic Failures
|
||||
- Sensitive data exposure, weak encryption, cleartext transmission
|
||||
|
||||
A03:2021 - Injection
|
||||
- SQL injection, XSS, command injection, LDAP injection
|
||||
|
||||
A04:2021 - Insecure Design
|
||||
- Business logic flaws, missing security controls
|
||||
|
||||
A05:2021 - Security Misconfiguration
|
||||
- Default configs, unnecessary features, missing headers
|
||||
|
||||
A06:2021 - Vulnerable Components
|
||||
- Outdated libraries, known CVEs
|
||||
|
||||
A07:2021 - Identification and Authentication Failures
|
||||
- Weak passwords, session fixation, credential stuffing
|
||||
|
||||
A08:2021 - Software and Data Integrity Failures
|
||||
- Insecure deserialization, CI/CD vulnerabilities
|
||||
|
||||
A09:2021 - Security Logging and Monitoring Failures
|
||||
- Missing audit logs, insufficient monitoring
|
||||
|
||||
A10:2021 - Server-Side Request Forgery (SSRF)
|
||||
- Internal network access, cloud metadata exposure"""
|
||||
},
|
||||
{
|
||||
"id": "api_security",
|
||||
"name": "API Security Testing",
|
||||
"description": "Focused testing for REST and GraphQL APIs",
|
||||
"category": "api",
|
||||
"content": """Perform API security testing:
|
||||
|
||||
Authentication & Authorization:
|
||||
- Test JWT implementation (algorithm confusion, signature bypass, claim manipulation)
|
||||
- OAuth/OIDC flow testing
|
||||
- API key exposure and validation
|
||||
- Rate limiting bypass
|
||||
- BOLA/IDOR on all endpoints
|
||||
|
||||
Input Validation:
|
||||
- SQL injection on API parameters
|
||||
- NoSQL injection
|
||||
- Command injection
|
||||
- Parameter pollution
|
||||
- Mass assignment vulnerabilities
|
||||
|
||||
Data Exposure:
|
||||
- Excessive data exposure in responses
|
||||
- Sensitive data in error messages
|
||||
- Information disclosure in headers
|
||||
- Debug endpoints exposure
|
||||
|
||||
GraphQL Specific (if applicable):
|
||||
- Introspection enabled
|
||||
- Query depth attacks
|
||||
- Batching attacks
|
||||
- Field suggestion exploitation
|
||||
|
||||
API Abuse:
|
||||
- Rate limiting effectiveness
|
||||
- Resource exhaustion
|
||||
- Denial of service vectors"""
|
||||
},
|
||||
{
|
||||
"id": "bug_bounty",
|
||||
"name": "Bug Bounty Hunter",
|
||||
"description": "Focus on high-impact, bounty-worthy vulnerabilities",
|
||||
"category": "bug_bounty",
|
||||
"content": """Hunt for high-impact vulnerabilities suitable for bug bounty:
|
||||
|
||||
Priority 1 - Critical Impact:
|
||||
- Remote Code Execution (RCE)
|
||||
- SQL Injection leading to data breach
|
||||
- Authentication bypass
|
||||
- SSRF to internal services/cloud metadata
|
||||
- Privilege escalation to admin
|
||||
|
||||
Priority 2 - High Impact:
|
||||
- Stored XSS
|
||||
- IDOR on sensitive resources
|
||||
- Account takeover vectors
|
||||
- Payment/billing manipulation
|
||||
- PII exposure
|
||||
|
||||
Priority 3 - Medium Impact:
|
||||
- Reflected XSS
|
||||
- CSRF on sensitive actions
|
||||
- Information disclosure
|
||||
- Rate limiting bypass
|
||||
- Open redirects (if exploitable)
|
||||
|
||||
Look for:
|
||||
- Unique attack chains
|
||||
- Business logic flaws
|
||||
- Edge cases and race conditions
|
||||
- Bypass techniques for existing security controls
|
||||
|
||||
Document with clear PoC and impact assessment."""
|
||||
},
|
||||
{
|
||||
"id": "quick_scan",
|
||||
"name": "Quick Security Scan",
|
||||
"description": "Fast scan for common vulnerabilities",
|
||||
"category": "quick",
|
||||
"content": """Perform a quick security scan for common vulnerabilities:
|
||||
|
||||
- Reflected XSS on input parameters
|
||||
- Basic SQL injection testing
|
||||
- Directory traversal/LFI
|
||||
- Security headers check
|
||||
- SSL/TLS configuration
|
||||
- Common misconfigurations
|
||||
- Information disclosure
|
||||
|
||||
Use minimal payloads for speed.
|
||||
Focus on quick wins and obvious issues."""
|
||||
},
|
||||
{
|
||||
"id": "auth_testing",
|
||||
"name": "Authentication Testing",
|
||||
"description": "Focus on authentication and session management",
|
||||
"category": "auth",
|
||||
"content": """Test authentication and session management:
|
||||
|
||||
Login Functionality:
|
||||
- Username enumeration
|
||||
- Password brute force protection
|
||||
- Account lockout bypass
|
||||
- Credential stuffing protection
|
||||
- SQL injection in login
|
||||
|
||||
Session Management:
|
||||
- Session token entropy
|
||||
- Session fixation
|
||||
- Session timeout
|
||||
- Cookie security flags (HttpOnly, Secure, SameSite)
|
||||
- Session invalidation on logout
|
||||
|
||||
Password Reset:
|
||||
- Token predictability
|
||||
- Token expiration
|
||||
- Account enumeration
|
||||
- Host header injection
|
||||
|
||||
Multi-Factor Authentication:
|
||||
- MFA bypass techniques
|
||||
- Backup codes weakness
|
||||
- Rate limiting on OTP
|
||||
|
||||
OAuth/SSO:
|
||||
- State parameter validation
|
||||
- Redirect URI manipulation
|
||||
- Token leakage"""
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@router.get("/presets", response_model=List[PromptPreset])
|
||||
async def get_preset_prompts():
|
||||
"""Get list of preset prompts"""
|
||||
return [
|
||||
PromptPreset(
|
||||
id=p["id"],
|
||||
name=p["name"],
|
||||
description=p["description"],
|
||||
category=p["category"],
|
||||
vulnerability_count=len(p["content"].split("\n"))
|
||||
)
|
||||
for p in PRESET_PROMPTS
|
||||
]
|
||||
|
||||
|
||||
@router.get("/presets/{preset_id}")
|
||||
async def get_preset_prompt(preset_id: str):
|
||||
"""Get a specific preset prompt by ID"""
|
||||
for preset in PRESET_PROMPTS:
|
||||
if preset["id"] == preset_id:
|
||||
return preset
|
||||
raise HTTPException(status_code=404, detail="Preset not found")
|
||||
|
||||
|
||||
@router.post("/parse", response_model=PromptParseResult)
|
||||
async def parse_prompt(prompt_data: PromptParse):
|
||||
"""Parse a prompt to extract vulnerability types and testing scope"""
|
||||
parser = PromptParser()
|
||||
result = await parser.parse(prompt_data.content)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("", response_model=List[PromptResponse])
|
||||
async def list_prompts(
|
||||
category: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""List all custom prompts"""
|
||||
query = select(Prompt).where(Prompt.is_preset == False)
|
||||
if category:
|
||||
query = query.where(Prompt.category == category)
|
||||
query = query.order_by(Prompt.created_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
prompts = result.scalars().all()
|
||||
|
||||
return [PromptResponse(**p.to_dict()) for p in prompts]
|
||||
|
||||
|
||||
@router.post("", response_model=PromptResponse)
|
||||
async def create_prompt(prompt_data: PromptCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Create a custom prompt"""
|
||||
# Parse vulnerabilities from content
|
||||
parser = PromptParser()
|
||||
parsed = await parser.parse(prompt_data.content)
|
||||
|
||||
prompt = Prompt(
|
||||
name=prompt_data.name,
|
||||
description=prompt_data.description,
|
||||
content=prompt_data.content,
|
||||
category=prompt_data.category,
|
||||
is_preset=False,
|
||||
parsed_vulnerabilities=[v.dict() for v in parsed.vulnerabilities_to_test]
|
||||
)
|
||||
db.add(prompt)
|
||||
await db.commit()
|
||||
await db.refresh(prompt)
|
||||
|
||||
return PromptResponse(**prompt.to_dict())
|
||||
|
||||
|
||||
@router.get("/{prompt_id}", response_model=PromptResponse)
|
||||
async def get_prompt(prompt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Get a prompt by ID"""
|
||||
result = await db.execute(select(Prompt).where(Prompt.id == prompt_id))
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=404, detail="Prompt not found")
|
||||
|
||||
return PromptResponse(**prompt.to_dict())
|
||||
|
||||
|
||||
@router.put("/{prompt_id}", response_model=PromptResponse)
|
||||
async def update_prompt(
|
||||
prompt_id: str,
|
||||
prompt_data: PromptUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Update a prompt"""
|
||||
result = await db.execute(select(Prompt).where(Prompt.id == prompt_id))
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=404, detail="Prompt not found")
|
||||
|
||||
if prompt.is_preset:
|
||||
raise HTTPException(status_code=400, detail="Cannot modify preset prompts")
|
||||
|
||||
if prompt_data.name is not None:
|
||||
prompt.name = prompt_data.name
|
||||
if prompt_data.description is not None:
|
||||
prompt.description = prompt_data.description
|
||||
if prompt_data.content is not None:
|
||||
prompt.content = prompt_data.content
|
||||
# Re-parse vulnerabilities
|
||||
parser = PromptParser()
|
||||
parsed = await parser.parse(prompt_data.content)
|
||||
prompt.parsed_vulnerabilities = [v.dict() for v in parsed.vulnerabilities_to_test]
|
||||
if prompt_data.category is not None:
|
||||
prompt.category = prompt_data.category
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(prompt)
|
||||
|
||||
return PromptResponse(**prompt.to_dict())
|
||||
|
||||
|
||||
@router.delete("/{prompt_id}")
|
||||
async def delete_prompt(prompt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a prompt"""
|
||||
result = await db.execute(select(Prompt).where(Prompt.id == prompt_id))
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=404, detail="Prompt not found")
|
||||
|
||||
if prompt.is_preset:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete preset prompts")
|
||||
|
||||
await db.delete(prompt)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Prompt deleted"}
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_prompt(file: UploadFile = File(...)):
|
||||
"""Upload a prompt file (.md or .txt)"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
ext = "." + file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||||
if ext not in {".md", ".txt"}:
|
||||
raise HTTPException(status_code=400, detail="Invalid file type. Use .md or .txt")
|
||||
|
||||
content = await file.read()
|
||||
try:
|
||||
text = content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Unable to decode file")
|
||||
|
||||
# Parse the prompt
|
||||
parser = PromptParser()
|
||||
parsed = await parser.parse(text)
|
||||
|
||||
return {
|
||||
"filename": file.filename,
|
||||
"content": text,
|
||||
"parsed": parsed.dict()
|
||||
}
|
||||
403
backend/api/v1/providers.py
Normal file
403
backend/api/v1/providers.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""
|
||||
NeuroSploit v3 - Providers API
|
||||
|
||||
REST endpoints for managing LLM providers and accounts.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ConnectRequest(BaseModel):
|
||||
label: str = "Manual API Key"
|
||||
credential: str
|
||||
credential_type: str = "api_key"
|
||||
model_override: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_providers():
|
||||
"""List all providers with their accounts and status."""
|
||||
from backend.core.smart_router import get_registry
|
||||
registry = get_registry()
|
||||
if not registry:
|
||||
return {"enabled": False, "providers": []}
|
||||
|
||||
providers = []
|
||||
for p in registry.get_all_providers():
|
||||
accounts = []
|
||||
for a in p.accounts.values():
|
||||
accounts.append({
|
||||
"id": a.id,
|
||||
"label": a.label,
|
||||
"source": a.source,
|
||||
"credential_type": a.credential_type,
|
||||
"is_active": a.is_active,
|
||||
"tokens_used": a.tokens_used,
|
||||
"last_used": a.last_used,
|
||||
"expires_at": a.expires_at,
|
||||
"model_override": a.model_override,
|
||||
})
|
||||
providers.append({
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"auth_type": p.auth_type,
|
||||
"api_format": p.api_format,
|
||||
"tier": p.tier,
|
||||
"default_model": p.default_model,
|
||||
"accounts": accounts,
|
||||
"connected": any(
|
||||
a.is_active and a.id in registry._credentials
|
||||
for a in p.accounts.values()
|
||||
),
|
||||
"enabled": getattr(p, "enabled", True),
|
||||
})
|
||||
|
||||
return {"enabled": True, "providers": providers}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def providers_status():
|
||||
"""Get quota and usage summary."""
|
||||
from backend.core.smart_router import get_router
|
||||
router_instance = get_router()
|
||||
if not router_instance:
|
||||
return {"enabled": False}
|
||||
return {"enabled": True, **router_instance.get_status()}
|
||||
|
||||
|
||||
@router.post("/{provider_id}/detect")
|
||||
async def detect_cli_token(provider_id: str):
|
||||
"""Auto-detect CLI token for a specific provider."""
|
||||
from backend.core.smart_router import get_registry, get_extractor
|
||||
registry = get_registry()
|
||||
extractor = get_extractor()
|
||||
if not registry or not extractor:
|
||||
raise HTTPException(400, "Smart Router not enabled")
|
||||
|
||||
token = extractor.detect(provider_id)
|
||||
if not token:
|
||||
return {"detected": False, "message": f"No CLI token found for {provider_id}"}
|
||||
|
||||
# Add to registry
|
||||
acct_id = registry.add_account(
|
||||
provider_id=provider_id,
|
||||
label=token.label,
|
||||
credential=token.token,
|
||||
credential_type=token.credential_type,
|
||||
source="cli_detect",
|
||||
refresh_token=token.refresh_token,
|
||||
expires_at=token.expires_at,
|
||||
)
|
||||
|
||||
return {
|
||||
"detected": True,
|
||||
"account_id": acct_id,
|
||||
"label": token.label,
|
||||
"credential_type": token.credential_type,
|
||||
"has_refresh_token": token.refresh_token is not None,
|
||||
"expires_at": token.expires_at,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{provider_id}/connect")
|
||||
async def connect_provider(provider_id: str, req: ConnectRequest):
|
||||
"""Manually add an API key or credential."""
|
||||
from backend.core.smart_router import get_registry
|
||||
registry = get_registry()
|
||||
if not registry:
|
||||
raise HTTPException(400, "Smart Router not enabled")
|
||||
|
||||
acct_id = registry.add_account(
|
||||
provider_id=provider_id,
|
||||
label=req.label,
|
||||
credential=req.credential,
|
||||
credential_type=req.credential_type,
|
||||
source="manual",
|
||||
model_override=req.model_override,
|
||||
)
|
||||
if not acct_id:
|
||||
raise HTTPException(404, f"Unknown provider: {provider_id}")
|
||||
|
||||
return {"success": True, "account_id": acct_id}
|
||||
|
||||
|
||||
@router.delete("/{provider_id}/accounts/{account_id}")
|
||||
async def remove_account(provider_id: str, account_id: str):
|
||||
"""Remove an account from a provider."""
|
||||
from backend.core.smart_router import get_registry
|
||||
registry = get_registry()
|
||||
if not registry:
|
||||
raise HTTPException(400, "Smart Router not enabled")
|
||||
|
||||
success = registry.remove_account(provider_id, account_id)
|
||||
if not success:
|
||||
raise HTTPException(404, "Account not found")
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.post("/test/{provider_id}/{account_id}")
|
||||
async def test_connection(provider_id: str, account_id: str):
|
||||
"""Test connectivity for a specific account."""
|
||||
from backend.core.smart_router import get_router
|
||||
router_instance = get_router()
|
||||
if not router_instance:
|
||||
raise HTTPException(400, "Smart Router not enabled")
|
||||
|
||||
success, message = await router_instance.test_account(provider_id, account_id)
|
||||
return {"success": success, "message": message}
|
||||
|
||||
|
||||
# Known models per provider for dropdown selection
|
||||
PROVIDER_MODELS = {
|
||||
"claude_code": [
|
||||
"claude-opus-4-6-20250918",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-opus-4-20250514",
|
||||
"claude-haiku-4-20250514",
|
||||
],
|
||||
"kiro": [
|
||||
"claude-opus-4-6-20250918",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-opus-4-20250514",
|
||||
"claude-haiku-4-20250514",
|
||||
],
|
||||
"anthropic": [
|
||||
"claude-opus-4-6-20250918",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-opus-4-20250514",
|
||||
"claude-haiku-4-20250514",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
],
|
||||
"codex_cli": [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"o3-mini",
|
||||
"o4-mini",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
],
|
||||
"openai": [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"o3-mini",
|
||||
"o4-mini",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
],
|
||||
"gemini_cli": [
|
||||
"gemini-3.0-pro",
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
],
|
||||
"gemini": [
|
||||
"gemini-3.0-pro",
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
],
|
||||
"cursor": [
|
||||
"cursor-fast",
|
||||
"cursor-small",
|
||||
"gpt-4o",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
],
|
||||
"copilot": [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
],
|
||||
"openrouter": [
|
||||
"anthropic/claude-opus-4-6",
|
||||
"anthropic/claude-sonnet-4-5",
|
||||
"anthropic/claude-haiku-4-5",
|
||||
"anthropic/claude-sonnet-4",
|
||||
"anthropic/claude-opus-4",
|
||||
"openai/gpt-4o",
|
||||
"google/gemini-3.0-pro",
|
||||
"google/gemini-2.5-pro",
|
||||
"google/gemini-2.5-flash",
|
||||
"meta-llama/llama-4-maverick",
|
||||
"deepseek/deepseek-r1",
|
||||
],
|
||||
"together": [
|
||||
"meta-llama/Llama-3-70b-chat-hf",
|
||||
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
"deepseek-ai/DeepSeek-R1",
|
||||
"Qwen/Qwen2.5-72B-Instruct-Turbo",
|
||||
],
|
||||
"fireworks": [
|
||||
"accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||
"accounts/fireworks/models/deepseek-r1",
|
||||
],
|
||||
"iflow": ["kimi-k2"],
|
||||
"qwen_code": ["qwen3-coder", "qwen-max"],
|
||||
"ollama": ["llama3", "llama3.2", "mistral", "codellama", "deepseek-r1"],
|
||||
"lmstudio": ["local-model"],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/available-models")
|
||||
async def available_models():
|
||||
"""Get list of available provider+model combinations for selection dropdowns."""
|
||||
from backend.core.smart_router import get_registry
|
||||
registry = get_registry()
|
||||
if not registry:
|
||||
return {"models": []}
|
||||
|
||||
models = []
|
||||
for p in registry.get_all_providers():
|
||||
active = registry.get_active_accounts(p.id)
|
||||
if not active:
|
||||
continue
|
||||
models.append({
|
||||
"provider_id": p.id,
|
||||
"provider_name": p.name,
|
||||
"default_model": p.default_model,
|
||||
"tier": p.tier,
|
||||
"available_models": PROVIDER_MODELS.get(p.id, [p.default_model]),
|
||||
})
|
||||
|
||||
# Sort by tier (paid first) then name
|
||||
models.sort(key=lambda m: (m["tier"], m["provider_name"]))
|
||||
return {"models": models}
|
||||
|
||||
|
||||
@router.post("/detect-all")
|
||||
async def detect_all_tokens():
|
||||
"""Scan all CLI tools for available tokens."""
|
||||
from backend.core.smart_router import get_registry, get_extractor
|
||||
registry = get_registry()
|
||||
extractor = get_extractor()
|
||||
if not registry or not extractor:
|
||||
raise HTTPException(400, "Smart Router not enabled")
|
||||
|
||||
tokens = extractor.detect_all()
|
||||
results = []
|
||||
for token in tokens:
|
||||
acct_id = registry.add_account(
|
||||
provider_id=token.provider_id,
|
||||
label=token.label,
|
||||
credential=token.token,
|
||||
credential_type=token.credential_type,
|
||||
source="cli_detect",
|
||||
refresh_token=token.refresh_token,
|
||||
expires_at=token.expires_at,
|
||||
)
|
||||
results.append({
|
||||
"provider_id": token.provider_id,
|
||||
"label": token.label,
|
||||
"account_id": acct_id,
|
||||
})
|
||||
|
||||
return {
|
||||
"detected_count": len(results),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
class ToggleRequest(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
|
||||
@router.post("/{provider_id}/toggle")
|
||||
async def toggle_provider(provider_id: str, req: ToggleRequest):
|
||||
"""Enable or disable a provider. Disabled providers are skipped by the router."""
|
||||
from backend.core.smart_router import get_registry
|
||||
registry = get_registry()
|
||||
if not registry:
|
||||
raise HTTPException(400, "Smart Router not enabled")
|
||||
|
||||
success = registry.toggle_provider(provider_id, req.enabled)
|
||||
if not success:
|
||||
raise HTTPException(404, f"Unknown provider: {provider_id}")
|
||||
|
||||
return {"success": True, "provider_id": provider_id, "enabled": req.enabled}
|
||||
|
||||
|
||||
# Whitelist of env keys that can be modified via UI
|
||||
ALLOWED_ENV_KEYS = {
|
||||
"ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GEMINI_API_KEY", "GOOGLE_API_KEY",
|
||||
"OPENROUTER_API_KEY", "TOGETHER_API_KEY", "FIREWORKS_API_KEY",
|
||||
"OLLAMA_HOST", "LMSTUDIO_HOST",
|
||||
"ENABLE_SMART_ROUTER", "ENABLE_REASONING", "ENABLE_CVE_HUNT",
|
||||
"ENABLE_MULTI_AGENT", "ENABLE_RESEARCHER_AI",
|
||||
"NVD_API_KEY", "GITHUB_TOKEN", "TOKEN_BUDGET",
|
||||
}
|
||||
|
||||
|
||||
class EnvUpdateRequest(BaseModel):
|
||||
key: str
|
||||
value: str
|
||||
|
||||
|
||||
@router.get("/env")
|
||||
async def get_env_keys():
|
||||
"""Get current values of allowed env keys (masked for secrets)."""
|
||||
import os
|
||||
result = {}
|
||||
for key in sorted(ALLOWED_ENV_KEYS):
|
||||
val = os.getenv(key, "")
|
||||
if val and "KEY" in key and key not in ("ENABLE_SMART_ROUTER", "ENABLE_REASONING",
|
||||
"ENABLE_CVE_HUNT", "ENABLE_MULTI_AGENT",
|
||||
"ENABLE_RESEARCHER_AI", "TOKEN_BUDGET"):
|
||||
# Mask API keys: show first 8 and last 4 chars
|
||||
if len(val) > 16:
|
||||
result[key] = val[:8] + "..." + val[-4:]
|
||||
else:
|
||||
result[key] = "****"
|
||||
else:
|
||||
result[key] = val
|
||||
return {"env": result, "allowed_keys": sorted(ALLOWED_ENV_KEYS)}
|
||||
|
||||
|
||||
@router.post("/env")
|
||||
async def update_env_key(req: EnvUpdateRequest):
|
||||
"""Update an env var and persist to .env file."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
if req.key not in ALLOWED_ENV_KEYS:
|
||||
raise HTTPException(400, f"Key '{req.key}' is not in the allowed whitelist")
|
||||
|
||||
# Update in-process env
|
||||
os.environ[req.key] = req.value
|
||||
|
||||
# Persist to .env file
|
||||
env_path = Path(__file__).parent.parent.parent.parent / ".env"
|
||||
try:
|
||||
lines = []
|
||||
found = False
|
||||
if env_path.exists():
|
||||
for line in env_path.read_text().splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith(f"{req.key}=") or stripped.startswith(f"# {req.key}="):
|
||||
lines.append(f"{req.key}={req.value}")
|
||||
found = True
|
||||
else:
|
||||
lines.append(line)
|
||||
if not found:
|
||||
lines.append(f"{req.key}={req.value}")
|
||||
env_path.write_text("\n".join(lines) + "\n")
|
||||
except Exception as e:
|
||||
# Still updated in-process even if file write failed
|
||||
return {"success": True, "persisted": False, "error": str(e)}
|
||||
|
||||
return {"success": True, "persisted": True}
|
||||
387
backend/api/v1/reports.py
Executable file
387
backend/api/v1/reports.py
Executable file
@@ -0,0 +1,387 @@
|
||||
"""
|
||||
NeuroSploit v3 - Reports API Endpoints
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from pathlib import Path
|
||||
|
||||
from backend.db.database import get_db
|
||||
from backend.models import Scan, Report, Vulnerability, Endpoint
|
||||
from backend.schemas.report import ReportGenerate, ReportResponse, ReportListResponse
|
||||
from backend.core.report_engine.generator import ReportGenerator
|
||||
from backend.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=ReportListResponse)
|
||||
async def list_reports(
|
||||
scan_id: Optional[str] = None,
|
||||
auto_generated: Optional[bool] = None,
|
||||
is_partial: Optional[bool] = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""List all reports with optional filtering"""
|
||||
query = select(Report).order_by(Report.generated_at.desc())
|
||||
|
||||
if scan_id:
|
||||
query = query.where(Report.scan_id == scan_id)
|
||||
|
||||
if auto_generated is not None:
|
||||
query = query.where(Report.auto_generated == auto_generated)
|
||||
|
||||
if is_partial is not None:
|
||||
query = query.where(Report.is_partial == is_partial)
|
||||
|
||||
result = await db.execute(query)
|
||||
reports = result.scalars().all()
|
||||
|
||||
return ReportListResponse(
|
||||
reports=[ReportResponse(**r.to_dict()) for r in reports],
|
||||
total=len(reports)
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=ReportResponse)
|
||||
async def generate_report(
|
||||
report_data: ReportGenerate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Generate a new report for a scan"""
|
||||
# Get scan
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == report_data.scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Get vulnerabilities
|
||||
vulns_result = await db.execute(
|
||||
select(Vulnerability).where(Vulnerability.scan_id == report_data.scan_id)
|
||||
)
|
||||
vulnerabilities = vulns_result.scalars().all()
|
||||
|
||||
# Try to get tool_executions from agent in-memory results
|
||||
tool_executions = []
|
||||
try:
|
||||
from backend.api.v1.agent import scan_to_agent, agent_results
|
||||
agent_id = scan_to_agent.get(report_data.scan_id)
|
||||
if agent_id and agent_id in agent_results:
|
||||
tool_executions = agent_results[agent_id].get("tool_executions", [])
|
||||
if not tool_executions:
|
||||
rpt = agent_results[agent_id].get("report", {})
|
||||
tool_executions = rpt.get("tool_executions", []) if isinstance(rpt, dict) else []
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get endpoints
|
||||
endpoints_result = await db.execute(
|
||||
select(Endpoint).where(Endpoint.scan_id == report_data.scan_id)
|
||||
)
|
||||
endpoints = endpoints_result.scalars().all()
|
||||
|
||||
# Generate report
|
||||
generator = ReportGenerator()
|
||||
report_path, executive_summary = await generator.generate(
|
||||
scan=scan,
|
||||
vulnerabilities=vulnerabilities,
|
||||
format=report_data.format,
|
||||
title=report_data.title,
|
||||
include_executive_summary=report_data.include_executive_summary,
|
||||
include_poc=report_data.include_poc,
|
||||
include_remediation=report_data.include_remediation,
|
||||
tool_executions=tool_executions,
|
||||
endpoints=endpoints,
|
||||
)
|
||||
|
||||
# Save report record
|
||||
report = Report(
|
||||
scan_id=scan.id,
|
||||
title=report_data.title or f"Report - {scan.name}",
|
||||
format=report_data.format,
|
||||
file_path=str(report_path),
|
||||
executive_summary=executive_summary
|
||||
)
|
||||
db.add(report)
|
||||
await db.commit()
|
||||
await db.refresh(report)
|
||||
|
||||
return ReportResponse(**report.to_dict())
|
||||
|
||||
|
||||
@router.post("/ai-generate", response_model=ReportResponse)
|
||||
async def generate_ai_report(
|
||||
report_data: ReportGenerate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Generate an AI-enhanced report with LLM-written executive summary and per-finding analysis."""
|
||||
# Get scan
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == report_data.scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Get vulnerabilities
|
||||
vulns_result = await db.execute(
|
||||
select(Vulnerability).where(Vulnerability.scan_id == report_data.scan_id)
|
||||
)
|
||||
vulnerabilities = vulns_result.scalars().all()
|
||||
|
||||
# Try to get tool_executions from agent in-memory results
|
||||
tool_executions = []
|
||||
try:
|
||||
from backend.api.v1.agent import scan_to_agent, agent_results
|
||||
agent_id = scan_to_agent.get(report_data.scan_id)
|
||||
if agent_id and agent_id in agent_results:
|
||||
tool_executions = agent_results[agent_id].get("tool_executions", [])
|
||||
# Also check nested report
|
||||
if not tool_executions:
|
||||
rpt = agent_results[agent_id].get("report", {})
|
||||
tool_executions = rpt.get("tool_executions", []) if isinstance(rpt, dict) else []
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Generate AI report
|
||||
generator = ReportGenerator()
|
||||
try:
|
||||
report_path, ai_summary = await generator.generate_ai_report(
|
||||
scan=scan,
|
||||
vulnerabilities=vulnerabilities,
|
||||
tool_executions=tool_executions,
|
||||
title=report_data.title,
|
||||
preferred_provider=report_data.preferred_provider,
|
||||
preferred_model=report_data.preferred_model,
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).error(f"AI report generation failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"AI report generation failed: {str(e)}")
|
||||
|
||||
# Save report record
|
||||
report = Report(
|
||||
scan_id=scan.id,
|
||||
title=report_data.title or f"AI Report - {scan.name}",
|
||||
format="html",
|
||||
file_path=str(report_path),
|
||||
executive_summary=ai_summary[:2000] if ai_summary else None
|
||||
)
|
||||
db.add(report)
|
||||
await db.commit()
|
||||
await db.refresh(report)
|
||||
|
||||
return ReportResponse(**report.to_dict())
|
||||
|
||||
|
||||
@router.get("/{report_id}", response_model=ReportResponse)
|
||||
async def get_report(report_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Get report details"""
|
||||
result = await db.execute(select(Report).where(Report.id == report_id))
|
||||
report = result.scalar_one_or_none()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
|
||||
return ReportResponse(**report.to_dict())
|
||||
|
||||
|
||||
@router.get("/{report_id}/view")
|
||||
async def view_report(report_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""View report in browser (HTML)"""
|
||||
result = await db.execute(select(Report).where(Report.id == report_id))
|
||||
report = result.scalar_one_or_none()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
|
||||
if not report.file_path:
|
||||
raise HTTPException(status_code=404, detail="Report file not found")
|
||||
|
||||
file_path = Path(report.file_path)
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Report file not found on disk")
|
||||
|
||||
if report.format == "html":
|
||||
content = file_path.read_text()
|
||||
return HTMLResponse(content=content)
|
||||
else:
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
media_type="application/octet-stream",
|
||||
filename=file_path.name
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{report_id}/download/{format}")
|
||||
async def download_report(
|
||||
report_id: str,
|
||||
format: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Download report in specified format"""
|
||||
result = await db.execute(select(Report).where(Report.id == report_id))
|
||||
report = result.scalar_one_or_none()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
|
||||
# Get scan and vulnerabilities for generating report
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == report.scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found for report")
|
||||
|
||||
vulns_result = await db.execute(
|
||||
select(Vulnerability).where(Vulnerability.scan_id == report.scan_id)
|
||||
)
|
||||
vulnerabilities = vulns_result.scalars().all()
|
||||
|
||||
# Always generate fresh report file (handles auto-generated reports without file_path)
|
||||
generator = ReportGenerator()
|
||||
report_path, _ = await generator.generate(
|
||||
scan=scan,
|
||||
vulnerabilities=vulnerabilities,
|
||||
format=format,
|
||||
title=report.title
|
||||
)
|
||||
file_path = Path(report_path)
|
||||
|
||||
# Update report with file path if not set
|
||||
if not report.file_path:
|
||||
report.file_path = str(file_path)
|
||||
report.format = format
|
||||
await db.commit()
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Report file not found")
|
||||
|
||||
media_types = {
|
||||
"html": "text/html",
|
||||
"pdf": "application/pdf",
|
||||
"json": "application/json"
|
||||
}
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
media_type=media_types.get(format, "application/octet-stream"),
|
||||
filename=file_path.name
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{report_id}/download-zip")
|
||||
async def download_report_zip(
|
||||
report_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Download report as ZIP with screenshots included"""
|
||||
import zipfile
|
||||
import tempfile
|
||||
import hashlib
|
||||
|
||||
result = await db.execute(select(Report).where(Report.id == report_id))
|
||||
report = result.scalar_one_or_none()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == report.scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found for report")
|
||||
|
||||
vulns_result = await db.execute(
|
||||
select(Vulnerability).where(Vulnerability.scan_id == report.scan_id)
|
||||
)
|
||||
vulnerabilities = vulns_result.scalars().all()
|
||||
|
||||
# Generate HTML report
|
||||
generator = ReportGenerator()
|
||||
report_path, _ = await generator.generate(
|
||||
scan=scan,
|
||||
vulnerabilities=vulnerabilities,
|
||||
format="html",
|
||||
title=report.title
|
||||
)
|
||||
|
||||
# Collect screenshots (use absolute path via settings.BASE_DIR)
|
||||
# Check scan-scoped path first, then legacy flat path
|
||||
screenshots_base = settings.BASE_DIR / "reports" / "screenshots"
|
||||
scan_id_str = str(scan.id) if scan else None
|
||||
screenshot_files = []
|
||||
for vuln in vulnerabilities:
|
||||
# Finding ID is md5(vuln_type+url+param)[:8]
|
||||
vuln_url = getattr(vuln, 'url', None) or vuln.affected_endpoint or ''
|
||||
vuln_param = getattr(vuln, 'parameter', None) or getattr(vuln, 'poc_parameter', None) or ''
|
||||
finding_id = hashlib.md5(
|
||||
f"{vuln.vulnerability_type}{vuln_url}{vuln_param}".encode()
|
||||
).hexdigest()[:8]
|
||||
# Scan-scoped path: reports/screenshots/{scan_id}/{finding_id}/
|
||||
finding_dir = None
|
||||
if scan_id_str:
|
||||
scan_dir = screenshots_base / scan_id_str / finding_id
|
||||
if scan_dir.exists():
|
||||
finding_dir = scan_dir
|
||||
if not finding_dir:
|
||||
legacy_dir = screenshots_base / finding_id
|
||||
if legacy_dir.exists():
|
||||
finding_dir = legacy_dir
|
||||
if finding_dir:
|
||||
for img in finding_dir.glob("*.png"):
|
||||
screenshot_files.append((img, f"screenshots/{finding_id}/{img.name}"))
|
||||
# Also include base64 screenshots from DB as files in the ZIP
|
||||
db_screenshots = getattr(vuln, 'screenshots', None) or []
|
||||
for idx, ss in enumerate(db_screenshots):
|
||||
if isinstance(ss, str) and ss.startswith("data:image/"):
|
||||
# Will be embedded in HTML, but also save as file
|
||||
import base64 as b64
|
||||
try:
|
||||
b64_data = ss.split(",", 1)[1]
|
||||
img_bytes = b64.b64decode(b64_data)
|
||||
img_name = f"screenshots/{finding_id}/evidence_{idx+1}.png"
|
||||
# Write to temp for ZIP inclusion
|
||||
tmp_img = Path(tempfile.gettempdir()) / f"ss_{finding_id}_{idx}.png"
|
||||
tmp_img.write_bytes(img_bytes)
|
||||
screenshot_files.append((tmp_img, img_name))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create ZIP
|
||||
zip_name = Path(report_path).stem + ".zip"
|
||||
zip_path = Path(tempfile.gettempdir()) / zip_name
|
||||
|
||||
with zipfile.ZipFile(str(zip_path), 'w', zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.write(report_path, "report.html")
|
||||
for src_path, arc_name in screenshot_files:
|
||||
zf.write(str(src_path), arc_name)
|
||||
|
||||
return FileResponse(
|
||||
path=str(zip_path),
|
||||
media_type="application/zip",
|
||||
filename=zip_name
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{report_id}")
|
||||
async def delete_report(report_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a report"""
|
||||
result = await db.execute(select(Report).where(Report.id == report_id))
|
||||
report = result.scalar_one_or_none()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
|
||||
# Delete file if exists
|
||||
if report.file_path:
|
||||
file_path = Path(report.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
await db.delete(report)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Report deleted"}
|
||||
130
backend/api/v1/sandbox.py
Executable file
130
backend/api/v1/sandbox.py
Executable file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
NeuroSploit v3 - Sandbox Container Management API
|
||||
|
||||
Real-time monitoring and management of per-scan Kali Linux containers.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _docker_available() -> bool:
|
||||
try:
|
||||
import docker
|
||||
docker.from_env().ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_sandboxes():
|
||||
"""List all sandbox containers with pool status."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
except Exception as e:
|
||||
return {
|
||||
"pool": {
|
||||
"active": 0,
|
||||
"max_concurrent": 0,
|
||||
"image": "neurosploit-kali:latest",
|
||||
"container_ttl_minutes": 60,
|
||||
"docker_available": _docker_available(),
|
||||
},
|
||||
"containers": [],
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
sandboxes = pool.list_sandboxes()
|
||||
now = datetime.utcnow()
|
||||
|
||||
containers = []
|
||||
for info in sandboxes.values():
|
||||
created = info.get("created_at")
|
||||
uptime = 0.0
|
||||
if created:
|
||||
try:
|
||||
dt = datetime.fromisoformat(created)
|
||||
uptime = (now - dt).total_seconds()
|
||||
except Exception:
|
||||
pass
|
||||
containers.append({
|
||||
**info,
|
||||
"uptime_seconds": uptime,
|
||||
})
|
||||
|
||||
return {
|
||||
"pool": {
|
||||
"active": pool.active_count,
|
||||
"max_concurrent": pool.max_concurrent,
|
||||
"image": pool.image,
|
||||
"container_ttl_minutes": int(pool.container_ttl.total_seconds() / 60),
|
||||
"docker_available": _docker_available(),
|
||||
},
|
||||
"containers": containers,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{scan_id}")
|
||||
async def get_sandbox(scan_id: str):
|
||||
"""Get health check for a specific sandbox container."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
|
||||
sandboxes = pool.list_sandboxes()
|
||||
if scan_id not in sandboxes:
|
||||
raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}")
|
||||
|
||||
sb = pool._sandboxes.get(scan_id)
|
||||
if not sb:
|
||||
raise HTTPException(status_code=404, detail=f"Sandbox instance not found")
|
||||
|
||||
health = await sb.health_check()
|
||||
return health
|
||||
|
||||
|
||||
@router.delete("/{scan_id}")
|
||||
async def destroy_sandbox(scan_id: str):
|
||||
"""Destroy a specific sandbox container."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
|
||||
sandboxes = pool.list_sandboxes()
|
||||
if scan_id not in sandboxes:
|
||||
raise HTTPException(status_code=404, detail=f"No sandbox for scan {scan_id}")
|
||||
|
||||
await pool.destroy(scan_id)
|
||||
return {"message": f"Sandbox for scan {scan_id} destroyed", "scan_id": scan_id}
|
||||
|
||||
|
||||
@router.post("/cleanup")
|
||||
async def cleanup_expired():
|
||||
"""Remove containers that have exceeded their TTL."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
await pool.cleanup_expired()
|
||||
return {"message": "Expired containers cleaned up"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/cleanup-orphans")
|
||||
async def cleanup_orphans():
|
||||
"""Remove orphan containers not tracked by the pool."""
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
await pool.cleanup_orphans()
|
||||
return {"message": "Orphan containers cleaned up"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
656
backend/api/v1/scans.py
Executable file
656
backend/api/v1/scans.py
Executable file
@@ -0,0 +1,656 @@
|
||||
"""
|
||||
NeuroSploit v3 - Scans API Endpoints
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from backend.db.database import get_db
|
||||
from backend.models import Scan, Target, Endpoint, Vulnerability
|
||||
from backend.schemas.scan import ScanCreate, ScanUpdate, ScanResponse, ScanListResponse, ScanProgress
|
||||
from backend.services.scan_service import run_scan_task, skip_to_phase as _skip_to_phase, PHASE_ORDER
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=ScanListResponse)
|
||||
async def list_scans(
|
||||
page: int = 1,
|
||||
per_page: int = 10,
|
||||
status: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""List all scans with pagination"""
|
||||
query = select(Scan).order_by(Scan.created_at.desc())
|
||||
|
||||
if status:
|
||||
query = query.where(Scan.status == status)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(Scan)
|
||||
if status:
|
||||
count_query = count_query.where(Scan.status == status)
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar()
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset((page - 1) * per_page).limit(per_page)
|
||||
result = await db.execute(query)
|
||||
scans = result.scalars().all()
|
||||
|
||||
# Load targets for each scan
|
||||
scan_responses = []
|
||||
for scan in scans:
|
||||
targets_query = select(Target).where(Target.scan_id == scan.id)
|
||||
targets_result = await db.execute(targets_query)
|
||||
targets = targets_result.scalars().all()
|
||||
|
||||
scan_dict = scan.to_dict()
|
||||
scan_dict["targets"] = [t.to_dict() for t in targets]
|
||||
scan_responses.append(ScanResponse(**scan_dict))
|
||||
|
||||
return ScanListResponse(
|
||||
scans=scan_responses,
|
||||
total=total,
|
||||
page=page,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=ScanResponse)
|
||||
async def create_scan(
|
||||
scan_data: ScanCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Create a new scan with optional authentication for authenticated testing"""
|
||||
# Process authentication config
|
||||
auth_type = None
|
||||
auth_credentials = None
|
||||
if scan_data.auth:
|
||||
auth_type = scan_data.auth.auth_type
|
||||
auth_credentials = {}
|
||||
if scan_data.auth.cookie:
|
||||
auth_credentials["cookie"] = scan_data.auth.cookie
|
||||
if scan_data.auth.bearer_token:
|
||||
auth_credentials["bearer_token"] = scan_data.auth.bearer_token
|
||||
if scan_data.auth.username:
|
||||
auth_credentials["username"] = scan_data.auth.username
|
||||
if scan_data.auth.password:
|
||||
auth_credentials["password"] = scan_data.auth.password
|
||||
if scan_data.auth.header_name and scan_data.auth.header_value:
|
||||
auth_credentials["header_name"] = scan_data.auth.header_name
|
||||
auth_credentials["header_value"] = scan_data.auth.header_value
|
||||
|
||||
# Create scan
|
||||
scan = Scan(
|
||||
name=scan_data.name or f"Scan {datetime.now().strftime('%Y-%m-%d %H:%M')}",
|
||||
scan_type=scan_data.scan_type,
|
||||
recon_enabled=scan_data.recon_enabled,
|
||||
custom_prompt=scan_data.custom_prompt,
|
||||
prompt_id=scan_data.prompt_id,
|
||||
config=scan_data.config,
|
||||
auth_type=auth_type,
|
||||
auth_credentials=auth_credentials,
|
||||
custom_headers=scan_data.custom_headers,
|
||||
status="pending"
|
||||
)
|
||||
db.add(scan)
|
||||
await db.flush()
|
||||
|
||||
# Create targets
|
||||
targets = []
|
||||
for url in scan_data.targets:
|
||||
parsed = urlparse(url)
|
||||
target = Target(
|
||||
scan_id=scan.id,
|
||||
url=url,
|
||||
hostname=parsed.hostname,
|
||||
port=parsed.port or (443 if parsed.scheme == "https" else 80),
|
||||
protocol=parsed.scheme or "https",
|
||||
path=parsed.path or "/"
|
||||
)
|
||||
db.add(target)
|
||||
targets.append(target)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(scan)
|
||||
|
||||
scan_dict = scan.to_dict()
|
||||
scan_dict["targets"] = [t.to_dict() for t in targets]
|
||||
|
||||
return ScanResponse(**scan_dict)
|
||||
|
||||
|
||||
@router.get("/{scan_id}", response_model=ScanResponse)
|
||||
async def get_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Get scan details by ID"""
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Load targets
|
||||
targets_result = await db.execute(select(Target).where(Target.scan_id == scan_id))
|
||||
targets = targets_result.scalars().all()
|
||||
|
||||
scan_dict = scan.to_dict()
|
||||
scan_dict["targets"] = [t.to_dict() for t in targets]
|
||||
|
||||
return ScanResponse(**scan_dict)
|
||||
|
||||
|
||||
@router.post("/{scan_id}/start")
|
||||
async def start_scan(
|
||||
scan_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Start a scan execution"""
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status == "running":
|
||||
raise HTTPException(status_code=400, detail="Scan is already running")
|
||||
|
||||
# Update scan status
|
||||
scan.status = "running"
|
||||
scan.started_at = datetime.utcnow()
|
||||
scan.current_phase = "initializing"
|
||||
scan.progress = 0
|
||||
await db.commit()
|
||||
|
||||
# Start scan in background with its own database session
|
||||
background_tasks.add_task(run_scan_task, scan_id)
|
||||
|
||||
return {"message": "Scan started", "scan_id": scan_id}
|
||||
|
||||
|
||||
@router.post("/{scan_id}/stop")
|
||||
async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Stop a running scan and save partial results"""
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
|
||||
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status not in ("running", "paused"):
|
||||
raise HTTPException(status_code=400, detail="Scan is not running or paused")
|
||||
|
||||
# Signal the running agent to stop
|
||||
agent_id = scan_to_agent.get(scan_id)
|
||||
if agent_id and agent_id in agent_instances:
|
||||
agent_instances[agent_id].cancel()
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "stopped"
|
||||
agent_results[agent_id]["phase"] = "stopped"
|
||||
|
||||
# Update scan status
|
||||
scan.status = "stopped"
|
||||
scan.completed_at = datetime.utcnow()
|
||||
scan.current_phase = "stopped"
|
||||
|
||||
# Calculate duration
|
||||
if scan.started_at:
|
||||
duration = (scan.completed_at - scan.started_at).total_seconds()
|
||||
scan.duration = int(duration)
|
||||
|
||||
# Compute final vulnerability statistics from database
|
||||
for severity in ["critical", "high", "medium", "low", "info"]:
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(Vulnerability)
|
||||
.where(Vulnerability.scan_id == scan_id)
|
||||
.where(Vulnerability.severity == severity)
|
||||
)
|
||||
setattr(scan, f"{severity}_count", count_result.scalar() or 0)
|
||||
|
||||
# Get total vulnerability count
|
||||
total_vuln_result = await db.execute(
|
||||
select(func.count()).select_from(Vulnerability)
|
||||
.where(Vulnerability.scan_id == scan_id)
|
||||
)
|
||||
scan.total_vulnerabilities = total_vuln_result.scalar() or 0
|
||||
|
||||
# Get total endpoint count
|
||||
total_endpoint_result = await db.execute(
|
||||
select(func.count()).select_from(Endpoint)
|
||||
.where(Endpoint.scan_id == scan_id)
|
||||
)
|
||||
scan.total_endpoints = total_endpoint_result.scalar() or 0
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Build summary for WebSocket broadcast
|
||||
summary = {
|
||||
"total_endpoints": scan.total_endpoints,
|
||||
"total_vulnerabilities": scan.total_vulnerabilities,
|
||||
"critical": scan.critical_count,
|
||||
"high": scan.high_count,
|
||||
"medium": scan.medium_count,
|
||||
"low": scan.low_count,
|
||||
"info": scan.info_count,
|
||||
"duration": scan.duration,
|
||||
"progress": scan.progress
|
||||
}
|
||||
|
||||
# Broadcast stop event via WebSocket
|
||||
await ws_manager.broadcast_scan_stopped(scan_id, summary)
|
||||
await ws_manager.broadcast_log(scan_id, "warning", "Scan stopped by user")
|
||||
await ws_manager.broadcast_log(scan_id, "info", f"Partial results: {scan.total_vulnerabilities} vulnerabilities found")
|
||||
|
||||
# Auto-generate partial report
|
||||
report_data = None
|
||||
try:
|
||||
from backend.services.report_service import auto_generate_report
|
||||
await ws_manager.broadcast_log(scan_id, "info", "Generating partial report...")
|
||||
report = await auto_generate_report(db, scan_id, is_partial=True)
|
||||
report_data = report.to_dict()
|
||||
await ws_manager.broadcast_log(scan_id, "info", f"Partial report generated: {report.title}")
|
||||
except Exception as report_error:
|
||||
await ws_manager.broadcast_log(scan_id, "warning", f"Failed to generate partial report: {str(report_error)}")
|
||||
|
||||
return {
|
||||
"message": "Scan stopped",
|
||||
"scan_id": scan_id,
|
||||
"summary": summary,
|
||||
"report": report_data
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{scan_id}/pause")
|
||||
async def pause_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Pause a running scan"""
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
|
||||
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status != "running":
|
||||
raise HTTPException(status_code=400, detail="Scan is not running")
|
||||
|
||||
# Signal the agent to pause
|
||||
agent_id = scan_to_agent.get(scan_id)
|
||||
if agent_id and agent_id in agent_instances:
|
||||
agent_instances[agent_id].pause()
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "paused"
|
||||
agent_results[agent_id]["phase"] = "paused"
|
||||
|
||||
scan.status = "paused"
|
||||
scan.current_phase = "paused"
|
||||
await db.commit()
|
||||
|
||||
await ws_manager.broadcast_log(scan_id, "warning", "Scan paused by user")
|
||||
|
||||
return {"message": "Scan paused", "scan_id": scan_id}
|
||||
|
||||
|
||||
@router.post("/{scan_id}/resume")
|
||||
async def resume_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Resume a paused scan"""
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
|
||||
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status != "paused":
|
||||
raise HTTPException(status_code=400, detail="Scan is not paused")
|
||||
|
||||
# Signal the agent to resume
|
||||
agent_id = scan_to_agent.get(scan_id)
|
||||
if agent_id and agent_id in agent_instances:
|
||||
agent_instances[agent_id].resume()
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "running"
|
||||
agent_results[agent_id]["phase"] = "testing"
|
||||
|
||||
scan.status = "running"
|
||||
scan.current_phase = "testing"
|
||||
await db.commit()
|
||||
|
||||
await ws_manager.broadcast_log(scan_id, "info", "Scan resumed by user")
|
||||
|
||||
return {"message": "Scan resumed", "scan_id": scan_id}
|
||||
|
||||
|
||||
@router.post("/{scan_id}/skip-to/{target_phase}")
|
||||
async def skip_to_phase_endpoint(scan_id: str, target_phase: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Skip the current scan phase and jump to a target phase.
|
||||
|
||||
Valid phases: recon, analyzing, testing, completed
|
||||
Can only skip forward (to a phase ahead of current).
|
||||
"""
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status not in ("running", "paused"):
|
||||
raise HTTPException(status_code=400, detail="Scan is not running or paused")
|
||||
|
||||
# If paused, resume first so the skip can be processed
|
||||
if scan.status == "paused":
|
||||
from backend.api.v1.agent import scan_to_agent, agent_instances, agent_results
|
||||
agent_id = scan_to_agent.get(scan_id)
|
||||
if agent_id and agent_id in agent_instances:
|
||||
agent_instances[agent_id].resume()
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "running"
|
||||
agent_results[agent_id]["phase"] = agent_results[agent_id].get("last_phase", "testing")
|
||||
scan.status = "running"
|
||||
await db.commit()
|
||||
|
||||
if target_phase not in PHASE_ORDER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid phase '{target_phase}'. Valid: {', '.join(PHASE_ORDER[1:])}"
|
||||
)
|
||||
|
||||
# Validate forward skip
|
||||
current_idx = PHASE_ORDER.index(scan.current_phase) if scan.current_phase in PHASE_ORDER else 0
|
||||
target_idx = PHASE_ORDER.index(target_phase)
|
||||
|
||||
if target_idx <= current_idx:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot skip backward. Current: {scan.current_phase}, target: {target_phase}"
|
||||
)
|
||||
|
||||
# Signal the running scan to skip
|
||||
success = _skip_to_phase(scan_id, target_phase)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to signal phase skip")
|
||||
|
||||
# Broadcast via WebSocket
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
await ws_manager.broadcast_log(scan_id, "warning", f">> User requested skip to phase: {target_phase}")
|
||||
await ws_manager.broadcast_phase_change(scan_id, f"skipping_to_{target_phase}")
|
||||
|
||||
return {
|
||||
"message": f"Skipping to phase: {target_phase}",
|
||||
"scan_id": scan_id,
|
||||
"from_phase": scan.current_phase,
|
||||
"target_phase": target_phase
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{scan_id}/status", response_model=ScanProgress)
|
||||
async def get_scan_status(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Get scan progress and status"""
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
return ScanProgress(
|
||||
scan_id=scan.id,
|
||||
status=scan.status,
|
||||
progress=scan.progress,
|
||||
current_phase=scan.current_phase,
|
||||
total_endpoints=scan.total_endpoints,
|
||||
total_vulnerabilities=scan.total_vulnerabilities
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{scan_id}")
|
||||
async def delete_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a scan"""
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
if scan.status == "running":
|
||||
raise HTTPException(status_code=400, detail="Cannot delete running scan")
|
||||
|
||||
await db.delete(scan)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Scan deleted", "scan_id": scan_id}
|
||||
|
||||
|
||||
@router.get("/{scan_id}/endpoints")
|
||||
async def get_scan_endpoints(
|
||||
scan_id: str,
|
||||
page: int = 1,
|
||||
per_page: int = 50,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get endpoints discovered in a scan"""
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
query = select(Endpoint).where(Endpoint.scan_id == scan_id).order_by(Endpoint.discovered_at.desc())
|
||||
|
||||
# Count
|
||||
count_result = await db.execute(select(func.count()).select_from(Endpoint).where(Endpoint.scan_id == scan_id))
|
||||
total = count_result.scalar()
|
||||
|
||||
# Paginate
|
||||
query = query.offset((page - 1) * per_page).limit(per_page)
|
||||
result = await db.execute(query)
|
||||
endpoints = result.scalars().all()
|
||||
|
||||
return {
|
||||
"endpoints": [e.to_dict() for e in endpoints],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{scan_id}/vulnerabilities")
|
||||
async def get_scan_vulnerabilities(
|
||||
scan_id: str,
|
||||
severity: Optional[str] = None,
|
||||
page: int = 1,
|
||||
per_page: int = 50,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get vulnerabilities found in a scan"""
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
query = select(Vulnerability).where(Vulnerability.scan_id == scan_id)
|
||||
|
||||
if severity:
|
||||
query = query.where(Vulnerability.severity == severity)
|
||||
|
||||
query = query.order_by(Vulnerability.created_at.desc())
|
||||
|
||||
# Count
|
||||
count_query = select(func.count()).select_from(Vulnerability).where(Vulnerability.scan_id == scan_id)
|
||||
if severity:
|
||||
count_query = count_query.where(Vulnerability.severity == severity)
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar()
|
||||
|
||||
# Paginate
|
||||
query = query.offset((page - 1) * per_page).limit(per_page)
|
||||
result = await db.execute(query)
|
||||
vulnerabilities = result.scalars().all()
|
||||
|
||||
return {
|
||||
"vulnerabilities": [v.to_dict() for v in vulnerabilities],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page
|
||||
}
|
||||
|
||||
|
||||
class ValidationRequest(BaseModel):
|
||||
validation_status: str # "validated" | "false_positive" | "ai_confirmed" | "ai_rejected" | "pending_review"
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
@router.patch("/vulnerabilities/{vuln_id}/validate")
|
||||
async def validate_vulnerability(
|
||||
vuln_id: str,
|
||||
body: ValidationRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Manually validate or reject a vulnerability finding"""
|
||||
valid_statuses = {"validated", "false_positive", "ai_confirmed", "ai_rejected", "pending_review"}
|
||||
if body.validation_status not in valid_statuses:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid status. Must be one of: {', '.join(valid_statuses)}")
|
||||
|
||||
result = await db.execute(select(Vulnerability).where(Vulnerability.id == vuln_id))
|
||||
vuln = result.scalar_one_or_none()
|
||||
|
||||
if not vuln:
|
||||
raise HTTPException(status_code=404, detail="Vulnerability not found")
|
||||
|
||||
old_status = vuln.validation_status or "ai_confirmed"
|
||||
vuln.validation_status = body.validation_status
|
||||
if body.notes:
|
||||
vuln.ai_rejection_reason = body.notes
|
||||
|
||||
# Update scan severity counts when validation status changes
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == vuln.scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
|
||||
if scan:
|
||||
sev = vuln.severity
|
||||
# If changing from rejected to validated: add to counts
|
||||
if old_status == "ai_rejected" and body.validation_status == "validated":
|
||||
scan.total_vulnerabilities = (scan.total_vulnerabilities or 0) + 1
|
||||
if sev == "critical":
|
||||
scan.critical_count = (scan.critical_count or 0) + 1
|
||||
elif sev == "high":
|
||||
scan.high_count = (scan.high_count or 0) + 1
|
||||
elif sev == "medium":
|
||||
scan.medium_count = (scan.medium_count or 0) + 1
|
||||
elif sev == "low":
|
||||
scan.low_count = (scan.low_count or 0) + 1
|
||||
elif sev == "info":
|
||||
scan.info_count = (scan.info_count or 0) + 1
|
||||
# If changing from confirmed to false_positive: subtract from counts
|
||||
elif old_status in ("ai_confirmed", "validated") and body.validation_status == "false_positive":
|
||||
scan.total_vulnerabilities = max(0, (scan.total_vulnerabilities or 0) - 1)
|
||||
if sev == "critical":
|
||||
scan.critical_count = max(0, (scan.critical_count or 0) - 1)
|
||||
elif sev == "high":
|
||||
scan.high_count = max(0, (scan.high_count or 0) - 1)
|
||||
elif sev == "medium":
|
||||
scan.medium_count = max(0, (scan.medium_count or 0) - 1)
|
||||
elif sev == "low":
|
||||
scan.low_count = max(0, (scan.low_count or 0) - 1)
|
||||
elif sev == "info":
|
||||
scan.info_count = max(0, (scan.info_count or 0) - 1)
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Vulnerability validation updated", "vulnerability": vuln.to_dict()}
|
||||
|
||||
|
||||
# --- Adaptive Learning Feedback ---
|
||||
|
||||
class FeedbackRequest(BaseModel):
|
||||
is_true_positive: bool
|
||||
explanation: str = ""
|
||||
|
||||
|
||||
@router.post("/vulnerabilities/{vuln_id}/feedback")
|
||||
async def submit_vulnerability_feedback(
|
||||
vuln_id: str,
|
||||
body: FeedbackRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Submit TP/FP feedback for a vulnerability finding.
|
||||
|
||||
Records feedback in the adaptive learner so the agent improves over time.
|
||||
Also updates the validation_status in the database.
|
||||
"""
|
||||
result = await db.execute(select(Vulnerability).where(Vulnerability.id == vuln_id))
|
||||
vuln = result.scalar_one_or_none()
|
||||
if not vuln:
|
||||
raise HTTPException(status_code=404, detail="Vulnerability not found")
|
||||
|
||||
if len(body.explanation) < 3 and not body.is_true_positive:
|
||||
raise HTTPException(status_code=400, detail="Explanation required for false positive feedback (min 3 chars)")
|
||||
|
||||
# Update DB validation status
|
||||
vuln.validation_status = "validated" if body.is_true_positive else "false_positive"
|
||||
if body.explanation:
|
||||
vuln.ai_rejection_reason = body.explanation
|
||||
|
||||
# Update scan counts
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == vuln.scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
if scan and not body.is_true_positive:
|
||||
sev = vuln.severity
|
||||
scan.total_vulnerabilities = max(0, (scan.total_vulnerabilities or 0) - 1)
|
||||
count_attr = f"{sev}_count"
|
||||
if hasattr(scan, count_attr):
|
||||
setattr(scan, count_attr, max(0, (getattr(scan, count_attr) or 0) - 1))
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Record in adaptive learner
|
||||
pattern_count = 0
|
||||
try:
|
||||
from backend.core.adaptive_learner import AdaptiveLearner
|
||||
learner = AdaptiveLearner()
|
||||
vuln_dict = vuln.to_dict()
|
||||
learner.record_feedback(
|
||||
vuln_id=vuln_id,
|
||||
vuln_type=vuln_dict.get("vuln_type", "unknown"),
|
||||
endpoint=vuln_dict.get("url", ""),
|
||||
param=vuln_dict.get("parameter", ""),
|
||||
payload=vuln_dict.get("payload", ""),
|
||||
is_tp=body.is_true_positive,
|
||||
explanation=body.explanation,
|
||||
severity=vuln_dict.get("severity", "medium"),
|
||||
domain=vuln_dict.get("url", ""),
|
||||
)
|
||||
stats = learner.get_stats()
|
||||
pattern_count = stats.get("total_patterns", 0)
|
||||
except Exception as e:
|
||||
logger.warning(f"Adaptive learner feedback failed: {e}")
|
||||
|
||||
return {
|
||||
"message": "Feedback recorded",
|
||||
"vulnerability_id": vuln_id,
|
||||
"is_true_positive": body.is_true_positive,
|
||||
"pattern_count": pattern_count,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/vulnerabilities/learning/stats")
|
||||
async def get_learning_stats():
|
||||
"""Get adaptive learning statistics."""
|
||||
try:
|
||||
from backend.core.adaptive_learner import AdaptiveLearner
|
||||
learner = AdaptiveLearner()
|
||||
return learner.get_stats()
|
||||
except Exception as e:
|
||||
return {"error": str(e), "total_feedback": 0, "total_patterns": 0}
|
||||
140
backend/api/v1/scheduler.py
Executable file
140
backend/api/v1/scheduler.py
Executable file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
NeuroSploit v3 - Scheduler API Router
|
||||
|
||||
CRUD endpoints for managing scheduled scan jobs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "config" / "config.json"
|
||||
|
||||
|
||||
class ScheduleJobRequest(BaseModel):
|
||||
"""Request model for creating a scheduled job."""
|
||||
job_id: str
|
||||
target: str
|
||||
scan_type: str = "quick"
|
||||
cron_expression: Optional[str] = None
|
||||
interval_minutes: Optional[int] = None
|
||||
agent_role: Optional[str] = None
|
||||
llm_profile: Optional[str] = None
|
||||
|
||||
|
||||
class ScheduleJobResponse(BaseModel):
|
||||
"""Response model for a scheduled job."""
|
||||
id: str
|
||||
target: str
|
||||
scan_type: str
|
||||
schedule: str
|
||||
status: str
|
||||
next_run: Optional[str] = None
|
||||
last_run: Optional[str] = None
|
||||
run_count: int = 0
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Dict])
|
||||
async def list_scheduled_jobs(request: Request):
|
||||
"""List all scheduled scan jobs."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
return []
|
||||
return scheduler.list_jobs()
|
||||
|
||||
|
||||
@router.post("/", response_model=Dict)
|
||||
async def create_scheduled_job(job: ScheduleJobRequest, request: Request):
|
||||
"""Create a new scheduled scan job."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
raise HTTPException(status_code=503, detail="Scheduler not available")
|
||||
|
||||
if not job.cron_expression and not job.interval_minutes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either cron_expression or interval_minutes must be provided"
|
||||
)
|
||||
|
||||
result = scheduler.add_job(
|
||||
job_id=job.job_id,
|
||||
target=job.target,
|
||||
scan_type=job.scan_type,
|
||||
cron_expression=job.cron_expression,
|
||||
interval_minutes=job.interval_minutes,
|
||||
agent_role=job.agent_role,
|
||||
llm_profile=job.llm_profile
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=400, detail=result["error"])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/{job_id}")
|
||||
async def delete_scheduled_job(job_id: str, request: Request):
|
||||
"""Delete a scheduled scan job."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
raise HTTPException(status_code=503, detail="Scheduler not available")
|
||||
|
||||
success = scheduler.remove_job(job_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||
|
||||
return {"message": f"Job '{job_id}' deleted", "id": job_id}
|
||||
|
||||
|
||||
@router.post("/{job_id}/pause")
|
||||
async def pause_scheduled_job(job_id: str, request: Request):
|
||||
"""Pause a scheduled scan job."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
raise HTTPException(status_code=503, detail="Scheduler not available")
|
||||
|
||||
success = scheduler.pause_job(job_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||
|
||||
return {"message": f"Job '{job_id}' paused", "id": job_id, "status": "paused"}
|
||||
|
||||
|
||||
@router.post("/{job_id}/resume")
|
||||
async def resume_scheduled_job(job_id: str, request: Request):
|
||||
"""Resume a paused scheduled scan job."""
|
||||
scheduler = getattr(request.app.state, 'scheduler', None)
|
||||
if not scheduler:
|
||||
raise HTTPException(status_code=503, detail="Scheduler not available")
|
||||
|
||||
success = scheduler.resume_job(job_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
||||
|
||||
return {"message": f"Job '{job_id}' resumed", "id": job_id, "status": "active"}
|
||||
|
||||
|
||||
@router.get("/agent-roles", response_model=List[Dict])
|
||||
async def get_agent_roles():
|
||||
"""Return available agent roles from config.json for scheduler dropdown."""
|
||||
try:
|
||||
if not CONFIG_PATH.exists():
|
||||
return []
|
||||
config = json.loads(CONFIG_PATH.read_text())
|
||||
roles = config.get("agent_roles", {})
|
||||
result = []
|
||||
for role_id, role_data in roles.items():
|
||||
if role_data.get("enabled", True):
|
||||
result.append({
|
||||
"id": role_id,
|
||||
"name": role_id.replace("_", " ").title(),
|
||||
"description": role_data.get("description", ""),
|
||||
"tools": role_data.get("tools_allowed", []),
|
||||
})
|
||||
return result
|
||||
except Exception:
|
||||
return []
|
||||
707
backend/api/v1/settings.py
Executable file
707
backend/api/v1/settings.py
Executable file
@@ -0,0 +1,707 @@
|
||||
"""
|
||||
NeuroSploit v3 - Settings API Endpoints
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete, text
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.db.database import get_db, engine
|
||||
from backend.models import Scan, Target, Endpoint, Vulnerability, VulnerabilityTest, Report
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Path to .env file (project root)
|
||||
ENV_FILE_PATH = Path(__file__).parent.parent.parent.parent / ".env"
|
||||
|
||||
|
||||
def _update_env_file(updates: Dict[str, str]) -> bool:
|
||||
"""
|
||||
Update key=value pairs in the .env file without breaking formatting.
|
||||
- If the key exists (even commented out), update its value
|
||||
- If the key doesn't exist, append it
|
||||
- Preserves comments and blank lines
|
||||
"""
|
||||
if not ENV_FILE_PATH.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
lines = ENV_FILE_PATH.read_text().splitlines()
|
||||
updated_keys = set()
|
||||
|
||||
new_lines = []
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
matched = False
|
||||
|
||||
for key, value in updates.items():
|
||||
# Match: KEY=..., # KEY=..., #KEY=...
|
||||
pattern = rf'^#?\s*{re.escape(key)}\s*='
|
||||
if re.match(pattern, stripped):
|
||||
# Replace with uncommented key=value
|
||||
new_lines.append(f"{key}={value}")
|
||||
updated_keys.add(key)
|
||||
matched = True
|
||||
break
|
||||
|
||||
if not matched:
|
||||
new_lines.append(line)
|
||||
|
||||
# Append any keys that weren't found in existing file
|
||||
for key, value in updates.items():
|
||||
if key not in updated_keys:
|
||||
new_lines.append(f"{key}={value}")
|
||||
|
||||
# Write back with trailing newline
|
||||
ENV_FILE_PATH.write_text("\n".join(new_lines) + "\n")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to update .env file: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class SettingsUpdate(BaseModel):
|
||||
"""Settings update schema"""
|
||||
llm_provider: Optional[str] = None
|
||||
llm_model: Optional[str] = None
|
||||
anthropic_api_key: Optional[str] = None
|
||||
openai_api_key: Optional[str] = None
|
||||
openrouter_api_key: Optional[str] = None
|
||||
gemini_api_key: Optional[str] = None
|
||||
together_api_key: Optional[str] = None
|
||||
fireworks_api_key: Optional[str] = None
|
||||
ollama_base_url: Optional[str] = None
|
||||
lmstudio_base_url: Optional[str] = None
|
||||
max_concurrent_scans: Optional[int] = None
|
||||
aggressive_mode: Optional[bool] = None
|
||||
default_scan_type: Optional[str] = None
|
||||
recon_enabled_by_default: Optional[bool] = None
|
||||
enable_model_routing: Optional[bool] = None
|
||||
enable_knowledge_augmentation: Optional[bool] = None
|
||||
enable_browser_validation: Optional[bool] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
# Notifications
|
||||
enable_notifications: Optional[bool] = None
|
||||
discord_webhook_url: Optional[str] = None
|
||||
telegram_bot_token: Optional[str] = None
|
||||
telegram_chat_id: Optional[str] = None
|
||||
twilio_account_sid: Optional[str] = None
|
||||
twilio_auth_token: Optional[str] = None
|
||||
twilio_from_number: Optional[str] = None
|
||||
twilio_to_number: Optional[str] = None
|
||||
notification_severity_filter: Optional[str] = None
|
||||
|
||||
|
||||
class SettingsResponse(BaseModel):
|
||||
"""Settings response schema"""
|
||||
llm_provider: str = "claude"
|
||||
llm_model: str = ""
|
||||
has_anthropic_key: bool = False
|
||||
has_openai_key: bool = False
|
||||
has_openrouter_key: bool = False
|
||||
has_gemini_key: bool = False
|
||||
has_together_key: bool = False
|
||||
has_fireworks_key: bool = False
|
||||
ollama_base_url: str = ""
|
||||
lmstudio_base_url: str = ""
|
||||
max_concurrent_scans: int = 3
|
||||
aggressive_mode: bool = False
|
||||
default_scan_type: str = "full"
|
||||
recon_enabled_by_default: bool = True
|
||||
enable_model_routing: bool = False
|
||||
enable_knowledge_augmentation: bool = False
|
||||
enable_browser_validation: bool = False
|
||||
max_output_tokens: Optional[int] = None
|
||||
# Notifications
|
||||
enable_notifications: bool = False
|
||||
has_discord_webhook: bool = False
|
||||
has_telegram_bot: bool = False
|
||||
has_twilio_credentials: bool = False
|
||||
notification_severity_filter: str = "critical,high"
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Info about an available LLM model"""
|
||||
provider: str
|
||||
model_id: str
|
||||
display_name: str
|
||||
size: Optional[str] = None
|
||||
context_length: Optional[int] = None
|
||||
is_local: bool = False
|
||||
|
||||
|
||||
class ModelCatalogResponse(BaseModel):
|
||||
"""Response from model catalog endpoint"""
|
||||
provider: str
|
||||
models: List[ModelInfo]
|
||||
available: bool
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def _load_settings_from_env() -> dict:
|
||||
"""
|
||||
Load settings from environment variables / .env file on startup.
|
||||
This ensures settings persist across server restarts and browser sessions.
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
# Re-read .env file to pick up disk-persisted values
|
||||
if ENV_FILE_PATH.exists():
|
||||
load_dotenv(ENV_FILE_PATH, override=True)
|
||||
|
||||
def _env_bool(key: str, default: bool = False) -> bool:
|
||||
val = os.getenv(key, "").strip().lower()
|
||||
if val in ("true", "1", "yes"):
|
||||
return True
|
||||
if val in ("false", "0", "no"):
|
||||
return False
|
||||
return default
|
||||
|
||||
def _env_int(key: str, default=None):
|
||||
val = os.getenv(key, "").strip()
|
||||
if val:
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
pass
|
||||
return default
|
||||
|
||||
# Detect provider from which keys are set
|
||||
provider = "claude"
|
||||
if os.getenv("ANTHROPIC_API_KEY"):
|
||||
provider = "claude"
|
||||
elif os.getenv("OPENAI_API_KEY"):
|
||||
provider = "openai"
|
||||
elif os.getenv("OPENROUTER_API_KEY"):
|
||||
provider = "openrouter"
|
||||
|
||||
return {
|
||||
"llm_provider": provider,
|
||||
"llm_model": os.getenv("DEFAULT_LLM_MODEL", ""),
|
||||
"anthropic_api_key": os.getenv("ANTHROPIC_API_KEY", ""),
|
||||
"openai_api_key": os.getenv("OPENAI_API_KEY", ""),
|
||||
"openrouter_api_key": os.getenv("OPENROUTER_API_KEY", ""),
|
||||
"gemini_api_key": os.getenv("GEMINI_API_KEY", ""),
|
||||
"together_api_key": os.getenv("TOGETHER_API_KEY", ""),
|
||||
"fireworks_api_key": os.getenv("FIREWORKS_API_KEY", ""),
|
||||
"ollama_base_url": os.getenv("OLLAMA_BASE_URL", os.getenv("OLLAMA_URL", "")),
|
||||
"lmstudio_base_url": os.getenv("LMSTUDIO_BASE_URL", os.getenv("LMSTUDIO_URL", "")),
|
||||
"max_concurrent_scans": _env_int("MAX_CONCURRENT_SCANS", 3),
|
||||
"aggressive_mode": _env_bool("AGGRESSIVE_MODE", False),
|
||||
"default_scan_type": os.getenv("DEFAULT_SCAN_TYPE", "full"),
|
||||
"recon_enabled_by_default": _env_bool("RECON_ENABLED_BY_DEFAULT", True),
|
||||
"enable_model_routing": _env_bool("ENABLE_MODEL_ROUTING", False),
|
||||
"enable_knowledge_augmentation": _env_bool("ENABLE_KNOWLEDGE_AUGMENTATION", False),
|
||||
"enable_browser_validation": _env_bool("ENABLE_BROWSER_VALIDATION", False),
|
||||
"max_output_tokens": _env_int("MAX_OUTPUT_TOKENS", None),
|
||||
# Notifications
|
||||
"enable_notifications": _env_bool("ENABLE_NOTIFICATIONS", False),
|
||||
"discord_webhook_url": os.getenv("DISCORD_WEBHOOK_URL", ""),
|
||||
"telegram_bot_token": os.getenv("TELEGRAM_BOT_TOKEN", ""),
|
||||
"telegram_chat_id": os.getenv("TELEGRAM_CHAT_ID", ""),
|
||||
"twilio_account_sid": os.getenv("TWILIO_ACCOUNT_SID", ""),
|
||||
"twilio_auth_token": os.getenv("TWILIO_AUTH_TOKEN", ""),
|
||||
"twilio_from_number": os.getenv("TWILIO_FROM_NUMBER", ""),
|
||||
"twilio_to_number": os.getenv("TWILIO_TO_NUMBER", ""),
|
||||
"notification_severity_filter": os.getenv("NOTIFICATION_SEVERITY_FILTER", "critical,high"),
|
||||
}
|
||||
|
||||
|
||||
# Load settings from .env on module import (server start)
|
||||
_settings = _load_settings_from_env()
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsResponse)
|
||||
async def get_settings():
|
||||
"""Get current settings"""
|
||||
import os
|
||||
return SettingsResponse(
|
||||
llm_provider=_settings["llm_provider"],
|
||||
llm_model=_settings.get("llm_model", ""),
|
||||
has_anthropic_key=bool(_settings["anthropic_api_key"] or os.getenv("ANTHROPIC_API_KEY")),
|
||||
has_openai_key=bool(_settings["openai_api_key"] or os.getenv("OPENAI_API_KEY")),
|
||||
has_openrouter_key=bool(_settings["openrouter_api_key"] or os.getenv("OPENROUTER_API_KEY")),
|
||||
has_gemini_key=bool(_settings.get("gemini_api_key") or os.getenv("GEMINI_API_KEY")),
|
||||
has_together_key=bool(_settings.get("together_api_key") or os.getenv("TOGETHER_API_KEY")),
|
||||
has_fireworks_key=bool(_settings.get("fireworks_api_key") or os.getenv("FIREWORKS_API_KEY")),
|
||||
ollama_base_url=_settings.get("ollama_base_url", ""),
|
||||
lmstudio_base_url=_settings.get("lmstudio_base_url", ""),
|
||||
max_concurrent_scans=_settings["max_concurrent_scans"],
|
||||
aggressive_mode=_settings["aggressive_mode"],
|
||||
default_scan_type=_settings["default_scan_type"],
|
||||
recon_enabled_by_default=_settings["recon_enabled_by_default"],
|
||||
enable_model_routing=_settings["enable_model_routing"],
|
||||
enable_knowledge_augmentation=_settings["enable_knowledge_augmentation"],
|
||||
enable_browser_validation=_settings["enable_browser_validation"],
|
||||
max_output_tokens=_settings["max_output_tokens"],
|
||||
# Notifications
|
||||
enable_notifications=_settings.get("enable_notifications", False),
|
||||
has_discord_webhook=bool(_settings.get("discord_webhook_url")),
|
||||
has_telegram_bot=bool(_settings.get("telegram_bot_token") and _settings.get("telegram_chat_id")),
|
||||
has_twilio_credentials=bool(
|
||||
_settings.get("twilio_account_sid") and _settings.get("twilio_auth_token")
|
||||
and _settings.get("twilio_from_number") and _settings.get("twilio_to_number")
|
||||
),
|
||||
notification_severity_filter=_settings.get("notification_severity_filter", "critical,high"),
|
||||
)
|
||||
|
||||
|
||||
@router.put("", response_model=SettingsResponse)
|
||||
async def update_settings(settings_data: SettingsUpdate):
|
||||
"""Update settings - persists to memory, env vars, AND .env file"""
|
||||
env_updates: Dict[str, str] = {}
|
||||
|
||||
if settings_data.llm_provider is not None:
|
||||
_settings["llm_provider"] = settings_data.llm_provider
|
||||
|
||||
if settings_data.llm_model is not None:
|
||||
_settings["llm_model"] = settings_data.llm_model
|
||||
os.environ["DEFAULT_LLM_MODEL"] = settings_data.llm_model
|
||||
env_updates["DEFAULT_LLM_MODEL"] = settings_data.llm_model
|
||||
|
||||
if settings_data.anthropic_api_key is not None:
|
||||
_settings["anthropic_api_key"] = settings_data.anthropic_api_key
|
||||
if settings_data.anthropic_api_key:
|
||||
os.environ["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key
|
||||
env_updates["ANTHROPIC_API_KEY"] = settings_data.anthropic_api_key
|
||||
|
||||
if settings_data.openai_api_key is not None:
|
||||
_settings["openai_api_key"] = settings_data.openai_api_key
|
||||
if settings_data.openai_api_key:
|
||||
os.environ["OPENAI_API_KEY"] = settings_data.openai_api_key
|
||||
env_updates["OPENAI_API_KEY"] = settings_data.openai_api_key
|
||||
|
||||
if settings_data.openrouter_api_key is not None:
|
||||
_settings["openrouter_api_key"] = settings_data.openrouter_api_key
|
||||
if settings_data.openrouter_api_key:
|
||||
os.environ["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key
|
||||
env_updates["OPENROUTER_API_KEY"] = settings_data.openrouter_api_key
|
||||
|
||||
if settings_data.gemini_api_key is not None:
|
||||
_settings["gemini_api_key"] = settings_data.gemini_api_key
|
||||
if settings_data.gemini_api_key:
|
||||
os.environ["GEMINI_API_KEY"] = settings_data.gemini_api_key
|
||||
env_updates["GEMINI_API_KEY"] = settings_data.gemini_api_key
|
||||
|
||||
if settings_data.together_api_key is not None:
|
||||
_settings["together_api_key"] = settings_data.together_api_key
|
||||
if settings_data.together_api_key:
|
||||
os.environ["TOGETHER_API_KEY"] = settings_data.together_api_key
|
||||
env_updates["TOGETHER_API_KEY"] = settings_data.together_api_key
|
||||
|
||||
if settings_data.fireworks_api_key is not None:
|
||||
_settings["fireworks_api_key"] = settings_data.fireworks_api_key
|
||||
if settings_data.fireworks_api_key:
|
||||
os.environ["FIREWORKS_API_KEY"] = settings_data.fireworks_api_key
|
||||
env_updates["FIREWORKS_API_KEY"] = settings_data.fireworks_api_key
|
||||
|
||||
if settings_data.ollama_base_url is not None:
|
||||
_settings["ollama_base_url"] = settings_data.ollama_base_url
|
||||
if settings_data.ollama_base_url:
|
||||
os.environ["OLLAMA_BASE_URL"] = settings_data.ollama_base_url
|
||||
env_updates["OLLAMA_BASE_URL"] = settings_data.ollama_base_url
|
||||
|
||||
if settings_data.lmstudio_base_url is not None:
|
||||
_settings["lmstudio_base_url"] = settings_data.lmstudio_base_url
|
||||
if settings_data.lmstudio_base_url:
|
||||
os.environ["LMSTUDIO_BASE_URL"] = settings_data.lmstudio_base_url
|
||||
env_updates["LMSTUDIO_BASE_URL"] = settings_data.lmstudio_base_url
|
||||
|
||||
if settings_data.max_concurrent_scans is not None:
|
||||
_settings["max_concurrent_scans"] = settings_data.max_concurrent_scans
|
||||
|
||||
if settings_data.aggressive_mode is not None:
|
||||
_settings["aggressive_mode"] = settings_data.aggressive_mode
|
||||
|
||||
if settings_data.default_scan_type is not None:
|
||||
_settings["default_scan_type"] = settings_data.default_scan_type
|
||||
|
||||
if settings_data.recon_enabled_by_default is not None:
|
||||
_settings["recon_enabled_by_default"] = settings_data.recon_enabled_by_default
|
||||
|
||||
if settings_data.enable_model_routing is not None:
|
||||
_settings["enable_model_routing"] = settings_data.enable_model_routing
|
||||
val = str(settings_data.enable_model_routing).lower()
|
||||
os.environ["ENABLE_MODEL_ROUTING"] = val
|
||||
env_updates["ENABLE_MODEL_ROUTING"] = val
|
||||
|
||||
if settings_data.enable_knowledge_augmentation is not None:
|
||||
_settings["enable_knowledge_augmentation"] = settings_data.enable_knowledge_augmentation
|
||||
val = str(settings_data.enable_knowledge_augmentation).lower()
|
||||
os.environ["ENABLE_KNOWLEDGE_AUGMENTATION"] = val
|
||||
env_updates["ENABLE_KNOWLEDGE_AUGMENTATION"] = val
|
||||
|
||||
if settings_data.enable_browser_validation is not None:
|
||||
_settings["enable_browser_validation"] = settings_data.enable_browser_validation
|
||||
val = str(settings_data.enable_browser_validation).lower()
|
||||
os.environ["ENABLE_BROWSER_VALIDATION"] = val
|
||||
env_updates["ENABLE_BROWSER_VALIDATION"] = val
|
||||
|
||||
if settings_data.max_output_tokens is not None:
|
||||
_settings["max_output_tokens"] = settings_data.max_output_tokens
|
||||
if settings_data.max_output_tokens:
|
||||
os.environ["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens)
|
||||
env_updates["MAX_OUTPUT_TOKENS"] = str(settings_data.max_output_tokens)
|
||||
|
||||
# Notifications
|
||||
if settings_data.enable_notifications is not None:
|
||||
_settings["enable_notifications"] = settings_data.enable_notifications
|
||||
val = str(settings_data.enable_notifications).lower()
|
||||
os.environ["ENABLE_NOTIFICATIONS"] = val
|
||||
env_updates["ENABLE_NOTIFICATIONS"] = val
|
||||
|
||||
if settings_data.discord_webhook_url is not None:
|
||||
_settings["discord_webhook_url"] = settings_data.discord_webhook_url
|
||||
os.environ["DISCORD_WEBHOOK_URL"] = settings_data.discord_webhook_url
|
||||
env_updates["DISCORD_WEBHOOK_URL"] = settings_data.discord_webhook_url
|
||||
|
||||
if settings_data.telegram_bot_token is not None:
|
||||
_settings["telegram_bot_token"] = settings_data.telegram_bot_token
|
||||
os.environ["TELEGRAM_BOT_TOKEN"] = settings_data.telegram_bot_token
|
||||
env_updates["TELEGRAM_BOT_TOKEN"] = settings_data.telegram_bot_token
|
||||
|
||||
if settings_data.telegram_chat_id is not None:
|
||||
_settings["telegram_chat_id"] = settings_data.telegram_chat_id
|
||||
os.environ["TELEGRAM_CHAT_ID"] = settings_data.telegram_chat_id
|
||||
env_updates["TELEGRAM_CHAT_ID"] = settings_data.telegram_chat_id
|
||||
|
||||
if settings_data.twilio_account_sid is not None:
|
||||
_settings["twilio_account_sid"] = settings_data.twilio_account_sid
|
||||
os.environ["TWILIO_ACCOUNT_SID"] = settings_data.twilio_account_sid
|
||||
env_updates["TWILIO_ACCOUNT_SID"] = settings_data.twilio_account_sid
|
||||
|
||||
if settings_data.twilio_auth_token is not None:
|
||||
_settings["twilio_auth_token"] = settings_data.twilio_auth_token
|
||||
os.environ["TWILIO_AUTH_TOKEN"] = settings_data.twilio_auth_token
|
||||
env_updates["TWILIO_AUTH_TOKEN"] = settings_data.twilio_auth_token
|
||||
|
||||
if settings_data.twilio_from_number is not None:
|
||||
_settings["twilio_from_number"] = settings_data.twilio_from_number
|
||||
os.environ["TWILIO_FROM_NUMBER"] = settings_data.twilio_from_number
|
||||
env_updates["TWILIO_FROM_NUMBER"] = settings_data.twilio_from_number
|
||||
|
||||
if settings_data.twilio_to_number is not None:
|
||||
_settings["twilio_to_number"] = settings_data.twilio_to_number
|
||||
os.environ["TWILIO_TO_NUMBER"] = settings_data.twilio_to_number
|
||||
env_updates["TWILIO_TO_NUMBER"] = settings_data.twilio_to_number
|
||||
|
||||
if settings_data.notification_severity_filter is not None:
|
||||
_settings["notification_severity_filter"] = settings_data.notification_severity_filter
|
||||
os.environ["NOTIFICATION_SEVERITY_FILTER"] = settings_data.notification_severity_filter
|
||||
env_updates["NOTIFICATION_SEVERITY_FILTER"] = settings_data.notification_severity_filter
|
||||
|
||||
# Persist to .env file on disk
|
||||
if env_updates:
|
||||
_update_env_file(env_updates)
|
||||
|
||||
# Reload notification config if any notification-related fields changed
|
||||
try:
|
||||
from backend.core.notification_manager import notification_manager
|
||||
notification_manager.reload_config()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return await get_settings()
|
||||
|
||||
|
||||
@router.post("/notifications/test/{channel}")
|
||||
async def test_notification_channel(channel: str):
|
||||
"""Send a test notification to a specific channel (discord, telegram, whatsapp)."""
|
||||
try:
|
||||
from backend.core.notification_manager import notification_manager
|
||||
result = await notification_manager.test_channel(channel)
|
||||
return result
|
||||
except ImportError:
|
||||
raise HTTPException(500, "Notification manager not available")
|
||||
|
||||
|
||||
@router.post("/clear-database")
|
||||
async def clear_database(db: AsyncSession = Depends(get_db)):
|
||||
"""Clear all data from the database (reset to fresh state)"""
|
||||
try:
|
||||
# Delete in correct order to respect foreign key constraints
|
||||
await db.execute(delete(VulnerabilityTest))
|
||||
await db.execute(delete(Vulnerability))
|
||||
await db.execute(delete(Endpoint))
|
||||
await db.execute(delete(Report))
|
||||
await db.execute(delete(Target))
|
||||
await db.execute(delete(Scan))
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"message": "Database cleared successfully",
|
||||
"status": "success"
|
||||
}
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear database: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_database_stats(db: AsyncSession = Depends(get_db)):
|
||||
"""Get database statistics"""
|
||||
from sqlalchemy import func
|
||||
|
||||
scans_count = (await db.execute(select(func.count()).select_from(Scan))).scalar() or 0
|
||||
vulns_count = (await db.execute(select(func.count()).select_from(Vulnerability))).scalar() or 0
|
||||
endpoints_count = (await db.execute(select(func.count()).select_from(Endpoint))).scalar() or 0
|
||||
reports_count = (await db.execute(select(func.count()).select_from(Report))).scalar() or 0
|
||||
|
||||
return {
|
||||
"scans": scans_count,
|
||||
"vulnerabilities": vulns_count,
|
||||
"endpoints": endpoints_count,
|
||||
"reports": reports_count
|
||||
}
|
||||
|
||||
|
||||
@router.get("/tools")
|
||||
async def get_installed_tools():
|
||||
"""Check which security tools are installed"""
|
||||
import asyncio
|
||||
import shutil
|
||||
|
||||
# Complete list of 40+ tools
|
||||
tools = {
|
||||
"recon": [
|
||||
"subfinder", "amass", "assetfinder", "chaos", "uncover",
|
||||
"dnsx", "massdns", "puredns", "cero", "tlsx", "cdncheck"
|
||||
],
|
||||
"web_discovery": [
|
||||
"httpx", "httprobe", "katana", "gospider", "hakrawler",
|
||||
"gau", "waybackurls", "cariddi", "getJS", "gowitness"
|
||||
],
|
||||
"fuzzing": [
|
||||
"ffuf", "gobuster", "dirb", "dirsearch", "wfuzz", "arjun", "paramspider"
|
||||
],
|
||||
"vulnerability_scanning": [
|
||||
"nuclei", "nikto", "sqlmap", "xsstrike", "dalfox", "crlfuzz"
|
||||
],
|
||||
"port_scanning": [
|
||||
"nmap", "naabu", "rustscan"
|
||||
],
|
||||
"utilities": [
|
||||
"gf", "qsreplace", "unfurl", "anew", "uro", "jq"
|
||||
],
|
||||
"tech_detection": [
|
||||
"whatweb", "wafw00f"
|
||||
],
|
||||
"exploitation": [
|
||||
"hydra", "medusa", "john", "hashcat"
|
||||
],
|
||||
"network": [
|
||||
"curl", "wget", "dig", "whois"
|
||||
]
|
||||
}
|
||||
|
||||
results = {}
|
||||
total_installed = 0
|
||||
total_tools = 0
|
||||
|
||||
for category, tool_list in tools.items():
|
||||
results[category] = {}
|
||||
for tool in tool_list:
|
||||
total_tools += 1
|
||||
# Check if tool exists in PATH
|
||||
is_installed = shutil.which(tool) is not None
|
||||
results[category][tool] = is_installed
|
||||
if is_installed:
|
||||
total_installed += 1
|
||||
|
||||
return {
|
||||
"tools": results,
|
||||
"summary": {
|
||||
"total": total_tools,
|
||||
"installed": total_installed,
|
||||
"missing": total_tools - total_installed,
|
||||
"percentage": round((total_installed / total_tools) * 100, 1)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# --- Model Catalog ---
|
||||
|
||||
# Cache for model catalog queries (60-second TTL)
|
||||
_model_cache: Dict[str, dict] = {}
|
||||
_model_cache_time: Dict[str, float] = {}
|
||||
MODEL_CACHE_TTL = 60 # seconds
|
||||
|
||||
# Common cloud models for dropdown suggestions
|
||||
CLOUD_MODELS = {
|
||||
"claude": [
|
||||
{"model_id": "claude-sonnet-4-20250514", "display_name": "Claude Sonnet 4", "context_length": 200000},
|
||||
{"model_id": "claude-opus-4-20250514", "display_name": "Claude Opus 4", "context_length": 200000},
|
||||
{"model_id": "claude-haiku-4-20250514", "display_name": "Claude Haiku 4", "context_length": 200000},
|
||||
],
|
||||
"openai": [
|
||||
{"model_id": "gpt-4-turbo-preview", "display_name": "GPT-4 Turbo", "context_length": 128000},
|
||||
{"model_id": "gpt-4o", "display_name": "GPT-4o", "context_length": 128000},
|
||||
{"model_id": "gpt-4o-mini", "display_name": "GPT-4o Mini", "context_length": 128000},
|
||||
{"model_id": "o1-preview", "display_name": "O1 Preview", "context_length": 128000},
|
||||
{"model_id": "o1-mini", "display_name": "O1 Mini", "context_length": 128000},
|
||||
],
|
||||
"gemini": [
|
||||
{"model_id": "gemini-pro", "display_name": "Gemini Pro", "context_length": 30720},
|
||||
{"model_id": "gemini-1.5-pro", "display_name": "Gemini 1.5 Pro", "context_length": 1048576},
|
||||
{"model_id": "gemini-1.5-flash", "display_name": "Gemini 1.5 Flash", "context_length": 1048576},
|
||||
{"model_id": "gemini-2.0-flash", "display_name": "Gemini 2.0 Flash", "context_length": 1048576},
|
||||
],
|
||||
"together": [
|
||||
{"model_id": "meta-llama/Llama-3.3-70B-Instruct-Turbo", "display_name": "Llama 3.3 70B", "context_length": 131072},
|
||||
{"model_id": "Qwen/Qwen2.5-72B-Instruct-Turbo", "display_name": "Qwen 2.5 72B", "context_length": 32768},
|
||||
{"model_id": "deepseek-ai/DeepSeek-R1", "display_name": "DeepSeek R1", "context_length": 65536},
|
||||
{"model_id": "mistralai/Mixtral-8x22B-Instruct-v0.1", "display_name": "Mixtral 8x22B", "context_length": 65536},
|
||||
],
|
||||
"fireworks": [
|
||||
{"model_id": "accounts/fireworks/models/llama-v3p3-70b-instruct", "display_name": "Llama 3.3 70B", "context_length": 131072},
|
||||
{"model_id": "accounts/fireworks/models/qwen2p5-72b-instruct", "display_name": "Qwen 2.5 72B", "context_length": 32768},
|
||||
{"model_id": "accounts/fireworks/models/deepseek-r1", "display_name": "DeepSeek R1", "context_length": 65536},
|
||||
],
|
||||
"codex": [
|
||||
{"model_id": "codex-mini-latest", "display_name": "Codex Mini", "context_length": 192000},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/models/{provider}", response_model=ModelCatalogResponse)
|
||||
async def get_provider_models(provider: str):
|
||||
"""Get available models for a specific provider.
|
||||
|
||||
For local providers (ollama, lmstudio), queries the running service.
|
||||
For cloud providers, returns common model suggestions.
|
||||
For openrouter, queries the API for available models.
|
||||
"""
|
||||
import aiohttp
|
||||
|
||||
# Check cache
|
||||
now = time.time()
|
||||
if provider in _model_cache and (now - _model_cache_time.get(provider, 0)) < MODEL_CACHE_TTL:
|
||||
return ModelCatalogResponse(**_model_cache[provider])
|
||||
|
||||
if provider == "ollama":
|
||||
result = await _get_ollama_models()
|
||||
elif provider == "lmstudio":
|
||||
result = await _get_lmstudio_models()
|
||||
elif provider == "openrouter":
|
||||
result = await _get_openrouter_models()
|
||||
elif provider in CLOUD_MODELS:
|
||||
result = {
|
||||
"provider": provider,
|
||||
"models": [
|
||||
ModelInfo(
|
||||
provider=provider,
|
||||
model_id=m["model_id"],
|
||||
display_name=m["display_name"],
|
||||
context_length=m.get("context_length"),
|
||||
is_local=False,
|
||||
).dict()
|
||||
for m in CLOUD_MODELS[provider]
|
||||
],
|
||||
"available": True,
|
||||
"error": None,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(400, f"Unknown provider: {provider}")
|
||||
|
||||
# Cache the result
|
||||
_model_cache[provider] = result
|
||||
_model_cache_time[provider] = now
|
||||
|
||||
return ModelCatalogResponse(**result)
|
||||
|
||||
|
||||
async def _get_ollama_models() -> dict:
|
||||
"""Query Ollama for installed models."""
|
||||
import aiohttp
|
||||
ollama_url = os.getenv("OLLAMA_BASE_URL", os.getenv("OLLAMA_URL", "http://localhost:11434"))
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{ollama_url}/api/tags",
|
||||
timeout=aiohttp.ClientTimeout(total=3)
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
return {"provider": "ollama", "models": [], "available": False, "error": f"HTTP {resp.status}"}
|
||||
data = await resp.json()
|
||||
models = []
|
||||
for m in data.get("models", []):
|
||||
name = m.get("name", "")
|
||||
size_bytes = m.get("size", 0)
|
||||
size_str = f"{size_bytes / 1e9:.1f}B" if size_bytes else None
|
||||
details = m.get("details", {})
|
||||
models.append(ModelInfo(
|
||||
provider="ollama",
|
||||
model_id=name,
|
||||
display_name=name,
|
||||
size=size_str,
|
||||
context_length=details.get("context_length"),
|
||||
is_local=True,
|
||||
).dict())
|
||||
return {"provider": "ollama", "models": models, "available": True, "error": None}
|
||||
except Exception as e:
|
||||
return {"provider": "ollama", "models": [], "available": False, "error": str(e)}
|
||||
|
||||
|
||||
async def _get_lmstudio_models() -> dict:
|
||||
"""Query LM Studio for loaded models."""
|
||||
import aiohttp
|
||||
lmstudio_url = os.getenv("LMSTUDIO_BASE_URL", os.getenv("LMSTUDIO_URL", "http://localhost:1234"))
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{lmstudio_url}/v1/models",
|
||||
timeout=aiohttp.ClientTimeout(total=3)
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
return {"provider": "lmstudio", "models": [], "available": False, "error": f"HTTP {resp.status}"}
|
||||
data = await resp.json()
|
||||
models = []
|
||||
for m in data.get("data", []):
|
||||
model_id = m.get("id", "")
|
||||
models.append(ModelInfo(
|
||||
provider="lmstudio",
|
||||
model_id=model_id,
|
||||
display_name=model_id,
|
||||
is_local=True,
|
||||
).dict())
|
||||
return {"provider": "lmstudio", "models": models, "available": True, "error": None}
|
||||
except Exception as e:
|
||||
return {"provider": "lmstudio", "models": [], "available": False, "error": str(e)}
|
||||
|
||||
|
||||
async def _get_openrouter_models() -> dict:
|
||||
"""Query OpenRouter for available models."""
|
||||
import aiohttp
|
||||
api_key = os.getenv("OPENROUTER_API_KEY", "")
|
||||
try:
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
"https://openrouter.ai/api/v1/models",
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=5)
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
return {"provider": "openrouter", "models": [], "available": False, "error": f"HTTP {resp.status}"}
|
||||
data = await resp.json()
|
||||
models = []
|
||||
for m in data.get("data", [])[:100]: # Limit to 100 models
|
||||
model_id = m.get("id", "")
|
||||
name = m.get("name", model_id)
|
||||
ctx = m.get("context_length")
|
||||
models.append(ModelInfo(
|
||||
provider="openrouter",
|
||||
model_id=model_id,
|
||||
display_name=name,
|
||||
context_length=ctx,
|
||||
is_local=False,
|
||||
).dict())
|
||||
return {"provider": "openrouter", "models": models, "available": True, "error": None}
|
||||
except Exception as e:
|
||||
return {"provider": "openrouter", "models": [], "available": False, "error": str(e)}
|
||||
142
backend/api/v1/targets.py
Executable file
142
backend/api/v1/targets.py
Executable file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
NeuroSploit v3 - Targets API Endpoints
|
||||
"""
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from urllib.parse import urlparse
|
||||
import re
|
||||
|
||||
from backend.db.database import get_db
|
||||
from backend.schemas.target import TargetCreate, TargetBulkCreate, TargetValidation, TargetResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def validate_url(url: str) -> TargetValidation:
|
||||
"""Validate and parse a URL"""
|
||||
url = url.strip()
|
||||
|
||||
if not url:
|
||||
return TargetValidation(url=url, valid=False, error="URL is empty")
|
||||
|
||||
# URL pattern
|
||||
url_pattern = re.compile(
|
||||
r'^https?://'
|
||||
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'
|
||||
r'localhost|'
|
||||
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'
|
||||
r'(?::\d+)?'
|
||||
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
|
||||
|
||||
# Try with the URL as-is
|
||||
if url_pattern.match(url):
|
||||
normalized = url
|
||||
elif url_pattern.match(f"https://{url}"):
|
||||
normalized = f"https://{url}"
|
||||
else:
|
||||
return TargetValidation(url=url, valid=False, error="Invalid URL format")
|
||||
|
||||
# Parse URL
|
||||
parsed = urlparse(normalized)
|
||||
|
||||
return TargetValidation(
|
||||
url=url,
|
||||
valid=True,
|
||||
normalized_url=normalized,
|
||||
hostname=parsed.hostname,
|
||||
port=parsed.port or (443 if parsed.scheme == "https" else 80),
|
||||
protocol=parsed.scheme
|
||||
)
|
||||
|
||||
|
||||
@router.post("/validate", response_model=TargetValidation)
|
||||
async def validate_target(target: TargetCreate):
|
||||
"""Validate a single target URL"""
|
||||
return validate_url(target.url)
|
||||
|
||||
|
||||
@router.post("/validate/bulk", response_model=List[TargetValidation])
|
||||
async def validate_targets_bulk(targets: TargetBulkCreate):
|
||||
"""Validate multiple target URLs"""
|
||||
results = []
|
||||
for url in targets.urls:
|
||||
results.append(validate_url(url))
|
||||
return results
|
||||
|
||||
|
||||
@router.post("/upload", response_model=List[TargetValidation])
|
||||
async def upload_targets(file: UploadFile = File(...)):
|
||||
"""Upload a file with URLs (one per line)"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
# Check file extension
|
||||
allowed_extensions = {".txt", ".csv", ".lst"}
|
||||
ext = "." + file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||||
if ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file type. Allowed: {', '.join(allowed_extensions)}"
|
||||
)
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
try:
|
||||
text = content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
text = content.decode("latin-1")
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Unable to decode file")
|
||||
|
||||
# Parse URLs (one per line, or comma-separated)
|
||||
urls = []
|
||||
for line in text.split("\n"):
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
# Handle comma-separated URLs
|
||||
if "," in line and "://" in line:
|
||||
for url in line.split(","):
|
||||
url = url.strip()
|
||||
if url:
|
||||
urls.append(url)
|
||||
else:
|
||||
urls.append(line)
|
||||
|
||||
if not urls:
|
||||
raise HTTPException(status_code=400, detail="No URLs found in file")
|
||||
|
||||
# Validate all URLs
|
||||
results = []
|
||||
for url in urls:
|
||||
results.append(validate_url(url))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@router.post("/parse-input", response_model=List[TargetValidation])
|
||||
async def parse_target_input(input_text: str):
|
||||
"""Parse target input (comma-separated or newline-separated)"""
|
||||
urls = []
|
||||
|
||||
# Split by newlines first
|
||||
for line in input_text.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Then split by commas
|
||||
for url in line.split(","):
|
||||
url = url.strip()
|
||||
if url:
|
||||
urls.append(url)
|
||||
|
||||
if not urls:
|
||||
raise HTTPException(status_code=400, detail="No URLs provided")
|
||||
|
||||
results = []
|
||||
for url in urls:
|
||||
results.append(validate_url(url))
|
||||
|
||||
return results
|
||||
753
backend/api/v1/terminal.py
Executable file
753
backend/api/v1/terminal.py
Executable file
@@ -0,0 +1,753 @@
|
||||
"""
|
||||
Terminal Agent API - Interactive infrastructure pentesting via AI chat + Docker sandbox.
|
||||
|
||||
Provides session-based terminal interaction with AI-guided command execution,
|
||||
exploitation path tracking, and VPN status monitoring.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.llm_manager import LLMManager
|
||||
from core.sandbox_manager import get_sandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory session store
|
||||
# ---------------------------------------------------------------------------
|
||||
terminal_sessions: Dict[str, Dict] = {}
|
||||
|
||||
# Map session_id -> KaliSandbox instance (per-session container)
|
||||
session_sandboxes: Dict[str, object] = {}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-built templates
|
||||
# ---------------------------------------------------------------------------
|
||||
TEMPLATES = {
|
||||
"network_scanner": {
|
||||
"name": "Network Scanner",
|
||||
"description": "Host discovery, port scanning, and service detection",
|
||||
"system_prompt": (
|
||||
"You are an expert network reconnaissance specialist. You guide the "
|
||||
"operator through systematic host discovery, port scanning, and service "
|
||||
"fingerprinting. Always suggest nmap flags appropriate for the situation, "
|
||||
"explain output, and recommend next steps based on discovered services. "
|
||||
"Prioritize stealth when asked and suggest timing/fragmentation options."
|
||||
),
|
||||
"initial_commands": [
|
||||
"nmap -sn {target}",
|
||||
"nmap -sV -sC -O -p- {target}",
|
||||
"nmap -sU --top-ports 50 {target}",
|
||||
],
|
||||
},
|
||||
"lateral_movement": {
|
||||
"name": "Lateral Movement",
|
||||
"description": "Pass-the-hash, SMB/WinRM pivoting, and SSH tunneling",
|
||||
"system_prompt": (
|
||||
"You are a lateral movement specialist. You help the operator pivot "
|
||||
"through compromised networks using techniques such as pass-the-hash, "
|
||||
"SMB relay, WinRM sessions, SSH tunneling, and SOCKS proxying. Always "
|
||||
"verify credentials before attempting pivots, suggest cleanup steps, "
|
||||
"and track which hosts have been compromised."
|
||||
),
|
||||
"initial_commands": [
|
||||
"crackmapexec smb {target} -u '' -p ''",
|
||||
"crackmapexec smb {target} --shares -u '' -p ''",
|
||||
"ssh -D 1080 -N -f user@{target}",
|
||||
],
|
||||
},
|
||||
"privilege_escalation": {
|
||||
"name": "Privilege Escalation",
|
||||
"description": "SUID binaries, kernel exploits, cron jobs, and writable paths",
|
||||
"system_prompt": (
|
||||
"You are a privilege escalation expert for Linux and Windows systems. "
|
||||
"Guide the operator through enumeration of SUID/SGID binaries, kernel "
|
||||
"version checks, misconfigured cron jobs, writable PATH directories, "
|
||||
"sudo misconfigurations, and capability abuse. Suggest automated tools "
|
||||
"like linpeas/winpeas when appropriate and explain each finding."
|
||||
),
|
||||
"initial_commands": [
|
||||
"id && whoami && uname -a",
|
||||
"find / -perm -4000 -type f 2>/dev/null",
|
||||
"cat /etc/crontab && ls -la /etc/cron.*",
|
||||
"echo $PATH | tr ':' '\\n' | xargs -I {} ls -ld {}",
|
||||
],
|
||||
},
|
||||
"vpn_recon": {
|
||||
"name": "VPN Reconnaissance",
|
||||
"description": "VPN connection management and internal network discovery",
|
||||
"system_prompt": (
|
||||
"You are a VPN and internal network reconnaissance specialist. You "
|
||||
"help the operator connect to target VPNs, verify tunnel status, "
|
||||
"discover internal subnets, and enumerate services behind the VPN. "
|
||||
"Always confirm connectivity before proceeding with scans and suggest "
|
||||
"appropriate scope for internal reconnaissance."
|
||||
),
|
||||
"initial_commands": [
|
||||
"openvpn --config client.ovpn --daemon",
|
||||
"ip addr show tun0",
|
||||
"ip route | grep tun",
|
||||
"nmap -sn 10.0.0.0/24",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic request / response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
template_id: Optional[str] = None
|
||||
target: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
|
||||
|
||||
class MessageRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class ExecuteCommandRequest(BaseModel):
|
||||
command: str
|
||||
execution_method: str = "sandbox" # "sandbox" or "direct"
|
||||
|
||||
|
||||
class ExploitationStepRequest(BaseModel):
|
||||
description: str
|
||||
command: Optional[str] = ""
|
||||
result: Optional[str] = ""
|
||||
step_type: str = "recon" # recon | exploit | pivot | escalate | action
|
||||
|
||||
|
||||
class SessionSummary(BaseModel):
|
||||
session_id: str
|
||||
name: str
|
||||
target: str
|
||||
template_id: Optional[str]
|
||||
status: str
|
||||
created_at: str
|
||||
messages_count: int
|
||||
commands_count: int
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
role: str
|
||||
response: str
|
||||
timestamp: str
|
||||
suggested_commands: List[str]
|
||||
|
||||
|
||||
class CommandResult(BaseModel):
|
||||
command: str
|
||||
exit_code: int
|
||||
stdout: str
|
||||
stderr: str
|
||||
duration: float
|
||||
execution_method: str
|
||||
timestamp: str
|
||||
|
||||
|
||||
class VPNStatus(BaseModel):
|
||||
connected: bool
|
||||
ip: Optional[str] = None
|
||||
interface: Optional[str] = None
|
||||
container_name: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _build_session(
|
||||
session_id: str,
|
||||
name: str,
|
||||
target: str,
|
||||
template_id: Optional[str],
|
||||
) -> Dict:
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"name": name,
|
||||
"target": target,
|
||||
"template_id": template_id,
|
||||
"status": "active",
|
||||
"created_at": _now_iso(),
|
||||
"messages": [],
|
||||
"command_history": [],
|
||||
"exploitation_path": [],
|
||||
"vpn_status": {"connected": False, "ip": None},
|
||||
"container_name": None,
|
||||
"vpn_config_uploaded": False,
|
||||
}
|
||||
|
||||
|
||||
def _get_session(session_id: str) -> Dict:
|
||||
session = terminal_sessions.get(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return session
|
||||
|
||||
|
||||
def _build_context_string(
|
||||
messages: List[Dict],
|
||||
commands: List[Dict],
|
||||
exploitation: List[Dict],
|
||||
) -> str:
|
||||
parts: List[str] = []
|
||||
|
||||
if messages:
|
||||
parts.append("=== Recent Conversation ===")
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown").upper()
|
||||
parts.append(f"[{role}] {msg.get('content', '')}")
|
||||
|
||||
if commands:
|
||||
parts.append("\n=== Recent Command Results ===")
|
||||
for cmd in commands:
|
||||
parts.append(
|
||||
f"$ {cmd['command']}\n"
|
||||
f"Exit code: {cmd['exit_code']}\n"
|
||||
f"Stdout: {cmd['stdout'][:500]}\n"
|
||||
f"Stderr: {cmd['stderr'][:300]}"
|
||||
)
|
||||
|
||||
if exploitation:
|
||||
parts.append("\n=== Exploitation Path ===")
|
||||
for i, step in enumerate(exploitation, 1):
|
||||
parts.append(
|
||||
f"Step {i} [{step['step_type']}]: {step['description']}"
|
||||
)
|
||||
if step.get("command"):
|
||||
parts.append(f" Command: {step['command']}")
|
||||
if step.get("result"):
|
||||
parts.append(f" Result: {step['result'][:300]}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _extract_suggested_commands(text: str) -> List[str]:
|
||||
"""Extract commands from backtick-fenced code blocks."""
|
||||
blocks = re.findall(r"```(?:bash|sh|shell)?\n?(.*?)```", text, re.DOTALL)
|
||||
commands: List[str] = []
|
||||
for block in blocks:
|
||||
for line in block.strip().splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("#"):
|
||||
commands.append(stripped)
|
||||
return commands
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Template endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/templates")
|
||||
async def list_templates():
|
||||
"""List all available session templates."""
|
||||
result = []
|
||||
for tid, tmpl in TEMPLATES.items():
|
||||
result.append({
|
||||
"id": tid,
|
||||
"name": tmpl["name"],
|
||||
"description": tmpl["description"],
|
||||
"initial_commands": tmpl["initial_commands"],
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/session")
|
||||
async def create_session(req: CreateSessionRequest):
|
||||
"""Create a new terminal session, optionally from a template."""
|
||||
session_id = str(uuid.uuid4())
|
||||
target = req.target or ""
|
||||
template_id = req.template_id
|
||||
|
||||
if template_id and template_id not in TEMPLATES:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown template: {template_id}")
|
||||
|
||||
name = req.name or (
|
||||
TEMPLATES[template_id]["name"] if template_id else f"Session {session_id[:8]}"
|
||||
)
|
||||
|
||||
session = _build_session(session_id, name, target, template_id)
|
||||
|
||||
# Provision a per-session Kali container (best-effort)
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
sandbox = await pool.get_or_create(f"terminal-{session_id}", enable_vpn=True)
|
||||
session_sandboxes[session_id] = sandbox
|
||||
session["container_name"] = sandbox.container_name
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to provision Kali container for terminal session: {exc}")
|
||||
|
||||
# Seed initial system message from template
|
||||
if template_id:
|
||||
tmpl = TEMPLATES[template_id]
|
||||
session["messages"].append({
|
||||
"role": "system",
|
||||
"content": tmpl["system_prompt"],
|
||||
"timestamp": _now_iso(),
|
||||
"metadata": {"template": template_id},
|
||||
})
|
||||
# Provide initial suggested commands with target interpolated
|
||||
initial_cmds = [
|
||||
cmd.replace("{target}", target) for cmd in tmpl["initial_commands"]
|
||||
]
|
||||
session["messages"].append({
|
||||
"role": "assistant",
|
||||
"content": (
|
||||
f"Session initialised with the **{tmpl['name']}** template.\n\n"
|
||||
f"Target: `{target or '(not set)'}`\n\n"
|
||||
"Suggested starting commands:\n"
|
||||
+ "\n".join(f"```\n{c}\n```" for c in initial_cmds)
|
||||
),
|
||||
"timestamp": _now_iso(),
|
||||
"suggested_commands": initial_cmds,
|
||||
})
|
||||
|
||||
terminal_sessions[session_id] = session
|
||||
return session
|
||||
|
||||
|
||||
@router.get("/sessions")
|
||||
async def list_sessions():
|
||||
"""Return lightweight summaries of every session."""
|
||||
summaries = []
|
||||
for sid, s in terminal_sessions.items():
|
||||
summaries.append(
|
||||
SessionSummary(
|
||||
session_id=sid,
|
||||
name=s["name"],
|
||||
target=s["target"],
|
||||
template_id=s["template_id"],
|
||||
status=s["status"],
|
||||
created_at=s["created_at"],
|
||||
messages_count=len(s["messages"]),
|
||||
commands_count=len(s["command_history"]),
|
||||
).model_dump()
|
||||
)
|
||||
return summaries
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}")
|
||||
async def get_session(session_id: str):
|
||||
"""Return the full session including messages, commands, and exploitation path."""
|
||||
return _get_session(session_id)
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
async def delete_session(session_id: str):
|
||||
"""Delete a terminal session and its Kali container."""
|
||||
if session_id not in terminal_sessions:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
# Destroy associated Kali container
|
||||
sandbox = session_sandboxes.pop(session_id, None)
|
||||
if sandbox:
|
||||
try:
|
||||
from core.container_pool import get_pool
|
||||
pool = get_pool()
|
||||
await pool.destroy(f"terminal-{session_id}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to destroy container for session {session_id}: {exc}")
|
||||
|
||||
del terminal_sessions[session_id]
|
||||
return {"status": "deleted", "session_id": session_id}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI message interaction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/message")
|
||||
async def send_message(session_id: str, req: MessageRequest):
|
||||
"""Send a user prompt to the AI and receive a response with suggested commands."""
|
||||
session = _get_session(session_id)
|
||||
user_message = req.message.strip()
|
||||
if not user_message:
|
||||
raise HTTPException(status_code=400, detail="Message content cannot be empty")
|
||||
|
||||
# Record user message
|
||||
session["messages"].append({
|
||||
"role": "user",
|
||||
"content": user_message,
|
||||
"timestamp": _now_iso(),
|
||||
"metadata": {},
|
||||
})
|
||||
|
||||
# Determine system prompt
|
||||
template_id = session.get("template_id")
|
||||
if template_id and template_id in TEMPLATES:
|
||||
system_prompt = TEMPLATES[template_id]["system_prompt"]
|
||||
else:
|
||||
system_prompt = (
|
||||
"You are an expert infrastructure penetration tester. Help the "
|
||||
"operator plan and execute attacks against the target. Suggest "
|
||||
"concrete commands, explain their purpose, and interpret output. "
|
||||
"Always wrap commands in fenced code blocks so they can be extracted."
|
||||
)
|
||||
|
||||
# Build context window
|
||||
context_messages = session["messages"][-20:]
|
||||
context_cmds = session["command_history"][-10:]
|
||||
exploitation = session["exploitation_path"]
|
||||
context = _build_context_string(context_messages, context_cmds, exploitation)
|
||||
|
||||
# Call LLM
|
||||
try:
|
||||
llm = LLMManager()
|
||||
prompt = f"{context}\n\nUser: {user_message}"
|
||||
response = await llm.generate(prompt, system_prompt)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"LLM call failed: {exc}")
|
||||
|
||||
suggested_commands = _extract_suggested_commands(response)
|
||||
|
||||
# Record assistant response
|
||||
session["messages"].append({
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
"timestamp": _now_iso(),
|
||||
"suggested_commands": suggested_commands,
|
||||
})
|
||||
|
||||
return MessageResponse(
|
||||
role="assistant",
|
||||
response=response,
|
||||
timestamp=session["messages"][-1]["timestamp"],
|
||||
suggested_commands=suggested_commands,
|
||||
).model_dump()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command execution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/execute")
|
||||
async def execute_command(session_id: str, req: ExecuteCommandRequest):
|
||||
"""Execute a command in the Docker sandbox (fallback: direct shell)."""
|
||||
session = _get_session(session_id)
|
||||
command = req.command.strip()
|
||||
if not command:
|
||||
raise HTTPException(status_code=400, detail="Command cannot be empty")
|
||||
|
||||
start = time.time()
|
||||
stdout = ""
|
||||
stderr = ""
|
||||
exit_code = -1
|
||||
execution_method = "direct"
|
||||
|
||||
# Use requested execution method
|
||||
use_sandbox = req.execution_method == "sandbox"
|
||||
|
||||
if use_sandbox:
|
||||
# Prefer session's own Kali container
|
||||
sandbox = session_sandboxes.get(session_id)
|
||||
if sandbox and sandbox.is_available:
|
||||
try:
|
||||
result = await sandbox.execute_raw(command)
|
||||
stdout = result.stdout
|
||||
stderr = result.stderr
|
||||
exit_code = result.exit_code
|
||||
execution_method = "kali-sandbox"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to shared sandbox
|
||||
if execution_method == "direct":
|
||||
try:
|
||||
shared = await get_sandbox()
|
||||
if shared and shared.is_available:
|
||||
result = await shared.execute_raw(command)
|
||||
stdout = result.stdout
|
||||
stderr = result.stderr
|
||||
exit_code = result.exit_code
|
||||
execution_method = "sandbox"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback or direct execution requested
|
||||
if execution_method not in ("kali-sandbox", "sandbox"):
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
raw_stdout, raw_stderr = await asyncio.wait_for(
|
||||
proc.communicate(), timeout=120
|
||||
)
|
||||
stdout = raw_stdout.decode(errors="replace")
|
||||
stderr = raw_stderr.decode(errors="replace")
|
||||
exit_code = proc.returncode or 0
|
||||
execution_method = "direct"
|
||||
except asyncio.TimeoutError:
|
||||
stderr = "Command timed out after 120 seconds"
|
||||
exit_code = 124
|
||||
except Exception as exc:
|
||||
stderr = str(exc)
|
||||
exit_code = 1
|
||||
|
||||
duration = round(time.time() - start, 3)
|
||||
|
||||
cmd_record = {
|
||||
"command": command,
|
||||
"exit_code": exit_code,
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
"duration": duration,
|
||||
"execution_method": execution_method,
|
||||
"timestamp": _now_iso(),
|
||||
}
|
||||
session["command_history"].append(cmd_record)
|
||||
|
||||
# Mirror into messages for AI context continuity
|
||||
output_preview = stdout[:2000] if stdout else stderr[:2000]
|
||||
session["messages"].append({
|
||||
"role": "tool",
|
||||
"content": f"$ {command}\n[exit {exit_code}] ({execution_method}, {duration}s)\n{output_preview}",
|
||||
"timestamp": cmd_record["timestamp"],
|
||||
"metadata": {"exit_code": exit_code, "execution_method": execution_method},
|
||||
})
|
||||
|
||||
return CommandResult(**cmd_record).model_dump()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exploitation path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/exploitation-path")
|
||||
async def add_exploitation_step(session_id: str, req: ExploitationStepRequest):
|
||||
"""Add a manual step to the exploitation path timeline."""
|
||||
session = _get_session(session_id)
|
||||
|
||||
valid_types = {"recon", "exploit", "pivot", "escalate", "action"}
|
||||
if req.step_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"step_type must be one of {sorted(valid_types)}",
|
||||
)
|
||||
|
||||
step = {
|
||||
"description": req.description,
|
||||
"command": req.command or "",
|
||||
"result": req.result or "",
|
||||
"timestamp": _now_iso(),
|
||||
"step_type": req.step_type,
|
||||
}
|
||||
session["exploitation_path"].append(step)
|
||||
return step
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/exploitation-path")
|
||||
async def get_exploitation_path(session_id: str):
|
||||
"""Return the full exploitation path timeline."""
|
||||
session = _get_session(session_id)
|
||||
return session["exploitation_path"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VPN management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/vpn/upload")
|
||||
async def upload_vpn_config(
|
||||
session_id: str,
|
||||
ovpn_file: UploadFile = File(...),
|
||||
username: Optional[str] = Form(None),
|
||||
password: Optional[str] = Form(None),
|
||||
):
|
||||
"""Upload .ovpn config and optionally credentials into the session's container."""
|
||||
session = _get_session(session_id)
|
||||
sandbox = session_sandboxes.get(session_id)
|
||||
|
||||
if not sandbox or not sandbox.is_available:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="No Kali container available for this session.",
|
||||
)
|
||||
|
||||
content = await ovpn_file.read()
|
||||
if len(content) > 1_000_000:
|
||||
raise HTTPException(status_code=400, detail="File too large (max 1MB)")
|
||||
if not (ovpn_file.filename or "").endswith((".ovpn", ".conf")):
|
||||
raise HTTPException(status_code=400, detail="File must be .ovpn or .conf")
|
||||
|
||||
# Upload config to container
|
||||
dest = "/etc/openvpn/client.ovpn"
|
||||
ok = await sandbox.upload_file(content, dest)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=500, detail="Failed to upload config to container")
|
||||
|
||||
# Write auth file if credentials provided
|
||||
if username and password:
|
||||
auth_bytes = f"{username}\n{password}\n".encode()
|
||||
await sandbox.upload_file(auth_bytes, "/etc/openvpn/auth.txt")
|
||||
await sandbox._exec("chmod 600 /etc/openvpn/auth.txt", timeout=5)
|
||||
await sandbox._exec(
|
||||
"grep -q 'auth-user-pass' /etc/openvpn/client.ovpn || "
|
||||
"echo 'auth-user-pass /etc/openvpn/auth.txt' >> /etc/openvpn/client.ovpn",
|
||||
timeout=5,
|
||||
)
|
||||
await sandbox._exec(
|
||||
"sed -i 's|auth-user-pass$|auth-user-pass /etc/openvpn/auth.txt|' /etc/openvpn/client.ovpn",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
session["vpn_config_uploaded"] = True
|
||||
return {
|
||||
"status": "uploaded",
|
||||
"filename": ovpn_file.filename,
|
||||
"credentials_set": bool(username),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/vpn/connect")
|
||||
async def connect_vpn(session_id: str):
|
||||
"""Start VPN connection using previously uploaded config."""
|
||||
session = _get_session(session_id)
|
||||
sandbox = session_sandboxes.get(session_id)
|
||||
|
||||
if not sandbox or not sandbox.is_available:
|
||||
raise HTTPException(status_code=503, detail="No Kali container for this session")
|
||||
|
||||
if not session.get("vpn_config_uploaded"):
|
||||
raise HTTPException(status_code=400, detail="No VPN config uploaded. Upload .ovpn first.")
|
||||
|
||||
# Create TUN device
|
||||
await sandbox._exec(
|
||||
"mkdir -p /dev/net && "
|
||||
"[ -c /dev/net/tun ] || mknod /dev/net/tun c 10 200; "
|
||||
"chmod 600 /dev/net/tun",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
# Kill any existing VPN
|
||||
await sandbox._exec("pkill -9 openvpn 2>/dev/null", timeout=5)
|
||||
|
||||
# Start OpenVPN
|
||||
result = await sandbox._exec(
|
||||
"openvpn --config /etc/openvpn/client.ovpn --daemon "
|
||||
"--log /var/log/openvpn.log --writepid /var/run/openvpn.pid",
|
||||
timeout=15,
|
||||
)
|
||||
if result.exit_code != 0:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"OpenVPN failed to start: {result.stderr or result.stdout}",
|
||||
)
|
||||
|
||||
# Wait for tunnel (max 20s)
|
||||
for _ in range(20):
|
||||
await asyncio.sleep(1)
|
||||
check = await sandbox._exec("ip addr show tun0 2>/dev/null", timeout=5)
|
||||
if check.exit_code == 0 and "inet " in check.stdout:
|
||||
match = re.search(r"inet\s+(\d+\.\d+\.\d+\.\d+)", check.stdout)
|
||||
ip = match.group(1) if match else None
|
||||
vpn = {"connected": True, "ip": ip}
|
||||
session["vpn_status"] = vpn
|
||||
return {"status": "connected", "ip": ip}
|
||||
|
||||
# Timeout
|
||||
log_result = await sandbox._exec("tail -30 /var/log/openvpn.log 2>/dev/null", timeout=5)
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail=f"VPN connection timed out (20s). Log:\n{(log_result.stdout or '')[-500:]}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/vpn/disconnect")
|
||||
async def disconnect_vpn(session_id: str):
|
||||
"""Kill VPN connection inside the container."""
|
||||
session = _get_session(session_id)
|
||||
sandbox = session_sandboxes.get(session_id)
|
||||
|
||||
if not sandbox or not sandbox.is_available:
|
||||
raise HTTPException(status_code=503, detail="No Kali container for this session")
|
||||
|
||||
await sandbox._exec(
|
||||
"kill $(cat /var/run/openvpn.pid 2>/dev/null) 2>/dev/null; "
|
||||
"pkill -9 openvpn 2>/dev/null",
|
||||
timeout=10,
|
||||
)
|
||||
session["vpn_status"] = {"connected": False, "ip": None}
|
||||
return {"status": "disconnected"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VPN status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions/{session_id}/vpn-status")
|
||||
async def get_vpn_status(session_id: str):
|
||||
"""Check VPN status inside the session's Kali container (fallback: host)."""
|
||||
session = _get_session(session_id)
|
||||
|
||||
sandbox = session_sandboxes.get(session_id)
|
||||
|
||||
# Check inside container if available
|
||||
if sandbox and sandbox.is_available:
|
||||
vpn_data = await sandbox.get_vpn_status()
|
||||
session["vpn_status"] = vpn_data
|
||||
return VPNStatus(
|
||||
connected=vpn_data["connected"],
|
||||
ip=vpn_data.get("ip"),
|
||||
interface=vpn_data.get("interface"),
|
||||
container_name=sandbox.container_name,
|
||||
).model_dump()
|
||||
|
||||
# Fallback: check on host (legacy behavior)
|
||||
connected = False
|
||||
ip_addr: Optional[str] = None
|
||||
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
"pgrep -a openvpn",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
raw_stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5)
|
||||
if proc.returncode == 0 and raw_stdout.strip():
|
||||
connected = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if connected:
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
"ip addr show tun0",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
raw_stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5)
|
||||
if proc.returncode == 0:
|
||||
match = re.search(
|
||||
r"inet\s+(\d+\.\d+\.\d+\.\d+)", raw_stdout.decode(errors="replace")
|
||||
)
|
||||
if match:
|
||||
ip_addr = match.group(1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
vpn = {"connected": connected, "ip": ip_addr}
|
||||
session["vpn_status"] = vpn
|
||||
return VPNStatus(**vpn).model_dump()
|
||||
876
backend/api/v1/vuln_lab.py
Executable file
876
backend/api/v1/vuln_lab.py
Executable file
@@ -0,0 +1,876 @@
|
||||
"""
|
||||
NeuroSploit v3 - Vulnerability Lab API Endpoints
|
||||
|
||||
Isolated vulnerability testing against labs, CTFs, and PortSwigger challenges.
|
||||
Test individual vuln types one at a time and track results.
|
||||
"""
|
||||
from typing import Optional, Dict, List
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, func, text
|
||||
|
||||
from backend.core.autonomous_agent import AutonomousAgent, OperationMode
|
||||
from backend.core.vuln_engine.registry import VulnerabilityRegistry
|
||||
from backend.db.database import async_session_factory
|
||||
from backend.models import Scan, Target, Vulnerability, Endpoint, Report, VulnLabChallenge
|
||||
|
||||
# Import agent.py's shared dicts so ScanDetailsPage can find our scans
|
||||
from backend.api.v1.agent import (
|
||||
agent_results, agent_instances, agent_to_scan, scan_to_agent
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# In-memory tracking for running lab tests
|
||||
lab_agents: Dict[str, AutonomousAgent] = {}
|
||||
lab_results: Dict[str, Dict] = {}
|
||||
|
||||
|
||||
# --- Request/Response Models ---
|
||||
|
||||
class VulnLabRunRequest(BaseModel):
|
||||
target_url: str = Field(..., description="Target URL to test (lab, CTF, etc.)")
|
||||
vuln_type: str = Field(..., description="Vulnerability type to test (e.g. xss_reflected)")
|
||||
challenge_name: Optional[str] = Field(None, description="Name of the lab/challenge")
|
||||
auth_type: Optional[str] = Field(None, description="Auth type: cookie, bearer, basic, header")
|
||||
auth_value: Optional[str] = Field(None, description="Auth credential value")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom HTTP headers")
|
||||
notes: Optional[str] = Field(None, description="Notes about this challenge")
|
||||
|
||||
|
||||
class VulnLabResponse(BaseModel):
|
||||
challenge_id: str
|
||||
agent_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class VulnTypeInfo(BaseModel):
|
||||
key: str
|
||||
title: str
|
||||
severity: str
|
||||
cwe_id: str
|
||||
category: str
|
||||
|
||||
|
||||
# --- Vuln type categories for the selector ---
|
||||
|
||||
VULN_CATEGORIES = {
|
||||
"injection": {
|
||||
"label": "Injection",
|
||||
"types": [
|
||||
"xss_reflected", "xss_stored", "xss_dom",
|
||||
"sqli_error", "sqli_union", "sqli_blind", "sqli_time",
|
||||
"command_injection", "ssti", "nosql_injection",
|
||||
]
|
||||
},
|
||||
"advanced_injection": {
|
||||
"label": "Advanced Injection",
|
||||
"types": [
|
||||
"ldap_injection", "xpath_injection", "graphql_injection",
|
||||
"crlf_injection", "header_injection", "email_injection",
|
||||
"el_injection", "log_injection", "html_injection",
|
||||
"csv_injection", "orm_injection",
|
||||
]
|
||||
},
|
||||
"file_access": {
|
||||
"label": "File Access",
|
||||
"types": [
|
||||
"lfi", "rfi", "path_traversal", "xxe", "file_upload",
|
||||
"arbitrary_file_read", "arbitrary_file_delete", "zip_slip",
|
||||
]
|
||||
},
|
||||
"request_forgery": {
|
||||
"label": "Request Forgery",
|
||||
"types": [
|
||||
"ssrf", "csrf", "graphql_introspection", "graphql_dos",
|
||||
]
|
||||
},
|
||||
"authentication": {
|
||||
"label": "Authentication",
|
||||
"types": [
|
||||
"auth_bypass", "jwt_manipulation", "session_fixation",
|
||||
"weak_password", "default_credentials", "two_factor_bypass",
|
||||
"oauth_misconfig",
|
||||
]
|
||||
},
|
||||
"authorization": {
|
||||
"label": "Authorization",
|
||||
"types": [
|
||||
"idor", "bola", "privilege_escalation",
|
||||
"bfla", "mass_assignment", "forced_browsing",
|
||||
]
|
||||
},
|
||||
"client_side": {
|
||||
"label": "Client-Side",
|
||||
"types": [
|
||||
"cors_misconfiguration", "clickjacking", "open_redirect",
|
||||
"dom_clobbering", "postmessage_vuln", "websocket_hijack",
|
||||
"prototype_pollution", "css_injection", "tabnabbing",
|
||||
]
|
||||
},
|
||||
"infrastructure": {
|
||||
"label": "Infrastructure",
|
||||
"types": [
|
||||
"security_headers", "ssl_issues", "http_methods",
|
||||
"directory_listing", "debug_mode", "exposed_admin_panel",
|
||||
"exposed_api_docs", "insecure_cookie_flags",
|
||||
]
|
||||
},
|
||||
"logic": {
|
||||
"label": "Business Logic",
|
||||
"types": [
|
||||
"race_condition", "business_logic", "rate_limit_bypass",
|
||||
"parameter_pollution", "type_juggling", "timing_attack",
|
||||
"host_header_injection", "http_smuggling", "cache_poisoning",
|
||||
]
|
||||
},
|
||||
"data_exposure": {
|
||||
"label": "Data Exposure",
|
||||
"types": [
|
||||
"sensitive_data_exposure", "information_disclosure",
|
||||
"api_key_exposure", "source_code_disclosure",
|
||||
"backup_file_exposure", "version_disclosure",
|
||||
]
|
||||
},
|
||||
"cloud_supply": {
|
||||
"label": "Cloud & Supply Chain",
|
||||
"types": [
|
||||
"s3_bucket_misconfig", "cloud_metadata_exposure",
|
||||
"subdomain_takeover", "vulnerable_dependency",
|
||||
"container_escape", "serverless_misconfiguration",
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_vuln_category(vuln_type: str) -> str:
|
||||
"""Get category for a vuln type"""
|
||||
for cat_key, cat_info in VULN_CATEGORIES.items():
|
||||
if vuln_type in cat_info["types"]:
|
||||
return cat_key
|
||||
return "other"
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.get("/types")
|
||||
async def list_vuln_types():
|
||||
"""List all available vulnerability types grouped by category"""
|
||||
registry = VulnerabilityRegistry()
|
||||
result = {}
|
||||
|
||||
for cat_key, cat_info in VULN_CATEGORIES.items():
|
||||
types_list = []
|
||||
for vtype in cat_info["types"]:
|
||||
info = registry.VULNERABILITY_INFO.get(vtype, {})
|
||||
types_list.append({
|
||||
"key": vtype,
|
||||
"title": info.get("title", vtype.replace("_", " ").title()),
|
||||
"severity": info.get("severity", "medium"),
|
||||
"cwe_id": info.get("cwe_id", ""),
|
||||
"description": info.get("description", "")[:120] if info.get("description") else "",
|
||||
})
|
||||
result[cat_key] = {
|
||||
"label": cat_info["label"],
|
||||
"types": types_list,
|
||||
"count": len(types_list),
|
||||
}
|
||||
|
||||
return {"categories": result, "total_types": sum(len(c["types"]) for c in VULN_CATEGORIES.values())}
|
||||
|
||||
|
||||
@router.post("/run", response_model=VulnLabResponse)
|
||||
async def run_vuln_lab(request: VulnLabRunRequest, background_tasks: BackgroundTasks):
|
||||
"""Launch an isolated vulnerability test for a specific vuln type"""
|
||||
import uuid
|
||||
|
||||
# Validate vuln type exists
|
||||
registry = VulnerabilityRegistry()
|
||||
if request.vuln_type not in registry.VULNERABILITY_INFO:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unknown vulnerability type: {request.vuln_type}. Use GET /vuln-lab/types for available types."
|
||||
)
|
||||
|
||||
challenge_id = str(uuid.uuid4())
|
||||
agent_id = str(uuid.uuid4())[:8]
|
||||
category = _get_vuln_category(request.vuln_type)
|
||||
|
||||
# Build auth headers
|
||||
auth_headers = {}
|
||||
if request.auth_type and request.auth_value:
|
||||
if request.auth_type == "cookie":
|
||||
auth_headers["Cookie"] = request.auth_value
|
||||
elif request.auth_type == "bearer":
|
||||
auth_headers["Authorization"] = f"Bearer {request.auth_value}"
|
||||
elif request.auth_type == "basic":
|
||||
import base64
|
||||
auth_headers["Authorization"] = f"Basic {base64.b64encode(request.auth_value.encode()).decode()}"
|
||||
elif request.auth_type == "header":
|
||||
if ":" in request.auth_value:
|
||||
name, value = request.auth_value.split(":", 1)
|
||||
auth_headers[name.strip()] = value.strip()
|
||||
|
||||
if request.custom_headers:
|
||||
auth_headers.update(request.custom_headers)
|
||||
|
||||
# Create DB record
|
||||
async with async_session_factory() as db:
|
||||
challenge = VulnLabChallenge(
|
||||
id=challenge_id,
|
||||
target_url=request.target_url,
|
||||
challenge_name=request.challenge_name,
|
||||
vuln_type=request.vuln_type,
|
||||
vuln_category=category,
|
||||
auth_type=request.auth_type,
|
||||
auth_value=request.auth_value,
|
||||
status="running",
|
||||
agent_id=agent_id,
|
||||
started_at=datetime.utcnow(),
|
||||
notes=request.notes,
|
||||
)
|
||||
db.add(challenge)
|
||||
await db.commit()
|
||||
|
||||
# Init in-memory tracking (both local and in agent.py's shared dicts)
|
||||
vuln_info = registry.VULNERABILITY_INFO[request.vuln_type]
|
||||
lab_results[challenge_id] = {
|
||||
"status": "running",
|
||||
"agent_id": agent_id,
|
||||
"vuln_type": request.vuln_type,
|
||||
"target": request.target_url,
|
||||
"progress": 0,
|
||||
"phase": "initializing",
|
||||
"findings": [],
|
||||
"logs": [],
|
||||
}
|
||||
|
||||
# Also register in agent.py's shared results dict so /agent/status works
|
||||
agent_results[agent_id] = {
|
||||
"status": "running",
|
||||
"mode": "full_auto",
|
||||
"started_at": datetime.utcnow().isoformat(),
|
||||
"target": request.target_url,
|
||||
"task": f"VulnLab: {vuln_info.get('title', request.vuln_type)}",
|
||||
"logs": [],
|
||||
"findings": [],
|
||||
"report": None,
|
||||
"progress": 0,
|
||||
"phase": "initializing",
|
||||
}
|
||||
|
||||
# Launch agent in background
|
||||
background_tasks.add_task(
|
||||
_run_lab_test,
|
||||
challenge_id,
|
||||
agent_id,
|
||||
request.target_url,
|
||||
request.vuln_type,
|
||||
vuln_info.get("title", request.vuln_type),
|
||||
auth_headers,
|
||||
request.challenge_name,
|
||||
request.notes,
|
||||
)
|
||||
|
||||
return VulnLabResponse(
|
||||
challenge_id=challenge_id,
|
||||
agent_id=agent_id,
|
||||
status="running",
|
||||
message=f"Testing {vuln_info.get('title', request.vuln_type)} against {request.target_url}"
|
||||
)
|
||||
|
||||
|
||||
async def _run_lab_test(
|
||||
challenge_id: str,
|
||||
agent_id: str,
|
||||
target: str,
|
||||
vuln_type: str,
|
||||
vuln_title: str,
|
||||
auth_headers: Dict,
|
||||
challenge_name: Optional[str] = None,
|
||||
notes: Optional[str] = None,
|
||||
):
|
||||
"""Background task: run the agent focused on a single vuln type"""
|
||||
import asyncio
|
||||
|
||||
logs = []
|
||||
findings_list = []
|
||||
scan_id = None
|
||||
|
||||
async def log_callback(level: str, message: str):
|
||||
source = "llm" if any(tag in message for tag in ["[AI]", "[LLM]", "[USER PROMPT]", "[AI RESPONSE]"]) else "script"
|
||||
entry = {"level": level, "message": message, "time": datetime.utcnow().isoformat(), "source": source}
|
||||
logs.append(entry)
|
||||
# Update local tracking
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["logs"] = logs
|
||||
# Also update agent.py's shared dict so /agent/logs works
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["logs"] = logs
|
||||
|
||||
async def progress_callback(progress: int, phase: str):
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["progress"] = progress
|
||||
lab_results[challenge_id]["phase"] = phase
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["progress"] = progress
|
||||
agent_results[agent_id]["phase"] = phase
|
||||
|
||||
async def finding_callback(finding: Dict):
|
||||
findings_list.append(finding)
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["findings"] = findings_list
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["findings"] = findings_list
|
||||
agent_results[agent_id]["findings_count"] = len(findings_list)
|
||||
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
# Create a scan record linked to this challenge
|
||||
scan = Scan(
|
||||
name=f"VulnLab: {vuln_title} - {target[:50]}",
|
||||
status="running",
|
||||
scan_type="full_auto",
|
||||
recon_enabled=True,
|
||||
progress=0,
|
||||
current_phase="initializing",
|
||||
custom_prompt=f"Focus ONLY on testing for {vuln_title} ({vuln_type}). "
|
||||
f"Do NOT test other vulnerability types. "
|
||||
f"Test thoroughly with multiple payloads and techniques for this specific vulnerability.",
|
||||
)
|
||||
db.add(scan)
|
||||
await db.commit()
|
||||
await db.refresh(scan)
|
||||
scan_id = scan.id
|
||||
|
||||
# Create target record
|
||||
target_record = Target(scan_id=scan_id, url=target, status="pending")
|
||||
db.add(target_record)
|
||||
await db.commit()
|
||||
|
||||
# Update challenge with scan_id
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if challenge:
|
||||
challenge.scan_id = scan_id
|
||||
await db.commit()
|
||||
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["scan_id"] = scan_id
|
||||
|
||||
# Register in agent.py's shared mappings so ScanDetailsPage works
|
||||
agent_to_scan[agent_id] = scan_id
|
||||
scan_to_agent[scan_id] = agent_id
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["scan_id"] = scan_id
|
||||
|
||||
# Build focused prompt for isolated testing
|
||||
focused_prompt = (
|
||||
f"You are testing specifically for {vuln_title} ({vuln_type}). "
|
||||
f"Focus ALL your efforts on detecting and exploiting this single vulnerability type. "
|
||||
f"Do NOT scan for other vulnerability types. "
|
||||
f"Use all relevant payloads and techniques for {vuln_type}. "
|
||||
f"Be thorough: try multiple injection points, encoding bypasses, and edge cases. "
|
||||
f"This is a lab/CTF challenge - the vulnerability is expected to exist."
|
||||
)
|
||||
if challenge_name:
|
||||
focused_prompt += (
|
||||
f"\n\nCHALLENGE HINT: This is PortSwigger lab '{challenge_name}'. "
|
||||
f"Use this name to understand what specific technique or bypass is needed. "
|
||||
f"For example, 'angle brackets HTML-encoded' means attribute-based XSS, "
|
||||
f"'most tags and attributes blocked' means fuzz for allowed tags/events."
|
||||
)
|
||||
if notes:
|
||||
focused_prompt += f"\n\nUSER NOTES: {notes}"
|
||||
|
||||
lab_ctx = {
|
||||
"challenge_name": challenge_name,
|
||||
"notes": notes,
|
||||
"vuln_type": vuln_type,
|
||||
"is_lab": True,
|
||||
}
|
||||
|
||||
async with AutonomousAgent(
|
||||
target=target,
|
||||
mode=OperationMode.FULL_AUTO,
|
||||
log_callback=log_callback,
|
||||
progress_callback=progress_callback,
|
||||
auth_headers=auth_headers,
|
||||
custom_prompt=focused_prompt,
|
||||
finding_callback=finding_callback,
|
||||
lab_context=lab_ctx,
|
||||
) as agent:
|
||||
lab_agents[challenge_id] = agent
|
||||
# Also register in agent.py's shared instances so stop works
|
||||
agent_instances[agent_id] = agent
|
||||
|
||||
report = await agent.run()
|
||||
|
||||
lab_agents.pop(challenge_id, None)
|
||||
agent_instances.pop(agent_id, None)
|
||||
|
||||
# Use findings from report OR from real-time callbacks (fallback)
|
||||
report_findings = report.get("findings", [])
|
||||
# If report findings are empty but we got findings via callback, use those
|
||||
findings = report_findings if report_findings else findings_list
|
||||
# Also merge: if findings_list has entries not in report_findings, add them
|
||||
if not findings and findings_list:
|
||||
findings = findings_list
|
||||
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
|
||||
findings_detail = []
|
||||
|
||||
for finding in findings:
|
||||
severity = finding.get("severity", "medium").lower()
|
||||
if severity in severity_counts:
|
||||
severity_counts[severity] += 1
|
||||
|
||||
findings_detail.append({
|
||||
"title": finding.get("title", ""),
|
||||
"vulnerability_type": finding.get("vulnerability_type", ""),
|
||||
"severity": severity,
|
||||
"affected_endpoint": finding.get("affected_endpoint", ""),
|
||||
"evidence": (finding.get("evidence", "") or "")[:500],
|
||||
"payload": (finding.get("payload", "") or "")[:200],
|
||||
})
|
||||
|
||||
# Save to vulnerabilities table
|
||||
vuln = Vulnerability(
|
||||
scan_id=scan_id,
|
||||
title=finding.get("title", finding.get("type", "Unknown")),
|
||||
vulnerability_type=finding.get("vulnerability_type", finding.get("type", "unknown")),
|
||||
severity=severity,
|
||||
cvss_score=finding.get("cvss_score"),
|
||||
cvss_vector=finding.get("cvss_vector"),
|
||||
cwe_id=finding.get("cwe_id"),
|
||||
description=finding.get("description", finding.get("evidence", "")),
|
||||
affected_endpoint=finding.get("affected_endpoint", finding.get("url", target)),
|
||||
poc_payload=finding.get("payload", finding.get("poc_payload", finding.get("poc_code", ""))),
|
||||
poc_parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
poc_evidence=finding.get("evidence", finding.get("poc_evidence", "")),
|
||||
poc_request=str(finding.get("request", finding.get("poc_request", "")))[:5000],
|
||||
poc_response=str(finding.get("response", finding.get("poc_response", "")))[:5000],
|
||||
impact=finding.get("impact", ""),
|
||||
remediation=finding.get("remediation", ""),
|
||||
references=finding.get("references", []),
|
||||
ai_analysis=finding.get("ai_analysis", ""),
|
||||
screenshots=finding.get("screenshots", []),
|
||||
url=finding.get("url", finding.get("affected_endpoint", "")),
|
||||
parameter=finding.get("parameter", finding.get("poc_parameter", "")),
|
||||
)
|
||||
db.add(vuln)
|
||||
|
||||
# Save discovered endpoints from recon data
|
||||
endpoints_count = 0
|
||||
for ep in report.get("recon", {}).get("endpoints", []):
|
||||
endpoints_count += 1
|
||||
if isinstance(ep, str):
|
||||
endpoint = Endpoint(
|
||||
scan_id=scan_id,
|
||||
target_id=target_record.id,
|
||||
url=ep,
|
||||
method="GET",
|
||||
path=ep.split("?")[0].split("/")[-1] or "/"
|
||||
)
|
||||
else:
|
||||
endpoint = Endpoint(
|
||||
scan_id=scan_id,
|
||||
target_id=target_record.id,
|
||||
url=ep.get("url", ""),
|
||||
method=ep.get("method", "GET"),
|
||||
path=ep.get("path", "/")
|
||||
)
|
||||
db.add(endpoint)
|
||||
|
||||
# Determine result - more flexible matching
|
||||
# Check if any finding matches the target vuln type
|
||||
target_type_findings = [
|
||||
f for f in findings
|
||||
if _vuln_type_matches(vuln_type, f.get("vulnerability_type", ""))
|
||||
]
|
||||
# If the agent found ANY vulnerability, it detected something
|
||||
# (since we told it to focus on one type, any finding is relevant)
|
||||
if target_type_findings:
|
||||
result_status = "detected"
|
||||
elif len(findings) > 0:
|
||||
# Found other vulns but not the exact type
|
||||
result_status = "detected"
|
||||
else:
|
||||
result_status = "not_detected"
|
||||
|
||||
# Update scan
|
||||
scan.status = "completed"
|
||||
scan.completed_at = datetime.utcnow()
|
||||
scan.progress = 100
|
||||
scan.current_phase = "completed"
|
||||
scan.total_vulnerabilities = len(findings)
|
||||
scan.total_endpoints = endpoints_count
|
||||
scan.critical_count = severity_counts["critical"]
|
||||
scan.high_count = severity_counts["high"]
|
||||
scan.medium_count = severity_counts["medium"]
|
||||
scan.low_count = severity_counts["low"]
|
||||
scan.info_count = severity_counts["info"]
|
||||
|
||||
# Auto-generate report
|
||||
exec_summary = report.get("executive_summary", f"VulnLab test for {vuln_title} on {target}")
|
||||
report_record = Report(
|
||||
scan_id=scan_id,
|
||||
title=f"VulnLab: {vuln_title} - {target[:50]}",
|
||||
format="json",
|
||||
executive_summary=exec_summary[:1000] if exec_summary else None,
|
||||
)
|
||||
db.add(report_record)
|
||||
|
||||
# Persist logs (keep last 500 entries to avoid huge DB rows)
|
||||
persisted_logs = logs[-500:] if len(logs) > 500 else logs
|
||||
|
||||
# Update challenge record
|
||||
result_q = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result_q.scalar_one_or_none()
|
||||
if challenge:
|
||||
challenge.status = "completed"
|
||||
challenge.result = result_status
|
||||
challenge.completed_at = datetime.utcnow()
|
||||
challenge.duration = int((datetime.utcnow() - challenge.started_at).total_seconds()) if challenge.started_at else 0
|
||||
challenge.findings_count = len(findings)
|
||||
challenge.critical_count = severity_counts["critical"]
|
||||
challenge.high_count = severity_counts["high"]
|
||||
challenge.medium_count = severity_counts["medium"]
|
||||
challenge.low_count = severity_counts["low"]
|
||||
challenge.info_count = severity_counts["info"]
|
||||
challenge.findings_detail = findings_detail
|
||||
challenge.logs = persisted_logs
|
||||
challenge.endpoints_count = endpoints_count
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Update in-memory results
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["status"] = "completed"
|
||||
lab_results[challenge_id]["result"] = result_status
|
||||
lab_results[challenge_id]["findings"] = findings
|
||||
lab_results[challenge_id]["progress"] = 100
|
||||
lab_results[challenge_id]["phase"] = "completed"
|
||||
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "completed"
|
||||
agent_results[agent_id]["completed_at"] = datetime.utcnow().isoformat()
|
||||
agent_results[agent_id]["report"] = report
|
||||
agent_results[agent_id]["findings"] = findings
|
||||
agent_results[agent_id]["progress"] = 100
|
||||
agent_results[agent_id]["phase"] = "completed"
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_tb = traceback.format_exc()
|
||||
print(f"VulnLab error: {error_tb}")
|
||||
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["status"] = "error"
|
||||
lab_results[challenge_id]["error"] = str(e)
|
||||
|
||||
if agent_id in agent_results:
|
||||
agent_results[agent_id]["status"] = "error"
|
||||
agent_results[agent_id]["error"] = str(e)
|
||||
|
||||
# Persist logs even on error
|
||||
persisted_logs = logs[-500:] if len(logs) > 500 else logs
|
||||
|
||||
# Update DB records
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if challenge:
|
||||
challenge.status = "failed"
|
||||
challenge.result = "error"
|
||||
challenge.completed_at = datetime.utcnow()
|
||||
challenge.notes = (challenge.notes or "") + f"\nError: {str(e)}"
|
||||
challenge.logs = persisted_logs
|
||||
await db.commit()
|
||||
|
||||
if scan_id:
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
if scan:
|
||||
scan.status = "failed"
|
||||
scan.error_message = str(e)
|
||||
scan.completed_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
lab_agents.pop(challenge_id, None)
|
||||
agent_instances.pop(agent_id, None)
|
||||
|
||||
|
||||
def _vuln_type_matches(target_type: str, found_type: str) -> bool:
|
||||
"""Check if a found vuln type matches the target type (flexible matching)"""
|
||||
if not found_type:
|
||||
return False
|
||||
target = target_type.lower().replace("_", " ").replace("-", " ")
|
||||
found = found_type.lower().replace("_", " ").replace("-", " ")
|
||||
# Exact match
|
||||
if target == found:
|
||||
return True
|
||||
# Target is substring of found or vice versa
|
||||
if target in found or found in target:
|
||||
return True
|
||||
# Key word matching for common patterns
|
||||
target_words = set(target.split())
|
||||
found_words = set(found.split())
|
||||
# If they share major keywords (xss, sqli, ssrf, etc.)
|
||||
major_keywords = {"xss", "sqli", "sql", "injection", "ssrf", "csrf", "lfi", "rfi",
|
||||
"xxe", "ssti", "idor", "cors", "jwt", "redirect", "traversal"}
|
||||
shared = target_words & found_words & major_keywords
|
||||
if shared:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@router.get("/challenges")
|
||||
async def list_challenges(
|
||||
vuln_type: Optional[str] = None,
|
||||
vuln_category: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
result: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
):
|
||||
"""List all vulnerability lab challenges with optional filtering"""
|
||||
async with async_session_factory() as db:
|
||||
query = select(VulnLabChallenge).order_by(VulnLabChallenge.created_at.desc())
|
||||
|
||||
if vuln_type:
|
||||
query = query.where(VulnLabChallenge.vuln_type == vuln_type)
|
||||
if vuln_category:
|
||||
query = query.where(VulnLabChallenge.vuln_category == vuln_category)
|
||||
if status:
|
||||
query = query.where(VulnLabChallenge.status == status)
|
||||
if result:
|
||||
query = query.where(VulnLabChallenge.result == result)
|
||||
|
||||
query = query.limit(limit)
|
||||
db_result = await db.execute(query)
|
||||
challenges = db_result.scalars().all()
|
||||
|
||||
# For list view, exclude large logs field to save bandwidth
|
||||
result_list = []
|
||||
for c in challenges:
|
||||
d = c.to_dict()
|
||||
d["logs_count"] = len(d.get("logs", []))
|
||||
d.pop("logs", None) # Don't send full logs in list view
|
||||
result_list.append(d)
|
||||
|
||||
return {
|
||||
"challenges": result_list,
|
||||
"total": len(challenges),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/challenges/{challenge_id}")
|
||||
async def get_challenge(challenge_id: str):
|
||||
"""Get challenge details including real-time status if running"""
|
||||
# Check in-memory first for real-time data
|
||||
if challenge_id in lab_results:
|
||||
mem = lab_results[challenge_id]
|
||||
return {
|
||||
"challenge_id": challenge_id,
|
||||
"status": mem["status"],
|
||||
"progress": mem.get("progress", 0),
|
||||
"phase": mem.get("phase", ""),
|
||||
"findings_count": len(mem.get("findings", [])),
|
||||
"findings": mem.get("findings", []),
|
||||
"logs_count": len(mem.get("logs", [])),
|
||||
"logs": mem.get("logs", [])[-200:], # Last 200 log entries for real-time
|
||||
"error": mem.get("error"),
|
||||
"result": mem.get("result"),
|
||||
"scan_id": mem.get("scan_id"),
|
||||
"agent_id": mem.get("agent_id"),
|
||||
"vuln_type": mem.get("vuln_type"),
|
||||
"target": mem.get("target"),
|
||||
"source": "realtime",
|
||||
}
|
||||
|
||||
# Fall back to DB
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if not challenge:
|
||||
raise HTTPException(status_code=404, detail="Challenge not found")
|
||||
|
||||
data = challenge.to_dict()
|
||||
data["source"] = "database"
|
||||
data["logs_count"] = len(data.get("logs", []))
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_lab_stats():
|
||||
"""Get aggregated stats for all lab challenges"""
|
||||
async with async_session_factory() as db:
|
||||
# Total counts by status
|
||||
total_result = await db.execute(
|
||||
select(
|
||||
VulnLabChallenge.status,
|
||||
func.count(VulnLabChallenge.id)
|
||||
).group_by(VulnLabChallenge.status)
|
||||
)
|
||||
status_counts = {row[0]: row[1] for row in total_result.fetchall()}
|
||||
|
||||
# Results breakdown
|
||||
results_q = await db.execute(
|
||||
select(
|
||||
VulnLabChallenge.result,
|
||||
func.count(VulnLabChallenge.id)
|
||||
).where(VulnLabChallenge.result.isnot(None))
|
||||
.group_by(VulnLabChallenge.result)
|
||||
)
|
||||
result_counts = {row[0]: row[1] for row in results_q.fetchall()}
|
||||
|
||||
# Per vuln_type stats
|
||||
type_stats_q = await db.execute(
|
||||
select(
|
||||
VulnLabChallenge.vuln_type,
|
||||
VulnLabChallenge.result,
|
||||
func.count(VulnLabChallenge.id)
|
||||
).where(VulnLabChallenge.status == "completed")
|
||||
.group_by(VulnLabChallenge.vuln_type, VulnLabChallenge.result)
|
||||
)
|
||||
type_stats = {}
|
||||
for row in type_stats_q.fetchall():
|
||||
vtype, res, count = row
|
||||
if vtype not in type_stats:
|
||||
type_stats[vtype] = {"detected": 0, "not_detected": 0, "error": 0, "total": 0}
|
||||
type_stats[vtype][res or "error"] = count
|
||||
type_stats[vtype]["total"] += count
|
||||
|
||||
# Per category stats
|
||||
cat_stats_q = await db.execute(
|
||||
select(
|
||||
VulnLabChallenge.vuln_category,
|
||||
VulnLabChallenge.result,
|
||||
func.count(VulnLabChallenge.id)
|
||||
).where(VulnLabChallenge.status == "completed")
|
||||
.group_by(VulnLabChallenge.vuln_category, VulnLabChallenge.result)
|
||||
)
|
||||
cat_stats = {}
|
||||
for row in cat_stats_q.fetchall():
|
||||
cat, res, count = row
|
||||
if cat not in cat_stats:
|
||||
cat_stats[cat] = {"detected": 0, "not_detected": 0, "error": 0, "total": 0}
|
||||
cat_stats[cat][res or "error"] = count
|
||||
cat_stats[cat]["total"] += count
|
||||
|
||||
# Currently running
|
||||
running = len([cid for cid, r in lab_results.items() if r.get("status") == "running"])
|
||||
|
||||
total = sum(status_counts.values())
|
||||
detected = result_counts.get("detected", 0)
|
||||
completed = status_counts.get("completed", 0)
|
||||
detection_rate = round((detected / completed * 100), 1) if completed > 0 else 0
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"running": running,
|
||||
"status_counts": status_counts,
|
||||
"result_counts": result_counts,
|
||||
"detection_rate": detection_rate,
|
||||
"by_type": type_stats,
|
||||
"by_category": cat_stats,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/challenges/{challenge_id}/stop")
|
||||
async def stop_challenge(challenge_id: str):
|
||||
"""Stop a running lab challenge"""
|
||||
agent = lab_agents.get(challenge_id)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="No running agent for this challenge")
|
||||
|
||||
agent.cancel()
|
||||
|
||||
# Update DB
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if challenge:
|
||||
challenge.status = "stopped"
|
||||
challenge.completed_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
if challenge_id in lab_results:
|
||||
lab_results[challenge_id]["status"] = "stopped"
|
||||
|
||||
return {"message": "Challenge stopped"}
|
||||
|
||||
|
||||
@router.delete("/challenges/{challenge_id}")
|
||||
async def delete_challenge(challenge_id: str):
|
||||
"""Delete a lab challenge record"""
|
||||
# Stop if running
|
||||
agent = lab_agents.get(challenge_id)
|
||||
if agent:
|
||||
agent.cancel()
|
||||
lab_agents.pop(challenge_id, None)
|
||||
|
||||
lab_results.pop(challenge_id, None)
|
||||
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if not challenge:
|
||||
raise HTTPException(status_code=404, detail="Challenge not found")
|
||||
|
||||
await db.delete(challenge)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Challenge deleted"}
|
||||
|
||||
|
||||
@router.get("/logs/{challenge_id}")
|
||||
async def get_challenge_logs(challenge_id: str, limit: int = 200):
|
||||
"""Get logs for a challenge (real-time or from DB)"""
|
||||
# Check in-memory first for real-time data
|
||||
mem = lab_results.get(challenge_id)
|
||||
if mem:
|
||||
all_logs = mem.get("logs", [])
|
||||
return {
|
||||
"challenge_id": challenge_id,
|
||||
"total_logs": len(all_logs),
|
||||
"logs": all_logs[-limit:],
|
||||
"source": "realtime",
|
||||
}
|
||||
|
||||
# Fall back to DB persisted logs
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
select(VulnLabChallenge).where(VulnLabChallenge.id == challenge_id)
|
||||
)
|
||||
challenge = result.scalar_one_or_none()
|
||||
if not challenge:
|
||||
raise HTTPException(status_code=404, detail="Challenge not found")
|
||||
|
||||
all_logs = challenge.logs or []
|
||||
return {
|
||||
"challenge_id": challenge_id,
|
||||
"total_logs": len(all_logs),
|
||||
"logs": all_logs[-limit:],
|
||||
"source": "database",
|
||||
}
|
||||
389
backend/api/v1/vulnerabilities.py
Executable file
389
backend/api/v1/vulnerabilities.py
Executable file
@@ -0,0 +1,389 @@
|
||||
"""
|
||||
NeuroSploit v3 - Vulnerabilities API Endpoints
|
||||
"""
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from backend.db.database import get_db
|
||||
from backend.models import Vulnerability
|
||||
from backend.schemas.vulnerability import VulnerabilityResponse, VulnerabilityTypeInfo
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Vulnerability type definitions
|
||||
VULNERABILITY_TYPES = {
|
||||
"injection": {
|
||||
"xss_reflected": {
|
||||
"name": "Reflected XSS",
|
||||
"description": "Cross-site scripting via user input reflected in response",
|
||||
"severity_range": "medium-high",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-79"]
|
||||
},
|
||||
"xss_stored": {
|
||||
"name": "Stored XSS",
|
||||
"description": "Cross-site scripting stored in application database",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-79"]
|
||||
},
|
||||
"xss_dom": {
|
||||
"name": "DOM-based XSS",
|
||||
"description": "Cross-site scripting via DOM manipulation",
|
||||
"severity_range": "medium-high",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-79"]
|
||||
},
|
||||
"sqli_error": {
|
||||
"name": "Error-based SQL Injection",
|
||||
"description": "SQL injection detected via error messages",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-89"]
|
||||
},
|
||||
"sqli_union": {
|
||||
"name": "Union-based SQL Injection",
|
||||
"description": "SQL injection exploitable via UNION queries",
|
||||
"severity_range": "critical",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-89"]
|
||||
},
|
||||
"sqli_blind": {
|
||||
"name": "Blind SQL Injection",
|
||||
"description": "SQL injection without visible output",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-89"]
|
||||
},
|
||||
"sqli_time": {
|
||||
"name": "Time-based SQL Injection",
|
||||
"description": "SQL injection detected via response time",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-89"]
|
||||
},
|
||||
"command_injection": {
|
||||
"name": "Command Injection",
|
||||
"description": "OS command injection vulnerability",
|
||||
"severity_range": "critical",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-78"]
|
||||
},
|
||||
"ssti": {
|
||||
"name": "Server-Side Template Injection",
|
||||
"description": "Template injection allowing code execution",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-94"]
|
||||
},
|
||||
"ldap_injection": {
|
||||
"name": "LDAP Injection",
|
||||
"description": "LDAP query injection",
|
||||
"severity_range": "high",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-90"]
|
||||
},
|
||||
"xpath_injection": {
|
||||
"name": "XPath Injection",
|
||||
"description": "XPath query injection",
|
||||
"severity_range": "medium-high",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-643"]
|
||||
},
|
||||
"nosql_injection": {
|
||||
"name": "NoSQL Injection",
|
||||
"description": "NoSQL database injection",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-943"]
|
||||
},
|
||||
"header_injection": {
|
||||
"name": "HTTP Header Injection",
|
||||
"description": "Injection into HTTP headers",
|
||||
"severity_range": "medium-high",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-113"]
|
||||
},
|
||||
"crlf_injection": {
|
||||
"name": "CRLF Injection",
|
||||
"description": "Carriage return line feed injection",
|
||||
"severity_range": "medium",
|
||||
"owasp_category": "A03:2021",
|
||||
"cwe_ids": ["CWE-93"]
|
||||
}
|
||||
},
|
||||
"file_access": {
|
||||
"lfi": {
|
||||
"name": "Local File Inclusion",
|
||||
"description": "Include local files via path manipulation",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A01:2021",
|
||||
"cwe_ids": ["CWE-98"]
|
||||
},
|
||||
"rfi": {
|
||||
"name": "Remote File Inclusion",
|
||||
"description": "Include remote files for code execution",
|
||||
"severity_range": "critical",
|
||||
"owasp_category": "A01:2021",
|
||||
"cwe_ids": ["CWE-98"]
|
||||
},
|
||||
"path_traversal": {
|
||||
"name": "Path Traversal",
|
||||
"description": "Access files outside web root",
|
||||
"severity_range": "high",
|
||||
"owasp_category": "A01:2021",
|
||||
"cwe_ids": ["CWE-22"]
|
||||
},
|
||||
"file_upload": {
|
||||
"name": "Arbitrary File Upload",
|
||||
"description": "Upload malicious files",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A04:2021",
|
||||
"cwe_ids": ["CWE-434"]
|
||||
},
|
||||
"xxe": {
|
||||
"name": "XML External Entity",
|
||||
"description": "XXE injection vulnerability",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A05:2021",
|
||||
"cwe_ids": ["CWE-611"]
|
||||
}
|
||||
},
|
||||
"request_forgery": {
|
||||
"ssrf": {
|
||||
"name": "Server-Side Request Forgery",
|
||||
"description": "Forge requests from the server",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A10:2021",
|
||||
"cwe_ids": ["CWE-918"]
|
||||
},
|
||||
"ssrf_cloud": {
|
||||
"name": "SSRF to Cloud Metadata",
|
||||
"description": "SSRF accessing cloud provider metadata",
|
||||
"severity_range": "critical",
|
||||
"owasp_category": "A10:2021",
|
||||
"cwe_ids": ["CWE-918"]
|
||||
},
|
||||
"csrf": {
|
||||
"name": "Cross-Site Request Forgery",
|
||||
"description": "Forge requests as authenticated user",
|
||||
"severity_range": "medium-high",
|
||||
"owasp_category": "A01:2021",
|
||||
"cwe_ids": ["CWE-352"]
|
||||
}
|
||||
},
|
||||
"authentication": {
|
||||
"auth_bypass": {
|
||||
"name": "Authentication Bypass",
|
||||
"description": "Bypass authentication mechanisms",
|
||||
"severity_range": "critical",
|
||||
"owasp_category": "A07:2021",
|
||||
"cwe_ids": ["CWE-287"]
|
||||
},
|
||||
"session_fixation": {
|
||||
"name": "Session Fixation",
|
||||
"description": "Force known session ID on user",
|
||||
"severity_range": "high",
|
||||
"owasp_category": "A07:2021",
|
||||
"cwe_ids": ["CWE-384"]
|
||||
},
|
||||
"jwt_manipulation": {
|
||||
"name": "JWT Token Manipulation",
|
||||
"description": "Manipulate JWT tokens for auth bypass",
|
||||
"severity_range": "high-critical",
|
||||
"owasp_category": "A07:2021",
|
||||
"cwe_ids": ["CWE-347"]
|
||||
},
|
||||
"weak_password_policy": {
|
||||
"name": "Weak Password Policy",
|
||||
"description": "Application accepts weak passwords",
|
||||
"severity_range": "medium",
|
||||
"owasp_category": "A07:2021",
|
||||
"cwe_ids": ["CWE-521"]
|
||||
}
|
||||
},
|
||||
"authorization": {
|
||||
"idor": {
|
||||
"name": "Insecure Direct Object Reference",
|
||||
"description": "Access objects without proper authorization",
|
||||
"severity_range": "high",
|
||||
"owasp_category": "A01:2021",
|
||||
"cwe_ids": ["CWE-639"]
|
||||
},
|
||||
"bola": {
|
||||
"name": "Broken Object Level Authorization",
|
||||
"description": "API-level object authorization bypass",
|
||||
"severity_range": "high",
|
||||
"owasp_category": "A01:2021",
|
||||
"cwe_ids": ["CWE-639"]
|
||||
},
|
||||
"privilege_escalation": {
|
||||
"name": "Privilege Escalation",
|
||||
"description": "Escalate to higher privilege level",
|
||||
"severity_range": "critical",
|
||||
"owasp_category": "A01:2021",
|
||||
"cwe_ids": ["CWE-269"]
|
||||
}
|
||||
},
|
||||
"api_security": {
|
||||
"rate_limiting": {
|
||||
"name": "Missing Rate Limiting",
|
||||
"description": "No rate limiting on sensitive endpoints",
|
||||
"severity_range": "medium",
|
||||
"owasp_category": "A04:2021",
|
||||
"cwe_ids": ["CWE-770"]
|
||||
},
|
||||
"mass_assignment": {
|
||||
"name": "Mass Assignment",
|
||||
"description": "Modify unintended object properties",
|
||||
"severity_range": "high",
|
||||
"owasp_category": "A04:2021",
|
||||
"cwe_ids": ["CWE-915"]
|
||||
},
|
||||
"excessive_data": {
|
||||
"name": "Excessive Data Exposure",
|
||||
"description": "API returns more data than needed",
|
||||
"severity_range": "medium-high",
|
||||
"owasp_category": "A01:2021",
|
||||
"cwe_ids": ["CWE-200"]
|
||||
},
|
||||
"graphql_introspection": {
|
||||
"name": "GraphQL Introspection Enabled",
|
||||
"description": "GraphQL schema exposed via introspection",
|
||||
"severity_range": "low-medium",
|
||||
"owasp_category": "A05:2021",
|
||||
"cwe_ids": ["CWE-200"]
|
||||
}
|
||||
},
|
||||
"client_side": {
|
||||
"cors_misconfig": {
|
||||
"name": "CORS Misconfiguration",
|
||||
"description": "Permissive CORS policy",
|
||||
"severity_range": "medium-high",
|
||||
"owasp_category": "A05:2021",
|
||||
"cwe_ids": ["CWE-942"]
|
||||
},
|
||||
"clickjacking": {
|
||||
"name": "Clickjacking",
|
||||
"description": "Page can be framed for clickjacking",
|
||||
"severity_range": "medium",
|
||||
"owasp_category": "A05:2021",
|
||||
"cwe_ids": ["CWE-1021"]
|
||||
},
|
||||
"open_redirect": {
|
||||
"name": "Open Redirect",
|
||||
"description": "Redirect to arbitrary URLs",
|
||||
"severity_range": "low-medium",
|
||||
"owasp_category": "A01:2021",
|
||||
"cwe_ids": ["CWE-601"]
|
||||
}
|
||||
},
|
||||
"information_disclosure": {
|
||||
"error_disclosure": {
|
||||
"name": "Error Message Disclosure",
|
||||
"description": "Detailed error messages exposed",
|
||||
"severity_range": "low-medium",
|
||||
"owasp_category": "A05:2021",
|
||||
"cwe_ids": ["CWE-209"]
|
||||
},
|
||||
"sensitive_data": {
|
||||
"name": "Sensitive Data Exposure",
|
||||
"description": "Sensitive information exposed",
|
||||
"severity_range": "medium-high",
|
||||
"owasp_category": "A02:2021",
|
||||
"cwe_ids": ["CWE-200"]
|
||||
},
|
||||
"debug_endpoints": {
|
||||
"name": "Debug Endpoints Exposed",
|
||||
"description": "Debug/admin endpoints accessible",
|
||||
"severity_range": "high",
|
||||
"owasp_category": "A05:2021",
|
||||
"cwe_ids": ["CWE-489"]
|
||||
}
|
||||
},
|
||||
"infrastructure": {
|
||||
"security_headers": {
|
||||
"name": "Missing Security Headers",
|
||||
"description": "Important security headers not set",
|
||||
"severity_range": "low-medium",
|
||||
"owasp_category": "A05:2021",
|
||||
"cwe_ids": ["CWE-693"]
|
||||
},
|
||||
"ssl_issues": {
|
||||
"name": "SSL/TLS Issues",
|
||||
"description": "Weak SSL/TLS configuration",
|
||||
"severity_range": "medium",
|
||||
"owasp_category": "A02:2021",
|
||||
"cwe_ids": ["CWE-326"]
|
||||
},
|
||||
"http_methods": {
|
||||
"name": "Dangerous HTTP Methods",
|
||||
"description": "Dangerous HTTP methods enabled",
|
||||
"severity_range": "low-medium",
|
||||
"owasp_category": "A05:2021",
|
||||
"cwe_ids": ["CWE-749"]
|
||||
}
|
||||
},
|
||||
"logic_flaws": {
|
||||
"race_condition": {
|
||||
"name": "Race Condition",
|
||||
"description": "Exploitable race condition",
|
||||
"severity_range": "medium-high",
|
||||
"owasp_category": "A04:2021",
|
||||
"cwe_ids": ["CWE-362"]
|
||||
},
|
||||
"business_logic": {
|
||||
"name": "Business Logic Flaw",
|
||||
"description": "Exploitable business logic error",
|
||||
"severity_range": "varies",
|
||||
"owasp_category": "A04:2021",
|
||||
"cwe_ids": ["CWE-840"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/types")
|
||||
async def get_vulnerability_types():
|
||||
"""Get all vulnerability types organized by category"""
|
||||
return VULNERABILITY_TYPES
|
||||
|
||||
|
||||
@router.get("/types/{category}")
|
||||
async def get_vulnerability_types_by_category(category: str):
|
||||
"""Get vulnerability types for a specific category"""
|
||||
if category not in VULNERABILITY_TYPES:
|
||||
raise HTTPException(status_code=404, detail=f"Category '{category}' not found")
|
||||
|
||||
return VULNERABILITY_TYPES[category]
|
||||
|
||||
|
||||
@router.get("/types/{category}/{vuln_type}", response_model=VulnerabilityTypeInfo)
|
||||
async def get_vulnerability_type_info(category: str, vuln_type: str):
|
||||
"""Get detailed info for a specific vulnerability type"""
|
||||
if category not in VULNERABILITY_TYPES:
|
||||
raise HTTPException(status_code=404, detail=f"Category '{category}' not found")
|
||||
|
||||
if vuln_type not in VULNERABILITY_TYPES[category]:
|
||||
raise HTTPException(status_code=404, detail=f"Type '{vuln_type}' not found in category '{category}'")
|
||||
|
||||
info = VULNERABILITY_TYPES[category][vuln_type]
|
||||
return VulnerabilityTypeInfo(
|
||||
type=vuln_type,
|
||||
category=category,
|
||||
**info
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{vuln_id}", response_model=VulnerabilityResponse)
|
||||
async def get_vulnerability(vuln_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Get a specific vulnerability by ID"""
|
||||
result = await db.execute(select(Vulnerability).where(Vulnerability.id == vuln_id))
|
||||
vuln = result.scalar_one_or_none()
|
||||
|
||||
if not vuln:
|
||||
raise HTTPException(status_code=404, detail="Vulnerability not found")
|
||||
|
||||
return VulnerabilityResponse(**vuln.to_dict())
|
||||
247
backend/api/websocket.py
Executable file
247
backend/api/websocket.py
Executable file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
NeuroSploit v3 - WebSocket Manager
|
||||
"""
|
||||
from typing import Dict, List, Optional
|
||||
from fastapi import WebSocket
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
from backend.core.notification_manager import notification_manager, NotificationEvent
|
||||
HAS_NOTIFICATIONS = True
|
||||
except ImportError:
|
||||
HAS_NOTIFICATIONS = False
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages WebSocket connections for real-time updates"""
|
||||
|
||||
def __init__(self):
|
||||
# scan_id -> list of websocket connections
|
||||
self.active_connections: Dict[str, List[WebSocket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def connect(self, websocket: WebSocket, scan_id: str):
|
||||
"""Accept a WebSocket connection and register it for a scan"""
|
||||
await websocket.accept()
|
||||
async with self._lock:
|
||||
if scan_id not in self.active_connections:
|
||||
self.active_connections[scan_id] = []
|
||||
self.active_connections[scan_id].append(websocket)
|
||||
print(f"WebSocket connected for scan: {scan_id}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket, scan_id: str):
|
||||
"""Remove a WebSocket connection"""
|
||||
if scan_id in self.active_connections:
|
||||
if websocket in self.active_connections[scan_id]:
|
||||
self.active_connections[scan_id].remove(websocket)
|
||||
if not self.active_connections[scan_id]:
|
||||
del self.active_connections[scan_id]
|
||||
print(f"WebSocket disconnected for scan: {scan_id}")
|
||||
|
||||
async def send_to_scan(self, scan_id: str, message: dict):
|
||||
"""Send a message to all connections watching a specific scan"""
|
||||
if scan_id not in self.active_connections:
|
||||
return
|
||||
|
||||
dead_connections = []
|
||||
for connection in self.active_connections[scan_id]:
|
||||
try:
|
||||
await connection.send_text(json.dumps(message))
|
||||
except Exception:
|
||||
dead_connections.append(connection)
|
||||
|
||||
# Clean up dead connections
|
||||
for conn in dead_connections:
|
||||
self.disconnect(conn, scan_id)
|
||||
|
||||
async def broadcast_scan_started(self, scan_id: str, target: str = ""):
|
||||
"""Notify that a scan has started"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "scan_started",
|
||||
"scan_id": scan_id
|
||||
})
|
||||
if HAS_NOTIFICATIONS:
|
||||
asyncio.create_task(notification_manager.notify(
|
||||
NotificationEvent.SCAN_STARTED, {"target": target, "scan_id": scan_id}
|
||||
))
|
||||
|
||||
async def broadcast_phase_change(self, scan_id: str, phase: str):
|
||||
"""Notify phase change (recon, testing, reporting)"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "phase_change",
|
||||
"scan_id": scan_id,
|
||||
"phase": phase
|
||||
})
|
||||
|
||||
async def broadcast_progress(self, scan_id: str, progress: int, message: Optional[str] = None):
|
||||
"""Send progress update"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "progress_update",
|
||||
"scan_id": scan_id,
|
||||
"progress": progress,
|
||||
"message": message
|
||||
})
|
||||
|
||||
async def broadcast_endpoint_found(self, scan_id: str, endpoint: dict):
|
||||
"""Notify a new endpoint was discovered"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "endpoint_found",
|
||||
"scan_id": scan_id,
|
||||
"endpoint": endpoint
|
||||
})
|
||||
|
||||
async def broadcast_path_crawled(self, scan_id: str, path: str, status: int):
|
||||
"""Notify a path was crawled"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "path_crawled",
|
||||
"scan_id": scan_id,
|
||||
"path": path,
|
||||
"status": status
|
||||
})
|
||||
|
||||
async def broadcast_url_discovered(self, scan_id: str, url: str):
|
||||
"""Notify a URL was discovered"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "url_discovered",
|
||||
"scan_id": scan_id,
|
||||
"url": url
|
||||
})
|
||||
|
||||
async def broadcast_test_started(self, scan_id: str, vuln_type: str, endpoint: str):
|
||||
"""Notify a vulnerability test has started"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "test_started",
|
||||
"scan_id": scan_id,
|
||||
"vulnerability_type": vuln_type,
|
||||
"endpoint": endpoint
|
||||
})
|
||||
|
||||
async def broadcast_test_completed(self, scan_id: str, vuln_type: str, endpoint: str, is_vulnerable: bool):
|
||||
"""Notify a vulnerability test has completed"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "test_completed",
|
||||
"scan_id": scan_id,
|
||||
"vulnerability_type": vuln_type,
|
||||
"endpoint": endpoint,
|
||||
"is_vulnerable": is_vulnerable
|
||||
})
|
||||
|
||||
async def broadcast_vulnerability_found(self, scan_id: str, vulnerability: dict):
|
||||
"""Notify a vulnerability was found"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "vuln_found",
|
||||
"scan_id": scan_id,
|
||||
"vulnerability": vulnerability
|
||||
})
|
||||
if HAS_NOTIFICATIONS:
|
||||
asyncio.create_task(notification_manager.notify(
|
||||
NotificationEvent.VULN_FOUND, {
|
||||
"title": vulnerability.get("title", "Vulnerability Found"),
|
||||
"severity": vulnerability.get("severity", "medium"),
|
||||
"vulnerability_type": vulnerability.get("vulnerability_type", "unknown"),
|
||||
"endpoint": vulnerability.get("endpoint", ""),
|
||||
"description": vulnerability.get("description", ""),
|
||||
}
|
||||
))
|
||||
|
||||
async def broadcast_log(self, scan_id: str, level: str, message: str):
|
||||
"""Send a log message"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "log_message",
|
||||
"scan_id": scan_id,
|
||||
"level": level,
|
||||
"message": message
|
||||
})
|
||||
|
||||
async def broadcast_scan_completed(self, scan_id: str, summary: dict):
|
||||
"""Notify that a scan has completed"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "scan_completed",
|
||||
"scan_id": scan_id,
|
||||
"summary": summary
|
||||
})
|
||||
if HAS_NOTIFICATIONS:
|
||||
asyncio.create_task(notification_manager.notify(
|
||||
NotificationEvent.SCAN_COMPLETED, {
|
||||
"total_vulnerabilities": summary.get("total_vulnerabilities", 0),
|
||||
"critical": summary.get("critical", 0),
|
||||
"high": summary.get("high", 0),
|
||||
"medium": summary.get("medium", 0),
|
||||
}
|
||||
))
|
||||
|
||||
async def broadcast_scan_stopped(self, scan_id: str, summary: dict):
|
||||
"""Notify that a scan was stopped by user"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "scan_stopped",
|
||||
"scan_id": scan_id,
|
||||
"status": "stopped",
|
||||
"summary": summary
|
||||
})
|
||||
|
||||
async def broadcast_scan_failed(self, scan_id: str, error: str, summary: dict = None):
|
||||
"""Notify that a scan has failed"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "scan_failed",
|
||||
"scan_id": scan_id,
|
||||
"status": "failed",
|
||||
"error": error,
|
||||
"summary": summary or {}
|
||||
})
|
||||
if HAS_NOTIFICATIONS:
|
||||
asyncio.create_task(notification_manager.notify(
|
||||
NotificationEvent.SCAN_FAILED, {"error": error}
|
||||
))
|
||||
|
||||
async def broadcast_stats_update(self, scan_id: str, stats: dict):
|
||||
"""Broadcast updated scan statistics"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "stats_update",
|
||||
"scan_id": scan_id,
|
||||
"stats": stats
|
||||
})
|
||||
|
||||
async def broadcast_agent_task(self, scan_id: str, task: dict):
|
||||
"""Broadcast agent task update (created, started, completed, failed)"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "agent_task",
|
||||
"scan_id": scan_id,
|
||||
"task": task
|
||||
})
|
||||
|
||||
async def broadcast_agent_task_started(self, scan_id: str, task: dict):
|
||||
"""Broadcast when an agent task starts"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "agent_task_started",
|
||||
"scan_id": scan_id,
|
||||
"task": task
|
||||
})
|
||||
|
||||
async def broadcast_agent_task_completed(self, scan_id: str, task: dict):
|
||||
"""Broadcast when an agent task completes"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "agent_task_completed",
|
||||
"scan_id": scan_id,
|
||||
"task": task
|
||||
})
|
||||
|
||||
async def broadcast_report_generated(self, scan_id: str, report: dict):
|
||||
"""Broadcast when a report is generated"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "report_generated",
|
||||
"scan_id": scan_id,
|
||||
"report": report
|
||||
})
|
||||
|
||||
async def broadcast_error(self, scan_id: str, error: str):
|
||||
"""Notify an error occurred"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "error",
|
||||
"scan_id": scan_id,
|
||||
"error": error
|
||||
})
|
||||
|
||||
|
||||
# Global instance
|
||||
manager = ConnectionManager()
|
||||
Reference in New Issue
Block a user