Files
fuzzforge_ai/backend/src/services/prefect_stats_monitor.py
T
Tanguy Duhamel 323a434c73 Initial commit
2025-09-29 21:26:41 +02:00

395 lines
16 KiB
Python

"""
Generic Prefect Statistics Monitor Service
This service monitors ALL workflows for structured live data logging and
updates the appropriate statistics APIs. Works with any workflow that follows
the standard LIVE_STATS logging pattern.
"""
# Copyright (c) 2025 FuzzingLabs
#
# Licensed under the Business Source License 1.1 (BSL). See the LICENSE file
# at the root of this repository for details.
#
# After the Change Date (four years from publication), this version of the
# Licensed Work will be made available under the Apache License, Version 2.0.
# See the LICENSE-APACHE file or http://www.apache.org/licenses/LICENSE-2.0
#
# Additional attribution and requirements are provided in the NOTICE file.
import asyncio
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Dict, Any, Optional
from prefect.client.orchestration import get_client
from prefect.client.schemas.objects import FlowRun, TaskRun
from src.models.findings import FuzzingStats
from src.api.fuzzing import fuzzing_stats, initialize_fuzzing_tracking, active_connections
logger = logging.getLogger(__name__)
class PrefectStatsMonitor:
"""Monitors Prefect flows and tasks for live statistics from any workflow"""
def __init__(self):
self.monitoring = False
self.monitor_task = None
self.monitored_runs = set()
self.last_log_ts: Dict[str, datetime] = {}
self._client = None
self._client_refresh_time = None
self._client_refresh_interval = 300 # Refresh connection every 5 minutes
async def start_monitoring(self):
"""Start the Prefect statistics monitoring service"""
if self.monitoring:
logger.warning("Prefect stats monitor already running")
return
self.monitoring = True
self.monitor_task = asyncio.create_task(self._monitor_flows())
logger.info("Started Prefect statistics monitor")
async def stop_monitoring(self):
"""Stop the monitoring service"""
self.monitoring = False
if self.monitor_task:
self.monitor_task.cancel()
try:
await self.monitor_task
except asyncio.CancelledError:
pass
logger.info("Stopped Prefect statistics monitor")
async def _get_or_refresh_client(self):
"""Get or refresh Prefect client with connection pooling."""
now = datetime.now(timezone.utc)
if (self._client is None or
self._client_refresh_time is None or
(now - self._client_refresh_time).total_seconds() > self._client_refresh_interval):
if self._client:
try:
await self._client.aclose()
except Exception:
pass
self._client = get_client()
self._client_refresh_time = now
await self._client.__aenter__()
return self._client
async def _monitor_flows(self):
"""Main monitoring loop that watches Prefect flows"""
try:
while self.monitoring:
try:
# Use connection pooling for better performance
client = await self._get_or_refresh_client()
# Get recent flow runs (limit to reduce load)
flow_runs = await client.read_flow_runs(
limit=50,
sort="START_TIME_DESC",
)
# Only consider runs from the last 15 minutes
recent_cutoff = datetime.now(timezone.utc) - timedelta(minutes=15)
for flow_run in flow_runs:
created = getattr(flow_run, "created", None)
if created is None:
continue
try:
# Ensure timezone-aware comparison
if created.tzinfo is None:
created = created.replace(tzinfo=timezone.utc)
if created >= recent_cutoff:
await self._monitor_flow_run(client, flow_run)
except Exception:
# If comparison fails, attempt monitoring anyway
await self._monitor_flow_run(client, flow_run)
await asyncio.sleep(5) # Check every 5 seconds
except Exception as e:
logger.error(f"Error in Prefect monitoring: {e}")
await asyncio.sleep(10)
except asyncio.CancelledError:
logger.info("Prefect monitoring cancelled")
except Exception as e:
logger.error(f"Fatal error in Prefect monitoring: {e}")
finally:
# Clean up client on exit
if self._client:
try:
await self._client.__aexit__(None, None, None)
except Exception:
pass
self._client = None
async def _monitor_flow_run(self, client, flow_run: FlowRun):
"""Monitor a specific flow run for statistics"""
run_id = str(flow_run.id)
workflow_name = flow_run.name or "unknown"
try:
# Initialize tracking if not exists - only for workflows that might have live stats
if run_id not in fuzzing_stats:
initialize_fuzzing_tracking(run_id, workflow_name)
self.monitored_runs.add(run_id)
# Skip corrupted entries (should not happen after startup cleanup, but defensive)
elif not isinstance(fuzzing_stats[run_id], FuzzingStats):
logger.warning(f"Skipping corrupted stats entry for {run_id}, reinitializing")
initialize_fuzzing_tracking(run_id, workflow_name)
self.monitored_runs.add(run_id)
# Get task runs for this flow
task_runs = await client.read_task_runs(
flow_run_filter={"id": {"any_": [flow_run.id]}},
limit=25,
)
# Check all tasks for live statistics logging
for task_run in task_runs:
await self._extract_stats_from_task(client, run_id, task_run, workflow_name)
# Also scan flow-level logs as a fallback
await self._extract_stats_from_flow_logs(client, run_id, flow_run, workflow_name)
except Exception as e:
logger.warning(f"Error monitoring flow run {run_id}: {e}")
async def _extract_stats_from_task(self, client, run_id: str, task_run: TaskRun, workflow_name: str):
"""Extract statistics from any task that logs live stats"""
try:
# Get task run logs
logs = await client.read_logs(
log_filter={
"task_run_id": {"any_": [task_run.id]}
},
limit=100,
sort="TIMESTAMP_ASC"
)
# Parse logs for LIVE_STATS entries (generic pattern for any workflow)
latest_stats = None
for log in logs:
# Prefer structured extra field if present
extra_data = getattr(log, "extra", None) or getattr(log, "extra_fields", None) or None
if isinstance(extra_data, dict):
stat_type = extra_data.get("stats_type")
if stat_type in ["fuzzing_live_update", "scan_progress", "analysis_update", "live_stats"]:
latest_stats = extra_data
continue
# Fallback to parsing from message text
if ("FUZZ_STATS" in log.message or "LIVE_STATS" in log.message):
stats = self._parse_stats_from_log(log.message)
if stats:
latest_stats = stats
# Update statistics if we found any
if latest_stats:
# Calculate elapsed time from task start
elapsed_time = 0
if task_run.start_time:
# Ensure timezone-aware arithmetic
now = datetime.now(timezone.utc)
try:
elapsed_time = int((now - task_run.start_time).total_seconds())
except Exception:
# Fallback to naive UTC if types mismatch
elapsed_time = int((datetime.utcnow() - task_run.start_time.replace(tzinfo=None)).total_seconds())
updated_stats = FuzzingStats(
run_id=run_id,
workflow=workflow_name,
executions=latest_stats.get("executions", 0),
executions_per_sec=latest_stats.get("executions_per_sec", 0.0),
crashes=latest_stats.get("crashes", 0),
unique_crashes=latest_stats.get("unique_crashes", 0),
corpus_size=latest_stats.get("corpus_size", 0),
elapsed_time=elapsed_time
)
# Update the global stats
previous = fuzzing_stats.get(run_id)
fuzzing_stats[run_id] = updated_stats
# Broadcast to any active WebSocket clients for this run
if active_connections.get(run_id):
# Handle both Pydantic objects and plain dicts
if isinstance(updated_stats, dict):
stats_data = updated_stats
elif hasattr(updated_stats, 'model_dump'):
stats_data = updated_stats.model_dump()
elif hasattr(updated_stats, 'dict'):
stats_data = updated_stats.dict()
else:
stats_data = updated_stats.__dict__
message = {
"type": "stats_update",
"data": stats_data,
}
disconnected = []
for ws in active_connections[run_id]:
try:
await ws.send_text(json.dumps(message))
except Exception:
disconnected.append(ws)
# Clean up disconnected sockets
for ws in disconnected:
try:
active_connections[run_id].remove(ws)
except ValueError:
pass
logger.debug(f"Updated Prefect stats for {run_id}: {updated_stats.executions} execs")
except Exception as e:
logger.warning(f"Error extracting stats from task {task_run.id}: {e}")
async def _extract_stats_from_flow_logs(self, client, run_id: str, flow_run: FlowRun, workflow_name: str):
"""Extract statistics by scanning flow-level logs for LIVE/FUZZ stats"""
try:
logs = await client.read_logs(
log_filter={
"flow_run_id": {"any_": [flow_run.id]}
},
limit=200,
sort="TIMESTAMP_ASC"
)
latest_stats = None
last_seen = self.last_log_ts.get(run_id)
max_ts = last_seen
for log in logs:
# Skip logs we've already processed
ts = getattr(log, "timestamp", None)
if last_seen and ts and ts <= last_seen:
continue
if ts and (max_ts is None or ts > max_ts):
max_ts = ts
# Prefer structured extra field if available
extra_data = getattr(log, "extra", None) or getattr(log, "extra_fields", None) or None
if isinstance(extra_data, dict):
stat_type = extra_data.get("stats_type")
if stat_type in ["fuzzing_live_update", "scan_progress", "analysis_update", "live_stats"]:
latest_stats = extra_data
continue
# Fallback to message parse
if ("FUZZ_STATS" in log.message or "LIVE_STATS" in log.message):
stats = self._parse_stats_from_log(log.message)
if stats:
latest_stats = stats
if max_ts:
self.last_log_ts[run_id] = max_ts
if latest_stats:
# Use flow_run timestamps for elapsed time if available
elapsed_time = 0
start_time = getattr(flow_run, "start_time", None) or getattr(flow_run, "start_time", None)
if start_time:
now = datetime.now(timezone.utc)
try:
if start_time.tzinfo is None:
start_time = start_time.replace(tzinfo=timezone.utc)
elapsed_time = int((now - start_time).total_seconds())
except Exception:
elapsed_time = int((datetime.utcnow() - start_time.replace(tzinfo=None)).total_seconds())
updated_stats = FuzzingStats(
run_id=run_id,
workflow=workflow_name,
executions=latest_stats.get("executions", 0),
executions_per_sec=latest_stats.get("executions_per_sec", 0.0),
crashes=latest_stats.get("crashes", 0),
unique_crashes=latest_stats.get("unique_crashes", 0),
corpus_size=latest_stats.get("corpus_size", 0),
elapsed_time=elapsed_time
)
fuzzing_stats[run_id] = updated_stats
# Broadcast if listeners exist
if active_connections.get(run_id):
# Handle both Pydantic objects and plain dicts
if isinstance(updated_stats, dict):
stats_data = updated_stats
elif hasattr(updated_stats, 'model_dump'):
stats_data = updated_stats.model_dump()
elif hasattr(updated_stats, 'dict'):
stats_data = updated_stats.dict()
else:
stats_data = updated_stats.__dict__
message = {
"type": "stats_update",
"data": stats_data,
}
disconnected = []
for ws in active_connections[run_id]:
try:
await ws.send_text(json.dumps(message))
except Exception:
disconnected.append(ws)
for ws in disconnected:
try:
active_connections[run_id].remove(ws)
except ValueError:
pass
except Exception as e:
logger.warning(f"Error extracting stats from flow logs {run_id}: {e}")
def _parse_stats_from_log(self, log_message: str) -> Optional[Dict[str, Any]]:
"""Parse statistics from a log message"""
try:
import re
# Prefer explicit JSON after marker tokens
m = re.search(r'(?:FUZZ_STATS|LIVE_STATS)\s+(\{.*\})', log_message)
if m:
try:
return json.loads(m.group(1))
except Exception:
pass
# Fallback: Extract the extra= dict and coerce to JSON
stats_match = re.search(r'extra=({.*?})', log_message)
if not stats_match:
return None
extra_str = stats_match.group(1)
extra_str = extra_str.replace("'", '"')
extra_str = extra_str.replace('None', 'null')
extra_str = extra_str.replace('True', 'true')
extra_str = extra_str.replace('False', 'false')
stats_data = json.loads(extra_str)
# Support multiple stat types for different workflows
stat_type = stats_data.get("stats_type")
if stat_type in ["fuzzing_live_update", "scan_progress", "analysis_update", "live_stats"]:
return stats_data
except Exception as e:
logger.debug(f"Error parsing log stats: {e}")
return None
# Global instance
prefect_stats_monitor = PrefectStatsMonitor()