Add files via upload

This commit is contained in:
Joas A Santos
2026-01-23 15:46:05 -03:00
committed by GitHub
parent 2a5e9b139a
commit f9e4ec16ec
19 changed files with 1398 additions and 158 deletions
+176
View File
@@ -0,0 +1,176 @@
"""
NeuroSploit v3 - Agent Tasks API Endpoints
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from backend.db.database import get_db
from backend.models import AgentTask, Scan
from backend.schemas.agent_task import (
AgentTaskResponse,
AgentTaskListResponse,
AgentTaskSummary
)
router = APIRouter()
@router.get("", response_model=AgentTaskListResponse)
async def list_agent_tasks(
scan_id: str,
status: Optional[str] = None,
task_type: Optional[str] = None,
page: int = 1,
per_page: int = 50,
db: AsyncSession = Depends(get_db)
):
"""List all agent tasks for a scan"""
# Verify scan exists
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
# Build query
query = select(AgentTask).where(AgentTask.scan_id == scan_id)
if status:
query = query.where(AgentTask.status == status)
if task_type:
query = query.where(AgentTask.task_type == task_type)
query = query.order_by(AgentTask.created_at.desc())
# Get total count
count_query = select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
if status:
count_query = count_query.where(AgentTask.status == status)
if task_type:
count_query = count_query.where(AgentTask.task_type == task_type)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# Apply pagination
query = query.offset((page - 1) * per_page).limit(per_page)
result = await db.execute(query)
tasks = result.scalars().all()
return AgentTaskListResponse(
tasks=[AgentTaskResponse(**t.to_dict()) for t in tasks],
total=total,
scan_id=scan_id
)
@router.get("/summary", response_model=AgentTaskSummary)
async def get_agent_tasks_summary(
scan_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get summary statistics for agent tasks in a scan"""
# Verify scan exists
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
# Total count
total_result = await db.execute(
select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
)
total = total_result.scalar() or 0
# Count by status
status_counts = {}
for status in ["pending", "running", "completed", "failed"]:
count_result = await db.execute(
select(func.count()).select_from(AgentTask)
.where(AgentTask.scan_id == scan_id)
.where(AgentTask.status == status)
)
status_counts[status] = count_result.scalar() or 0
# Count by task type
type_query = select(
AgentTask.task_type,
func.count(AgentTask.id).label("count")
).where(AgentTask.scan_id == scan_id).group_by(AgentTask.task_type)
type_result = await db.execute(type_query)
by_type = {row[0]: row[1] for row in type_result.all()}
# Count by tool
tool_query = select(
AgentTask.tool_name,
func.count(AgentTask.id).label("count")
).where(AgentTask.scan_id == scan_id).where(AgentTask.tool_name.isnot(None)).group_by(AgentTask.tool_name)
tool_result = await db.execute(tool_query)
by_tool = {row[0]: row[1] for row in tool_result.all()}
return AgentTaskSummary(
total=total,
pending=status_counts.get("pending", 0),
running=status_counts.get("running", 0),
completed=status_counts.get("completed", 0),
failed=status_counts.get("failed", 0),
by_type=by_type,
by_tool=by_tool
)
@router.get("/{task_id}", response_model=AgentTaskResponse)
async def get_agent_task(
task_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get a specific agent task by ID"""
result = await db.execute(select(AgentTask).where(AgentTask.id == task_id))
task = result.scalar_one_or_none()
if not task:
raise HTTPException(status_code=404, detail="Agent task not found")
return AgentTaskResponse(**task.to_dict())
@router.get("/scan/{scan_id}/timeline")
async def get_agent_tasks_timeline(
scan_id: str,
db: AsyncSession = Depends(get_db)
):
"""Get agent tasks as a timeline for visualization"""
# Verify scan exists
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = scan_result.scalar_one_or_none()
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
# Get all tasks ordered by creation time
query = select(AgentTask).where(AgentTask.scan_id == scan_id).order_by(AgentTask.created_at.asc())
result = await db.execute(query)
tasks = result.scalars().all()
timeline = []
for task in tasks:
timeline_item = {
"id": task.id,
"task_name": task.task_name,
"task_type": task.task_type,
"tool_name": task.tool_name,
"status": task.status,
"started_at": task.started_at.isoformat() if task.started_at else None,
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
"duration_ms": task.duration_ms,
"items_processed": task.items_processed,
"items_found": task.items_found,
"result_summary": task.result_summary,
"error_message": task.error_message
}
timeline.append(timeline_item)
return {
"scan_id": scan_id,
"timeline": timeline,
"total": len(timeline)
}
+125 -3
View File
@@ -8,7 +8,7 @@ from sqlalchemy import select, func
from datetime import datetime, timedelta
from backend.db.database import get_db
from backend.models import Scan, Vulnerability, Endpoint
from backend.models import Scan, Vulnerability, Endpoint, AgentTask, Report
router = APIRouter()
@@ -20,18 +20,32 @@ async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
total_scans_result = await db.execute(select(func.count()).select_from(Scan))
total_scans = total_scans_result.scalar() or 0
# Running scans
# Scans by status
running_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "running")
)
running_scans = running_result.scalar() or 0
# Completed scans
completed_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "completed")
)
completed_scans = completed_result.scalar() or 0
stopped_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "stopped")
)
stopped_scans = stopped_result.scalar() or 0
failed_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "failed")
)
failed_scans = failed_result.scalar() or 0
pending_result = await db.execute(
select(func.count()).select_from(Scan).where(Scan.status == "pending")
)
pending_scans = pending_result.scalar() or 0
# Total vulnerabilities by severity
vuln_counts = {}
for severity in ["critical", "high", "medium", "low", "info"]:
@@ -63,6 +77,9 @@ async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
"total": total_scans,
"running": running_scans,
"completed": completed_scans,
"stopped": stopped_scans,
"failed": failed_scans,
"pending": pending_scans,
"recent": recent_scans
},
"vulnerabilities": {
@@ -175,3 +192,108 @@ async def get_scan_history(
history[date_str]["high"] += scan.high_count
return {"history": list(history.values())}
@router.get("/agent-tasks")
async def get_recent_agent_tasks(
limit: int = 20,
db: AsyncSession = Depends(get_db)
):
"""Get recent agent tasks across all scans"""
query = (
select(AgentTask)
.order_by(AgentTask.created_at.desc())
.limit(limit)
)
result = await db.execute(query)
tasks = result.scalars().all()
return {
"agent_tasks": [t.to_dict() for t in tasks],
"total": len(tasks)
}
@router.get("/activity-feed")
async def get_activity_feed(
limit: int = 30,
db: AsyncSession = Depends(get_db)
):
"""Get unified activity feed with all recent events"""
activities = []
# Get recent scans
scans_result = await db.execute(
select(Scan).order_by(Scan.created_at.desc()).limit(limit // 3)
)
for scan in scans_result.scalars().all():
activities.append({
"type": "scan",
"action": f"Scan {scan.status}",
"title": scan.name or "Unnamed Scan",
"description": f"{scan.total_vulnerabilities} vulnerabilities found",
"status": scan.status,
"severity": None,
"timestamp": scan.created_at.isoformat(),
"scan_id": scan.id,
"link": f"/scan/{scan.id}"
})
# Get recent vulnerabilities
vulns_result = await db.execute(
select(Vulnerability).order_by(Vulnerability.created_at.desc()).limit(limit // 3)
)
for vuln in vulns_result.scalars().all():
activities.append({
"type": "vulnerability",
"action": "Vulnerability found",
"title": vuln.title,
"description": vuln.affected_endpoint or "",
"status": None,
"severity": vuln.severity,
"timestamp": vuln.created_at.isoformat(),
"scan_id": vuln.scan_id,
"link": f"/scan/{vuln.scan_id}"
})
# Get recent agent tasks
tasks_result = await db.execute(
select(AgentTask).order_by(AgentTask.created_at.desc()).limit(limit // 3)
)
for task in tasks_result.scalars().all():
activities.append({
"type": "agent_task",
"action": f"Task {task.status}",
"title": task.task_name,
"description": task.result_summary or task.description or "",
"status": task.status,
"severity": None,
"timestamp": task.created_at.isoformat(),
"scan_id": task.scan_id,
"link": f"/scan/{task.scan_id}"
})
# Get recent reports
reports_result = await db.execute(
select(Report).order_by(Report.generated_at.desc()).limit(limit // 4)
)
for report in reports_result.scalars().all():
activities.append({
"type": "report",
"action": "Report generated" if report.auto_generated else "Report created",
"title": report.title or "Report",
"description": f"{report.format.upper()} format",
"status": "auto" if report.auto_generated else "manual",
"severity": None,
"timestamp": report.generated_at.isoformat(),
"scan_id": report.scan_id,
"link": f"/reports"
})
# Sort all activities by timestamp (newest first)
activities.sort(key=lambda x: x["timestamp"], reverse=True)
return {
"activities": activities[:limit],
"total": len(activities)
}
+9 -1
View File
@@ -20,14 +20,22 @@ router = APIRouter()
@router.get("", response_model=ReportListResponse)
async def list_reports(
scan_id: Optional[str] = None,
auto_generated: Optional[bool] = None,
is_partial: Optional[bool] = None,
db: AsyncSession = Depends(get_db)
):
"""List all reports"""
"""List all reports with optional filtering"""
query = select(Report).order_by(Report.generated_at.desc())
if scan_id:
query = query.where(Report.scan_id == scan_id)
if auto_generated is not None:
query = query.where(Report.auto_generated == auto_generated)
if is_partial is not None:
query = query.where(Report.is_partial == is_partial)
result = await db.execute(query)
reports = result.scalars().all()
+69 -2
View File
@@ -175,7 +175,9 @@ async def start_scan(
@router.post("/{scan_id}/stop")
async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
"""Stop a running scan"""
"""Stop a running scan and save partial results"""
from backend.api.websocket import manager as ws_manager
result = await db.execute(select(Scan).where(Scan.id == scan_id))
scan = result.scalar_one_or_none()
@@ -185,11 +187,76 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
if scan.status != "running":
raise HTTPException(status_code=400, detail="Scan is not running")
# Update scan status
scan.status = "stopped"
scan.completed_at = datetime.utcnow()
scan.current_phase = "stopped"
# Calculate duration
if scan.started_at:
duration = (scan.completed_at - scan.started_at).total_seconds()
scan.duration = int(duration)
# Compute final vulnerability statistics from database
for severity in ["critical", "high", "medium", "low", "info"]:
count_result = await db.execute(
select(func.count()).select_from(Vulnerability)
.where(Vulnerability.scan_id == scan_id)
.where(Vulnerability.severity == severity)
)
setattr(scan, f"{severity}_count", count_result.scalar() or 0)
# Get total vulnerability count
total_vuln_result = await db.execute(
select(func.count()).select_from(Vulnerability)
.where(Vulnerability.scan_id == scan_id)
)
scan.total_vulnerabilities = total_vuln_result.scalar() or 0
# Get total endpoint count
total_endpoint_result = await db.execute(
select(func.count()).select_from(Endpoint)
.where(Endpoint.scan_id == scan_id)
)
scan.total_endpoints = total_endpoint_result.scalar() or 0
await db.commit()
return {"message": "Scan stopped", "scan_id": scan_id}
# Build summary for WebSocket broadcast
summary = {
"total_endpoints": scan.total_endpoints,
"total_vulnerabilities": scan.total_vulnerabilities,
"critical": scan.critical_count,
"high": scan.high_count,
"medium": scan.medium_count,
"low": scan.low_count,
"info": scan.info_count,
"duration": scan.duration,
"progress": scan.progress
}
# Broadcast stop event via WebSocket
await ws_manager.broadcast_scan_stopped(scan_id, summary)
await ws_manager.broadcast_log(scan_id, "warning", "Scan stopped by user")
await ws_manager.broadcast_log(scan_id, "info", f"Partial results: {scan.total_vulnerabilities} vulnerabilities found")
# Auto-generate partial report
report_data = None
try:
from backend.services.report_service import auto_generate_report
await ws_manager.broadcast_log(scan_id, "info", "Generating partial report...")
report = await auto_generate_report(db, scan_id, is_partial=True)
report_data = report.to_dict()
await ws_manager.broadcast_log(scan_id, "info", f"Partial report generated: {report.title}")
except Exception as report_error:
await ws_manager.broadcast_log(scan_id, "warning", f"Failed to generate partial report: {str(report_error)}")
return {
"message": "Scan stopped",
"scan_id": scan_id,
"summary": summary,
"report": report_data
}
@router.get("/{scan_id}/status", response_model=ScanProgress)
+59
View File
@@ -142,6 +142,65 @@ class ConnectionManager:
"summary": summary
})
async def broadcast_scan_stopped(self, scan_id: str, summary: dict):
"""Notify that a scan was stopped by user"""
await self.send_to_scan(scan_id, {
"type": "scan_stopped",
"scan_id": scan_id,
"status": "stopped",
"summary": summary
})
async def broadcast_scan_failed(self, scan_id: str, error: str, summary: dict = None):
"""Notify that a scan has failed"""
await self.send_to_scan(scan_id, {
"type": "scan_failed",
"scan_id": scan_id,
"status": "failed",
"error": error,
"summary": summary or {}
})
async def broadcast_stats_update(self, scan_id: str, stats: dict):
"""Broadcast updated scan statistics"""
await self.send_to_scan(scan_id, {
"type": "stats_update",
"scan_id": scan_id,
"stats": stats
})
async def broadcast_agent_task(self, scan_id: str, task: dict):
"""Broadcast agent task update (created, started, completed, failed)"""
await self.send_to_scan(scan_id, {
"type": "agent_task",
"scan_id": scan_id,
"task": task
})
async def broadcast_agent_task_started(self, scan_id: str, task: dict):
"""Broadcast when an agent task starts"""
await self.send_to_scan(scan_id, {
"type": "agent_task_started",
"scan_id": scan_id,
"task": task
})
async def broadcast_agent_task_completed(self, scan_id: str, task: dict):
"""Broadcast when an agent task completes"""
await self.send_to_scan(scan_id, {
"type": "agent_task_completed",
"scan_id": scan_id,
"task": task
})
async def broadcast_report_generated(self, scan_id: str, report: dict):
"""Broadcast when a report is generated"""
await self.send_to_scan(scan_id, {
"type": "report_generated",
"scan_id": scan_id,
"report": report
})
async def broadcast_error(self, scan_id: str, error: str):
"""Notify an error occurred"""
await self.send_to_scan(scan_id, {