mirror of https://github.com/CyberSecurityUP/NeuroSploit.git
synced 2026-02-12 14:02:45 +00:00

Commit: Add files via upload

backend/api/v1/agent_tasks.py (new file, 176 lines)
@@ -0,0 +1,176 @@
"""
NeuroSploit v3 - Agent Tasks API Endpoints
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func

from backend.db.database import get_db
from backend.models import AgentTask, Scan
from backend.schemas.agent_task import (
    AgentTaskResponse,
    AgentTaskListResponse,
    AgentTaskSummary
)

router = APIRouter()


@router.get("", response_model=AgentTaskListResponse)
async def list_agent_tasks(
    scan_id: str,
    status: Optional[str] = None,
    task_type: Optional[str] = None,
    page: int = 1,
    per_page: int = 50,
    db: AsyncSession = Depends(get_db)
):
    """List all agent tasks for a scan"""
    # Verify scan exists
    scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
    scan = scan_result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")

    # Build query
    query = select(AgentTask).where(AgentTask.scan_id == scan_id)

    if status:
        query = query.where(AgentTask.status == status)
    if task_type:
        query = query.where(AgentTask.task_type == task_type)

    query = query.order_by(AgentTask.created_at.desc())

    # Get total count
    count_query = select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
    if status:
        count_query = count_query.where(AgentTask.status == status)
    if task_type:
        count_query = count_query.where(AgentTask.task_type == task_type)
    total_result = await db.execute(count_query)
    total = total_result.scalar() or 0

    # Apply pagination
    query = query.offset((page - 1) * per_page).limit(per_page)
    result = await db.execute(query)
    tasks = result.scalars().all()

    return AgentTaskListResponse(
        tasks=[AgentTaskResponse(**t.to_dict()) for t in tasks],
        total=total,
        scan_id=scan_id
    )


@router.get("/summary", response_model=AgentTaskSummary)
async def get_agent_tasks_summary(
    scan_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Get summary statistics for agent tasks in a scan"""
    # Verify scan exists
    scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
    scan = scan_result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")

    # Total count
    total_result = await db.execute(
        select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
    )
    total = total_result.scalar() or 0

    # Count by status
    status_counts = {}
    for status in ["pending", "running", "completed", "failed"]:
        count_result = await db.execute(
            select(func.count()).select_from(AgentTask)
            .where(AgentTask.scan_id == scan_id)
            .where(AgentTask.status == status)
        )
        status_counts[status] = count_result.scalar() or 0

    # Count by task type
    type_query = select(
        AgentTask.task_type,
        func.count(AgentTask.id).label("count")
    ).where(AgentTask.scan_id == scan_id).group_by(AgentTask.task_type)
    type_result = await db.execute(type_query)
    by_type = {row[0]: row[1] for row in type_result.all()}

    # Count by tool
    tool_query = select(
        AgentTask.tool_name,
        func.count(AgentTask.id).label("count")
    ).where(AgentTask.scan_id == scan_id).where(AgentTask.tool_name.isnot(None)).group_by(AgentTask.tool_name)
    tool_result = await db.execute(tool_query)
    by_tool = {row[0]: row[1] for row in tool_result.all()}

    return AgentTaskSummary(
        total=total,
        pending=status_counts.get("pending", 0),
        running=status_counts.get("running", 0),
        completed=status_counts.get("completed", 0),
        failed=status_counts.get("failed", 0),
        by_type=by_type,
        by_tool=by_tool
    )


@router.get("/{task_id}", response_model=AgentTaskResponse)
async def get_agent_task(
    task_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Get a specific agent task by ID"""
    result = await db.execute(select(AgentTask).where(AgentTask.id == task_id))
    task = result.scalar_one_or_none()

    if not task:
        raise HTTPException(status_code=404, detail="Agent task not found")

    return AgentTaskResponse(**task.to_dict())


@router.get("/scan/{scan_id}/timeline")
async def get_agent_tasks_timeline(
    scan_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Get agent tasks as a timeline for visualization"""
    # Verify scan exists
    scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
    scan = scan_result.scalar_one_or_none()
    if not scan:
        raise HTTPException(status_code=404, detail="Scan not found")

    # Get all tasks ordered by creation time
    query = select(AgentTask).where(AgentTask.scan_id == scan_id).order_by(AgentTask.created_at.asc())
    result = await db.execute(query)
    tasks = result.scalars().all()

    timeline = []
    for task in tasks:
        timeline_item = {
            "id": task.id,
            "task_name": task.task_name,
            "task_type": task.task_type,
            "tool_name": task.tool_name,
            "status": task.status,
            "started_at": task.started_at.isoformat() if task.started_at else None,
            "completed_at": task.completed_at.isoformat() if task.completed_at else None,
            "duration_ms": task.duration_ms,
            "items_processed": task.items_processed,
            "items_found": task.items_found,
            "result_summary": task.result_summary,
            "error_message": task.error_message
        }
        timeline.append(timeline_item)

    return {
        "scan_id": scan_id,
        "timeline": timeline,
        "total": len(timeline)
    }
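For orientation, a minimal client-side sketch of exercising these endpoints once the router is mounted at /api/v1/agent-tasks (see the app.include_router call in backend/main.py below). The base URL and scan ID here are placeholder assumptions, not values from the commit:

import httpx

BASE = "http://localhost:8000/api/v1"  # assumed local dev address
SCAN_ID = "<scan-uuid>"                # placeholder

with httpx.Client() as client:
    # Paginated task list, filtered by status and task type
    tasks = client.get(f"{BASE}/agent-tasks", params={
        "scan_id": SCAN_ID, "status": "completed",
        "task_type": "recon", "page": 1, "per_page": 50,
    }).json()

    # Aggregate counts by status, task type, and tool
    summary = client.get(f"{BASE}/agent-tasks/summary",
                         params={"scan_id": SCAN_ID}).json()

    # Chronologically ordered items for a timeline view
    timeline = client.get(f"{BASE}/agent-tasks/scan/{SCAN_ID}/timeline").json()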
backend/api/v1/dashboard.py (modified)

@@ -8,7 +8,7 @@ from sqlalchemy import select, func
 from datetime import datetime, timedelta
 
 from backend.db.database import get_db
-from backend.models import Scan, Vulnerability, Endpoint
+from backend.models import Scan, Vulnerability, Endpoint, AgentTask, Report
 
 router = APIRouter()
@@ -20,18 +20,32 @@ async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
     total_scans_result = await db.execute(select(func.count()).select_from(Scan))
     total_scans = total_scans_result.scalar() or 0
 
-    # Running scans
+    # Scans by status
     running_result = await db.execute(
         select(func.count()).select_from(Scan).where(Scan.status == "running")
     )
     running_scans = running_result.scalar() or 0
 
     # Completed scans
     completed_result = await db.execute(
         select(func.count()).select_from(Scan).where(Scan.status == "completed")
     )
     completed_scans = completed_result.scalar() or 0
 
+    stopped_result = await db.execute(
+        select(func.count()).select_from(Scan).where(Scan.status == "stopped")
+    )
+    stopped_scans = stopped_result.scalar() or 0
+
+    failed_result = await db.execute(
+        select(func.count()).select_from(Scan).where(Scan.status == "failed")
+    )
+    failed_scans = failed_result.scalar() or 0
+
+    pending_result = await db.execute(
+        select(func.count()).select_from(Scan).where(Scan.status == "pending")
+    )
+    pending_scans = pending_result.scalar() or 0
 
     # Total vulnerabilities by severity
     vuln_counts = {}
     for severity in ["critical", "high", "medium", "low", "info"]:
@@ -63,6 +77,9 @@ async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
             "total": total_scans,
             "running": running_scans,
             "completed": completed_scans,
+            "stopped": stopped_scans,
+            "failed": failed_scans,
+            "pending": pending_scans,
             "recent": recent_scans
         },
         "vulnerabilities": {
@@ -175,3 +192,108 @@ async def get_scan_history(
        history[date_str]["high"] += scan.high_count

    return {"history": list(history.values())}


@router.get("/agent-tasks")
async def get_recent_agent_tasks(
    limit: int = 20,
    db: AsyncSession = Depends(get_db)
):
    """Get recent agent tasks across all scans"""
    query = (
        select(AgentTask)
        .order_by(AgentTask.created_at.desc())
        .limit(limit)
    )
    result = await db.execute(query)
    tasks = result.scalars().all()

    return {
        "agent_tasks": [t.to_dict() for t in tasks],
        "total": len(tasks)
    }


@router.get("/activity-feed")
async def get_activity_feed(
    limit: int = 30,
    db: AsyncSession = Depends(get_db)
):
    """Get unified activity feed with all recent events"""
    activities = []

    # Get recent scans
    scans_result = await db.execute(
        select(Scan).order_by(Scan.created_at.desc()).limit(limit // 3)
    )
    for scan in scans_result.scalars().all():
        activities.append({
            "type": "scan",
            "action": f"Scan {scan.status}",
            "title": scan.name or "Unnamed Scan",
            "description": f"{scan.total_vulnerabilities} vulnerabilities found",
            "status": scan.status,
            "severity": None,
            "timestamp": scan.created_at.isoformat(),
            "scan_id": scan.id,
            "link": f"/scan/{scan.id}"
        })

    # Get recent vulnerabilities
    vulns_result = await db.execute(
        select(Vulnerability).order_by(Vulnerability.created_at.desc()).limit(limit // 3)
    )
    for vuln in vulns_result.scalars().all():
        activities.append({
            "type": "vulnerability",
            "action": "Vulnerability found",
            "title": vuln.title,
            "description": vuln.affected_endpoint or "",
            "status": None,
            "severity": vuln.severity,
            "timestamp": vuln.created_at.isoformat(),
            "scan_id": vuln.scan_id,
            "link": f"/scan/{vuln.scan_id}"
        })

    # Get recent agent tasks
    tasks_result = await db.execute(
        select(AgentTask).order_by(AgentTask.created_at.desc()).limit(limit // 3)
    )
    for task in tasks_result.scalars().all():
        activities.append({
            "type": "agent_task",
            "action": f"Task {task.status}",
            "title": task.task_name,
            "description": task.result_summary or task.description or "",
            "status": task.status,
            "severity": None,
            "timestamp": task.created_at.isoformat(),
            "scan_id": task.scan_id,
            "link": f"/scan/{task.scan_id}"
        })

    # Get recent reports
    reports_result = await db.execute(
        select(Report).order_by(Report.generated_at.desc()).limit(limit // 4)
    )
    for report in reports_result.scalars().all():
        activities.append({
            "type": "report",
            "action": "Report generated" if report.auto_generated else "Report created",
            "title": report.title or "Report",
            "description": f"{report.format.upper()} format",
            "status": "auto" if report.auto_generated else "manual",
            "severity": None,
            "timestamp": report.generated_at.isoformat(),
            "scan_id": report.scan_id,
            "link": "/reports"
        })

    # Sort all activities by timestamp (newest first)
    activities.sort(key=lambda x: x["timestamp"], reverse=True)

    return {
        "activities": activities[:limit],
        "total": len(activities)
    }
backend/api/v1/reports.py (modified)

@@ -20,14 +20,22 @@ router = APIRouter()
 @router.get("", response_model=ReportListResponse)
 async def list_reports(
     scan_id: Optional[str] = None,
+    auto_generated: Optional[bool] = None,
+    is_partial: Optional[bool] = None,
     db: AsyncSession = Depends(get_db)
 ):
-    """List all reports"""
+    """List all reports with optional filtering"""
     query = select(Report).order_by(Report.generated_at.desc())
 
     if scan_id:
         query = query.where(Report.scan_id == scan_id)
 
+    if auto_generated is not None:
+        query = query.where(Report.auto_generated == auto_generated)
+
+    if is_partial is not None:
+        query = query.where(Report.is_partial == is_partial)
 
     result = await db.execute(query)
     reports = result.scalars().all()
backend/api/v1/scans.py (modified)

@@ -175,7 +175,9 @@ async def start_scan(
 
 @router.post("/{scan_id}/stop")
 async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
-    """Stop a running scan"""
+    """Stop a running scan and save partial results"""
+    from backend.api.websocket import manager as ws_manager
 
     result = await db.execute(select(Scan).where(Scan.id == scan_id))
     scan = result.scalar_one_or_none()
 
@@ -185,11 +187,76 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
     if scan.status != "running":
         raise HTTPException(status_code=400, detail="Scan is not running")
 
     # Update scan status
     scan.status = "stopped"
     scan.completed_at = datetime.utcnow()
     scan.current_phase = "stopped"
 
+    # Calculate duration
+    if scan.started_at:
+        duration = (scan.completed_at - scan.started_at).total_seconds()
+        scan.duration = int(duration)
+
+    # Compute final vulnerability statistics from database
+    for severity in ["critical", "high", "medium", "low", "info"]:
+        count_result = await db.execute(
+            select(func.count()).select_from(Vulnerability)
+            .where(Vulnerability.scan_id == scan_id)
+            .where(Vulnerability.severity == severity)
+        )
+        setattr(scan, f"{severity}_count", count_result.scalar() or 0)
+
+    # Get total vulnerability count
+    total_vuln_result = await db.execute(
+        select(func.count()).select_from(Vulnerability)
+        .where(Vulnerability.scan_id == scan_id)
+    )
+    scan.total_vulnerabilities = total_vuln_result.scalar() or 0
+
+    # Get total endpoint count
+    total_endpoint_result = await db.execute(
+        select(func.count()).select_from(Endpoint)
+        .where(Endpoint.scan_id == scan_id)
+    )
+    scan.total_endpoints = total_endpoint_result.scalar() or 0
+
     await db.commit()
 
-    return {"message": "Scan stopped", "scan_id": scan_id}
+    # Build summary for WebSocket broadcast
+    summary = {
+        "total_endpoints": scan.total_endpoints,
+        "total_vulnerabilities": scan.total_vulnerabilities,
+        "critical": scan.critical_count,
+        "high": scan.high_count,
+        "medium": scan.medium_count,
+        "low": scan.low_count,
+        "info": scan.info_count,
+        "duration": scan.duration,
+        "progress": scan.progress
+    }
+
+    # Broadcast stop event via WebSocket
+    await ws_manager.broadcast_scan_stopped(scan_id, summary)
+    await ws_manager.broadcast_log(scan_id, "warning", "Scan stopped by user")
+    await ws_manager.broadcast_log(scan_id, "info", f"Partial results: {scan.total_vulnerabilities} vulnerabilities found")
+
+    # Auto-generate partial report
+    report_data = None
+    try:
+        from backend.services.report_service import auto_generate_report
+        await ws_manager.broadcast_log(scan_id, "info", "Generating partial report...")
+        report = await auto_generate_report(db, scan_id, is_partial=True)
+        report_data = report.to_dict()
+        await ws_manager.broadcast_log(scan_id, "info", f"Partial report generated: {report.title}")
+    except Exception as report_error:
+        await ws_manager.broadcast_log(scan_id, "warning", f"Failed to generate partial report: {str(report_error)}")
+
+    return {
+        "message": "Scan stopped",
+        "scan_id": scan_id,
+        "summary": summary,
+        "report": report_data
+    }
 
 
 @router.get("/{scan_id}/status", response_model=ScanProgress)
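A sketch of consuming the enriched stop response above, assuming the scans router is mounted at /api/v1/scans, consistent with the other routers in main.py below (the host and scan ID are placeholders):

import httpx

scan_id = "<scan-uuid>"  # placeholder
data = httpx.post(f"http://localhost:8000/api/v1/scans/{scan_id}/stop").json()
print(data["message"])                           # "Scan stopped"
print(data["summary"]["total_vulnerabilities"])  # final counts recomputed above
if data["report"]:                               # None when partial report generation failed
    print("Partial report:", data["report"]["title"])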
backend/api/websocket.py (modified)

@@ -142,6 +142,65 @@ class ConnectionManager:
            "summary": summary
        })

    async def broadcast_scan_stopped(self, scan_id: str, summary: dict):
        """Notify that a scan was stopped by the user"""
        await self.send_to_scan(scan_id, {
            "type": "scan_stopped",
            "scan_id": scan_id,
            "status": "stopped",
            "summary": summary
        })

    async def broadcast_scan_failed(self, scan_id: str, error: str, summary: dict = None):
        """Notify that a scan has failed"""
        await self.send_to_scan(scan_id, {
            "type": "scan_failed",
            "scan_id": scan_id,
            "status": "failed",
            "error": error,
            "summary": summary or {}
        })

    async def broadcast_stats_update(self, scan_id: str, stats: dict):
        """Broadcast updated scan statistics"""
        await self.send_to_scan(scan_id, {
            "type": "stats_update",
            "scan_id": scan_id,
            "stats": stats
        })

    async def broadcast_agent_task(self, scan_id: str, task: dict):
        """Broadcast an agent task update (created, started, completed, failed)"""
        await self.send_to_scan(scan_id, {
            "type": "agent_task",
            "scan_id": scan_id,
            "task": task
        })

    async def broadcast_agent_task_started(self, scan_id: str, task: dict):
        """Broadcast when an agent task starts"""
        await self.send_to_scan(scan_id, {
            "type": "agent_task_started",
            "scan_id": scan_id,
            "task": task
        })

    async def broadcast_agent_task_completed(self, scan_id: str, task: dict):
        """Broadcast when an agent task completes"""
        await self.send_to_scan(scan_id, {
            "type": "agent_task_completed",
            "scan_id": scan_id,
            "task": task
        })

    async def broadcast_report_generated(self, scan_id: str, report: dict):
        """Broadcast when a report is generated"""
        await self.send_to_scan(scan_id, {
            "type": "report_generated",
            "scan_id": scan_id,
            "report": report
        })

    async def broadcast_error(self, scan_id: str, error: str):
        """Notify that an error occurred"""
        await self.send_to_scan(scan_id, {
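On the consumer side, a client would dispatch on the "type" field of these messages. A minimal Python sketch using the third-party websockets library; only the message shapes are taken from this diff, while the /ws/{scan_id} endpoint path is an assumption (the actual route is not shown here):

import asyncio
import json
import websockets  # third-party client, for illustration only

async def follow_scan(scan_id: str):
    # The websocket URL is an assumed path; adjust to the real route.
    async with websockets.connect(f"ws://localhost:8000/ws/{scan_id}") as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg["type"] == "agent_task_completed":
                print("task done:", msg["task"]["task_name"])
            elif msg["type"] == "scan_stopped":
                print("stopped, summary:", msg["summary"])
            elif msg["type"] == "report_generated":
                print("report:", msg["report"]["title"])

asyncio.run(follow_scan("<scan-uuid>"))  # placeholder scan ID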
backend/db/database.py (modified)

@@ -1,10 +1,14 @@
 """
 NeuroSploit v3 - Database Configuration
 """
+import logging
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
 from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy import text
 from backend.config import settings
 
+logger = logging.getLogger(__name__)
+
 
 class Base(DeclarativeBase):
     """Base class for all models"""

@@ -42,10 +46,114 @@ async def get_db() -> AsyncSession:
        await session.close()


async def _run_migrations(conn):
    """Run schema migrations to add missing columns"""
    try:
        # Check and add duration column to scans table
        result = await conn.execute(text("PRAGMA table_info(scans)"))
        columns = [row[1] for row in result.fetchall()]

        if "duration" not in columns:
            logger.info("Adding 'duration' column to scans table...")
            await conn.execute(text("ALTER TABLE scans ADD COLUMN duration INTEGER"))

        # Check and add columns to reports table
        result = await conn.execute(text("PRAGMA table_info(reports)"))
        columns = [row[1] for row in result.fetchall()]

        if columns:  # Table exists
            if "auto_generated" not in columns:
                logger.info("Adding 'auto_generated' column to reports table...")
                await conn.execute(text("ALTER TABLE reports ADD COLUMN auto_generated BOOLEAN DEFAULT 0"))

            if "is_partial" not in columns:
                logger.info("Adding 'is_partial' column to reports table...")
                await conn.execute(text("ALTER TABLE reports ADD COLUMN is_partial BOOLEAN DEFAULT 0"))

        # Check and add columns to vulnerabilities table
        result = await conn.execute(text("PRAGMA table_info(vulnerabilities)"))
        columns = [row[1] for row in result.fetchall()]

        if columns:  # Table exists
            if "test_id" not in columns:
                logger.info("Adding 'test_id' column to vulnerabilities table...")
                await conn.execute(text("ALTER TABLE vulnerabilities ADD COLUMN test_id VARCHAR(36)"))

            if "poc_parameter" not in columns:
                logger.info("Adding 'poc_parameter' column to vulnerabilities table...")
                await conn.execute(text("ALTER TABLE vulnerabilities ADD COLUMN poc_parameter VARCHAR(500)"))

            if "poc_evidence" not in columns:
                logger.info("Adding 'poc_evidence' column to vulnerabilities table...")
                await conn.execute(text("ALTER TABLE vulnerabilities ADD COLUMN poc_evidence TEXT"))

        # Check if agent_tasks table exists
        result = await conn.execute(
            text("SELECT name FROM sqlite_master WHERE type='table' AND name='agent_tasks'")
        )
        if not result.fetchone():
            logger.info("Creating 'agent_tasks' table...")
            await conn.execute(text("""
                CREATE TABLE agent_tasks (
                    id VARCHAR(36) PRIMARY KEY,
                    scan_id VARCHAR(36) NOT NULL,
                    task_type VARCHAR(50) NOT NULL,
                    task_name VARCHAR(255) NOT NULL,
                    description TEXT,
                    tool_name VARCHAR(100),
                    tool_category VARCHAR(100),
                    status VARCHAR(20) DEFAULT 'pending',
                    started_at DATETIME,
                    completed_at DATETIME,
                    duration_ms INTEGER,
                    items_processed INTEGER DEFAULT 0,
                    items_found INTEGER DEFAULT 0,
                    result_summary TEXT,
                    error_message TEXT,
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE
                )
            """))
            await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_agent_tasks_scan_id ON agent_tasks(scan_id)"))
            await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_agent_tasks_status ON agent_tasks(status)"))

        # Check if vulnerability_tests table exists
        result = await conn.execute(
            text("SELECT name FROM sqlite_master WHERE type='table' AND name='vulnerability_tests'")
        )
        if not result.fetchone():
            logger.info("Creating 'vulnerability_tests' table...")
            await conn.execute(text("""
                CREATE TABLE vulnerability_tests (
                    id VARCHAR(36) PRIMARY KEY,
                    scan_id VARCHAR(36) NOT NULL,
                    endpoint_id VARCHAR(36),
                    vulnerability_type VARCHAR(100) NOT NULL,
                    payload TEXT,
                    request_data JSON DEFAULT '{}',
                    response_data JSON DEFAULT '{}',
                    is_vulnerable BOOLEAN DEFAULT 0,
                    confidence FLOAT,
                    evidence TEXT,
                    tested_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE,
                    FOREIGN KEY (endpoint_id) REFERENCES endpoints(id) ON DELETE SET NULL
                )
            """))
            await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_vulnerability_tests_scan_id ON vulnerability_tests(scan_id)"))

        logger.info("Database migrations completed")
    except Exception as e:
        logger.warning(f"Migration check failed (may be normal on first run): {e}")


async def init_db():
-    """Initialize database tables"""
+    """Initialize database tables and run migrations"""
    async with engine.begin() as conn:
        # Create all tables from models
        await conn.run_sync(Base.metadata.create_all)
        # Run migrations to add any missing columns
        await _run_migrations(conn)


async def close_db():
backend/main.py (modified)

@@ -11,7 +11,7 @@ from pathlib import Path
 
 from backend.config import settings
 from backend.db.database import init_db, close_db
-from backend.api.v1 import scans, targets, prompts, reports, dashboard, vulnerabilities, settings as settings_router, agent
+from backend.api.v1 import scans, targets, prompts, reports, dashboard, vulnerabilities, settings as settings_router, agent, agent_tasks
 from backend.api.websocket import manager as ws_manager
 
 
@@ -59,6 +59,7 @@ app.include_router(dashboard.router, prefix="/api/v1/dashboard", tags=["Dashboard"])
 app.include_router(vulnerabilities.router, prefix="/api/v1/vulnerabilities", tags=["Vulnerabilities"])
 app.include_router(settings_router.router, prefix="/api/v1/settings", tags=["Settings"])
 app.include_router(agent.router, prefix="/api/v1/agent", tags=["AI Agent"])
+app.include_router(agent_tasks.router, prefix="/api/v1/agent-tasks", tags=["Agent Tasks"])
 
 
 @app.get("/api/health")
backend/migrations/001_add_dashboard_integration.sql (new file, 36 lines)

@@ -0,0 +1,36 @@
-- Migration: Add Dashboard Integration Columns
-- Date: 2026-01-23
-- Description: Adds duration column to scans, auto_generated/is_partial to reports, and creates agent_tasks table

-- Add duration column to scans table
ALTER TABLE scans ADD COLUMN duration INTEGER;

-- Add auto_generated and is_partial columns to reports table
ALTER TABLE reports ADD COLUMN auto_generated BOOLEAN DEFAULT 0;
ALTER TABLE reports ADD COLUMN is_partial BOOLEAN DEFAULT 0;

-- Create agent_tasks table
CREATE TABLE IF NOT EXISTS agent_tasks (
    id VARCHAR(36) PRIMARY KEY,
    scan_id VARCHAR(36) NOT NULL,
    task_type VARCHAR(50) NOT NULL,
    task_name VARCHAR(255) NOT NULL,
    description TEXT,
    tool_name VARCHAR(100),
    tool_category VARCHAR(100),
    status VARCHAR(20) DEFAULT 'pending',
    started_at DATETIME,
    completed_at DATETIME,
    duration_ms INTEGER,
    items_processed INTEGER DEFAULT 0,
    items_found INTEGER DEFAULT 0,
    result_summary TEXT,
    error_message TEXT,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE
);

-- Create indexes for performance
CREATE INDEX IF NOT EXISTS idx_agent_tasks_scan_id ON agent_tasks(scan_id);
CREATE INDEX IF NOT EXISTS idx_agent_tasks_status ON agent_tasks(status);
CREATE INDEX IF NOT EXISTS idx_agent_tasks_task_type ON agent_tasks(task_type);
backend/migrations/__init__.py (new file, 4 lines)

@@ -0,0 +1,4 @@
"""Database migrations for NeuroSploit v3"""
from backend.migrations.run_migrations import run_migration, get_db_path

__all__ = ["run_migration", "get_db_path"]
backend/migrations/run_migrations.py (new file, 137 lines)

@@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""
Run database migrations for NeuroSploit v3

Usage:
    python -m backend.migrations.run_migrations

Or from the backend directory:
    python migrations/run_migrations.py
"""
import sqlite3
import os
from pathlib import Path


def get_db_path():
    """Get the database file path"""
    # Try common locations
    possible_paths = [
        Path("./data/neurosploit.db"),
        Path("../data/neurosploit.db"),
        Path("/opt/NeuroSploitv2/data/neurosploit.db"),
        Path("/opt/NeuroSploitv2/backend/data/neurosploit.db"),
    ]

    for path in possible_paths:
        if path.exists():
            return str(path.resolve())

    # Default path
    return "./data/neurosploit.db"


def column_exists(cursor, table_name, column_name):
    """Check if a column exists in a table"""
    cursor.execute(f"PRAGMA table_info({table_name})")
    columns = [row[1] for row in cursor.fetchall()]
    return column_name in columns


def table_exists(cursor, table_name):
    """Check if a table exists"""
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
        (table_name,)
    )
    return cursor.fetchone() is not None


def run_migration(db_path: str):
    """Run the database migration"""
    print(f"Running migration on database: {db_path}")

    if not os.path.exists(db_path):
        print(f"Database file not found at {db_path}")
        print("Creating data directory; the database will be created on first run")
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        return

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    try:
        # Migration 1: Add duration column to scans table
        if not column_exists(cursor, "scans", "duration"):
            print("Adding 'duration' column to scans table...")
            cursor.execute("ALTER TABLE scans ADD COLUMN duration INTEGER")
            print("  Done!")
        else:
            print("Column 'duration' already exists in scans table")

        # Migration 2: Add auto_generated column to reports table
        if table_exists(cursor, "reports"):
            if not column_exists(cursor, "reports", "auto_generated"):
                print("Adding 'auto_generated' column to reports table...")
                cursor.execute("ALTER TABLE reports ADD COLUMN auto_generated BOOLEAN DEFAULT 0")
                print("  Done!")
            else:
                print("Column 'auto_generated' already exists in reports table")

            # Migration 3: Add is_partial column to reports table
            if not column_exists(cursor, "reports", "is_partial"):
                print("Adding 'is_partial' column to reports table...")
                cursor.execute("ALTER TABLE reports ADD COLUMN is_partial BOOLEAN DEFAULT 0")
                print("  Done!")
            else:
                print("Column 'is_partial' already exists in reports table")
        else:
            print("Reports table does not exist yet, will be created on first run")

        # Migration 4: Create agent_tasks table
        if not table_exists(cursor, "agent_tasks"):
            print("Creating 'agent_tasks' table...")
            cursor.execute("""
                CREATE TABLE agent_tasks (
                    id VARCHAR(36) PRIMARY KEY,
                    scan_id VARCHAR(36) NOT NULL,
                    task_type VARCHAR(50) NOT NULL,
                    task_name VARCHAR(255) NOT NULL,
                    description TEXT,
                    tool_name VARCHAR(100),
                    tool_category VARCHAR(100),
                    status VARCHAR(20) DEFAULT 'pending',
                    started_at DATETIME,
                    completed_at DATETIME,
                    duration_ms INTEGER,
                    items_processed INTEGER DEFAULT 0,
                    items_found INTEGER DEFAULT 0,
                    result_summary TEXT,
                    error_message TEXT,
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE
                )
            """)

            # Create indexes
            cursor.execute("CREATE INDEX idx_agent_tasks_scan_id ON agent_tasks(scan_id)")
            cursor.execute("CREATE INDEX idx_agent_tasks_status ON agent_tasks(status)")
            cursor.execute("CREATE INDEX idx_agent_tasks_task_type ON agent_tasks(task_type)")
            print("  Done!")
        else:
            print("Table 'agent_tasks' already exists")

        conn.commit()
        print("\nMigration completed successfully!")

    except Exception as e:
        conn.rollback()
        print(f"\nMigration failed: {e}")
        raise
    finally:
        conn.close()


if __name__ == "__main__":
    db_path = get_db_path()
    run_migration(db_path)
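Besides the CLI entry point, the package __init__ above re-exports the helpers, so the same migration can be invoked programmatically:

from backend.migrations import get_db_path, run_migration

run_migration(get_db_path())  # idempotent: each step checks before altering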
backend/models/__init__.py (modified)

@@ -4,6 +4,7 @@ from backend.models.prompt import Prompt
 from backend.models.endpoint import Endpoint
 from backend.models.vulnerability import Vulnerability, VulnerabilityTest
 from backend.models.report import Report
+from backend.models.agent_task import AgentTask
 
 __all__ = [
     "Scan",
@@ -12,5 +13,6 @@ __all__ = [
     "Endpoint",
     "Vulnerability",
     "VulnerabilityTest",
-    "Report"
+    "Report",
+    "AgentTask"
 ]
backend/models/agent_task.py (new file, 94 lines)

@@ -0,0 +1,94 @@
"""
NeuroSploit v3 - Agent Task Model

Tracks all agent activities during scans for dashboard visibility.
"""
from datetime import datetime
from typing import Optional
from sqlalchemy import String, Integer, DateTime, Text, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from backend.db.database import Base
import uuid


class AgentTask(Base):
    """Agent task record for tracking scan activities"""
    __tablename__ = "agent_tasks"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    scan_id: Mapped[str] = mapped_column(String(36), ForeignKey("scans.id", ondelete="CASCADE"))

    # Task identification
    task_type: Mapped[str] = mapped_column(String(50))  # recon, analysis, testing, reporting
    task_name: Mapped[str] = mapped_column(String(255))  # Human-readable name
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Tool information
    tool_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)  # nmap, nuclei, claude, httpx, etc.
    tool_category: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)  # scanner, analyzer, ai, crawler

    # Status tracking
    status: Mapped[str] = mapped_column(String(20), default="pending")  # pending, running, completed, failed, cancelled

    # Timing
    started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
    duration_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)  # Duration in milliseconds

    # Results
    items_processed: Mapped[int] = mapped_column(Integer, default=0)  # URLs tested, hosts scanned, etc.
    items_found: Mapped[int] = mapped_column(Integer, default=0)  # Endpoints found, vulns found, etc.
    result_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # Brief summary of results

    # Error handling
    error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Metadata
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)

    # Relationships
    scan: Mapped["Scan"] = relationship("Scan", back_populates="agent_tasks")

    def to_dict(self) -> dict:
        """Convert to dictionary"""
        return {
            "id": self.id,
            "scan_id": self.scan_id,
            "task_type": self.task_type,
            "task_name": self.task_name,
            "description": self.description,
            "tool_name": self.tool_name,
            "tool_category": self.tool_category,
            "status": self.status,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "duration_ms": self.duration_ms,
            "items_processed": self.items_processed,
            "items_found": self.items_found,
            "result_summary": self.result_summary,
            "error_message": self.error_message,
            "created_at": self.created_at.isoformat() if self.created_at else None
        }

    def start(self):
        """Mark task as started"""
        self.status = "running"
        self.started_at = datetime.utcnow()

    def complete(self, items_processed: int = 0, items_found: int = 0, summary: str = None):
        """Mark task as completed"""
        self.status = "completed"
        self.completed_at = datetime.utcnow()
        self.items_processed = items_processed
        self.items_found = items_found
        self.result_summary = summary
        if self.started_at:
            self.duration_ms = int((self.completed_at - self.started_at).total_seconds() * 1000)

    def fail(self, error: str):
        """Mark task as failed"""
        self.status = "failed"
        self.completed_at = datetime.utcnow()
        self.error_message = error
        if self.started_at:
            self.duration_ms = int((self.completed_at - self.started_at).total_seconds() * 1000)
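A short illustrative sketch of the lifecycle these helpers implement; the field values are placeholders, and persisting the instance through a session is handled by ScanService further below:

task = AgentTask(
    scan_id="<scan-uuid>",   # placeholder
    task_type="recon",
    task_name="Port scan",
    tool_name="nmap",
    tool_category="scanner",
)
task.start()   # status -> "running", started_at stamped
# ... the tool runs ...
task.complete(items_processed=1, items_found=12, summary="12 open ports")
# duration_ms is derived from started_at/completed_at; on error, use task.fail("...")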
backend/models/report.py (modified)

@@ -3,7 +3,7 @@ NeuroSploit v3 - Report Model
 """
 from datetime import datetime
 from typing import Optional
-from sqlalchemy import String, DateTime, Text, ForeignKey
+from sqlalchemy import String, DateTime, Text, ForeignKey, Boolean
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 from backend.db.database import Base
 import uuid
@@ -24,6 +24,10 @@ class Report(Base):
     # Content
     executive_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
 
+    # Auto-generation flags
+    auto_generated: Mapped[bool] = mapped_column(Boolean, default=False)  # True if auto-generated on scan completion/stop
+    is_partial: Mapped[bool] = mapped_column(Boolean, default=False)  # True if generated from a stopped/incomplete scan
+
     # Timestamps
     generated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
 
@@ -39,5 +43,7 @@ class Report(Base):
             "format": self.format,
             "file_path": self.file_path,
             "executive_summary": self.executive_summary,
+            "auto_generated": self.auto_generated,
+            "is_partial": self.is_partial,
             "generated_at": self.generated_at.isoformat() if self.generated_at else None
         }
backend/models/scan.py (modified)

@@ -39,6 +39,7 @@ class Scan(Base):
     created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
     started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
     completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    duration: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)  # Duration in seconds
 
     # Error handling
     error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
@@ -57,6 +58,7 @@ class Scan(Base):
     endpoints: Mapped[List["Endpoint"]] = relationship("Endpoint", back_populates="scan", cascade="all, delete-orphan")
     vulnerabilities: Mapped[List["Vulnerability"]] = relationship("Vulnerability", back_populates="scan", cascade="all, delete-orphan")
     reports: Mapped[List["Report"]] = relationship("Report", back_populates="scan", cascade="all, delete-orphan")
+    agent_tasks: Mapped[List["AgentTask"]] = relationship("AgentTask", back_populates="scan", cascade="all, delete-orphan")
 
     def to_dict(self) -> dict:
         """Convert to dictionary"""
@@ -77,6 +79,7 @@ class Scan(Base):
             "created_at": self.created_at.isoformat() if self.created_at else None,
             "started_at": self.started_at.isoformat() if self.started_at else None,
             "completed_at": self.completed_at.isoformat() if self.completed_at else None,
+            "duration": self.duration,
             "error_message": self.error_message,
             "total_endpoints": self.total_endpoints,
             "total_vulnerabilities": self.total_vulnerabilities,
backend/schemas/__init__.py (modified)

@@ -27,11 +27,19 @@ from backend.schemas.report import (
     ReportResponse,
     ReportGenerate
 )
+from backend.schemas.agent_task import (
+    AgentTaskCreate,
+    AgentTaskUpdate,
+    AgentTaskResponse,
+    AgentTaskListResponse,
+    AgentTaskSummary
+)
 
 __all__ = [
     "ScanCreate", "ScanUpdate", "ScanResponse", "ScanListResponse", "ScanProgress",
     "TargetCreate", "TargetResponse", "TargetBulkCreate", "TargetValidation",
     "PromptCreate", "PromptUpdate", "PromptResponse", "PromptParse", "PromptParseResult",
     "VulnerabilityResponse", "VulnerabilityTestResponse", "VulnerabilityTypeInfo",
-    "ReportResponse", "ReportGenerate"
+    "ReportResponse", "ReportGenerate",
+    "AgentTaskCreate", "AgentTaskUpdate", "AgentTaskResponse", "AgentTaskListResponse", "AgentTaskSummary"
 ]
backend/schemas/agent_task.py (new file, 66 lines)

@@ -0,0 +1,66 @@
"""
NeuroSploit v3 - Agent Task Schemas
"""
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field


class AgentTaskCreate(BaseModel):
    """Schema for creating an agent task"""
    scan_id: str = Field(..., description="Scan ID this task belongs to")
    task_type: str = Field(..., description="Task type: recon, analysis, testing, reporting")
    task_name: str = Field(..., description="Human-readable task name")
    description: Optional[str] = Field(None, description="Task description")
    tool_name: Optional[str] = Field(None, description="Tool being used")
    tool_category: Optional[str] = Field(None, description="Tool category")


class AgentTaskUpdate(BaseModel):
    """Schema for updating an agent task"""
    status: Optional[str] = Field(None, description="Task status")
    items_processed: Optional[int] = Field(None, description="Items processed")
    items_found: Optional[int] = Field(None, description="Items found")
    result_summary: Optional[str] = Field(None, description="Result summary")
    error_message: Optional[str] = Field(None, description="Error message if failed")


class AgentTaskResponse(BaseModel):
    """Schema for agent task response"""
    id: str
    scan_id: str
    task_type: str
    task_name: str
    description: Optional[str]
    tool_name: Optional[str]
    tool_category: Optional[str]
    status: str
    started_at: Optional[datetime]
    completed_at: Optional[datetime]
    duration_ms: Optional[int]
    items_processed: int
    items_found: int
    result_summary: Optional[str]
    error_message: Optional[str]
    created_at: datetime

    class Config:
        from_attributes = True


class AgentTaskListResponse(BaseModel):
    """Schema for list of agent tasks"""
    tasks: List[AgentTaskResponse]
    total: int
    scan_id: str


class AgentTaskSummary(BaseModel):
    """Schema for agent task summary statistics"""
    total: int
    pending: int
    running: int
    completed: int
    failed: int
    by_type: dict  # recon, analysis, testing, reporting counts
    by_tool: dict  # tool name -> count
backend/schemas/report.py (modified)

@@ -24,6 +24,8 @@ class ReportResponse(BaseModel):
     format: str
     file_path: Optional[str]
     executive_summary: Optional[str]
+    auto_generated: bool = False
+    is_partial: bool = False
     generated_at: datetime
 
     class Config:
backend/services/report_service.py (new file, 105 lines)

@@ -0,0 +1,105 @@
"""
NeuroSploit v3 - Report Service

Handles automatic report generation on scan completion/stop.
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from backend.models import Scan, Report, Vulnerability
from backend.core.report_engine.generator import ReportGenerator
from backend.api.websocket import manager as ws_manager


class ReportService:
    """Service for automatic report generation"""

    def __init__(self, db: AsyncSession):
        self.db = db
        self.generator = ReportGenerator()

    async def auto_generate_report(
        self,
        scan_id: str,
        is_partial: bool = False,
        format: str = "html"
    ) -> Report:
        """
        Automatically generate a report for a scan.

        Args:
            scan_id: The scan ID to generate a report for
            is_partial: True if the scan was stopped/incomplete
            format: Report format (html, pdf, json)

        Returns:
            The generated Report model instance
        """
        # Get scan
        scan_result = await self.db.execute(select(Scan).where(Scan.id == scan_id))
        scan = scan_result.scalar_one_or_none()

        if not scan:
            raise ValueError(f"Scan {scan_id} not found")

        # Get vulnerabilities
        vulns_result = await self.db.execute(
            select(Vulnerability).where(Vulnerability.scan_id == scan_id)
        )
        vulnerabilities = vulns_result.scalars().all()

        # Generate title
        if is_partial:
            title = f"Partial Report - {scan.name or 'Unnamed Scan'}"
        else:
            title = f"Security Assessment Report - {scan.name or 'Unnamed Scan'}"

        # Generate report
        try:
            report_path, executive_summary = await self.generator.generate(
                scan=scan,
                vulnerabilities=vulnerabilities,
                format=format,
                title=title,
                include_executive_summary=True,
                include_poc=True,
                include_remediation=True
            )

            # Create report record
            report = Report(
                scan_id=scan_id,
                title=title,
                format=format,
                file_path=str(report_path) if report_path else None,
                executive_summary=executive_summary,
                auto_generated=True,
                is_partial=is_partial
            )
            self.db.add(report)
            await self.db.commit()
            await self.db.refresh(report)

            # Broadcast report generated event
            await ws_manager.broadcast_report_generated(scan_id, report.to_dict())
            await ws_manager.broadcast_log(
                scan_id,
                "info",
                f"Report auto-generated: {title}"
            )

            return report

        except Exception as e:
            await ws_manager.broadcast_log(
                scan_id,
                "error",
                f"Failed to auto-generate report: {str(e)}"
            )
            raise


async def auto_generate_report(db: AsyncSession, scan_id: str, is_partial: bool = False) -> Report:
    """Helper function to auto-generate a report"""
    service = ReportService(db)
    return await service.auto_generate_report(scan_id, is_partial)
@@ -19,7 +19,7 @@ from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from backend.models import Scan, Target, Endpoint, Vulnerability, VulnerabilityTest
|
||||
from backend.models import Scan, Target, Endpoint, Vulnerability, VulnerabilityTest, AgentTask
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
from backend.api.v1.prompts import PRESET_PROMPTS
|
||||
from backend.db.database import async_session_factory
|
||||
@@ -72,6 +72,55 @@ class ScanService:
|
||||
self.payload_generator = PayloadGenerator()
|
||||
self._stop_requested = False
|
||||
|
||||
async def _create_agent_task(
|
||||
self,
|
||||
scan_id: str,
|
||||
task_type: str,
|
||||
task_name: str,
|
||||
description: str = None,
|
||||
tool_name: str = None,
|
||||
tool_category: str = None
|
||||
) -> AgentTask:
|
||||
"""Create and start a new agent task"""
|
||||
task = AgentTask(
|
||||
scan_id=scan_id,
|
||||
task_type=task_type,
|
||||
task_name=task_name,
|
||||
description=description,
|
||||
tool_name=tool_name,
|
||||
tool_category=tool_category
|
||||
)
|
||||
task.start()
|
||||
self.db.add(task)
|
||||
await self.db.flush()
|
||||
|
||||
# Broadcast task started
|
||||
await ws_manager.broadcast_agent_task_started(scan_id, task.to_dict())
|
||||
|
||||
return task
|
||||
|
||||
async def _complete_agent_task(
|
||||
self,
|
||||
task: AgentTask,
|
||||
items_processed: int = 0,
|
||||
items_found: int = 0,
|
||||
summary: str = None
|
||||
):
|
||||
"""Mark an agent task as completed"""
|
||||
task.complete(items_processed, items_found, summary)
|
||||
await self.db.commit()
|
||||
|
||||
# Broadcast task completed
|
||||
await ws_manager.broadcast_agent_task_completed(task.scan_id, task.to_dict())
|
||||
|
||||
async def _fail_agent_task(self, task: AgentTask, error: str):
|
||||
"""Mark an agent task as failed"""
|
||||
task.fail(error)
|
||||
await self.db.commit()
|
||||
|
||||
# Broadcast task update
|
||||
await ws_manager.broadcast_agent_task(task.scan_id, task.to_dict())
|
||||
|
||||
async def execute_scan(self, scan_id: str):
|
||||
"""Execute a complete scan with real recon, autonomous discovery, and AI analysis"""
|
||||
try:
|
||||
@@ -112,13 +161,29 @@ class ScanService:
|
||||
|
||||
await ws_manager.broadcast_log(scan_id, "info", f"Targets: {', '.join([t.url for t in targets])}")
|
||||
|
||||
# Check available tools
|
||||
# Check available tools - Create task for initialization
|
||||
init_task = await self._create_agent_task(
|
||||
scan_id=scan_id,
|
||||
task_type="recon",
|
||||
task_name="Initialize Security Tools",
|
||||
description="Checking available security tools and dependencies",
|
||||
tool_name="system",
|
||||
tool_category="setup"
|
||||
)
|
||||
|
||||
await ws_manager.broadcast_log(scan_id, "info", "")
|
||||
await ws_manager.broadcast_log(scan_id, "info", "Checking installed security tools...")
|
||||
tools_status = await check_tools_installed()
|
||||
installed_tools = [t for t, installed in tools_status.items() if installed]
|
||||
await ws_manager.broadcast_log(scan_id, "info", f"Available: {', '.join(installed_tools[:15])}...")
|
||||
|
||||
await self._complete_agent_task(
|
||||
init_task,
|
||||
items_processed=len(tools_status),
|
||||
items_found=len(installed_tools),
|
||||
summary=f"Found {len(installed_tools)} available security tools"
|
||||
)
|
||||
|
||||
# Get prompt content
|
||||
prompt_content = await self._get_prompt_content(scan)
|
||||
await ws_manager.broadcast_log(scan_id, "info", "")
|
||||
@@ -141,24 +206,46 @@ class ScanService:
|
||||
depth = "medium" if scan.scan_type == "full" else "quick"
|
||||
|
||||
for target in targets:
|
||||
await ws_manager.broadcast_log(scan_id, "info", f"Target: {target.url}")
|
||||
target_recon = await recon_integration.run_full_recon(target.url, depth=depth)
|
||||
recon_data = self._merge_recon_data(recon_data, target_recon)
|
||||
# Create recon task for each target
|
||||
recon_task = await self._create_agent_task(
|
||||
scan_id=scan_id,
|
||||
task_type="recon",
|
||||
task_name=f"Reconnaissance: {target.hostname or target.url[:30]}",
|
||||
description=f"Running {depth} reconnaissance on {target.url}",
|
||||
tool_name="recon_integration",
|
||||
tool_category="scanner"
|
||||
)
|
||||
|
||||
# Save discovered endpoints to database
|
||||
for endpoint_data in target_recon.get("endpoints", []):
|
||||
if isinstance(endpoint_data, dict):
|
||||
endpoint = Endpoint(
|
||||
scan_id=scan_id,
|
||||
target_id=target.id,
|
||||
url=endpoint_data.get("url", ""),
|
||||
method="GET",
|
||||
path=endpoint_data.get("path", "/"),
|
||||
response_status=endpoint_data.get("status"),
|
||||
content_type=endpoint_data.get("content_type", "")
|
||||
)
|
||||
self.db.add(endpoint)
|
||||
scan.total_endpoints += 1
|
||||
try:
|
||||
await ws_manager.broadcast_log(scan_id, "info", f"Target: {target.url}")
|
||||
target_recon = await recon_integration.run_full_recon(target.url, depth=depth)
|
||||
recon_data = self._merge_recon_data(recon_data, target_recon)
|
||||
|
||||
endpoints_found = 0
|
||||
# Save discovered endpoints to database
|
||||
for endpoint_data in target_recon.get("endpoints", []):
|
||||
if isinstance(endpoint_data, dict):
|
||||
endpoint = Endpoint(
|
||||
scan_id=scan_id,
|
||||
target_id=target.id,
|
||||
url=endpoint_data.get("url", ""),
|
||||
method="GET",
|
||||
path=endpoint_data.get("path", "/"),
|
||||
response_status=endpoint_data.get("status"),
|
||||
content_type=endpoint_data.get("content_type", "")
|
||||
)
|
||||
self.db.add(endpoint)
|
||||
scan.total_endpoints += 1
|
||||
endpoints_found += 1
|
||||
|
||||
await self._complete_agent_task(
|
||||
recon_task,
|
||||
items_processed=1,
|
||||
items_found=endpoints_found,
|
||||
summary=f"Found {endpoints_found} endpoints, {len(target_recon.get('urls', []))} URLs"
|
||||
)
|
||||
except Exception as e:
|
||||
await self._fail_agent_task(recon_task, str(e))
|
||||
|
||||
await self.db.commit()
|
||||
recon_endpoints = scan.total_endpoints
|
||||
@@ -181,60 +268,90 @@ class ScanService:
|
||||
            await ws_manager.broadcast_log(scan_id, level, message)

        for target in targets:
            # Create autonomous discovery task
            discovery_task = await self._create_agent_task(
                scan_id=scan_id,
                task_type="recon",
                task_name=f"Autonomous Discovery: {target.hostname or target.url[:30]}",
                description="AI-powered endpoint discovery and vulnerability scanning",
                tool_name="autonomous_scanner",
                tool_category="ai"
            )

            try:
                endpoints_discovered = 0
                vulns_discovered = 0

                async with AutonomousScanner(
                    scan_id=scan_id,
                    log_callback=scanner_log,
                    timeout=15,
                    max_depth=3
                ) as scanner:
                    autonomous_results = await scanner.run_autonomous_scan(
                        target_url=target.url,
                        recon_data=recon_data
                    )

                    # Merge autonomous results
                    for ep in autonomous_results.get("endpoints", []):
                        if isinstance(ep, dict):
                            endpoint = Endpoint(
                                scan_id=scan_id,
                                target_id=target.id,
                                url=ep.get("url", ""),
                                method=ep.get("method", "GET"),
                                path=ep.get("url", "").split("?")[0].split("/")[-1] or "/"
                            )
                            self.db.add(endpoint)
                            scan.total_endpoints += 1
                            endpoints_discovered += 1

                    # Add URLs to recon data
                    recon_data["urls"] = recon_data.get("urls", []) + [
                        ep.get("url") for ep in autonomous_results.get("endpoints", [])
                        if isinstance(ep, dict)
                    ]
                    recon_data["directories"] = autonomous_results.get("directories_found", [])
                    recon_data["parameters"] = autonomous_results.get("parameters_found", [])

                    # Save autonomous vulnerabilities directly
                    for vuln in autonomous_results.get("vulnerabilities", []):
                        vuln_severity = self._confidence_to_severity(vuln["confidence"])
                        db_vuln = Vulnerability(
                            scan_id=scan_id,
                            title=f"{vuln['type'].replace('_', ' ').title()} on {vuln['endpoint'][:50]}",
                            vulnerability_type=vuln["type"],
                            severity=vuln_severity,
                            description=vuln["evidence"],
                            affected_endpoint=vuln["endpoint"],
                            poc_payload=vuln["payload"],
                            poc_request=str(vuln.get("request", {}))[:5000],
                            poc_response=str(vuln.get("response", {}))[:5000]
                        )
                        self.db.add(db_vuln)
                        await self.db.flush()  # Ensure ID is assigned
                        vulns_discovered += 1

                        # Increment vulnerability count
                        await self._increment_vulnerability_count(scan, vuln_severity)

                        await ws_manager.broadcast_vulnerability_found(scan_id, {
                            "id": db_vuln.id,
                            "title": db_vuln.title,
                            "severity": db_vuln.severity,
                            "type": vuln["type"],
                            "endpoint": vuln["endpoint"]
                        })

                await self._complete_agent_task(
                    discovery_task,
                    items_processed=endpoints_discovered,
                    items_found=vulns_discovered,
                    summary=f"Discovered {endpoints_discovered} endpoints, {vulns_discovered} vulnerabilities"
                )
            except Exception as e:
                await self._fail_agent_task(discovery_task, str(e))

        await self.db.commit()
        await ws_manager.broadcast_log(scan_id, "info", f"Autonomous discovery complete. Total endpoints: {scan.total_endpoints}")
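The task-lifecycle helpers this commit wraps around each phase (`_create_agent_task`, `_complete_agent_task`, `_fail_agent_task`) are defined elsewhere in scan_service.py and do not appear in this diff. A minimal sketch of the pattern they presumably follow, assuming the AgentTask model carries status, timestamp, and counter columns (the started_at/completed_at/error_message names are guesses; the status values and items_* counters are taken from the call sites):

from datetime import datetime
from backend.models import AgentTask

async def _create_agent_task(self, scan_id: str, task_type: str, task_name: str,
                             description: str = "", tool_name: str = "",
                             tool_category: str = "") -> AgentTask:
    # Persist the task in "running" state so the dashboard sees it immediately
    task = AgentTask(
        scan_id=scan_id, task_type=task_type, task_name=task_name,
        description=description, tool_name=tool_name,
        tool_category=tool_category, status="running",
        started_at=datetime.utcnow(),  # assumed column name
    )
    self.db.add(task)
    await self.db.flush()  # assign an ID before anything references it
    return task

async def _complete_agent_task(self, task: AgentTask, items_processed: int = 0,
                               items_found: int = 0, summary: str = ""):
    task.status = "completed"
    task.completed_at = datetime.utcnow()  # assumed column name
    task.items_processed = items_processed
    task.items_found = items_found
    task.summary = summary
    await self.db.commit()

async def _fail_agent_task(self, task: AgentTask, error: str):
    # Failures are recorded but, in the discovery phase above, not re-raised,
    # so one failed target does not abort the whole scan
    task.status = "failed"
    task.completed_at = datetime.utcnow()
    task.error_message = error  # assumed column name
    await self.db.commit()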
@@ -249,27 +366,48 @@ class ScanService:
        await ws_manager.broadcast_log(scan_id, "info", "PHASE 2: AI ANALYSIS")
        await ws_manager.broadcast_log(scan_id, "info", "=" * 40)

        # Create AI analysis task
        analysis_task = await self._create_agent_task(
            scan_id=scan_id,
            task_type="analysis",
            task_name="AI Strategy Analysis",
            description="Analyzing prompt and recon data to determine testing strategy",
            tool_name="ai_prompt_processor",
            tool_category="ai"
        )

        try:
            # Enhance prompt with authorization
            enhanced_prompt = f"{GLOBAL_AUTHORIZATION}\n\nUSER REQUEST:\n{prompt_content}"

            # Get AI-generated testing plan
            await ws_manager.broadcast_log(scan_id, "info", "AI processing prompt and determining attack strategy...")

            testing_plan = await self.ai_processor.process_prompt(
                prompt=enhanced_prompt,
                recon_data=recon_data,
                target_info={"targets": [t.url for t in targets]}
            )

            await ws_manager.broadcast_log(scan_id, "info", "")
            await ws_manager.broadcast_log(scan_id, "info", "AI TESTING PLAN:")
            await ws_manager.broadcast_log(scan_id, "info", f"  Vulnerability Types: {', '.join(testing_plan.vulnerability_types[:10])}")
            if len(testing_plan.vulnerability_types) > 10:
                await ws_manager.broadcast_log(scan_id, "info", f"  ... and {len(testing_plan.vulnerability_types) - 10} more types")
            await ws_manager.broadcast_log(scan_id, "info", f"  Testing Focus: {', '.join(testing_plan.testing_focus[:5])}")
            await ws_manager.broadcast_log(scan_id, "info", f"  Depth: {testing_plan.testing_depth}")
            await ws_manager.broadcast_log(scan_id, "info", "")
            await ws_manager.broadcast_log(scan_id, "info", f"AI Reasoning: {testing_plan.ai_reasoning[:300]}...")

            await self._complete_agent_task(
                analysis_task,
                items_processed=1,
                items_found=len(testing_plan.vulnerability_types),
                summary=f"Generated testing plan with {len(testing_plan.vulnerability_types)} vulnerability types"
            )
        except Exception as e:
            await self._fail_agent_task(analysis_task, str(e))
            raise

        await ws_manager.broadcast_progress(scan_id, 45, f"Testing {len(testing_plan.vulnerability_types)} vuln types")
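For orientation, the testing_plan object consumed above only needs the four attributes the logging reads. A hedged sketch of the shape ai_processor.process_prompt presumably returns (attribute names come from the call sites; the real schema lives in the AI processor module, not in this diff):

from dataclasses import dataclass, field
from typing import List

@dataclass
class TestingPlan:
    # Attribute names inferred from the call sites in this hunk
    vulnerability_types: List[str] = field(default_factory=list)  # e.g. ["sqli", "xss", "idor"]
    testing_focus: List[str] = field(default_factory=list)        # e.g. ["auth", "input handling"]
    testing_depth: str = "standard"                               # free-form depth label
    ai_reasoning: str = ""                                        # rationale, logged truncated to 300 chars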
@@ -286,48 +424,78 @@ class ScanService:
        for target in targets:
            await ws_manager.broadcast_log(scan_id, "info", f"Deploying AI Agent on: {target.url}")

            # Create AI pentest agent task
            agent_task = await self._create_agent_task(
                scan_id=scan_id,
                task_type="testing",
                task_name=f"AI Pentest Agent: {target.hostname or target.url[:30]}",
                description=f"AI-powered penetration testing on {target.url}",
                tool_name="ai_pentest_agent",
                tool_category="ai"
            )

            try:
                # Create log callback for the agent
                async def agent_log(level: str, message: str):
                    await ws_manager.broadcast_log(scan_id, level, message)

                # Build auth headers
                auth_headers = self._build_auth_headers(scan)

                findings_count = 0
                endpoints_tested = 0

                async with AIPentestAgent(
                    target=target.url,
                    log_callback=agent_log,
                    auth_headers=auth_headers,
                    max_depth=5
                ) as agent:
                    agent_report = await agent.run()

                    # Save agent findings as vulnerabilities
                    for finding in agent_report.get("findings", []):
                        finding_severity = finding["severity"]
                        vuln = Vulnerability(
                            scan_id=scan_id,
                            title=f"{finding['type'].upper()} - {finding['endpoint'][:50]}",
                            vulnerability_type=finding["type"],
                            severity=finding_severity,
                            description=finding["evidence"],
                            affected_endpoint=finding["endpoint"],
                            poc_payload=finding["payload"],
                            poc_request=finding.get("raw_request", "")[:5000],
                            poc_response=finding.get("raw_response", "")[:5000],
                            remediation=finding.get("impact", ""),
                            ai_analysis="\n".join(finding.get("exploitation_steps", []))
                        )
                        self.db.add(vuln)
                        await self.db.flush()  # Ensure ID is assigned
                        findings_count += 1

                        # Increment vulnerability count
                        await self._increment_vulnerability_count(scan, finding_severity)

                        await ws_manager.broadcast_vulnerability_found(scan_id, {
                            "id": vuln.id,
                            "title": vuln.title,
                            "severity": vuln.severity,
                            "type": finding["type"],
                            "endpoint": finding["endpoint"]
                        })

                    # Update endpoint count
                    endpoints_tested = agent_report.get("summary", {}).get("total_endpoints", 0)
                    scan.total_endpoints += endpoints_tested

                await self._complete_agent_task(
                    agent_task,
                    items_processed=endpoints_tested,
                    items_found=findings_count,
                    summary=f"Tested {endpoints_tested} endpoints, found {findings_count} vulnerabilities"
                )
            except Exception as e:
                await self._fail_agent_task(agent_task, str(e))

        await self.db.commit()
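The agent_report dict returned by AIPentestAgent.run() is consumed but never defined in this diff. Judging from the .get() calls above, it presumably looks roughly like this (illustrative values only; keys not read here may also exist):

# Shape inferred from the consuming code above -- not the agent's documented contract
agent_report = {
    "findings": [
        {
            "type": "sqli",                          # feeds the title and vulnerability_type
            "severity": "high",                      # stored as-is, drives _increment_vulnerability_count
            "endpoint": "https://example.test/item?id=1",
            "evidence": "boolean-based response delta",
            "payload": "' OR '1'='1",
            "raw_request": "GET /item?id=1%27 ...",  # truncated to 5000 chars on save
            "raw_response": "HTTP/1.1 500 ...",      # truncated to 5000 chars on save
            "impact": "database contents readable",  # stored in the remediation column
            "exploitation_steps": ["inject quote", "confirm delta", "extract data"],
        }
    ],
    "summary": {"total_endpoints": 42},              # added onto scan.total_endpoints
}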
@@ -377,38 +545,70 @@ class ScanService:
        await ws_manager.broadcast_log(scan_id, "info", f"Testing {len(endpoints)} endpoints for {len(testing_plan.vulnerability_types)} vuln types")
        await ws_manager.broadcast_log(scan_id, "info", "")

        # Create vulnerability testing task
        vuln_testing_task = await self._create_agent_task(
            scan_id=scan_id,
            task_type="testing",
            task_name="Vulnerability Testing",
            description=f"Testing {len(endpoints)} endpoints for {len(testing_plan.vulnerability_types)} vulnerability types",
            tool_name="dynamic_vuln_engine",
            tool_category="scanner"
        )

        try:
            # Test endpoints with AI-determined vulnerabilities
            total_endpoints = len(endpoints)
            endpoints_tested = 0
            vulns_before = scan.total_vulnerabilities

            async with DynamicVulnerabilityEngine() as engine:
                for i, endpoint in enumerate(endpoints):
                    if self._stop_requested:
                        break

                    progress = 45 + int((i / total_endpoints) * 45)
                    await ws_manager.broadcast_progress(
                        scan_id, progress,
                        f"Testing {i+1}/{total_endpoints}: {endpoint.path or endpoint.url[:50]}"
                    )

                    # Log what we're testing
                    await ws_manager.broadcast_log(scan_id, "debug", f"[{i+1}/{total_endpoints}] Testing: {endpoint.url[:80]}")

                    await self._test_endpoint_with_ai(
                        scan=scan,
                        endpoint=endpoint,
                        testing_plan=testing_plan,
                        engine=engine,
                        recon_data=recon_data
                    )
                    endpoints_tested += 1

            # Update final counts
            await self._update_vulnerability_counts(scan)

            vulns_found = scan.total_vulnerabilities - vulns_before
            await self._complete_agent_task(
                vuln_testing_task,
                items_processed=endpoints_tested,
                items_found=vulns_found,
                summary=f"Tested {endpoints_tested} endpoints, found {vulns_found} vulnerabilities"
            )
        except Exception as e:
            await self._fail_agent_task(vuln_testing_task, str(e))
            raise

        # Phase 4: Complete
        scan.status = "completed"
        scan.completed_at = datetime.utcnow()
        scan.progress = 100
        scan.current_phase = "completed"

        # Calculate duration
        if scan.started_at:
            duration = (scan.completed_at - scan.started_at).total_seconds()
            scan.duration = int(duration)

        await self.db.commit()

        await ws_manager.broadcast_log(scan_id, "info", "")
@@ -432,6 +632,16 @@ class ScanService:
                "low": scan.low_count
            })

            # Auto-generate report on completion
            try:
                from backend.services.report_service import auto_generate_report
                await ws_manager.broadcast_log(scan_id, "info", "")
                await ws_manager.broadcast_log(scan_id, "info", "Generating security assessment report...")
                report = await auto_generate_report(self.db, scan_id, is_partial=False)
                await ws_manager.broadcast_log(scan_id, "info", f"Report generated: {report.title}")
            except Exception as report_error:
                await ws_manager.broadcast_log(scan_id, "warning", f"Failed to auto-generate report: {str(report_error)}")

        except Exception as e:
            import traceback
            error_msg = f"Scan error: {str(e)}"
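Because report generation sits in its own try/except, a reporting failure only logs a warning and never marks the scan failed. The is_partial flag suggests the same helper can be reused for interrupted or stopped scans, roughly like this (hypothetical call site; only the is_partial=False call appears in this commit):

# Hypothetical stop/failure path: emit a partial report for whatever was found so far
report = await auto_generate_report(self.db, scan_id, is_partial=True)
await ws_manager.broadcast_log(scan_id, "info", f"Partial report generated: {report.title}")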
@@ -559,11 +769,12 @@ Be thorough and test all discovered endpoints aggressively.
                if confidence >= 0.5:  # Lower threshold to catch more
                    # Create vulnerability record
                    vuln_severity = ai_analysis.get("severity", self._confidence_to_severity(confidence))
                    vuln = Vulnerability(
                        scan_id=scan.id,
                        title=f"{vuln_type.replace('_', ' ').title()} on {endpoint.path or endpoint.url}",
                        vulnerability_type=vuln_type,
                        severity=vuln_severity,
                        description=ai_analysis.get("evidence", result.get("evidence", "")),
                        affected_endpoint=endpoint.url,
                        poc_payload=payload,
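The refactor here simply hoists the severity lookup into vuln_severity so the same value feeds both the record and the new counter increment. _confidence_to_severity itself sits outside this hunk; one plausible mapping, with thresholds chosen purely for illustration:

def _confidence_to_severity(self, confidence: float) -> str:
    # Illustrative thresholds -- the real cut-offs live elsewhere in scan_service.py
    if confidence >= 0.9:
        return "high"
    if confidence >= 0.7:
        return "medium"
    if confidence >= 0.5:
        return "low"
    return "info"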
@@ -573,6 +784,10 @@ Be thorough and test all discovered endpoints aggressively.
                        ai_analysis=ai_analysis.get("exploitation_path", "")
                    )
                    self.db.add(vuln)
                    await self.db.flush()  # Ensure ID is assigned

                    # Increment vulnerability count
                    await self._increment_vulnerability_count(scan, vuln_severity)

                    await ws_manager.broadcast_vulnerability_found(scan.id, {
                        "id": vuln.id,
@@ -761,3 +976,24 @@ Be thorough and test all discovered endpoints aggressively.
        scan.total_endpoints = result.scalar() or 0

        await self.db.commit()

    async def _increment_vulnerability_count(self, scan: Scan, severity: str):
        """Increment vulnerability count for a severity level and broadcast update"""
        # Increment the appropriate counter
        severity_lower = severity.lower()
        if severity_lower in ["critical", "high", "medium", "low", "info"]:
            current = getattr(scan, f"{severity_lower}_count", 0)
            setattr(scan, f"{severity_lower}_count", current + 1)
            scan.total_vulnerabilities = (scan.total_vulnerabilities or 0) + 1
            await self.db.commit()

            # Broadcast stats update
            await ws_manager.broadcast_stats_update(scan.id, {
                "total_vulnerabilities": scan.total_vulnerabilities,
                "critical": scan.critical_count,
                "high": scan.high_count,
                "medium": scan.medium_count,
                "low": scan.low_count,
                "info": scan.info_count,
                "total_endpoints": scan.total_endpoints
            })
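Two counting strategies now coexist: _increment_vulnerability_count above bumps counters row by row as findings stream in, while _update_vulnerability_counts (called at the end of the testing phase) recounts from the database. A hedged sketch of what the recount side presumably does, assuming per-severity totals come from one grouped query (the actual method exists elsewhere in scan_service.py and may differ):

from sqlalchemy import select, func
from backend.models import Vulnerability

async def _update_vulnerability_counts(self, scan: Scan):
    # Recount from the DB so counters stay correct even if an increment was missed
    result = await self.db.execute(
        select(Vulnerability.severity, func.count())
        .where(Vulnerability.scan_id == scan.id)
        .group_by(Vulnerability.severity)
    )
    counts = {severity: count for severity, count in result.all()}
    for sev in ["critical", "high", "medium", "low", "info"]:
        setattr(scan, f"{sev}_count", counts.get(sev, 0))
    scan.total_vulnerabilities = sum(counts.values())
    await self.db.commit()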