Add files via upload

This commit is contained in:
Joas A Santos
2026-01-23 15:46:05 -03:00
committed by GitHub
parent 2a5e9b139a
commit f9e4ec16ec
19 changed files with 1398 additions and 158 deletions

View File

@@ -0,0 +1,176 @@
"""
NeuroSploit v3 - Agent Tasks API Endpoints
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from backend.db.database import get_db
from backend.models import AgentTask, Scan
from backend.schemas.agent_task import (
AgentTaskResponse,
AgentTaskListResponse,
AgentTaskSummary
)
router = APIRouter()
@router.get("", response_model=AgentTaskListResponse)
async def list_agent_tasks(
scan_id: str,
status: Optional[str] = None,
task_type: Optional[str] = None,
page: int = 1,
per_page: int = 50,
db: AsyncSession = Depends(get_db)
):
"""List all agent tasks for a scan"""
# Verify scan exists
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
# Build query
query = select(AgentTask).where(AgentTask.scan_id == scan_id)
if status:
query = query.where(AgentTask.status == status)
if task_type:
query = query.where(AgentTask.task_type == task_type)
query = query.order_by(AgentTask.created_at.desc())
# Get total count
count_query = select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
if status:
count_query = count_query.where(AgentTask.status == status)
if task_type:
count_query = count_query.where(AgentTask.task_type == task_type)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# Apply pagination
query = query.offset((page - 1) * per_page).limit(per_page)
result = await db.execute(query)
tasks = result.scalars().all()
return AgentTaskListResponse(
tasks=[AgentTaskResponse(**t.to_dict()) for t in tasks],
total=total,
scan_id=scan_id
)
@router.get("/summary", response_model=AgentTaskSummary)
async def get_agent_tasks_summary(
scan_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get summary statistics for agent tasks in a scan"""
# Verify scan exists
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
# Total count
total_result = await db.execute(
select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
)
total = total_result.scalar() or 0
# Count by status
status_counts = {}
for status in ["pending", "running", "completed", "failed"]:
count_result = await db.execute(
select(func.count()).select_from(AgentTask)
.where(AgentTask.scan_id == scan_id)
.where(AgentTask.status == status)
)
status_counts[status] = count_result.scalar() or 0
# Count by task type
type_query = select(
AgentTask.task_type,
func.count(AgentTask.id).label("count")
).where(AgentTask.scan_id == scan_id).group_by(AgentTask.task_type)
type_result = await db.execute(type_query)
by_type = {row[0]: row[1] for row in type_result.all()}
# Count by tool
tool_query = select(
AgentTask.tool_name,
func.count(AgentTask.id).label("count")
).where(AgentTask.scan_id == scan_id).where(AgentTask.tool_name.isnot(None)).group_by(AgentTask.tool_name)
tool_result = await db.execute(tool_query)
by_tool = {row[0]: row[1] for row in tool_result.all()}
return AgentTaskSummary(
total=total,
pending=status_counts.get("pending", 0),
running=status_counts.get("running", 0),
completed=status_counts.get("completed", 0),
failed=status_counts.get("failed", 0),
by_type=by_type,
by_tool=by_tool
)
@router.get("/{task_id}", response_model=AgentTaskResponse)
async def get_agent_task(
task_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get a specific agent task by ID"""
result = await db.execute(select(AgentTask).where(AgentTask.id == task_id))
task = result.scalar_one_or_none()
if not task:
raise HTTPException(status_code=404, detail="Agent task not found")
return AgentTaskResponse(**task.to_dict())
@router.get("/scan/{scan_id}/timeline")
async def get_agent_tasks_timeline(
scan_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get agent tasks as a timeline for visualization"""
# Verify scan exists
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
# Get all tasks ordered by creation time
query = select(AgentTask).where(AgentTask.scan_id == scan_id).order_by(AgentTask.created_at.asc())
result = await db.execute(query)
tasks = result.scalars().all()
timeline = []
for task in tasks:
timeline_item = {
"id": task.id,
"task_name": task.task_name,
"task_type": task.task_type,
"tool_name": task.tool_name,
"status": task.status,
"started_at": task.started_at.isoformat() if task.started_at else None,
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
"duration_ms": task.duration_ms,
"items_processed": task.items_processed,
"items_found": task.items_found,
"result_summary": task.result_summary,
"error_message": task.error_message
}
timeline.append(timeline_item)
return {
"scan_id": scan_id,
"timeline": timeline,
"total": len(timeline)
}

View File

@@ -8,7 +8,7 @@ from sqlalchemy import select, func
from datetime import datetime, timedelta
from backend.db.database import get_db
from backend.models import Scan, Vulnerability, Endpoint
from backend.models import Scan, Vulnerability, Endpoint, AgentTask, Report
router = APIRouter()
@@ -20,18 +20,32 @@ async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
total_scans_result = await db.execute(select(func.count()).select_from(Scan))
total_scans = total_scans_result.scalar() or 0
# Running scans
# Scans by status
running_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "running")
)
running_scans = running_result.scalar() or 0
# Completed scans
completed_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "completed")
)
completed_scans = completed_result.scalar() or 0
stopped_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "stopped")
)
stopped_scans = stopped_result.scalar() or 0
failed_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "failed")
)
failed_scans = failed_result.scalar() or 0
pending_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "pending")
)
pending_scans = pending_result.scalar() or 0
# Total vulnerabilities by severity
vuln_counts = {}
for severity in ["critical", "high", "medium", "low", "info"]:
@@ -63,6 +77,9 @@ async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
"total": total_scans,
"running": running_scans,
"completed": completed_scans,
"stopped": stopped_scans,
"failed": failed_scans,
"pending": pending_scans,
"recent": recent_scans
},
"vulnerabilities": {
@@ -175,3 +192,108 @@ async def get_scan_history(
history[date_str]["high"] += scan.high_count
return {"history": list(history.values())}
@router.get("/agent-tasks")
async def get_recent_agent_tasks(
limit: int = 20,
db: AsyncSession = Depends(get_db)
):
"""Get recent agent tasks across all scans"""
query = (
select(AgentTask)
.order_by(AgentTask.created_at.desc())
.limit(limit)
)
result = await db.execute(query)
tasks = result.scalars().all()
return {
"agent_tasks": [t.to_dict() for t in tasks],
"total": len(tasks)
}
@router.get("/activity-feed")
async def get_activity_feed(
limit: int = 30,
db: AsyncSession = Depends(get_db)
):
"""Get unified activity feed with all recent events"""
activities = []
# Get recent scans
scans_result = await db.execute(
select(Scan).order_by(Scan.created_at.desc()).limit(limit // 3)
)
for scan in scans_result.scalars().all():
activities.append({
"type": "scan",
"action": f"Scan {scan.status}",
"title": scan.name or "Unnamed Scan",
"description": f"{scan.total_vulnerabilities} vulnerabilities found",
"status": scan.status,
"severity": None,
"timestamp": scan.created_at.isoformat(),
"scan_id": scan.id,
"link": f"/scan/{scan.id}"
})
# Get recent vulnerabilities
vulns_result = await db.execute(
select(Vulnerability).order_by(Vulnerability.created_at.desc()).limit(limit // 3)
)
for vuln in vulns_result.scalars().all():
activities.append({
"type": "vulnerability",
"action": "Vulnerability found",
"title": vuln.title,
"description": vuln.affected_endpoint or "",
"status": None,
"severity": vuln.severity,
"timestamp": vuln.created_at.isoformat(),
"scan_id": vuln.scan_id,
"link": f"/scan/{vuln.scan_id}"
})
# Get recent agent tasks
tasks_result = await db.execute(
select(AgentTask).order_by(AgentTask.created_at.desc()).limit(limit // 3)
)
for task in tasks_result.scalars().all():
activities.append({
"type": "agent_task",
"action": f"Task {task.status}",
"title": task.task_name,
"description": task.result_summary or task.description or "",
"status": task.status,
"severity": None,
"timestamp": task.created_at.isoformat(),
"scan_id": task.scan_id,
"link": f"/scan/{task.scan_id}"
})
# Get recent reports
reports_result = await db.execute(
select(Report).order_by(Report.generated_at.desc()).limit(limit // 4)
)
for report in reports_result.scalars().all():
activities.append({
"type": "report",
"action": "Report generated" if report.auto_generated else "Report created",
"title": report.title or "Report",
"description": f"{report.format.upper()} format",
"status": "auto" if report.auto_generated else "manual",
"severity": None,
"timestamp": report.generated_at.isoformat(),
"scan_id": report.scan_id,
"link": f"/reports"
})
# Sort all activities by timestamp (newest first)
activities.sort(key=lambda x: x["timestamp"], reverse=True)
return {
"activities": activities[:limit],
"total": len(activities)
}

View File

@@ -20,14 +20,22 @@ router = APIRouter()
@router.get("", response_model=ReportListResponse)
async def list_reports(
scan_id: Optional[str] = None,
auto_generated: Optional[bool] = None,
is_partial: Optional[bool] = None,
db: AsyncSession = Depends(get_db)
):
"""List all reports"""
"""List all reports with optional filtering"""
query = select(Report).order_by(Report.generated_at.desc())
if scan_id:
query = query.where(Report.scan_id == scan_id)
if auto_generated is not None:
query = query.where(Report.auto_generated == auto_generated)
if is_partial is not None:
query = query.where(Report.is_partial == is_partial)
result = await db.execute(query)
reports = result.scalars().all()

View File

@@ -175,7 +175,9 @@ async def start_scan(
@router.post("/{scan_id}/stop")
async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
"""Stop a running scan"""
"""Stop a running scan and save partial results"""
from backend.api.websocket import manager as ws_manager
result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = result.scalar_one_or_none()
@@ -185,11 +187,76 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
if scan.status != "running":
raise HTTPException(status_code=400, detail="Scan is not running")
# Update scan status
scan.status = "stopped"
scan.completed_at = datetime.utcnow()
scan.current_phase = "stopped"
# Calculate duration
if scan.started_at:
duration = (scan.completed_at - scan.started_at).total_seconds()
scan.duration = int(duration)
# Compute final vulnerability statistics from database
for severity in ["critical", "high", "medium", "low", "info"]:
count_result = await db.execute(
select(func.count()).select_from(Vulnerability)
.where(Vulnerability.scan_id == scan_id)
.where(Vulnerability.severity == severity)
)
setattr(scan, f"{severity}_count", count_result.scalar() or 0)
# Get total vulnerability count
total_vuln_result = await db.execute(
select(func.count()).select_from(Vulnerability)
.where(Vulnerability.scan_id == scan_id)
)
scan.total_vulnerabilities = total_vuln_result.scalar() or 0
# Get total endpoint count
total_endpoint_result = await db.execute(
select(func.count()).select_from(Endpoint)
.where(Endpoint.scan_id == scan_id)
)
scan.total_endpoints = total_endpoint_result.scalar() or 0
await db.commit()
return {"message": "Scan stopped", "scan_id": scan_id}
# Build summary for WebSocket broadcast
summary = {
"total_endpoints": scan.total_endpoints,
"total_vulnerabilities": scan.total_vulnerabilities,
"critical": scan.critical_count,
"high": scan.high_count,
"medium": scan.medium_count,
"low": scan.low_count,
"info": scan.info_count,
"duration": scan.duration,
"progress": scan.progress
}
# Broadcast stop event via WebSocket
await ws_manager.broadcast_scan_stopped(scan_id, summary)
await ws_manager.broadcast_log(scan_id, "warning", "Scan stopped by user")
await ws_manager.broadcast_log(scan_id, "info", f"Partial results: {scan.total_vulnerabilities} vulnerabilities found")
# Auto-generate partial report
report_data = None
try:
from backend.services.report_service import auto_generate_report
await ws_manager.broadcast_log(scan_id, "info", "Generating partial report...")
report = await auto_generate_report(db, scan_id, is_partial=True)
report_data = report.to_dict()
await ws_manager.broadcast_log(scan_id, "info", f"Partial report generated: {report.title}")
except Exception as report_error:
await ws_manager.broadcast_log(scan_id, "warning", f"Failed to generate partial report: {str(report_error)}")
return {
"message": "Scan stopped",
"scan_id": scan_id,
"summary": summary,
"report": report_data
}
@router.get("/{scan_id}/status", response_model=ScanProgress)

View File

@@ -142,6 +142,65 @@ class ConnectionManager:
"summary": summary
})
async def broadcast_scan_stopped(self, scan_id: str, summary: dict):
"""Notify that a scan was stopped by user"""
await self.send_to_scan(scan_id, {
"type": "scan_stopped",
"scan_id": scan_id,
"status": "stopped",
"summary": summary
})
async def broadcast_scan_failed(self, scan_id: str, error: str, summary: dict = None):
"""Notify that a scan has failed"""
await self.send_to_scan(scan_id, {
"type": "scan_failed",
"scan_id": scan_id,
"status": "failed",
"error": error,
"summary": summary or {}
})
async def broadcast_stats_update(self, scan_id: str, stats: dict):
"""Broadcast updated scan statistics"""
await self.send_to_scan(scan_id, {
"type": "stats_update",
"scan_id": scan_id,
"stats": stats
})
async def broadcast_agent_task(self, scan_id: str, task: dict):
"""Broadcast agent task update (created, started, completed, failed)"""
await self.send_to_scan(scan_id, {
"type": "agent_task",
"scan_id": scan_id,
"task": task
})
async def broadcast_agent_task_started(self, scan_id: str, task: dict):
"""Broadcast when an agent task starts"""
await self.send_to_scan(scan_id, {
"type": "agent_task_started",
"scan_id": scan_id,
"task": task
})
async def broadcast_agent_task_completed(self, scan_id: str, task: dict):
"""Broadcast when an agent task completes"""
await self.send_to_scan(scan_id, {
"type": "agent_task_completed",
"scan_id": scan_id,
"task": task
})
async def broadcast_report_generated(self, scan_id: str, report: dict):
"""Broadcast when a report is generated"""
await self.send_to_scan(scan_id, {
"type": "report_generated",
"scan_id": scan_id,
"report": report
})
async def broadcast_error(self, scan_id: str, error: str):
"""Notify an error occurred"""
await self.send_to_scan(scan_id, {

View File

@@ -1,10 +1,14 @@
"""
NeuroSploit v3 - Database Configuration
"""
import logging
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import text
from backend.config import settings
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
"""Base class for all models"""
@@ -42,10 +46,114 @@ async def get_db() -> AsyncSession:
await session.close()
async def _run_migrations(conn):
"""Run schema migrations to add missing columns"""
try:
# Check and add duration column to scans table
result = await conn.execute(text("PRAGMA table_info(scans)"))
columns = [row[1] for row in result.fetchall()]
if "duration" not in columns:
logger.info("Adding 'duration' column to scans table...")
await conn.execute(text("ALTER TABLE scans ADD COLUMN duration INTEGER"))
# Check and add columns to reports table
result = await conn.execute(text("PRAGMA table_info(reports)"))
columns = [row[1] for row in result.fetchall()]
if columns: # Table exists
if "auto_generated" not in columns:
logger.info("Adding 'auto_generated' column to reports table...")
await conn.execute(text("ALTER TABLE reports ADD COLUMN auto_generated BOOLEAN DEFAULT 0"))
if "is_partial" not in columns:
logger.info("Adding 'is_partial' column to reports table...")
await conn.execute(text("ALTER TABLE reports ADD COLUMN is_partial BOOLEAN DEFAULT 0"))
# Check and add columns to vulnerabilities table
result = await conn.execute(text("PRAGMA table_info(vulnerabilities)"))
columns = [row[1] for row in result.fetchall()]
if columns: # Table exists
if "test_id" not in columns:
logger.info("Adding 'test_id' column to vulnerabilities table...")
await conn.execute(text("ALTER TABLE vulnerabilities ADD COLUMN test_id VARCHAR(36)"))
if "poc_parameter" not in columns:
logger.info("Adding 'poc_parameter' column to vulnerabilities table...")
await conn.execute(text("ALTER TABLE vulnerabilities ADD COLUMN poc_parameter VARCHAR(500)"))
if "poc_evidence" not in columns:
logger.info("Adding 'poc_evidence' column to vulnerabilities table...")
await conn.execute(text("ALTER TABLE vulnerabilities ADD COLUMN poc_evidence TEXT"))
# Check if agent_tasks table exists
result = await conn.execute(
text("SELECT name FROM sqlite_master WHERE type='table' AND name='agent_tasks'")
)
if not result.fetchone():
logger.info("Creating 'agent_tasks' table...")
await conn.execute(text("""
CREATE TABLE agent_tasks (
id VARCHAR(36) PRIMARY KEY,
scan_id VARCHAR(36) NOT NULL,
task_type VARCHAR(50) NOT NULL,
task_name VARCHAR(255) NOT NULL,
description TEXT,
tool_name VARCHAR(100),
tool_category VARCHAR(100),
status VARCHAR(20) DEFAULT 'pending',
started_at DATETIME,
completed_at DATETIME,
duration_ms INTEGER,
items_processed INTEGER DEFAULT 0,
items_found INTEGER DEFAULT 0,
result_summary TEXT,
error_message TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE
)
"""))
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_agent_tasks_scan_id ON agent_tasks(scan_id)"))
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_agent_tasks_status ON agent_tasks(status)"))
# Check if vulnerability_tests table exists
result = await conn.execute(
text("SELECT name FROM sqlite_master WHERE type='table' AND name='vulnerability_tests'")
)
if not result.fetchone():
logger.info("Creating 'vulnerability_tests' table...")
await conn.execute(text("""
CREATE TABLE vulnerability_tests (
id VARCHAR(36) PRIMARY KEY,
scan_id VARCHAR(36) NOT NULL,
endpoint_id VARCHAR(36),
vulnerability_type VARCHAR(100) NOT NULL,
payload TEXT,
request_data JSON DEFAULT '{}',
response_data JSON DEFAULT '{}',
is_vulnerable BOOLEAN DEFAULT 0,
confidence FLOAT,
evidence TEXT,
tested_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE,
FOREIGN KEY (endpoint_id) REFERENCES endpoints(id) ON DELETE SET NULL
)
"""))
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_vulnerability_tests_scan_id ON vulnerability_tests(scan_id)"))
logger.info("Database migrations completed")
except Exception as e:
logger.warning(f"Migration check failed (may be normal on first run): {e}")
async def init_db():
"""Initialize database tables"""
"""Initialize database tables and run migrations"""
async with engine.begin() as conn:
# Create all tables from models
await conn.run_sync(Base.metadata.create_all)
# Run migrations to add any missing columns
await _run_migrations(conn)
async def close_db():

View File

@@ -11,7 +11,7 @@ from pathlib import Path
from backend.config import settings
from backend.db.database import init_db, close_db
from backend.api.v1 import scans, targets, prompts, reports, dashboard, vulnerabilities, settings as settings_router, agent
from backend.api.v1 import scans, targets, prompts, reports, dashboard, vulnerabilities, settings as settings_router, agent, agent_tasks
from backend.api.websocket import manager as ws_manager
@@ -59,6 +59,7 @@ app.include_router(dashboard.router, prefix="/api/v1/dashboard", tags=["Dashboar
app.include_router(vulnerabilities.router, prefix="/api/v1/vulnerabilities", tags=["Vulnerabilities"])
app.include_router(settings_router.router, prefix="/api/v1/settings", tags=["Settings"])
app.include_router(agent.router, prefix="/api/v1/agent", tags=["AI Agent"])
app.include_router(agent_tasks.router, prefix="/api/v1/agent-tasks", tags=["Agent Tasks"])
@app.get("/api/health")

View File

@@ -0,0 +1,36 @@
-- Migration: Add Dashboard Integration Columns
-- Date: 2026-01-23
-- Description: Adds duration column to scans, auto_generated/is_partial to reports, and creates agent_tasks table
-- Add duration column to scans table
ALTER TABLE scans ADD COLUMN duration INTEGER;
-- Add auto_generated and is_partial columns to reports table
ALTER TABLE reports ADD COLUMN auto_generated BOOLEAN DEFAULT 0;
ALTER TABLE reports ADD COLUMN is_partial BOOLEAN DEFAULT 0;
-- Create agent_tasks table
CREATE TABLE IF NOT EXISTS agent_tasks (
id VARCHAR(36) PRIMARY KEY,
scan_id VARCHAR(36) NOT NULL,
task_type VARCHAR(50) NOT NULL,
task_name VARCHAR(255) NOT NULL,
description TEXT,
tool_name VARCHAR(100),
tool_category VARCHAR(100),
status VARCHAR(20) DEFAULT 'pending',
started_at DATETIME,
completed_at DATETIME,
duration_ms INTEGER,
items_processed INTEGER DEFAULT 0,
items_found INTEGER DEFAULT 0,
result_summary TEXT,
error_message TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE
);
-- Create indexes for performance
CREATE INDEX IF NOT EXISTS idx_agent_tasks_scan_id ON agent_tasks(scan_id);
CREATE INDEX IF NOT EXISTS idx_agent_tasks_status ON agent_tasks(status);
CREATE INDEX IF NOT EXISTS idx_agent_tasks_task_type ON agent_tasks(task_type);

View File

@@ -0,0 +1,4 @@
"""Database migrations for NeuroSploit v3"""
from backend.migrations.run_migrations import run_migration, get_db_path
__all__ = ["run_migration", "get_db_path"]

View File

@@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""
Run database migrations for NeuroSploit v3
Usage:
python -m backend.migrations.run_migrations
Or from backend directory:
python migrations/run_migrations.py
"""
import sqlite3
import os
from pathlib import Path
def get_db_path():
"""Get the database file path"""
# Try common locations
possible_paths = [
Path("./data/neurosploit.db"),
Path("../data/neurosploit.db"),
Path("/opt/NeuroSploitv2/data/neurosploit.db"),
Path("/opt/NeuroSploitv2/backend/data/neurosploit.db"),
]
for path in possible_paths:
if path.exists():
return str(path.resolve())
# Default path
return "./data/neurosploit.db"
def column_exists(cursor, table_name, column_name):
"""Check if a column exists in a table"""
cursor.execute(f"PRAGMA table_info({table_name})")
columns = [row[1] for row in cursor.fetchall()]
return column_name in columns
def table_exists(cursor, table_name):
"""Check if a table exists"""
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
(table_name,)
)
return cursor.fetchone() is not None
def run_migration(db_path: str):
"""Run the database migration"""
print(f"Running migration on database: {db_path}")
if not os.path.exists(db_path):
print(f"Database file not found at {db_path}")
print("Creating data directory and database will be created on first run")
os.makedirs(os.path.dirname(db_path), exist_ok=True)
return
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
try:
# Migration 1: Add duration column to scans table
if not column_exists(cursor, "scans", "duration"):
print("Adding 'duration' column to scans table...")
cursor.execute("ALTER TABLE scans ADD COLUMN duration INTEGER")
print(" Done!")
else:
print("Column 'duration' already exists in scans table")
# Migration 2: Add auto_generated column to reports table
if table_exists(cursor, "reports"):
if not column_exists(cursor, "reports", "auto_generated"):
print("Adding 'auto_generated' column to reports table...")
cursor.execute("ALTER TABLE reports ADD COLUMN auto_generated BOOLEAN DEFAULT 0")
print(" Done!")
else:
print("Column 'auto_generated' already exists in reports table")
# Migration 3: Add is_partial column to reports table
if not column_exists(cursor, "reports", "is_partial"):
print("Adding 'is_partial' column to reports table...")
cursor.execute("ALTER TABLE reports ADD COLUMN is_partial BOOLEAN DEFAULT 0")
print(" Done!")
else:
print("Column 'is_partial' already exists in reports table")
else:
print("Reports table does not exist yet, will be created on first run")
# Migration 4: Create agent_tasks table
if not table_exists(cursor, "agent_tasks"):
print("Creating 'agent_tasks' table...")
cursor.execute("""
CREATE TABLE agent_tasks (
id VARCHAR(36) PRIMARY KEY,
scan_id VARCHAR(36) NOT NULL,
task_type VARCHAR(50) NOT NULL,
task_name VARCHAR(255) NOT NULL,
description TEXT,
tool_name VARCHAR(100),
tool_category VARCHAR(100),
status VARCHAR(20) DEFAULT 'pending',
started_at DATETIME,
completed_at DATETIME,
duration_ms INTEGER,
items_processed INTEGER DEFAULT 0,
items_found INTEGER DEFAULT 0,
result_summary TEXT,
error_message TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE
)
""")
# Create indexes
cursor.execute("CREATE INDEX idx_agent_tasks_scan_id ON agent_tasks(scan_id)")
cursor.execute("CREATE INDEX idx_agent_tasks_status ON agent_tasks(status)")
cursor.execute("CREATE INDEX idx_agent_tasks_task_type ON agent_tasks(task_type)")
print(" Done!")
else:
print("Table 'agent_tasks' already exists")
conn.commit()
print("\nMigration completed successfully!")
except Exception as e:
conn.rollback()
print(f"\nMigration failed: {e}")
raise
finally:
conn.close()
if __name__ == "__main__":
db_path = get_db_path()
run_migration(db_path)

View File

@@ -4,6 +4,7 @@ from backend.models.prompt import Prompt
from backend.models.endpoint import Endpoint
from backend.models.vulnerability import Vulnerability, VulnerabilityTest
from backend.models.report import Report
from backend.models.agent_task import AgentTask
__all__ = [
"Scan",
@@ -12,5 +13,6 @@ __all__ = [
"Endpoint",
"Vulnerability",
"VulnerabilityTest",
"Report"
"Report",
"AgentTask"
]

View File

@@ -0,0 +1,94 @@
"""
NeuroSploit v3 - Agent Task Model
Tracks all agent activities during scans for dashboard visibility.
"""
from datetime import datetime
from typing import Optional
from sqlalchemy import String, Integer, DateTime, Text, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from backend.db.database import Base
import uuid
class AgentTask(Base):
"""Agent task record for tracking scan activities"""
__tablename__ = "agent_tasks"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
scan_id: Mapped[str] = mapped_column(String(36), ForeignKey("scans.id", ondelete="CASCADE"))
# Task identification
task_type: Mapped[str] = mapped_column(String(50)) # recon, analysis, testing, reporting
task_name: Mapped[str] = mapped_column(String(255)) # Human-readable name
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# Tool information
tool_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) # nmap, nuclei, claude, httpx, etc.
tool_category: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) # scanner, analyzer, ai, crawler
# Status tracking
status: Mapped[str] = mapped_column(String(20), default="pending") # pending, running, completed, failed, cancelled
# Timing
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
duration_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) # Duration in milliseconds
# Results
items_processed: Mapped[int] = mapped_column(Integer, default=0) # URLs tested, hosts scanned, etc.
items_found: Mapped[int] = mapped_column(Integer, default=0) # Endpoints found, vulns found, etc.
result_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True) # Brief summary of results
# Error handling
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# Metadata
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
# Relationships
scan: Mapped["Scan"] = relationship("Scan", back_populates="agent_tasks")
def to_dict(self) -> dict:
"""Convert to dictionary"""
return {
"id": self.id,
"scan_id": self.scan_id,
"task_type": self.task_type,
"task_name": self.task_name,
"description": self.description,
"tool_name": self.tool_name,
"tool_category": self.tool_category,
"status": self.status,
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"duration_ms": self.duration_ms,
"items_processed": self.items_processed,
"items_found": self.items_found,
"result_summary": self.result_summary,
"error_message": self.error_message,
"created_at": self.created_at.isoformat() if self.created_at else None
}
def start(self):
"""Mark task as started"""
self.status = "running"
self.started_at = datetime.utcnow()
def complete(self, items_processed: int = 0, items_found: int = 0, summary: str = None):
"""Mark task as completed"""
self.status = "completed"
self.completed_at = datetime.utcnow()
self.items_processed = items_processed
self.items_found = items_found
self.result_summary = summary
if self.started_at:
self.duration_ms = int((self.completed_at - self.started_at).total_seconds() * 1000)
def fail(self, error: str):
"""Mark task as failed"""
self.status = "failed"
self.completed_at = datetime.utcnow()
self.error_message = error
if self.started_at:
self.duration_ms = int((self.completed_at - self.started_at).total_seconds() * 1000)

View File

@@ -3,7 +3,7 @@ NeuroSploit v3 - Report Model
"""
from datetime import datetime
from typing import Optional
from sqlalchemy import String, DateTime, Text, ForeignKey
from sqlalchemy import String, DateTime, Text, ForeignKey, Boolean
from sqlalchemy.orm import Mapped, mapped_column, relationship
from backend.db.database import Base
import uuid
@@ -24,6 +24,10 @@ class Report(Base):
# Content
executive_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# Auto-generation flags
auto_generated: Mapped[bool] = mapped_column(Boolean, default=False) # True if auto-generated on scan completion/stop
is_partial: Mapped[bool] = mapped_column(Boolean, default=False) # True if generated from stopped/incomplete scan
# Timestamps
generated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
@@ -39,5 +43,7 @@ class Report(Base):
"format": self.format,
"file_path": self.file_path,
"executive_summary": self.executive_summary,
"auto_generated": self.auto_generated,
"is_partial": self.is_partial,
"generated_at": self.generated_at.isoformat() if self.generated_at else None
}

View File

@@ -39,6 +39,7 @@ class Scan(Base):
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
duration: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) # Duration in seconds
# Error handling
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
@@ -57,6 +58,7 @@ class Scan(Base):
endpoints: Mapped[List["Endpoint"]] = relationship("Endpoint", back_populates="scan", cascade="all, delete-orphan")
vulnerabilities: Mapped[List["Vulnerability"]] = relationship("Vulnerability", back_populates="scan", cascade="all, delete-orphan")
reports: Mapped[List["Report"]] = relationship("Report", back_populates="scan", cascade="all, delete-orphan")
agent_tasks: Mapped[List["AgentTask"]] = relationship("AgentTask", back_populates="scan", cascade="all, delete-orphan")
def to_dict(self) -> dict:
"""Convert to dictionary"""
@@ -77,6 +79,7 @@ class Scan(Base):
"created_at": self.created_at.isoformat() if self.created_at else None,
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"duration": self.duration,
"error_message": self.error_message,
"total_endpoints": self.total_endpoints,
"total_vulnerabilities": self.total_vulnerabilities,

View File

@@ -27,11 +27,19 @@ from backend.schemas.report import (
ReportResponse,
ReportGenerate
)
from backend.schemas.agent_task import (
AgentTaskCreate,
AgentTaskUpdate,
AgentTaskResponse,
AgentTaskListResponse,
AgentTaskSummary
)
__all__ = [
"ScanCreate", "ScanUpdate", "ScanResponse", "ScanListResponse", "ScanProgress",
"TargetCreate", "TargetResponse", "TargetBulkCreate", "TargetValidation",
"PromptCreate", "PromptUpdate", "PromptResponse", "PromptParse", "PromptParseResult",
"VulnerabilityResponse", "VulnerabilityTestResponse", "VulnerabilityTypeInfo",
"ReportResponse", "ReportGenerate"
"ReportResponse", "ReportGenerate",
"AgentTaskCreate", "AgentTaskUpdate", "AgentTaskResponse", "AgentTaskListResponse", "AgentTaskSummary"
]

View File

@@ -0,0 +1,66 @@
"""
NeuroSploit v3 - Agent Task Schemas
"""
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
class AgentTaskCreate(BaseModel):
"""Schema for creating an agent task"""
scan_id: str = Field(..., description="Scan ID this task belongs to")
task_type: str = Field(..., description="Task type: recon, analysis, testing, reporting")
task_name: str = Field(..., description="Human-readable task name")
description: Optional[str] = Field(None, description="Task description")
tool_name: Optional[str] = Field(None, description="Tool being used")
tool_category: Optional[str] = Field(None, description="Tool category")
class AgentTaskUpdate(BaseModel):
"""Schema for updating an agent task"""
status: Optional[str] = Field(None, description="Task status")
items_processed: Optional[int] = Field(None, description="Items processed")
items_found: Optional[int] = Field(None, description="Items found")
result_summary: Optional[str] = Field(None, description="Result summary")
error_message: Optional[str] = Field(None, description="Error message if failed")
class AgentTaskResponse(BaseModel):
"""Schema for agent task response"""
id: str
scan_id: str
task_type: str
task_name: str
description: Optional[str]
tool_name: Optional[str]
tool_category: Optional[str]
status: str
started_at: Optional[datetime]
completed_at: Optional[datetime]
duration_ms: Optional[int]
items_processed: int
items_found: int
result_summary: Optional[str]
error_message: Optional[str]
created_at: datetime
class Config:
from_attributes = True
class AgentTaskListResponse(BaseModel):
"""Schema for list of agent tasks"""
tasks: List[AgentTaskResponse]
total: int
scan_id: str
class AgentTaskSummary(BaseModel):
"""Schema for agent task summary statistics"""
total: int
pending: int
running: int
completed: int
failed: int
by_type: dict # recon, analysis, testing, reporting counts
by_tool: dict # tool name -> count

View File

@@ -24,6 +24,8 @@ class ReportResponse(BaseModel):
format: str
file_path: Optional[str]
executive_summary: Optional[str]
auto_generated: bool = False
is_partial: bool = False
generated_at: datetime
class Config:

View File

@@ -0,0 +1,105 @@
"""
NeuroSploit v3 - Report Service
Handles automatic report generation on scan completion/stop.
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from backend.models import Scan, Report, Vulnerability
from backend.core.report_engine.generator import ReportGenerator
from backend.api.websocket import manager as ws_manager
class ReportService:
"""Service for automatic report generation"""
def __init__(self, db: AsyncSession):
self.db = db
self.generator = ReportGenerator()
async def auto_generate_report(
self,
scan_id: str,
is_partial: bool = False,
format: str = "html"
) -> Report:
"""
Automatically generate a report for a scan.
Args:
scan_id: The scan ID to generate report for
is_partial: True if scan was stopped/incomplete
format: Report format (html, pdf, json)
Returns:
The generated Report model instance
"""
# Get scan
scan_result = await self.db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise ValueError(f"Scan {scan_id} not found")
# Get vulnerabilities
vulns_result = await self.db.execute(
select(Vulnerability).where(Vulnerability.scan_id == scan_id)
)
vulnerabilities = vulns_result.scalars().all()
# Generate title
if is_partial:
title = f"Partial Report - {scan.name or 'Unnamed Scan'}"
else:
title = f"Security Assessment Report - {scan.name or 'Unnamed Scan'}"
# Generate report
try:
report_path, executive_summary = await self.generator.generate(
scan=scan,
vulnerabilities=vulnerabilities,
format=format,
title=title,
include_executive_summary=True,
include_poc=True,
include_remediation=True
)
# Create report record
report = Report(
scan_id=scan_id,
title=title,
format=format,
file_path=str(report_path) if report_path else None,
executive_summary=executive_summary,
auto_generated=True,
is_partial=is_partial
)
self.db.add(report)
await self.db.commit()
await self.db.refresh(report)
# Broadcast report generated event
await ws_manager.broadcast_report_generated(scan_id, report.to_dict())
await ws_manager.broadcast_log(
scan_id,
"info",
f"Report auto-generated: {title}"
)
return report
except Exception as e:
await ws_manager.broadcast_log(
scan_id,
"error",
f"Failed to auto-generate report: {str(e)}"
)
raise
async def auto_generate_report(db: AsyncSession, scan_id: str, is_partial: bool = False) -> Report:
"""Helper function to auto-generate a report"""
service = ReportService(db)
return await service.auto_generate_report(scan_id, is_partial)

View File

@@ -19,7 +19,7 @@ from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from backend.models import Scan, Target, Endpoint, Vulnerability, VulnerabilityTest
from backend.models import Scan, Target, Endpoint, Vulnerability, VulnerabilityTest, AgentTask
from backend.api.websocket import manager as ws_manager
from backend.api.v1.prompts import PRESET_PROMPTS
from backend.db.database import async_session_factory
@@ -72,6 +72,55 @@ class ScanService:
self.payload_generator = PayloadGenerator()
self._stop_requested = False
async def _create_agent_task(
self,
scan_id: str,
task_type: str,
task_name: str,
description: str = None,
tool_name: str = None,
tool_category: str = None
) -> AgentTask:
"""Create and start a new agent task"""
task = AgentTask(
scan_id=scan_id,
task_type=task_type,
task_name=task_name,
description=description,
tool_name=tool_name,
tool_category=tool_category
)
task.start()
self.db.add(task)
await self.db.flush()
# Broadcast task started
await ws_manager.broadcast_agent_task_started(scan_id, task.to_dict())
return task
async def _complete_agent_task(
self,
task: AgentTask,
items_processed: int = 0,
items_found: int = 0,
summary: str = None
):
"""Mark an agent task as completed"""
task.complete(items_processed, items_found, summary)
await self.db.commit()
# Broadcast task completed
await ws_manager.broadcast_agent_task_completed(task.scan_id, task.to_dict())
async def _fail_agent_task(self, task: AgentTask, error: str):
"""Mark an agent task as failed"""
task.fail(error)
await self.db.commit()
# Broadcast task update
await ws_manager.broadcast_agent_task(task.scan_id, task.to_dict())
async def execute_scan(self, scan_id: str):
"""Execute a complete scan with real recon, autonomous discovery, and AI analysis"""
try:
@@ -112,13 +161,29 @@ class ScanService:
await ws_manager.broadcast_log(scan_id, "info", f"Targets: {', '.join([t.url for t in targets])}")
# Check available tools
# Check available tools - Create task for initialization
init_task = await self._create_agent_task(
scan_id=scan_id,
task_type="recon",
task_name="Initialize Security Tools",
description="Checking available security tools and dependencies",
tool_name="system",
tool_category="setup"
)
await ws_manager.broadcast_log(scan_id, "info", "")
await ws_manager.broadcast_log(scan_id, "info", "Checking installed security tools...")
tools_status = await check_tools_installed()
installed_tools = [t for t, installed in tools_status.items() if installed]
await ws_manager.broadcast_log(scan_id, "info", f"Available: {', '.join(installed_tools[:15])}...")
await self._complete_agent_task(
init_task,
items_processed=len(tools_status),
items_found=len(installed_tools),
summary=f"Found {len(installed_tools)} available security tools"
)
# Get prompt content
prompt_content = await self._get_prompt_content(scan)
await ws_manager.broadcast_log(scan_id, "info", "")
@@ -141,24 +206,46 @@ class ScanService:
depth = "medium" if scan.scan_type == "full" else "quick"
for target in targets:
await ws_manager.broadcast_log(scan_id, "info", f"Target: {target.url}")
target_recon = await recon_integration.run_full_recon(target.url, depth=depth)
recon_data = self._merge_recon_data(recon_data, target_recon)
# Create recon task for each target
recon_task = await self._create_agent_task(
scan_id=scan_id,
task_type="recon",
task_name=f"Reconnaissance: {target.hostname or target.url[:30]}",
description=f"Running {depth} reconnaissance on {target.url}",
tool_name="recon_integration",
tool_category="scanner"
)
# Save discovered endpoints to database
for endpoint_data in target_recon.get("endpoints", []):
if isinstance(endpoint_data, dict):
endpoint = Endpoint(
scan_id=scan_id,
target_id=target.id,
url=endpoint_data.get("url", ""),
method="GET",
path=endpoint_data.get("path", "/"),
response_status=endpoint_data.get("status"),
content_type=endpoint_data.get("content_type", "")
)
self.db.add(endpoint)
scan.total_endpoints += 1
try:
await ws_manager.broadcast_log(scan_id, "info", f"Target: {target.url}")
target_recon = await recon_integration.run_full_recon(target.url, depth=depth)
recon_data = self._merge_recon_data(recon_data, target_recon)
endpoints_found = 0
# Save discovered endpoints to database
for endpoint_data in target_recon.get("endpoints", []):
if isinstance(endpoint_data, dict):
endpoint = Endpoint(
scan_id=scan_id,
target_id=target.id,
url=endpoint_data.get("url", ""),
method="GET",
path=endpoint_data.get("path", "/"),
response_status=endpoint_data.get("status"),
content_type=endpoint_data.get("content_type", "")
)
self.db.add(endpoint)
scan.total_endpoints += 1
endpoints_found += 1
await self._complete_agent_task(
recon_task,
items_processed=1,
items_found=endpoints_found,
summary=f"Found {endpoints_found} endpoints, {len(target_recon.get('urls', []))} URLs"
)
except Exception as e:
await self._fail_agent_task(recon_task, str(e))
await self.db.commit()
recon_endpoints = scan.total_endpoints
@@ -181,60 +268,90 @@ class ScanService:
await ws_manager.broadcast_log(scan_id, level, message)
for target in targets:
async with AutonomousScanner(
# Create autonomous discovery task
discovery_task = await self._create_agent_task(
scan_id=scan_id,
log_callback=scanner_log,
timeout=15,
max_depth=3
) as scanner:
autonomous_results = await scanner.run_autonomous_scan(
target_url=target.url,
recon_data=recon_data
)
task_type="recon",
task_name=f"Autonomous Discovery: {target.hostname or target.url[:30]}",
description="AI-powered endpoint discovery and vulnerability scanning",
tool_name="autonomous_scanner",
tool_category="ai"
)
# Merge autonomous results
for ep in autonomous_results.get("endpoints", []):
if isinstance(ep, dict):
endpoint = Endpoint(
scan_id=scan_id,
target_id=target.id,
url=ep.get("url", ""),
method=ep.get("method", "GET"),
path=ep.get("url", "").split("?")[0].split("/")[-1] or "/"
)
self.db.add(endpoint)
scan.total_endpoints += 1
try:
endpoints_discovered = 0
vulns_discovered = 0
# Add URLs to recon data
recon_data["urls"] = recon_data.get("urls", []) + [
ep.get("url") for ep in autonomous_results.get("endpoints", [])
if isinstance(ep, dict)
]
recon_data["directories"] = autonomous_results.get("directories_found", [])
recon_data["parameters"] = autonomous_results.get("parameters_found", [])
# Save autonomous vulnerabilities directly
for vuln in autonomous_results.get("vulnerabilities", []):
db_vuln = Vulnerability(
scan_id=scan_id,
title=f"{vuln['type'].replace('_', ' ').title()} on {vuln['endpoint'][:50]}",
vulnerability_type=vuln["type"],
severity=self._confidence_to_severity(vuln["confidence"]),
description=vuln["evidence"],
affected_endpoint=vuln["endpoint"],
poc_payload=vuln["payload"],
poc_request=str(vuln.get("request", {}))[:5000],
poc_response=str(vuln.get("response", {}))[:5000]
async with AutonomousScanner(
scan_id=scan_id,
log_callback=scanner_log,
timeout=15,
max_depth=3
) as scanner:
autonomous_results = await scanner.run_autonomous_scan(
target_url=target.url,
recon_data=recon_data
)
self.db.add(db_vuln)
await ws_manager.broadcast_vulnerability_found(scan_id, {
"id": db_vuln.id,
"title": db_vuln.title,
"severity": db_vuln.severity,
"type": vuln["type"],
"endpoint": vuln["endpoint"]
})
# Merge autonomous results
for ep in autonomous_results.get("endpoints", []):
if isinstance(ep, dict):
endpoint = Endpoint(
scan_id=scan_id,
target_id=target.id,
url=ep.get("url", ""),
method=ep.get("method", "GET"),
path=ep.get("url", "").split("?")[0].split("/")[-1] or "/"
)
self.db.add(endpoint)
scan.total_endpoints += 1
endpoints_discovered += 1
# Add URLs to recon data
recon_data["urls"] = recon_data.get("urls", []) + [
ep.get("url") for ep in autonomous_results.get("endpoints", [])
if isinstance(ep, dict)
]
recon_data["directories"] = autonomous_results.get("directories_found", [])
recon_data["parameters"] = autonomous_results.get("parameters_found", [])
# Save autonomous vulnerabilities directly
for vuln in autonomous_results.get("vulnerabilities", []):
vuln_severity = self._confidence_to_severity(vuln["confidence"])
db_vuln = Vulnerability(
scan_id=scan_id,
title=f"{vuln['type'].replace('_', ' ').title()} on {vuln['endpoint'][:50]}",
vulnerability_type=vuln["type"],
severity=vuln_severity,
description=vuln["evidence"],
affected_endpoint=vuln["endpoint"],
poc_payload=vuln["payload"],
poc_request=str(vuln.get("request", {}))[:5000],
poc_response=str(vuln.get("response", {}))[:5000]
)
self.db.add(db_vuln)
await self.db.flush() # Ensure ID is assigned
vulns_discovered += 1
# Increment vulnerability count
await self._increment_vulnerability_count(scan, vuln_severity)
await ws_manager.broadcast_vulnerability_found(scan_id, {
"id": db_vuln.id,
"title": db_vuln.title,
"severity": db_vuln.severity,
"type": vuln["type"],
"endpoint": vuln["endpoint"]
})
await self._complete_agent_task(
discovery_task,
items_processed=endpoints_discovered,
items_found=vulns_discovered,
summary=f"Discovered {endpoints_discovered} endpoints, {vulns_discovered} vulnerabilities"
)
except Exception as e:
await self._fail_agent_task(discovery_task, str(e))
await self.db.commit()
await ws_manager.broadcast_log(scan_id, "info", f"Autonomous discovery complete. Total endpoints: {scan.total_endpoints}")
@@ -249,27 +366,48 @@ class ScanService:
await ws_manager.broadcast_log(scan_id, "info", "PHASE 2: AI ANALYSIS")
await ws_manager.broadcast_log(scan_id, "info", "=" * 40)
# Enhance prompt with authorization
enhanced_prompt = f"{GLOBAL_AUTHORIZATION}\n\nUSER REQUEST:\n{prompt_content}"
# Get AI-generated testing plan
await ws_manager.broadcast_log(scan_id, "info", "AI processing prompt and determining attack strategy...")
testing_plan = await self.ai_processor.process_prompt(
prompt=enhanced_prompt,
recon_data=recon_data,
target_info={"targets": [t.url for t in targets]}
# Create AI analysis task
analysis_task = await self._create_agent_task(
scan_id=scan_id,
task_type="analysis",
task_name="AI Strategy Analysis",
description="Analyzing prompt and recon data to determine testing strategy",
tool_name="ai_prompt_processor",
tool_category="ai"
)
await ws_manager.broadcast_log(scan_id, "info", "")
await ws_manager.broadcast_log(scan_id, "info", "AI TESTING PLAN:")
await ws_manager.broadcast_log(scan_id, "info", f" Vulnerability Types: {', '.join(testing_plan.vulnerability_types[:10])}")
if len(testing_plan.vulnerability_types) > 10:
await ws_manager.broadcast_log(scan_id, "info", f" ... and {len(testing_plan.vulnerability_types) - 10} more types")
await ws_manager.broadcast_log(scan_id, "info", f" Testing Focus: {', '.join(testing_plan.testing_focus[:5])}")
await ws_manager.broadcast_log(scan_id, "info", f" Depth: {testing_plan.testing_depth}")
await ws_manager.broadcast_log(scan_id, "info", "")
await ws_manager.broadcast_log(scan_id, "info", f"AI Reasoning: {testing_plan.ai_reasoning[:300]}...")
try:
# Enhance prompt with authorization
enhanced_prompt = f"{GLOBAL_AUTHORIZATION}\n\nUSER REQUEST:\n{prompt_content}"
# Get AI-generated testing plan
await ws_manager.broadcast_log(scan_id, "info", "AI processing prompt and determining attack strategy...")
testing_plan = await self.ai_processor.process_prompt(
prompt=enhanced_prompt,
recon_data=recon_data,
target_info={"targets": [t.url for t in targets]}
)
await ws_manager.broadcast_log(scan_id, "info", "")
await ws_manager.broadcast_log(scan_id, "info", "AI TESTING PLAN:")
await ws_manager.broadcast_log(scan_id, "info", f" Vulnerability Types: {', '.join(testing_plan.vulnerability_types[:10])}")
if len(testing_plan.vulnerability_types) > 10:
await ws_manager.broadcast_log(scan_id, "info", f" ... and {len(testing_plan.vulnerability_types) - 10} more types")
await ws_manager.broadcast_log(scan_id, "info", f" Testing Focus: {', '.join(testing_plan.testing_focus[:5])}")
await ws_manager.broadcast_log(scan_id, "info", f" Depth: {testing_plan.testing_depth}")
await ws_manager.broadcast_log(scan_id, "info", "")
await ws_manager.broadcast_log(scan_id, "info", f"AI Reasoning: {testing_plan.ai_reasoning[:300]}...")
await self._complete_agent_task(
analysis_task,
items_processed=1,
items_found=len(testing_plan.vulnerability_types),
summary=f"Generated testing plan with {len(testing_plan.vulnerability_types)} vulnerability types"
)
except Exception as e:
await self._fail_agent_task(analysis_task, str(e))
raise
await ws_manager.broadcast_progress(scan_id, 45, f"Testing {len(testing_plan.vulnerability_types)} vuln types")
@@ -286,48 +424,78 @@ class ScanService:
for target in targets:
await ws_manager.broadcast_log(scan_id, "info", f"Deploying AI Agent on: {target.url}")
# Create log callback for the agent
async def agent_log(level: str, message: str):
await ws_manager.broadcast_log(scan_id, level, message)
# Create AI pentest agent task
agent_task = await self._create_agent_task(
scan_id=scan_id,
task_type="testing",
task_name=f"AI Pentest Agent: {target.hostname or target.url[:30]}",
description=f"AI-powered penetration testing on {target.url}",
tool_name="ai_pentest_agent",
tool_category="ai"
)
# Build auth headers
auth_headers = self._build_auth_headers(scan)
try:
# Create log callback for the agent
async def agent_log(level: str, message: str):
await ws_manager.broadcast_log(scan_id, level, message)
async with AIPentestAgent(
target=target.url,
log_callback=agent_log,
auth_headers=auth_headers,
max_depth=5
) as agent:
agent_report = await agent.run()
# Build auth headers
auth_headers = self._build_auth_headers(scan)
# Save agent findings as vulnerabilities
for finding in agent_report.get("findings", []):
vuln = Vulnerability(
scan_id=scan_id,
title=f"{finding['type'].upper()} - {finding['endpoint'][:50]}",
vulnerability_type=finding["type"],
severity=finding["severity"],
description=finding["evidence"],
affected_endpoint=finding["endpoint"],
poc_payload=finding["payload"],
poc_request=finding.get("raw_request", "")[:5000],
poc_response=finding.get("raw_response", "")[:5000],
remediation=finding.get("impact", ""),
ai_analysis="\n".join(finding.get("exploitation_steps", []))
)
self.db.add(vuln)
findings_count = 0
endpoints_tested = 0
await ws_manager.broadcast_vulnerability_found(scan_id, {
"id": vuln.id,
"title": vuln.title,
"severity": vuln.severity,
"type": finding["type"],
"endpoint": finding["endpoint"]
})
async with AIPentestAgent(
target=target.url,
log_callback=agent_log,
auth_headers=auth_headers,
max_depth=5
) as agent:
agent_report = await agent.run()
# Update endpoint count
scan.total_endpoints += agent_report.get("summary", {}).get("total_endpoints", 0)
# Save agent findings as vulnerabilities
for finding in agent_report.get("findings", []):
finding_severity = finding["severity"]
vuln = Vulnerability(
scan_id=scan_id,
title=f"{finding['type'].upper()} - {finding['endpoint'][:50]}",
vulnerability_type=finding["type"],
severity=finding_severity,
description=finding["evidence"],
affected_endpoint=finding["endpoint"],
poc_payload=finding["payload"],
poc_request=finding.get("raw_request", "")[:5000],
poc_response=finding.get("raw_response", "")[:5000],
remediation=finding.get("impact", ""),
ai_analysis="\n".join(finding.get("exploitation_steps", []))
)
self.db.add(vuln)
await self.db.flush() # Ensure ID is assigned
findings_count += 1
# Increment vulnerability count
await self._increment_vulnerability_count(scan, finding_severity)
await ws_manager.broadcast_vulnerability_found(scan_id, {
"id": vuln.id,
"title": vuln.title,
"severity": vuln.severity,
"type": finding["type"],
"endpoint": finding["endpoint"]
})
# Update endpoint count
endpoints_tested = agent_report.get("summary", {}).get("total_endpoints", 0)
scan.total_endpoints += endpoints_tested
await self._complete_agent_task(
agent_task,
items_processed=endpoints_tested,
items_found=findings_count,
summary=f"Tested {endpoints_tested} endpoints, found {findings_count} vulnerabilities"
)
except Exception as e:
await self._fail_agent_task(agent_task, str(e))
await self.db.commit()
@@ -377,38 +545,70 @@ class ScanService:
await ws_manager.broadcast_log(scan_id, "info", f"Testing {len(endpoints)} endpoints for {len(testing_plan.vulnerability_types)} vuln types")
await ws_manager.broadcast_log(scan_id, "info", "")
# Test endpoints with AI-determined vulnerabilities
total_endpoints = len(endpoints)
async with DynamicVulnerabilityEngine() as engine:
for i, endpoint in enumerate(endpoints):
if self._stop_requested:
break
# Create vulnerability testing task
vuln_testing_task = await self._create_agent_task(
scan_id=scan_id,
task_type="testing",
task_name="Vulnerability Testing",
description=f"Testing {len(endpoints)} endpoints for {len(testing_plan.vulnerability_types)} vulnerability types",
tool_name="dynamic_vuln_engine",
tool_category="scanner"
)
progress = 45 + int((i / total_endpoints) * 45)
await ws_manager.broadcast_progress(
scan_id, progress,
f"Testing {i+1}/{total_endpoints}: {endpoint.path or endpoint.url[:50]}"
)
try:
# Test endpoints with AI-determined vulnerabilities
total_endpoints = len(endpoints)
endpoints_tested = 0
vulns_before = scan.total_vulnerabilities
# Log what we're testing
await ws_manager.broadcast_log(scan_id, "debug", f"[{i+1}/{total_endpoints}] Testing: {endpoint.url[:80]}")
async with DynamicVulnerabilityEngine() as engine:
for i, endpoint in enumerate(endpoints):
if self._stop_requested:
break
await self._test_endpoint_with_ai(
scan=scan,
endpoint=endpoint,
testing_plan=testing_plan,
engine=engine,
recon_data=recon_data
)
progress = 45 + int((i / total_endpoints) * 45)
await ws_manager.broadcast_progress(
scan_id, progress,
f"Testing {i+1}/{total_endpoints}: {endpoint.path or endpoint.url[:50]}"
)
# Update counts
await self._update_vulnerability_counts(scan)
# Log what we're testing
await ws_manager.broadcast_log(scan_id, "debug", f"[{i+1}/{total_endpoints}] Testing: {endpoint.url[:80]}")
await self._test_endpoint_with_ai(
scan=scan,
endpoint=endpoint,
testing_plan=testing_plan,
engine=engine,
recon_data=recon_data
)
endpoints_tested += 1
# Update final counts
await self._update_vulnerability_counts(scan)
vulns_found = scan.total_vulnerabilities - vulns_before
await self._complete_agent_task(
vuln_testing_task,
items_processed=endpoints_tested,
items_found=vulns_found,
summary=f"Tested {endpoints_tested} endpoints, found {vulns_found} vulnerabilities"
)
except Exception as e:
await self._fail_agent_task(vuln_testing_task, str(e))
raise
# Phase 4: Complete
scan.status = "completed"
scan.completed_at = datetime.utcnow()
scan.progress = 100
scan.current_phase = "completed"
# Calculate duration
if scan.started_at:
duration = (scan.completed_at - scan.started_at).total_seconds()
scan.duration = int(duration)
await self.db.commit()
await ws_manager.broadcast_log(scan_id, "info", "")
@@ -432,6 +632,16 @@ class ScanService:
"low": scan.low_count
})
# Auto-generate report on completion
try:
from backend.services.report_service import auto_generate_report
await ws_manager.broadcast_log(scan_id, "info", "")
await ws_manager.broadcast_log(scan_id, "info", "Generating security assessment report...")
report = await auto_generate_report(self.db, scan_id, is_partial=False)
await ws_manager.broadcast_log(scan_id, "info", f"Report generated: {report.title}")
except Exception as report_error:
await ws_manager.broadcast_log(scan_id, "warning", f"Failed to auto-generate report: {str(report_error)}")
except Exception as e:
import traceback
error_msg = f"Scan error: {str(e)}"
@@ -559,11 +769,12 @@ Be thorough and test all discovered endpoints aggressively.
if confidence >= 0.5: # Lower threshold to catch more
# Create vulnerability record
vuln_severity = ai_analysis.get("severity", self._confidence_to_severity(confidence))
vuln = Vulnerability(
scan_id=scan.id,
title=f"{vuln_type.replace('_', ' ').title()} on {endpoint.path or endpoint.url}",
vulnerability_type=vuln_type,
severity=ai_analysis.get("severity", self._confidence_to_severity(confidence)),
severity=vuln_severity,
description=ai_analysis.get("evidence", result.get("evidence", "")),
affected_endpoint=endpoint.url,
poc_payload=payload,
@@ -573,6 +784,10 @@ Be thorough and test all discovered endpoints aggressively.
ai_analysis=ai_analysis.get("exploitation_path", "")
)
self.db.add(vuln)
await self.db.flush() # Ensure ID is assigned
# Increment vulnerability count
await self._increment_vulnerability_count(scan, vuln_severity)
await ws_manager.broadcast_vulnerability_found(scan.id, {
"id": vuln.id,
@@ -761,3 +976,24 @@ Be thorough and test all discovered endpoints aggressively.
scan.total_endpoints = result.scalar() or 0
await self.db.commit()
async def _increment_vulnerability_count(self, scan: Scan, severity: str):
"""Increment vulnerability count for a severity level and broadcast update"""
# Increment the appropriate counter
severity_lower = severity.lower()
if severity_lower in ["critical", "high", "medium", "low", "info"]:
current = getattr(scan, f"{severity_lower}_count", 0)
setattr(scan, f"{severity_lower}_count", current + 1)
scan.total_vulnerabilities = (scan.total_vulnerabilities or 0) + 1
await self.db.commit()
# Broadcast stats update
await ws_manager.broadcast_stats_update(scan.id, {
"total_vulnerabilities": scan.total_vulnerabilities,
"critical": scan.critical_count,
"high": scan.high_count,
"medium": scan.medium_count,
"low": scan.low_count,
"info": scan.info_count,
"total_endpoints": scan.total_endpoints
})