mirror of
https://github.com/FuzzingLabs/fuzzforge_ai.git
synced 2026-05-21 19:36:53 +02:00
395 lines
16 KiB
Python
395 lines
16 KiB
Python
"""
|
|
Generic Prefect Statistics Monitor Service
|
|
|
|
This service monitors ALL workflows for structured live data logging and
|
|
updates the appropriate statistics APIs. Works with any workflow that follows
|
|
the standard LIVE_STATS logging pattern.
|
|
"""
|
|
# Copyright (c) 2025 FuzzingLabs
|
|
#
|
|
# Licensed under the Business Source License 1.1 (BSL). See the LICENSE file
|
|
# at the root of this repository for details.
|
|
#
|
|
# After the Change Date (four years from publication), this version of the
|
|
# Licensed Work will be made available under the Apache License, Version 2.0.
|
|
# See the LICENSE-APACHE file or http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Additional attribution and requirements are provided in the NOTICE file.
|
|
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Dict, Any, Optional
|
|
from prefect.client.orchestration import get_client
|
|
from prefect.client.schemas.objects import FlowRun, TaskRun
|
|
from src.models.findings import FuzzingStats
|
|
from src.api.fuzzing import fuzzing_stats, initialize_fuzzing_tracking, active_connections
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PrefectStatsMonitor:
|
|
"""Monitors Prefect flows and tasks for live statistics from any workflow"""
|
|
|
|
def __init__(self):
|
|
self.monitoring = False
|
|
self.monitor_task = None
|
|
self.monitored_runs = set()
|
|
self.last_log_ts: Dict[str, datetime] = {}
|
|
self._client = None
|
|
self._client_refresh_time = None
|
|
self._client_refresh_interval = 300 # Refresh connection every 5 minutes
|
|
|
|
async def start_monitoring(self):
|
|
"""Start the Prefect statistics monitoring service"""
|
|
if self.monitoring:
|
|
logger.warning("Prefect stats monitor already running")
|
|
return
|
|
|
|
self.monitoring = True
|
|
self.monitor_task = asyncio.create_task(self._monitor_flows())
|
|
logger.info("Started Prefect statistics monitor")
|
|
|
|
async def stop_monitoring(self):
|
|
"""Stop the monitoring service"""
|
|
self.monitoring = False
|
|
if self.monitor_task:
|
|
self.monitor_task.cancel()
|
|
try:
|
|
await self.monitor_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
logger.info("Stopped Prefect statistics monitor")
|
|
|
|
async def _get_or_refresh_client(self):
|
|
"""Get or refresh Prefect client with connection pooling."""
|
|
now = datetime.now(timezone.utc)
|
|
|
|
if (self._client is None or
|
|
self._client_refresh_time is None or
|
|
(now - self._client_refresh_time).total_seconds() > self._client_refresh_interval):
|
|
|
|
if self._client:
|
|
try:
|
|
await self._client.aclose()
|
|
except Exception:
|
|
pass
|
|
|
|
self._client = get_client()
|
|
self._client_refresh_time = now
|
|
await self._client.__aenter__()
|
|
|
|
return self._client
|
|
|
|
async def _monitor_flows(self):
|
|
"""Main monitoring loop that watches Prefect flows"""
|
|
try:
|
|
while self.monitoring:
|
|
try:
|
|
# Use connection pooling for better performance
|
|
client = await self._get_or_refresh_client()
|
|
|
|
# Get recent flow runs (limit to reduce load)
|
|
flow_runs = await client.read_flow_runs(
|
|
limit=50,
|
|
sort="START_TIME_DESC",
|
|
)
|
|
|
|
# Only consider runs from the last 15 minutes
|
|
recent_cutoff = datetime.now(timezone.utc) - timedelta(minutes=15)
|
|
for flow_run in flow_runs:
|
|
created = getattr(flow_run, "created", None)
|
|
if created is None:
|
|
continue
|
|
try:
|
|
# Ensure timezone-aware comparison
|
|
if created.tzinfo is None:
|
|
created = created.replace(tzinfo=timezone.utc)
|
|
if created >= recent_cutoff:
|
|
await self._monitor_flow_run(client, flow_run)
|
|
except Exception:
|
|
# If comparison fails, attempt monitoring anyway
|
|
await self._monitor_flow_run(client, flow_run)
|
|
|
|
await asyncio.sleep(5) # Check every 5 seconds
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in Prefect monitoring: {e}")
|
|
await asyncio.sleep(10)
|
|
|
|
except asyncio.CancelledError:
|
|
logger.info("Prefect monitoring cancelled")
|
|
except Exception as e:
|
|
logger.error(f"Fatal error in Prefect monitoring: {e}")
|
|
finally:
|
|
# Clean up client on exit
|
|
if self._client:
|
|
try:
|
|
await self._client.__aexit__(None, None, None)
|
|
except Exception:
|
|
pass
|
|
self._client = None
|
|
|
|
async def _monitor_flow_run(self, client, flow_run: FlowRun):
|
|
"""Monitor a specific flow run for statistics"""
|
|
run_id = str(flow_run.id)
|
|
workflow_name = flow_run.name or "unknown"
|
|
|
|
try:
|
|
# Initialize tracking if not exists - only for workflows that might have live stats
|
|
if run_id not in fuzzing_stats:
|
|
initialize_fuzzing_tracking(run_id, workflow_name)
|
|
self.monitored_runs.add(run_id)
|
|
|
|
# Skip corrupted entries (should not happen after startup cleanup, but defensive)
|
|
elif not isinstance(fuzzing_stats[run_id], FuzzingStats):
|
|
logger.warning(f"Skipping corrupted stats entry for {run_id}, reinitializing")
|
|
initialize_fuzzing_tracking(run_id, workflow_name)
|
|
self.monitored_runs.add(run_id)
|
|
|
|
# Get task runs for this flow
|
|
task_runs = await client.read_task_runs(
|
|
flow_run_filter={"id": {"any_": [flow_run.id]}},
|
|
limit=25,
|
|
)
|
|
|
|
# Check all tasks for live statistics logging
|
|
for task_run in task_runs:
|
|
await self._extract_stats_from_task(client, run_id, task_run, workflow_name)
|
|
|
|
# Also scan flow-level logs as a fallback
|
|
await self._extract_stats_from_flow_logs(client, run_id, flow_run, workflow_name)
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error monitoring flow run {run_id}: {e}")
|
|
|
|
async def _extract_stats_from_task(self, client, run_id: str, task_run: TaskRun, workflow_name: str):
|
|
"""Extract statistics from any task that logs live stats"""
|
|
try:
|
|
# Get task run logs
|
|
logs = await client.read_logs(
|
|
log_filter={
|
|
"task_run_id": {"any_": [task_run.id]}
|
|
},
|
|
limit=100,
|
|
sort="TIMESTAMP_ASC"
|
|
)
|
|
|
|
# Parse logs for LIVE_STATS entries (generic pattern for any workflow)
|
|
latest_stats = None
|
|
for log in logs:
|
|
# Prefer structured extra field if present
|
|
extra_data = getattr(log, "extra", None) or getattr(log, "extra_fields", None) or None
|
|
if isinstance(extra_data, dict):
|
|
stat_type = extra_data.get("stats_type")
|
|
if stat_type in ["fuzzing_live_update", "scan_progress", "analysis_update", "live_stats"]:
|
|
latest_stats = extra_data
|
|
continue
|
|
|
|
# Fallback to parsing from message text
|
|
if ("FUZZ_STATS" in log.message or "LIVE_STATS" in log.message):
|
|
stats = self._parse_stats_from_log(log.message)
|
|
if stats:
|
|
latest_stats = stats
|
|
|
|
# Update statistics if we found any
|
|
if latest_stats:
|
|
# Calculate elapsed time from task start
|
|
elapsed_time = 0
|
|
if task_run.start_time:
|
|
# Ensure timezone-aware arithmetic
|
|
now = datetime.now(timezone.utc)
|
|
try:
|
|
elapsed_time = int((now - task_run.start_time).total_seconds())
|
|
except Exception:
|
|
# Fallback to naive UTC if types mismatch
|
|
elapsed_time = int((datetime.utcnow() - task_run.start_time.replace(tzinfo=None)).total_seconds())
|
|
|
|
updated_stats = FuzzingStats(
|
|
run_id=run_id,
|
|
workflow=workflow_name,
|
|
executions=latest_stats.get("executions", 0),
|
|
executions_per_sec=latest_stats.get("executions_per_sec", 0.0),
|
|
crashes=latest_stats.get("crashes", 0),
|
|
unique_crashes=latest_stats.get("unique_crashes", 0),
|
|
corpus_size=latest_stats.get("corpus_size", 0),
|
|
elapsed_time=elapsed_time
|
|
)
|
|
|
|
# Update the global stats
|
|
previous = fuzzing_stats.get(run_id)
|
|
fuzzing_stats[run_id] = updated_stats
|
|
|
|
# Broadcast to any active WebSocket clients for this run
|
|
if active_connections.get(run_id):
|
|
# Handle both Pydantic objects and plain dicts
|
|
if isinstance(updated_stats, dict):
|
|
stats_data = updated_stats
|
|
elif hasattr(updated_stats, 'model_dump'):
|
|
stats_data = updated_stats.model_dump()
|
|
elif hasattr(updated_stats, 'dict'):
|
|
stats_data = updated_stats.dict()
|
|
else:
|
|
stats_data = updated_stats.__dict__
|
|
|
|
message = {
|
|
"type": "stats_update",
|
|
"data": stats_data,
|
|
}
|
|
disconnected = []
|
|
for ws in active_connections[run_id]:
|
|
try:
|
|
await ws.send_text(json.dumps(message))
|
|
except Exception:
|
|
disconnected.append(ws)
|
|
# Clean up disconnected sockets
|
|
for ws in disconnected:
|
|
try:
|
|
active_connections[run_id].remove(ws)
|
|
except ValueError:
|
|
pass
|
|
|
|
logger.debug(f"Updated Prefect stats for {run_id}: {updated_stats.executions} execs")
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error extracting stats from task {task_run.id}: {e}")
|
|
|
|
async def _extract_stats_from_flow_logs(self, client, run_id: str, flow_run: FlowRun, workflow_name: str):
|
|
"""Extract statistics by scanning flow-level logs for LIVE/FUZZ stats"""
|
|
try:
|
|
logs = await client.read_logs(
|
|
log_filter={
|
|
"flow_run_id": {"any_": [flow_run.id]}
|
|
},
|
|
limit=200,
|
|
sort="TIMESTAMP_ASC"
|
|
)
|
|
|
|
latest_stats = None
|
|
last_seen = self.last_log_ts.get(run_id)
|
|
max_ts = last_seen
|
|
|
|
for log in logs:
|
|
# Skip logs we've already processed
|
|
ts = getattr(log, "timestamp", None)
|
|
if last_seen and ts and ts <= last_seen:
|
|
continue
|
|
if ts and (max_ts is None or ts > max_ts):
|
|
max_ts = ts
|
|
|
|
# Prefer structured extra field if available
|
|
extra_data = getattr(log, "extra", None) or getattr(log, "extra_fields", None) or None
|
|
if isinstance(extra_data, dict):
|
|
stat_type = extra_data.get("stats_type")
|
|
if stat_type in ["fuzzing_live_update", "scan_progress", "analysis_update", "live_stats"]:
|
|
latest_stats = extra_data
|
|
continue
|
|
|
|
# Fallback to message parse
|
|
if ("FUZZ_STATS" in log.message or "LIVE_STATS" in log.message):
|
|
stats = self._parse_stats_from_log(log.message)
|
|
if stats:
|
|
latest_stats = stats
|
|
|
|
if max_ts:
|
|
self.last_log_ts[run_id] = max_ts
|
|
|
|
if latest_stats:
|
|
# Use flow_run timestamps for elapsed time if available
|
|
elapsed_time = 0
|
|
start_time = getattr(flow_run, "start_time", None) or getattr(flow_run, "start_time", None)
|
|
if start_time:
|
|
now = datetime.now(timezone.utc)
|
|
try:
|
|
if start_time.tzinfo is None:
|
|
start_time = start_time.replace(tzinfo=timezone.utc)
|
|
elapsed_time = int((now - start_time).total_seconds())
|
|
except Exception:
|
|
elapsed_time = int((datetime.utcnow() - start_time.replace(tzinfo=None)).total_seconds())
|
|
|
|
updated_stats = FuzzingStats(
|
|
run_id=run_id,
|
|
workflow=workflow_name,
|
|
executions=latest_stats.get("executions", 0),
|
|
executions_per_sec=latest_stats.get("executions_per_sec", 0.0),
|
|
crashes=latest_stats.get("crashes", 0),
|
|
unique_crashes=latest_stats.get("unique_crashes", 0),
|
|
corpus_size=latest_stats.get("corpus_size", 0),
|
|
elapsed_time=elapsed_time
|
|
)
|
|
|
|
fuzzing_stats[run_id] = updated_stats
|
|
|
|
# Broadcast if listeners exist
|
|
if active_connections.get(run_id):
|
|
# Handle both Pydantic objects and plain dicts
|
|
if isinstance(updated_stats, dict):
|
|
stats_data = updated_stats
|
|
elif hasattr(updated_stats, 'model_dump'):
|
|
stats_data = updated_stats.model_dump()
|
|
elif hasattr(updated_stats, 'dict'):
|
|
stats_data = updated_stats.dict()
|
|
else:
|
|
stats_data = updated_stats.__dict__
|
|
|
|
message = {
|
|
"type": "stats_update",
|
|
"data": stats_data,
|
|
}
|
|
disconnected = []
|
|
for ws in active_connections[run_id]:
|
|
try:
|
|
await ws.send_text(json.dumps(message))
|
|
except Exception:
|
|
disconnected.append(ws)
|
|
for ws in disconnected:
|
|
try:
|
|
active_connections[run_id].remove(ws)
|
|
except ValueError:
|
|
pass
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error extracting stats from flow logs {run_id}: {e}")
|
|
|
|
def _parse_stats_from_log(self, log_message: str) -> Optional[Dict[str, Any]]:
|
|
"""Parse statistics from a log message"""
|
|
try:
|
|
import re
|
|
|
|
# Prefer explicit JSON after marker tokens
|
|
m = re.search(r'(?:FUZZ_STATS|LIVE_STATS)\s+(\{.*\})', log_message)
|
|
if m:
|
|
try:
|
|
return json.loads(m.group(1))
|
|
except Exception:
|
|
pass
|
|
|
|
# Fallback: Extract the extra= dict and coerce to JSON
|
|
stats_match = re.search(r'extra=({.*?})', log_message)
|
|
if not stats_match:
|
|
return None
|
|
|
|
extra_str = stats_match.group(1)
|
|
extra_str = extra_str.replace("'", '"')
|
|
extra_str = extra_str.replace('None', 'null')
|
|
extra_str = extra_str.replace('True', 'true')
|
|
extra_str = extra_str.replace('False', 'false')
|
|
|
|
stats_data = json.loads(extra_str)
|
|
|
|
# Support multiple stat types for different workflows
|
|
stat_type = stats_data.get("stats_type")
|
|
if stat_type in ["fuzzing_live_update", "scan_progress", "analysis_update", "live_stats"]:
|
|
return stats_data
|
|
|
|
except Exception as e:
|
|
logger.debug(f"Error parsing log stats: {e}")
|
|
|
|
return None
|
|
|
|
|
|
# Global instance
|
|
prefect_stats_monitor = PrefectStatsMonitor()
|