mirror of
https://github.com/CyberSecurityUP/NeuroSploit.git
synced 2026-05-21 05:56:51 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
NeuroSploit v3 - Agent Tasks API Endpoints
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
from backend.db.database import get_db
|
||||
from backend.models import AgentTask, Scan
|
||||
from backend.schemas.agent_task import (
|
||||
AgentTaskResponse,
|
||||
AgentTaskListResponse,
|
||||
AgentTaskSummary
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=AgentTaskListResponse)
|
||||
async def list_agent_tasks(
|
||||
scan_id: str,
|
||||
status: Optional[str] = None,
|
||||
task_type: Optional[str] = None,
|
||||
page: int = 1,
|
||||
per_page: int = 50,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""List all agent tasks for a scan"""
|
||||
# Verify scan exists
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Build query
|
||||
query = select(AgentTask).where(AgentTask.scan_id == scan_id)
|
||||
|
||||
if status:
|
||||
query = query.where(AgentTask.status == status)
|
||||
if task_type:
|
||||
query = query.where(AgentTask.task_type == task_type)
|
||||
|
||||
query = query.order_by(AgentTask.created_at.desc())
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
|
||||
if status:
|
||||
count_query = count_query.where(AgentTask.status == status)
|
||||
if task_type:
|
||||
count_query = count_query.where(AgentTask.task_type == task_type)
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset((page - 1) * per_page).limit(per_page)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
|
||||
return AgentTaskListResponse(
|
||||
tasks=[AgentTaskResponse(**t.to_dict()) for t in tasks],
|
||||
total=total,
|
||||
scan_id=scan_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/summary", response_model=AgentTaskSummary)
|
||||
async def get_agent_tasks_summary(
|
||||
scan_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get summary statistics for agent tasks in a scan"""
|
||||
# Verify scan exists
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Total count
|
||||
total_result = await db.execute(
|
||||
select(func.count()).select_from(AgentTask).where(AgentTask.scan_id == scan_id)
|
||||
)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# Count by status
|
||||
status_counts = {}
|
||||
for status in ["pending", "running", "completed", "failed"]:
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(AgentTask)
|
||||
.where(AgentTask.scan_id == scan_id)
|
||||
.where(AgentTask.status == status)
|
||||
)
|
||||
status_counts[status] = count_result.scalar() or 0
|
||||
|
||||
# Count by task type
|
||||
type_query = select(
|
||||
AgentTask.task_type,
|
||||
func.count(AgentTask.id).label("count")
|
||||
).where(AgentTask.scan_id == scan_id).group_by(AgentTask.task_type)
|
||||
type_result = await db.execute(type_query)
|
||||
by_type = {row[0]: row[1] for row in type_result.all()}
|
||||
|
||||
# Count by tool
|
||||
tool_query = select(
|
||||
AgentTask.tool_name,
|
||||
func.count(AgentTask.id).label("count")
|
||||
).where(AgentTask.scan_id == scan_id).where(AgentTask.tool_name.isnot(None)).group_by(AgentTask.tool_name)
|
||||
tool_result = await db.execute(tool_query)
|
||||
by_tool = {row[0]: row[1] for row in tool_result.all()}
|
||||
|
||||
return AgentTaskSummary(
|
||||
total=total,
|
||||
pending=status_counts.get("pending", 0),
|
||||
running=status_counts.get("running", 0),
|
||||
completed=status_counts.get("completed", 0),
|
||||
failed=status_counts.get("failed", 0),
|
||||
by_type=by_type,
|
||||
by_tool=by_tool
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=AgentTaskResponse)
|
||||
async def get_agent_task(
|
||||
task_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get a specific agent task by ID"""
|
||||
result = await db.execute(select(AgentTask).where(AgentTask.id == task_id))
|
||||
task = result.scalar_one_or_none()
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Agent task not found")
|
||||
|
||||
return AgentTaskResponse(**task.to_dict())
|
||||
|
||||
|
||||
@router.get("/scan/{scan_id}/timeline")
|
||||
async def get_agent_tasks_timeline(
|
||||
scan_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get agent tasks as a timeline for visualization"""
|
||||
# Verify scan exists
|
||||
scan_result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = scan_result.scalar_one_or_none()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Get all tasks ordered by creation time
|
||||
query = select(AgentTask).where(AgentTask.scan_id == scan_id).order_by(AgentTask.created_at.asc())
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
|
||||
timeline = []
|
||||
for task in tasks:
|
||||
timeline_item = {
|
||||
"id": task.id,
|
||||
"task_name": task.task_name,
|
||||
"task_type": task.task_type,
|
||||
"tool_name": task.tool_name,
|
||||
"status": task.status,
|
||||
"started_at": task.started_at.isoformat() if task.started_at else None,
|
||||
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
|
||||
"duration_ms": task.duration_ms,
|
||||
"items_processed": task.items_processed,
|
||||
"items_found": task.items_found,
|
||||
"result_summary": task.result_summary,
|
||||
"error_message": task.error_message
|
||||
}
|
||||
timeline.append(timeline_item)
|
||||
|
||||
return {
|
||||
"scan_id": scan_id,
|
||||
"timeline": timeline,
|
||||
"total": len(timeline)
|
||||
}
|
||||
+125
-3
@@ -8,7 +8,7 @@ from sqlalchemy import select, func
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from backend.db.database import get_db
|
||||
from backend.models import Scan, Vulnerability, Endpoint
|
||||
from backend.models import Scan, Vulnerability, Endpoint, AgentTask, Report
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -20,18 +20,32 @@ async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
|
||||
total_scans_result = await db.execute(select(func.count()).select_from(Scan))
|
||||
total_scans = total_scans_result.scalar() or 0
|
||||
|
||||
# Running scans
|
||||
# Scans by status
|
||||
running_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "running")
|
||||
)
|
||||
running_scans = running_result.scalar() or 0
|
||||
|
||||
# Completed scans
|
||||
completed_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "completed")
|
||||
)
|
||||
completed_scans = completed_result.scalar() or 0
|
||||
|
||||
stopped_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "stopped")
|
||||
)
|
||||
stopped_scans = stopped_result.scalar() or 0
|
||||
|
||||
failed_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "failed")
|
||||
)
|
||||
failed_scans = failed_result.scalar() or 0
|
||||
|
||||
pending_result = await db.execute(
|
||||
select(func.count()).select_from(Scan).where(Scan.status == "pending")
|
||||
)
|
||||
pending_scans = pending_result.scalar() or 0
|
||||
|
||||
# Total vulnerabilities by severity
|
||||
vuln_counts = {}
|
||||
for severity in ["critical", "high", "medium", "low", "info"]:
|
||||
@@ -63,6 +77,9 @@ async def get_dashboard_stats(db: AsyncSession = Depends(get_db)):
|
||||
"total": total_scans,
|
||||
"running": running_scans,
|
||||
"completed": completed_scans,
|
||||
"stopped": stopped_scans,
|
||||
"failed": failed_scans,
|
||||
"pending": pending_scans,
|
||||
"recent": recent_scans
|
||||
},
|
||||
"vulnerabilities": {
|
||||
@@ -175,3 +192,108 @@ async def get_scan_history(
|
||||
history[date_str]["high"] += scan.high_count
|
||||
|
||||
return {"history": list(history.values())}
|
||||
|
||||
|
||||
@router.get("/agent-tasks")
|
||||
async def get_recent_agent_tasks(
|
||||
limit: int = 20,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get recent agent tasks across all scans"""
|
||||
query = (
|
||||
select(AgentTask)
|
||||
.order_by(AgentTask.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
|
||||
return {
|
||||
"agent_tasks": [t.to_dict() for t in tasks],
|
||||
"total": len(tasks)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/activity-feed")
|
||||
async def get_activity_feed(
|
||||
limit: int = 30,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get unified activity feed with all recent events"""
|
||||
activities = []
|
||||
|
||||
# Get recent scans
|
||||
scans_result = await db.execute(
|
||||
select(Scan).order_by(Scan.created_at.desc()).limit(limit // 3)
|
||||
)
|
||||
for scan in scans_result.scalars().all():
|
||||
activities.append({
|
||||
"type": "scan",
|
||||
"action": f"Scan {scan.status}",
|
||||
"title": scan.name or "Unnamed Scan",
|
||||
"description": f"{scan.total_vulnerabilities} vulnerabilities found",
|
||||
"status": scan.status,
|
||||
"severity": None,
|
||||
"timestamp": scan.created_at.isoformat(),
|
||||
"scan_id": scan.id,
|
||||
"link": f"/scan/{scan.id}"
|
||||
})
|
||||
|
||||
# Get recent vulnerabilities
|
||||
vulns_result = await db.execute(
|
||||
select(Vulnerability).order_by(Vulnerability.created_at.desc()).limit(limit // 3)
|
||||
)
|
||||
for vuln in vulns_result.scalars().all():
|
||||
activities.append({
|
||||
"type": "vulnerability",
|
||||
"action": "Vulnerability found",
|
||||
"title": vuln.title,
|
||||
"description": vuln.affected_endpoint or "",
|
||||
"status": None,
|
||||
"severity": vuln.severity,
|
||||
"timestamp": vuln.created_at.isoformat(),
|
||||
"scan_id": vuln.scan_id,
|
||||
"link": f"/scan/{vuln.scan_id}"
|
||||
})
|
||||
|
||||
# Get recent agent tasks
|
||||
tasks_result = await db.execute(
|
||||
select(AgentTask).order_by(AgentTask.created_at.desc()).limit(limit // 3)
|
||||
)
|
||||
for task in tasks_result.scalars().all():
|
||||
activities.append({
|
||||
"type": "agent_task",
|
||||
"action": f"Task {task.status}",
|
||||
"title": task.task_name,
|
||||
"description": task.result_summary or task.description or "",
|
||||
"status": task.status,
|
||||
"severity": None,
|
||||
"timestamp": task.created_at.isoformat(),
|
||||
"scan_id": task.scan_id,
|
||||
"link": f"/scan/{task.scan_id}"
|
||||
})
|
||||
|
||||
# Get recent reports
|
||||
reports_result = await db.execute(
|
||||
select(Report).order_by(Report.generated_at.desc()).limit(limit // 4)
|
||||
)
|
||||
for report in reports_result.scalars().all():
|
||||
activities.append({
|
||||
"type": "report",
|
||||
"action": "Report generated" if report.auto_generated else "Report created",
|
||||
"title": report.title or "Report",
|
||||
"description": f"{report.format.upper()} format",
|
||||
"status": "auto" if report.auto_generated else "manual",
|
||||
"severity": None,
|
||||
"timestamp": report.generated_at.isoformat(),
|
||||
"scan_id": report.scan_id,
|
||||
"link": f"/reports"
|
||||
})
|
||||
|
||||
# Sort all activities by timestamp (newest first)
|
||||
activities.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||
|
||||
return {
|
||||
"activities": activities[:limit],
|
||||
"total": len(activities)
|
||||
}
|
||||
|
||||
@@ -20,14 +20,22 @@ router = APIRouter()
|
||||
@router.get("", response_model=ReportListResponse)
|
||||
async def list_reports(
|
||||
scan_id: Optional[str] = None,
|
||||
auto_generated: Optional[bool] = None,
|
||||
is_partial: Optional[bool] = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""List all reports"""
|
||||
"""List all reports with optional filtering"""
|
||||
query = select(Report).order_by(Report.generated_at.desc())
|
||||
|
||||
if scan_id:
|
||||
query = query.where(Report.scan_id == scan_id)
|
||||
|
||||
if auto_generated is not None:
|
||||
query = query.where(Report.auto_generated == auto_generated)
|
||||
|
||||
if is_partial is not None:
|
||||
query = query.where(Report.is_partial == is_partial)
|
||||
|
||||
result = await db.execute(query)
|
||||
reports = result.scalars().all()
|
||||
|
||||
|
||||
+69
-2
@@ -175,7 +175,9 @@ async def start_scan(
|
||||
|
||||
@router.post("/{scan_id}/stop")
|
||||
async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Stop a running scan"""
|
||||
"""Stop a running scan and save partial results"""
|
||||
from backend.api.websocket import manager as ws_manager
|
||||
|
||||
result = await db.execute(select(Scan).where(Scan.id == scan_id))
|
||||
scan = result.scalar_one_or_none()
|
||||
|
||||
@@ -185,11 +187,76 @@ async def stop_scan(scan_id: str, db: AsyncSession = Depends(get_db)):
|
||||
if scan.status != "running":
|
||||
raise HTTPException(status_code=400, detail="Scan is not running")
|
||||
|
||||
# Update scan status
|
||||
scan.status = "stopped"
|
||||
scan.completed_at = datetime.utcnow()
|
||||
scan.current_phase = "stopped"
|
||||
|
||||
# Calculate duration
|
||||
if scan.started_at:
|
||||
duration = (scan.completed_at - scan.started_at).total_seconds()
|
||||
scan.duration = int(duration)
|
||||
|
||||
# Compute final vulnerability statistics from database
|
||||
for severity in ["critical", "high", "medium", "low", "info"]:
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(Vulnerability)
|
||||
.where(Vulnerability.scan_id == scan_id)
|
||||
.where(Vulnerability.severity == severity)
|
||||
)
|
||||
setattr(scan, f"{severity}_count", count_result.scalar() or 0)
|
||||
|
||||
# Get total vulnerability count
|
||||
total_vuln_result = await db.execute(
|
||||
select(func.count()).select_from(Vulnerability)
|
||||
.where(Vulnerability.scan_id == scan_id)
|
||||
)
|
||||
scan.total_vulnerabilities = total_vuln_result.scalar() or 0
|
||||
|
||||
# Get total endpoint count
|
||||
total_endpoint_result = await db.execute(
|
||||
select(func.count()).select_from(Endpoint)
|
||||
.where(Endpoint.scan_id == scan_id)
|
||||
)
|
||||
scan.total_endpoints = total_endpoint_result.scalar() or 0
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Scan stopped", "scan_id": scan_id}
|
||||
# Build summary for WebSocket broadcast
|
||||
summary = {
|
||||
"total_endpoints": scan.total_endpoints,
|
||||
"total_vulnerabilities": scan.total_vulnerabilities,
|
||||
"critical": scan.critical_count,
|
||||
"high": scan.high_count,
|
||||
"medium": scan.medium_count,
|
||||
"low": scan.low_count,
|
||||
"info": scan.info_count,
|
||||
"duration": scan.duration,
|
||||
"progress": scan.progress
|
||||
}
|
||||
|
||||
# Broadcast stop event via WebSocket
|
||||
await ws_manager.broadcast_scan_stopped(scan_id, summary)
|
||||
await ws_manager.broadcast_log(scan_id, "warning", "Scan stopped by user")
|
||||
await ws_manager.broadcast_log(scan_id, "info", f"Partial results: {scan.total_vulnerabilities} vulnerabilities found")
|
||||
|
||||
# Auto-generate partial report
|
||||
report_data = None
|
||||
try:
|
||||
from backend.services.report_service import auto_generate_report
|
||||
await ws_manager.broadcast_log(scan_id, "info", "Generating partial report...")
|
||||
report = await auto_generate_report(db, scan_id, is_partial=True)
|
||||
report_data = report.to_dict()
|
||||
await ws_manager.broadcast_log(scan_id, "info", f"Partial report generated: {report.title}")
|
||||
except Exception as report_error:
|
||||
await ws_manager.broadcast_log(scan_id, "warning", f"Failed to generate partial report: {str(report_error)}")
|
||||
|
||||
return {
|
||||
"message": "Scan stopped",
|
||||
"scan_id": scan_id,
|
||||
"summary": summary,
|
||||
"report": report_data
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{scan_id}/status", response_model=ScanProgress)
|
||||
|
||||
@@ -142,6 +142,65 @@ class ConnectionManager:
|
||||
"summary": summary
|
||||
})
|
||||
|
||||
async def broadcast_scan_stopped(self, scan_id: str, summary: dict):
|
||||
"""Notify that a scan was stopped by user"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "scan_stopped",
|
||||
"scan_id": scan_id,
|
||||
"status": "stopped",
|
||||
"summary": summary
|
||||
})
|
||||
|
||||
async def broadcast_scan_failed(self, scan_id: str, error: str, summary: dict = None):
|
||||
"""Notify that a scan has failed"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "scan_failed",
|
||||
"scan_id": scan_id,
|
||||
"status": "failed",
|
||||
"error": error,
|
||||
"summary": summary or {}
|
||||
})
|
||||
|
||||
async def broadcast_stats_update(self, scan_id: str, stats: dict):
|
||||
"""Broadcast updated scan statistics"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "stats_update",
|
||||
"scan_id": scan_id,
|
||||
"stats": stats
|
||||
})
|
||||
|
||||
async def broadcast_agent_task(self, scan_id: str, task: dict):
|
||||
"""Broadcast agent task update (created, started, completed, failed)"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "agent_task",
|
||||
"scan_id": scan_id,
|
||||
"task": task
|
||||
})
|
||||
|
||||
async def broadcast_agent_task_started(self, scan_id: str, task: dict):
|
||||
"""Broadcast when an agent task starts"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "agent_task_started",
|
||||
"scan_id": scan_id,
|
||||
"task": task
|
||||
})
|
||||
|
||||
async def broadcast_agent_task_completed(self, scan_id: str, task: dict):
|
||||
"""Broadcast when an agent task completes"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "agent_task_completed",
|
||||
"scan_id": scan_id,
|
||||
"task": task
|
||||
})
|
||||
|
||||
async def broadcast_report_generated(self, scan_id: str, report: dict):
|
||||
"""Broadcast when a report is generated"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
"type": "report_generated",
|
||||
"scan_id": scan_id,
|
||||
"report": report
|
||||
})
|
||||
|
||||
async def broadcast_error(self, scan_id: str, error: str):
|
||||
"""Notify an error occurred"""
|
||||
await self.send_to_scan(scan_id, {
|
||||
|
||||
Reference in New Issue
Block a user