Mirror of https://github.com/FuzzingLabs/fuzzforge_ai.git, synced 2026-02-13 07:12:50 +00:00

Compare commits: dev...chore-code (4 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a271a6bef7 |  |
|  | 40bbb18795 |  |
|  | a810e29f76 |  |
|  | 1dc0d967b3 |  |
6  .github/workflows/test.yml  vendored
@@ -110,7 +110,7 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff mypy
          pip install ruff mypy bandit

      - name: Run ruff
        run: ruff check backend/src backend/toolbox backend/tests backend/benchmarks --output-format=github
@@ -119,6 +119,10 @@
        run: mypy backend/src backend/toolbox || true
        continue-on-error: true

      - name: Run bandit (continue on error)
        run: bandit --recursive backend/src || true
        continue-on-error: true

  unit-tests:
    name: Unit Tests
    runs-on: ubuntu-latest
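Note: the two hunks above add bandit alongside ruff and mypy in the lint job, with `|| true` and `continue-on-error: true` so security findings are reported without failing the build. For readers unfamiliar with bandit, it is a static analyser for common Python security pitfalls; a minimal, hypothetical snippet (not from this repository) of the kind of code it flags:

    import subprocess

    def run_user_command(cmd: str) -> bytes:
        # Typically flagged by bandit rule B602 (subprocess call with
        # shell=True), which is dangerous when cmd contains untrusted input.
        return subprocess.check_output(cmd, shell=True)

    DB_PASSWORD = "hunter2"  # typically flagged by B105: possible hardcoded password

Running `bandit --recursive backend/src`, as the new step does, lists such findings per file with a severity and confidence rating.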
19  backend/Makefile  Normal file
@@ -0,0 +1,19 @@
SOURCES=./src
TESTS=./tests

.PHONY: bandit format mypy pytest ruff

bandit:
	uv run bandit --recursive $(SOURCES)

format:
	uv run ruff format $(SOURCES) $(TESTS)

mypy:
	uv run mypy $(SOURCES) $(TESTS)

pytest:
	PYTHONPATH=./toolbox uv run pytest $(TESTS)

ruff:
	uv run ruff check --fix $(SOURCES) $(TESTS)
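The new Makefile wraps the same tools behind `uv run`, so `make ruff`, `make mypy`, `make bandit`, `make format`, and `make pytest` mirror the CI steps locally. The pytest target also prepends `./toolbox` to PYTHONPATH, which suggests tests import toolbox code as top-level packages rather than as an installed distribution. A sketch of a test relying on that, with a placeholder module name (the real toolbox package names are not shown in this diff):

    # tests/test_toolbox_import.py  (illustrative; "some_toolbox_module" is a placeholder)
    import importlib

    def test_toolbox_module_importable() -> None:
        # Works only because `make pytest` runs with PYTHONPATH=./toolbox.
        module = importlib.import_module("some_toolbox_module")
        assert module is not None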
@@ -6,28 +6,31 @@ authors = []
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.116.1",
    "temporalio>=1.6.0",
    "boto3>=1.34.0",
    "pydantic>=2.0.0",
    "pyyaml>=6.0",
    "docker>=7.0.0",
    "aiofiles>=23.0.0",
    "uvicorn>=0.30.0",
    "aiohttp>=3.12.15",
    "fastmcp",
    "aiofiles==25.1.0",
    "aiohttp==3.13.2",
    "boto3==1.40.68",
    "docker==7.1.0",
    "fastapi==0.121.0",
    "fastmcp==2.13.0.2",
    "pydantic==2.12.4",
    "pyyaml==6.0.3",
    "temporalio==1.18.2",
    "uvicorn==0.38.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-asyncio>=0.23.0",
    "pytest-benchmark>=4.0.0",
    "pytest-cov>=5.0.0",
    "pytest-xdist>=3.5.0",
    "pytest-mock>=3.12.0",
    "httpx>=0.27.0",
    "ruff>=0.1.0",
lints = [
    "bandit==1.8.6",
    "mypy==1.18.2",
    "ruff==0.14.4",
]
tests = [
    "pytest==8.4.2",
    "pytest-asyncio==1.2.0",
    "pytest-benchmark==5.2.1",
    "pytest-cov==7.0.0",
    "pytest-mock==3.15.1",
    "pytest-xdist==3.8.0",
]

[tool.pytest.ini_options]
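This hunk, apparently from the backend's pyproject.toml (the file header was not captured in this view), replaces floor constraints such as `fastapi>=0.116.1` with exact pins and splits the old `dev` extra into separate `lints` and `tests` groups, installable with the usual extras syntax, for example `pip install -e ".[lints,tests]"`. A small sanity-check sketch for verifying that an environment matches the pinned runtime versions:

    # Assumes it runs inside the backend's own virtual environment.
    from importlib.metadata import version

    EXPECTED = {
        "fastapi": "0.121.0",
        "pydantic": "2.12.4",
        "temporalio": "1.18.2",
    }

    for name, pinned in EXPECTED.items():
        installed = version(name)
        status = "OK" if installed == pinned else "MISMATCH"
        print(f"{name}: installed {installed}, pinned {pinned} -> {status}")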
11  backend/ruff.toml  Normal file
@@ -0,0 +1,11 @@
line-length = 120

[lint]
select = [ "ALL" ]
ignore = []

[lint.per-file-ignores]
"tests/*" = [
    "PLR2004",  # allowing comparisons using unnamed numerical constants in tests
    "S101",  # allowing 'assert' statements in tests
]
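The new ruff configuration enables every rule (`select = [ "ALL" ]`) and then relaxes exactly two of them for test code, as the inline comments indicate. A short hypothetical test shows why both ignores are needed under tests/*:

    # tests/test_example.py  (illustrative only)
    def test_addition() -> None:
        result = 2 + 2
        # S101 would normally flag the bare `assert` (it disappears under
        # `python -O`), but it is the idiomatic pytest style.
        # PLR2004 would normally flag the comparison against the unnamed
        # "magic" constant 4.
        assert result == 4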
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
API endpoints for fuzzing workflow management and real-time monitoring
|
||||
"""
|
||||
"""API endpoints for fuzzing workflow management and real-time monitoring."""
|
||||
|
||||
# Copyright (c) 2025 FuzzingLabs
|
||||
#
|
||||
@@ -13,32 +11,29 @@ API endpoints for fuzzing workflow management and real-time monitoring
|
||||
#
|
||||
# Additional attribution and requirements are provided in the NOTICE file.
|
||||
|
||||
import logging
|
||||
from typing import List, Dict
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import StreamingResponse
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from src.models.findings import (
|
||||
FuzzingStats,
|
||||
CrashReport
|
||||
)
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from src.models.findings import CrashReport, FuzzingStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/fuzzing", tags=["fuzzing"])
|
||||
|
||||
# In-memory storage for real-time stats (in production, use Redis or similar)
|
||||
fuzzing_stats: Dict[str, FuzzingStats] = {}
|
||||
crash_reports: Dict[str, List[CrashReport]] = {}
|
||||
active_connections: Dict[str, List[WebSocket]] = {}
|
||||
fuzzing_stats: dict[str, FuzzingStats] = {}
|
||||
crash_reports: dict[str, list[CrashReport]] = {}
|
||||
active_connections: dict[str, list[WebSocket]] = {}
|
||||
|
||||
|
||||
def initialize_fuzzing_tracking(run_id: str, workflow_name: str):
|
||||
"""
|
||||
Initialize fuzzing tracking for a new run.
|
||||
def initialize_fuzzing_tracking(run_id: str, workflow_name: str) -> None:
|
||||
"""Initialize fuzzing tracking for a new run.
|
||||
|
||||
This function should be called when a workflow is submitted to enable
|
||||
real-time monitoring and stats collection.
|
||||
@@ -46,19 +41,19 @@ def initialize_fuzzing_tracking(run_id: str, workflow_name: str):
|
||||
Args:
|
||||
run_id: The run identifier
|
||||
workflow_name: Name of the workflow
|
||||
|
||||
"""
|
||||
fuzzing_stats[run_id] = FuzzingStats(
|
||||
run_id=run_id,
|
||||
workflow=workflow_name
|
||||
workflow=workflow_name,
|
||||
)
|
||||
crash_reports[run_id] = []
|
||||
active_connections[run_id] = []
|
||||
|
||||
|
||||
@router.get("/{run_id}/stats", response_model=FuzzingStats)
|
||||
@router.get("/{run_id}/stats")
|
||||
async def get_fuzzing_stats(run_id: str) -> FuzzingStats:
|
||||
"""
|
||||
Get current fuzzing statistics for a run.
|
||||
"""Get current fuzzing statistics for a run.
|
||||
|
||||
Args:
|
||||
run_id: The fuzzing run ID
|
||||
@@ -68,20 +63,20 @@ async def get_fuzzing_stats(run_id: str) -> FuzzingStats:
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if run not found
|
||||
|
||||
"""
|
||||
if run_id not in fuzzing_stats:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Fuzzing run not found: {run_id}"
|
||||
detail=f"Fuzzing run not found: {run_id}",
|
||||
)
|
||||
|
||||
return fuzzing_stats[run_id]
|
||||
|
||||
|
||||
@router.get("/{run_id}/crashes", response_model=List[CrashReport])
|
||||
async def get_crash_reports(run_id: str) -> List[CrashReport]:
|
||||
"""
|
||||
Get crash reports for a fuzzing run.
|
||||
@router.get("/{run_id}/crashes")
|
||||
async def get_crash_reports(run_id: str) -> list[CrashReport]:
|
||||
"""Get crash reports for a fuzzing run.
|
||||
|
||||
Args:
|
||||
run_id: The fuzzing run ID
|
||||
@@ -91,11 +86,12 @@ async def get_crash_reports(run_id: str) -> List[CrashReport]:
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if run not found
|
||||
|
||||
"""
|
||||
if run_id not in crash_reports:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Fuzzing run not found: {run_id}"
|
||||
detail=f"Fuzzing run not found: {run_id}",
|
||||
)
|
||||
|
||||
return crash_reports[run_id]
|
||||
@@ -103,8 +99,7 @@ async def get_crash_reports(run_id: str) -> List[CrashReport]:
|
||||
|
||||
@router.post("/{run_id}/stats")
|
||||
async def update_fuzzing_stats(run_id: str, stats: FuzzingStats):
|
||||
"""
|
||||
Update fuzzing statistics (called by fuzzing workflows).
|
||||
"""Update fuzzing statistics (called by fuzzing workflows).
|
||||
|
||||
Args:
|
||||
run_id: The fuzzing run ID
|
||||
@@ -112,18 +107,19 @@ async def update_fuzzing_stats(run_id: str, stats: FuzzingStats):
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if run not found
|
||||
|
||||
"""
|
||||
if run_id not in fuzzing_stats:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Fuzzing run not found: {run_id}"
|
||||
detail=f"Fuzzing run not found: {run_id}",
|
||||
)
|
||||
|
||||
# Update stats
|
||||
fuzzing_stats[run_id] = stats
|
||||
|
||||
# Debug: log reception for live instrumentation
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
logger.info(
|
||||
"Received fuzzing stats update: run_id=%s exec=%s eps=%.2f crashes=%s corpus=%s coverage=%s elapsed=%ss",
|
||||
run_id,
|
||||
@@ -134,14 +130,12 @@ async def update_fuzzing_stats(run_id: str, stats: FuzzingStats):
|
||||
stats.coverage,
|
||||
stats.elapsed_time,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Notify connected WebSocket clients
|
||||
if run_id in active_connections:
|
||||
message = {
|
||||
"type": "stats_update",
|
||||
"data": stats.model_dump()
|
||||
"data": stats.model_dump(),
|
||||
}
|
||||
for websocket in active_connections[run_id][:]: # Copy to avoid modification during iteration
|
||||
try:
|
||||
@@ -153,12 +147,12 @@ async def update_fuzzing_stats(run_id: str, stats: FuzzingStats):
|
||||
|
||||
@router.post("/{run_id}/crash")
|
||||
async def report_crash(run_id: str, crash: CrashReport):
|
||||
"""
|
||||
Report a new crash (called by fuzzing workflows).
|
||||
"""Report a new crash (called by fuzzing workflows).
|
||||
|
||||
Args:
|
||||
run_id: The fuzzing run ID
|
||||
crash: Crash report details
|
||||
|
||||
"""
|
||||
if run_id not in crash_reports:
|
||||
crash_reports[run_id] = []
|
||||
@@ -175,7 +169,7 @@ async def report_crash(run_id: str, crash: CrashReport):
|
||||
if run_id in active_connections:
|
||||
message = {
|
||||
"type": "crash_report",
|
||||
"data": crash.model_dump()
|
||||
"data": crash.model_dump(),
|
||||
}
|
||||
for websocket in active_connections[run_id][:]:
|
||||
try:
|
||||
@@ -186,12 +180,12 @@ async def report_crash(run_id: str, crash: CrashReport):
|
||||
|
||||
@router.websocket("/{run_id}/live")
|
||||
async def websocket_endpoint(websocket: WebSocket, run_id: str):
|
||||
"""
|
||||
WebSocket endpoint for real-time fuzzing updates.
|
||||
"""WebSocket endpoint for real-time fuzzing updates.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
run_id: The fuzzing run ID to monitor
|
||||
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
@@ -223,7 +217,7 @@ async def websocket_endpoint(websocket: WebSocket, run_id: str):
|
||||
# Echo back for ping-pong
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
# Send periodic heartbeat
|
||||
await websocket.send_text(json.dumps({"type": "heartbeat"}))
|
||||
|
||||
@@ -231,31 +225,31 @@ async def websocket_endpoint(websocket: WebSocket, run_id: str):
|
||||
# Clean up connection
|
||||
if run_id in active_connections and websocket in active_connections[run_id]:
|
||||
active_connections[run_id].remove(websocket)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for run {run_id}: {e}")
|
||||
except Exception:
|
||||
logger.exception("WebSocket error for run %s", run_id)
|
||||
if run_id in active_connections and websocket in active_connections[run_id]:
|
||||
active_connections[run_id].remove(websocket)
|
||||
|
||||
|
||||
@router.get("/{run_id}/stream")
|
||||
async def stream_fuzzing_updates(run_id: str):
|
||||
"""
|
||||
Server-Sent Events endpoint for real-time fuzzing updates.
|
||||
"""Server-Sent Events endpoint for real-time fuzzing updates.
|
||||
|
||||
Args:
|
||||
run_id: The fuzzing run ID to monitor
|
||||
|
||||
Returns:
|
||||
Streaming response with real-time updates
|
||||
|
||||
"""
|
||||
if run_id not in fuzzing_stats:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Fuzzing run not found: {run_id}"
|
||||
detail=f"Fuzzing run not found: {run_id}",
|
||||
)
|
||||
|
||||
async def event_stream():
|
||||
"""Generate server-sent events for fuzzing updates"""
|
||||
"""Generate server-sent events for fuzzing updates."""
|
||||
last_stats_time = datetime.utcnow()
|
||||
|
||||
while True:
|
||||
@@ -276,10 +270,7 @@ async def stream_fuzzing_updates(run_id: str):
|
||||
|
||||
# Send recent crashes
|
||||
if run_id in crash_reports:
|
||||
recent_crashes = [
|
||||
crash for crash in crash_reports[run_id]
|
||||
if crash.timestamp > last_stats_time
|
||||
]
|
||||
recent_crashes = [crash for crash in crash_reports[run_id] if crash.timestamp > last_stats_time]
|
||||
for crash in recent_crashes:
|
||||
event_data = f"data: {json.dumps({'type': 'crash', 'data': crash.model_dump()})}\n\n"
|
||||
yield event_data
|
||||
@@ -287,8 +278,8 @@ async def stream_fuzzing_updates(run_id: str):
|
||||
last_stats_time = datetime.utcnow()
|
||||
await asyncio.sleep(5) # Update every 5 seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in event stream for run {run_id}: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error in event stream for run %s", run_id)
|
||||
break
|
||||
|
||||
return StreamingResponse(
|
||||
@@ -297,17 +288,17 @@ async def stream_fuzzing_updates(run_id: str):
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{run_id}")
|
||||
async def cleanup_fuzzing_run(run_id: str):
|
||||
"""
|
||||
Clean up fuzzing run data.
|
||||
async def cleanup_fuzzing_run(run_id: str) -> dict[str, str]:
|
||||
"""Clean up fuzzing run data.
|
||||
|
||||
Args:
|
||||
run_id: The fuzzing run ID to clean up
|
||||
|
||||
"""
|
||||
# Clean up tracking data
|
||||
fuzzing_stats.pop(run_id, None)
|
||||
|
||||
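The hunks above come from the fuzzing API module (imported elsewhere in this diff as src.api.fuzzing). Besides the type-hint and logging cleanups, they keep the ping/pong and heartbeat protocol on the `/fuzzing/{run_id}/live` WebSocket. A minimal client sketch, assuming the API is reachable at localhost:8000 and using the third-party `websockets` package (neither assumption comes from this diff):

    import asyncio
    import json

    import websockets  # assumed client library, not part of this changeset

    async def watch_run(run_id: str) -> None:
        uri = f"ws://localhost:8000/fuzzing/{run_id}/live"
        async with websockets.connect(uri) as ws:
            for _ in range(10):
                await ws.send("ping")          # server answers "pong"
                reply = await ws.recv()        # or a JSON stats/heartbeat event
                if reply != "pong":
                    event = json.loads(reply)
                    print("event type:", event.get("type"))
                await asyncio.sleep(5)

    # asyncio.run(watch_run("fuzz-run-id"))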
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
API endpoints for workflow run management and findings retrieval
|
||||
"""
|
||||
"""API endpoints for workflow run management and findings retrieval."""
|
||||
|
||||
# Copyright (c) 2025 FuzzingLabs
|
||||
#
|
||||
@@ -14,37 +12,36 @@ API endpoints for workflow run management and findings retrieval
|
||||
# Additional attribution and requirements are provided in the NOTICE file.
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from src.main import temporal_mgr
|
||||
from src.models.findings import WorkflowFindings, WorkflowStatus
|
||||
from src.temporal import TemporalManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/runs", tags=["runs"])
|
||||
|
||||
|
||||
def get_temporal_manager():
|
||||
"""Dependency to get the Temporal manager instance"""
|
||||
from src.main import temporal_mgr
|
||||
def get_temporal_manager() -> TemporalManager:
|
||||
"""Dependency to get the Temporal manager instance."""
|
||||
return temporal_mgr
|
||||
|
||||
|
||||
@router.get("/{run_id}/status", response_model=WorkflowStatus)
|
||||
@router.get("/{run_id}/status")
|
||||
async def get_run_status(
|
||||
run_id: str,
|
||||
temporal_mgr=Depends(get_temporal_manager)
|
||||
temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
|
||||
) -> WorkflowStatus:
|
||||
"""
|
||||
Get the current status of a workflow run.
|
||||
"""Get the current status of a workflow run.
|
||||
|
||||
Args:
|
||||
run_id: The workflow run ID
|
||||
:param run_id: The workflow run ID
|
||||
:param temporal_mgr: The temporal manager instance.
|
||||
:return: Status information including state, timestamps, and completion flags
|
||||
:raises HTTPException: 404 if run not found
|
||||
|
||||
Returns:
|
||||
Status information including state, timestamps, and completion flags
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if run not found
|
||||
"""
|
||||
try:
|
||||
status = await temporal_mgr.get_workflow_status(run_id)
|
||||
@@ -56,7 +53,7 @@ async def get_run_status(
|
||||
is_running = workflow_status == "RUNNING"
|
||||
|
||||
# Extract workflow name from run_id (format: workflow_name-unique_id)
|
||||
workflow_name = run_id.rsplit('-', 1)[0] if '-' in run_id else "unknown"
|
||||
workflow_name = run_id.rsplit("-", 1)[0] if "-" in run_id else "unknown"
|
||||
|
||||
return WorkflowStatus(
|
||||
run_id=run_id,
|
||||
@@ -66,33 +63,29 @@ async def get_run_status(
|
||||
is_failed=is_failed,
|
||||
is_running=is_running,
|
||||
created_at=status.get("start_time"),
|
||||
updated_at=status.get("close_time") or status.get("execution_time")
|
||||
updated_at=status.get("close_time") or status.get("execution_time"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status for run {run_id}: {e}")
|
||||
logger.exception("Failed to get status for run %s", run_id)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Run not found: {run_id}"
|
||||
)
|
||||
detail=f"Run not found: {run_id}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/{run_id}/findings", response_model=WorkflowFindings)
|
||||
@router.get("/{run_id}/findings")
|
||||
async def get_run_findings(
|
||||
run_id: str,
|
||||
temporal_mgr=Depends(get_temporal_manager)
|
||||
temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
|
||||
) -> WorkflowFindings:
|
||||
"""
|
||||
Get the findings from a completed workflow run.
|
||||
"""Get the findings from a completed workflow run.
|
||||
|
||||
Args:
|
||||
run_id: The workflow run ID
|
||||
:param run_id: The workflow run ID
|
||||
:param temporal_mgr: The temporal manager instance.
|
||||
:return: SARIF-formatted findings from the workflow execution
|
||||
:raises HTTPException: 404 if run not found, 400 if run not completed
|
||||
|
||||
Returns:
|
||||
SARIF-formatted findings from the workflow execution
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if run not found, 400 if run not completed
|
||||
"""
|
||||
try:
|
||||
# Get run status first
|
||||
@@ -103,80 +96,72 @@ async def get_run_findings(
|
||||
if workflow_status == "RUNNING":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Run {run_id} is still running. Current status: {workflow_status}"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Run {run_id} not completed. Status: {workflow_status}"
|
||||
detail=f"Run {run_id} is still running. Current status: {workflow_status}",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Run {run_id} not completed. Status: {workflow_status}",
|
||||
)
|
||||
|
||||
if workflow_status == "FAILED":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Run {run_id} failed. Status: {workflow_status}"
|
||||
detail=f"Run {run_id} failed. Status: {workflow_status}",
|
||||
)
|
||||
|
||||
# Get the workflow result
|
||||
result = await temporal_mgr.get_workflow_result(run_id)
|
||||
|
||||
# Extract SARIF from result (handle None for backwards compatibility)
|
||||
if isinstance(result, dict):
|
||||
sarif = result.get("sarif") or {}
|
||||
else:
|
||||
sarif = {}
|
||||
sarif = result.get("sarif", {}) if isinstance(result, dict) else {}
|
||||
|
||||
# Extract workflow name from run_id (format: workflow_name-unique_id)
|
||||
workflow_name = run_id.rsplit('-', 1)[0] if '-' in run_id else "unknown"
|
||||
workflow_name = run_id.rsplit("-", 1)[0] if "-" in run_id else "unknown"
|
||||
|
||||
# Metadata
|
||||
metadata = {
|
||||
"completion_time": status.get("close_time"),
|
||||
"workflow_version": "unknown"
|
||||
"workflow_version": "unknown",
|
||||
}
|
||||
|
||||
return WorkflowFindings(
|
||||
workflow=workflow_name,
|
||||
run_id=run_id,
|
||||
sarif=sarif,
|
||||
metadata=metadata
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get findings for run {run_id}: {e}")
|
||||
logger.exception("Failed to get findings for run %s", run_id)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to retrieve findings: {str(e)}"
|
||||
)
|
||||
detail=f"Failed to retrieve findings: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/{workflow_name}/findings/{run_id}", response_model=WorkflowFindings)
|
||||
@router.get("/{workflow_name}/findings/{run_id}")
|
||||
async def get_workflow_findings(
|
||||
workflow_name: str,
|
||||
run_id: str,
|
||||
temporal_mgr=Depends(get_temporal_manager)
|
||||
temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
|
||||
) -> WorkflowFindings:
|
||||
"""
|
||||
Get findings for a specific workflow run.
|
||||
"""Get findings for a specific workflow run.
|
||||
|
||||
Alternative endpoint that includes workflow name in the path for clarity.
|
||||
|
||||
Args:
|
||||
workflow_name: Name of the workflow
|
||||
run_id: The workflow run ID
|
||||
:param workflow_name: Name of the workflow
|
||||
:param run_id: The workflow run ID
|
||||
:param temporal_mgr: The temporal manager instance.
|
||||
:return: SARIF-formatted findings from the workflow execution
|
||||
:raises HTTPException: 404 if workflow or run not found, 400 if run not completed
|
||||
|
||||
Returns:
|
||||
SARIF-formatted findings from the workflow execution
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if workflow or run not found, 400 if run not completed
|
||||
"""
|
||||
if workflow_name not in temporal_mgr.workflows:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Workflow not found: {workflow_name}"
|
||||
detail=f"Workflow not found: {workflow_name}",
|
||||
)
|
||||
|
||||
# Delegate to the main findings endpoint
|
||||
|
||||
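A pattern repeated in the run-management endpoints above and the workflow endpoints below replaces untyped `temporal_mgr=Depends(get_temporal_manager)` parameters with `Annotated[TemporalManager, Depends(get_temporal_manager)]`, and moves the duplicated `response_model=` into the return annotation. Behaviour is unchanged; mypy and ruff simply get a real type to check. A condensed, self-contained sketch of the pattern (the TemporalManager stub here is a stand-in, not the project's class):

    from typing import Annotated, Any

    from fastapi import APIRouter, Depends

    class TemporalManager:  # stand-in for src.temporal.TemporalManager
        async def get_workflow_status(self, run_id: str) -> dict[str, Any]:
            return {"run_id": run_id, "status": "RUNNING"}

    temporal_mgr = TemporalManager()
    router = APIRouter(prefix="/runs", tags=["runs"])

    def get_temporal_manager() -> TemporalManager:
        """Dependency provider, as in the diff above."""
        return temporal_mgr

    @router.get("/{run_id}/status")
    async def get_run_status(
        run_id: str,
        mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
    ) -> dict[str, Any]:
        return await mgr.get_workflow_status(run_id)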
@@ -9,14 +9,12 @@
#
# Additional attribution and requirements are provided in the NOTICE file.

"""
System information endpoints for FuzzForge API.
"""System information endpoints for FuzzForge API.

Provides system configuration and filesystem paths to CLI for worker management.
"""

import os
from typing import Dict

from fastapi import APIRouter

@@ -24,9 +22,8 @@ router = APIRouter(prefix="/system", tags=["system"])


@router.get("/info")
async def get_system_info() -> Dict[str, str]:
"""
Get system information including host filesystem paths.
async def get_system_info() -> dict[str, str]:
"""Get system information including host filesystem paths.

This endpoint exposes paths needed by the CLI to manage workers via docker-compose.
The FUZZFORGE_HOST_ROOT environment variable is set by docker-compose and points
@@ -37,6 +34,7 @@ async def get_system_info() -> Dict[str, str]:
- host_root: Absolute path to FuzzForge root on host
- docker_compose_path: Path to docker-compose.yml on host
- workers_dir: Path to workers directory on host

"""
host_root = os.getenv("FUZZFORGE_HOST_ROOT", "")
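Like the rest of the changeset, the system endpoints drop `from typing import Dict` in favour of the built-in generics that are standard on the project's Python floor (`requires-python = ">=3.11"`). A small illustration of the same style, with a made-up return value:

    import os

    def get_host_root() -> dict[str, str]:
        # dict[str, str] replaces typing.Dict[str, str]; no typing import needed.
        return {"host_root": os.getenv("FUZZFORGE_HOST_ROOT", "")}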
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
API endpoints for workflow management with enhanced error handling
|
||||
"""
|
||||
"""API endpoints for workflow management with enhanced error handling."""
|
||||
|
||||
# Copyright (c) 2025 FuzzingLabs
|
||||
#
|
||||
@@ -13,20 +11,24 @@ API endpoints for workflow management with enhanced error handling
|
||||
#
|
||||
# Additional attribution and requirements are provided in the NOTICE file.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
import tempfile
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
|
||||
from src.api.fuzzing import initialize_fuzzing_tracking
|
||||
from src.main import temporal_mgr
|
||||
from src.models.findings import (
|
||||
WorkflowSubmission,
|
||||
WorkflowMetadata,
|
||||
RunSubmissionResponse,
|
||||
WorkflowListItem,
|
||||
RunSubmissionResponse
|
||||
WorkflowMetadata,
|
||||
WorkflowSubmission,
|
||||
)
|
||||
from src.temporal.discovery import WorkflowDiscovery
|
||||
from src.temporal.manager import TemporalManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,9 +45,8 @@ ALLOWED_CONTENT_TYPES = [
|
||||
router = APIRouter(prefix="/workflows", tags=["workflows"])
|
||||
|
||||
|
||||
def extract_defaults_from_json_schema(metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract default parameter values from JSON Schema format.
|
||||
def extract_defaults_from_json_schema(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract default parameter values from JSON Schema format.
|
||||
|
||||
Converts from:
|
||||
parameters:
|
||||
@@ -61,6 +62,7 @@ def extract_defaults_from_json_schema(metadata: Dict[str, Any]) -> Dict[str, Any
|
||||
|
||||
Returns:
|
||||
Dictionary of parameter defaults
|
||||
|
||||
"""
|
||||
defaults = {}
|
||||
|
||||
@@ -82,19 +84,19 @@ def extract_defaults_from_json_schema(metadata: Dict[str, Any]) -> Dict[str, Any
|
||||
def create_structured_error_response(
|
||||
error_type: str,
|
||||
message: str,
|
||||
workflow_name: Optional[str] = None,
|
||||
run_id: Optional[str] = None,
|
||||
container_info: Optional[Dict[str, Any]] = None,
|
||||
deployment_info: Optional[Dict[str, Any]] = None,
|
||||
suggestions: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
workflow_name: str | None = None,
|
||||
run_id: str | None = None,
|
||||
container_info: dict[str, Any] | None = None,
|
||||
deployment_info: dict[str, Any] | None = None,
|
||||
suggestions: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a structured error response with rich context."""
|
||||
error_response = {
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
"timestamp": __import__("datetime").datetime.utcnow().isoformat() + "Z"
|
||||
}
|
||||
"timestamp": __import__("datetime").datetime.utcnow().isoformat() + "Z",
|
||||
},
|
||||
}
|
||||
|
||||
if workflow_name:
|
||||
@@ -115,39 +117,38 @@ def create_structured_error_response(
|
||||
return error_response
|
||||
|
||||
|
||||
def get_temporal_manager():
|
||||
"""Dependency to get the Temporal manager instance"""
|
||||
from src.main import temporal_mgr
|
||||
def get_temporal_manager() -> TemporalManager:
|
||||
"""Dependency to get the Temporal manager instance."""
|
||||
return temporal_mgr
|
||||
|
||||
|
||||
@router.get("/", response_model=List[WorkflowListItem])
|
||||
@router.get("/")
|
||||
async def list_workflows(
|
||||
temporal_mgr=Depends(get_temporal_manager)
|
||||
) -> List[WorkflowListItem]:
|
||||
"""
|
||||
List all discovered workflows with their metadata.
|
||||
temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
|
||||
) -> list[WorkflowListItem]:
|
||||
"""List all discovered workflows with their metadata.
|
||||
|
||||
Returns a summary of each workflow including name, version, description,
|
||||
author, and tags.
|
||||
"""
|
||||
workflows = []
|
||||
for name, info in temporal_mgr.workflows.items():
|
||||
workflows.append(WorkflowListItem(
|
||||
name=name,
|
||||
version=info.metadata.get("version", "0.6.0"),
|
||||
description=info.metadata.get("description", ""),
|
||||
author=info.metadata.get("author"),
|
||||
tags=info.metadata.get("tags", [])
|
||||
))
|
||||
workflows.append(
|
||||
WorkflowListItem(
|
||||
name=name,
|
||||
version=info.metadata.get("version", "0.6.0"),
|
||||
description=info.metadata.get("description", ""),
|
||||
author=info.metadata.get("author"),
|
||||
tags=info.metadata.get("tags", []),
|
||||
),
|
||||
)
|
||||
|
||||
return workflows
|
||||
|
||||
|
||||
@router.get("/metadata/schema")
|
||||
async def get_metadata_schema() -> Dict[str, Any]:
|
||||
"""
|
||||
Get the JSON schema for workflow metadata files.
|
||||
async def get_metadata_schema() -> dict[str, Any]:
|
||||
"""Get the JSON schema for workflow metadata files.
|
||||
|
||||
This schema defines the structure and requirements for metadata.yaml files
|
||||
that must accompany each workflow.
|
||||
@@ -155,23 +156,19 @@ async def get_metadata_schema() -> Dict[str, Any]:
|
||||
return WorkflowDiscovery.get_metadata_schema()
|
||||
|
||||
|
||||
@router.get("/{workflow_name}/metadata", response_model=WorkflowMetadata)
|
||||
@router.get("/{workflow_name}/metadata")
|
||||
async def get_workflow_metadata(
|
||||
workflow_name: str,
|
||||
temporal_mgr=Depends(get_temporal_manager)
|
||||
temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
|
||||
) -> WorkflowMetadata:
|
||||
"""
|
||||
Get complete metadata for a specific workflow.
|
||||
"""Get complete metadata for a specific workflow.
|
||||
|
||||
Args:
|
||||
workflow_name: Name of the workflow
|
||||
|
||||
Returns:
|
||||
Complete metadata including parameters schema, supported volume modes,
|
||||
:param workflow_name: Name of the workflow
|
||||
:param temporal_mgr: The temporal manager instance.
|
||||
:return: Complete metadata including parameters schema, supported volume modes,
|
||||
required modules, and more.
|
||||
:raises HTTPException: 404 if workflow not found
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if workflow not found
|
||||
"""
|
||||
if workflow_name not in temporal_mgr.workflows:
|
||||
available_workflows = list(temporal_mgr.workflows.keys())
|
||||
@@ -182,12 +179,12 @@ async def get_workflow_metadata(
|
||||
suggestions=[
|
||||
f"Available workflows: {', '.join(available_workflows)}",
|
||||
"Use GET /workflows/ to see all available workflows",
|
||||
"Check workflow name spelling and case sensitivity"
|
||||
]
|
||||
"Check workflow name spelling and case sensitivity",
|
||||
],
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=error_response
|
||||
detail=error_response,
|
||||
)
|
||||
|
||||
info = temporal_mgr.workflows[workflow_name]
|
||||
@@ -201,28 +198,24 @@ async def get_workflow_metadata(
|
||||
tags=metadata.get("tags", []),
|
||||
parameters=metadata.get("parameters", {}),
|
||||
default_parameters=extract_defaults_from_json_schema(metadata),
|
||||
required_modules=metadata.get("required_modules", [])
|
||||
required_modules=metadata.get("required_modules", []),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{workflow_name}/submit", response_model=RunSubmissionResponse)
|
||||
@router.post("/{workflow_name}/submit")
|
||||
async def submit_workflow(
|
||||
workflow_name: str,
|
||||
submission: WorkflowSubmission,
|
||||
temporal_mgr=Depends(get_temporal_manager)
|
||||
temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
|
||||
) -> RunSubmissionResponse:
|
||||
"""
|
||||
Submit a workflow for execution.
|
||||
"""Submit a workflow for execution.
|
||||
|
||||
Args:
|
||||
workflow_name: Name of the workflow to execute
|
||||
submission: Submission parameters including target path and parameters
|
||||
:param workflow_name: Name of the workflow to execute
|
||||
:param submission: Submission parameters including target path and parameters
|
||||
:param temporal_mgr: The temporal manager instance.
|
||||
:return: Run submission response with run_id and initial status
|
||||
:raises HTTPException: 404 if workflow not found, 400 for invalid parameters
|
||||
|
||||
Returns:
|
||||
Run submission response with run_id and initial status
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if workflow not found, 400 for invalid parameters
|
||||
"""
|
||||
if workflow_name not in temporal_mgr.workflows:
|
||||
available_workflows = list(temporal_mgr.workflows.keys())
|
||||
@@ -233,25 +226,26 @@ async def submit_workflow(
|
||||
suggestions=[
|
||||
f"Available workflows: {', '.join(available_workflows)}",
|
||||
"Use GET /workflows/ to see all available workflows",
|
||||
"Check workflow name spelling and case sensitivity"
|
||||
]
|
||||
"Check workflow name spelling and case sensitivity",
|
||||
],
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=error_response
|
||||
detail=error_response,
|
||||
)
|
||||
|
||||
try:
|
||||
# Upload target file to MinIO and get target_id
|
||||
target_path = Path(submission.target_path)
|
||||
if not target_path.exists():
|
||||
raise ValueError(f"Target path does not exist: {submission.target_path}")
|
||||
msg = f"Target path does not exist: {submission.target_path}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Upload target (using anonymous user for now)
|
||||
target_id = await temporal_mgr.upload_target(
|
||||
file_path=target_path,
|
||||
user_id="api-user",
|
||||
metadata={"workflow": workflow_name}
|
||||
metadata={"workflow": workflow_name},
|
||||
)
|
||||
|
||||
# Merge default parameters with user parameters
|
||||
@@ -265,23 +259,22 @@ async def submit_workflow(
|
||||
handle = await temporal_mgr.run_workflow(
|
||||
workflow_name=workflow_name,
|
||||
target_id=target_id,
|
||||
workflow_params=workflow_params
|
||||
workflow_params=workflow_params,
|
||||
)
|
||||
|
||||
run_id = handle.id
|
||||
|
||||
# Initialize fuzzing tracking if this looks like a fuzzing workflow
|
||||
workflow_info = temporal_mgr.workflows.get(workflow_name, {})
|
||||
workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, 'metadata') else []
|
||||
workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, "metadata") else []
|
||||
if "fuzzing" in workflow_tags or "fuzz" in workflow_name.lower():
|
||||
from src.api.fuzzing import initialize_fuzzing_tracking
|
||||
initialize_fuzzing_tracking(run_id, workflow_name)
|
||||
|
||||
return RunSubmissionResponse(
|
||||
run_id=run_id,
|
||||
status="RUNNING",
|
||||
workflow=workflow_name,
|
||||
message=f"Workflow '{workflow_name}' submitted successfully"
|
||||
message=f"Workflow '{workflow_name}' submitted successfully",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
@@ -293,14 +286,13 @@ async def submit_workflow(
|
||||
suggestions=[
|
||||
"Check parameter types and values",
|
||||
"Use GET /workflows/{workflow_name}/parameters for schema",
|
||||
"Ensure all required parameters are provided"
|
||||
]
|
||||
"Ensure all required parameters are provided",
|
||||
],
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_response)
|
||||
raise HTTPException(status_code=400, detail=error_response) from e
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to submit workflow '{workflow_name}': {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
logger.exception("Failed to submit workflow '%s'", workflow_name)
|
||||
|
||||
# Try to get more context about the error
|
||||
container_info = None
|
||||
@@ -313,47 +305,57 @@ async def submit_workflow(
|
||||
# Detect specific error patterns
|
||||
if "workflow" in error_message.lower() and "not found" in error_message.lower():
|
||||
error_type = "WorkflowError"
|
||||
suggestions.extend([
|
||||
"Check if Temporal server is running and accessible",
|
||||
"Verify workflow workers are running",
|
||||
"Check if workflow is registered with correct vertical",
|
||||
"Ensure Docker is running and has sufficient resources"
|
||||
])
|
||||
suggestions.extend(
|
||||
[
|
||||
"Check if Temporal server is running and accessible",
|
||||
"Verify workflow workers are running",
|
||||
"Check if workflow is registered with correct vertical",
|
||||
"Ensure Docker is running and has sufficient resources",
|
||||
],
|
||||
)
|
||||
|
||||
elif "volume" in error_message.lower() or "mount" in error_message.lower():
|
||||
error_type = "VolumeError"
|
||||
suggestions.extend([
|
||||
"Check if the target path exists and is accessible",
|
||||
"Verify file permissions (Docker needs read access)",
|
||||
"Ensure the path is not in use by another process",
|
||||
"Try using an absolute path instead of relative path"
|
||||
])
|
||||
suggestions.extend(
|
||||
[
|
||||
"Check if the target path exists and is accessible",
|
||||
"Verify file permissions (Docker needs read access)",
|
||||
"Ensure the path is not in use by another process",
|
||||
"Try using an absolute path instead of relative path",
|
||||
],
|
||||
)
|
||||
|
||||
elif "memory" in error_message.lower() or "resource" in error_message.lower():
|
||||
error_type = "ResourceError"
|
||||
suggestions.extend([
|
||||
"Check system memory and CPU availability",
|
||||
"Consider reducing resource limits or dataset size",
|
||||
"Monitor Docker resource usage",
|
||||
"Increase Docker memory limits if needed"
|
||||
])
|
||||
suggestions.extend(
|
||||
[
|
||||
"Check system memory and CPU availability",
|
||||
"Consider reducing resource limits or dataset size",
|
||||
"Monitor Docker resource usage",
|
||||
"Increase Docker memory limits if needed",
|
||||
],
|
||||
)
|
||||
|
||||
elif "image" in error_message.lower():
|
||||
error_type = "ImageError"
|
||||
suggestions.extend([
|
||||
"Check if the workflow image exists",
|
||||
"Verify Docker registry access",
|
||||
"Try rebuilding the workflow image",
|
||||
"Check network connectivity to registries"
|
||||
])
|
||||
suggestions.extend(
|
||||
[
|
||||
"Check if the workflow image exists",
|
||||
"Verify Docker registry access",
|
||||
"Try rebuilding the workflow image",
|
||||
"Check network connectivity to registries",
|
||||
],
|
||||
)
|
||||
|
||||
else:
|
||||
suggestions.extend([
|
||||
"Check FuzzForge backend logs for details",
|
||||
"Verify all services are running (docker-compose up -d)",
|
||||
"Try restarting the workflow deployment",
|
||||
"Contact support if the issue persists"
|
||||
])
|
||||
suggestions.extend(
|
||||
[
|
||||
"Check FuzzForge backend logs for details",
|
||||
"Verify all services are running (docker-compose up -d)",
|
||||
"Try restarting the workflow deployment",
|
||||
"Contact support if the issue persists",
|
||||
],
|
||||
)
|
||||
|
||||
error_response = create_structured_error_response(
|
||||
error_type=error_type,
|
||||
@@ -361,41 +363,35 @@ async def submit_workflow(
|
||||
workflow_name=workflow_name,
|
||||
container_info=container_info,
|
||||
deployment_info=deployment_info,
|
||||
suggestions=suggestions
|
||||
suggestions=suggestions,
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=error_response
|
||||
)
|
||||
detail=error_response,
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/{workflow_name}/upload-and-submit", response_model=RunSubmissionResponse)
|
||||
@router.post("/{workflow_name}/upload-and-submit")
|
||||
async def upload_and_submit_workflow(
|
||||
workflow_name: str,
|
||||
file: UploadFile = File(..., description="Target file or tarball to analyze"),
|
||||
parameters: Optional[str] = Form(None, description="JSON-encoded workflow parameters"),
|
||||
timeout: Optional[int] = Form(None, description="Timeout in seconds"),
|
||||
temporal_mgr=Depends(get_temporal_manager)
|
||||
temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
|
||||
file: Annotated[UploadFile, File(..., description="Target file or tarball to analyze")],
|
||||
parameters: Annotated[str, Form(None, description="JSON-encoded workflow parameters")],
|
||||
) -> RunSubmissionResponse:
|
||||
"""
|
||||
Upload a target file/tarball and submit workflow for execution.
|
||||
"""Upload a target file/tarball and submit workflow for execution.
|
||||
|
||||
This endpoint accepts multipart/form-data uploads and is the recommended
|
||||
way to submit workflows from remote CLI clients.
|
||||
|
||||
Args:
|
||||
workflow_name: Name of the workflow to execute
|
||||
file: Target file or tarball (compressed directory)
|
||||
parameters: JSON string of workflow parameters (optional)
|
||||
timeout: Execution timeout in seconds (optional)
|
||||
:param workflow_name: Name of the workflow to execute
|
||||
:param temporal_mgr: The temporal manager instance.
|
||||
:param file: Target file or tarball (compressed directory)
|
||||
:param parameters: JSON string of workflow parameters (optional)
|
||||
:returns: Run submission response with run_id and initial status
|
||||
:raises HTTPException: 404 if workflow not found, 400 for invalid parameters,
|
||||
413 if file too large
|
||||
|
||||
Returns:
|
||||
Run submission response with run_id and initial status
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if workflow not found, 400 for invalid parameters,
|
||||
413 if file too large
|
||||
"""
|
||||
if workflow_name not in temporal_mgr.workflows:
|
||||
available_workflows = list(temporal_mgr.workflows.keys())
|
||||
@@ -405,8 +401,8 @@ async def upload_and_submit_workflow(
|
||||
workflow_name=workflow_name,
|
||||
suggestions=[
|
||||
f"Available workflows: {', '.join(available_workflows)}",
|
||||
"Use GET /workflows/ to see all available workflows"
|
||||
]
|
||||
"Use GET /workflows/ to see all available workflows",
|
||||
],
|
||||
)
|
||||
raise HTTPException(status_code=404, detail=error_response)
|
||||
|
||||
@@ -420,10 +416,10 @@ async def upload_and_submit_workflow(
|
||||
# Create temporary file
|
||||
temp_fd, temp_file_path = tempfile.mkstemp(suffix=".tar.gz")
|
||||
|
||||
logger.info(f"Receiving file upload for workflow '{workflow_name}': {file.filename}")
|
||||
logger.info("Receiving file upload for workflow '%s': %s", workflow_name, file.filename)
|
||||
|
||||
# Stream file to disk
|
||||
with open(temp_fd, 'wb') as temp_file:
|
||||
with open(temp_fd, "wb") as temp_file:
|
||||
while True:
|
||||
chunk = await file.read(chunk_size)
|
||||
if not chunk:
|
||||
@@ -442,33 +438,33 @@ async def upload_and_submit_workflow(
|
||||
suggestions=[
|
||||
"Reduce the size of your target directory",
|
||||
"Exclude unnecessary files (build artifacts, dependencies, etc.)",
|
||||
"Consider splitting into smaller analysis targets"
|
||||
]
|
||||
)
|
||||
"Consider splitting into smaller analysis targets",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
temp_file.write(chunk)
|
||||
|
||||
logger.info(f"Received file: {file_size / (1024**2):.2f} MB")
|
||||
logger.info("Received file: %s MB", f"{file_size / (1024**2):.2f}")
|
||||
|
||||
# Parse parameters
|
||||
workflow_params = {}
|
||||
if parameters:
|
||||
try:
|
||||
import json
|
||||
workflow_params = json.loads(parameters)
|
||||
if not isinstance(workflow_params, dict):
|
||||
raise ValueError("Parameters must be a JSON object")
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
msg = "Parameters must be a JSON object"
|
||||
raise TypeError(msg)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=create_structured_error_response(
|
||||
error_type="InvalidParameters",
|
||||
message=f"Invalid parameters JSON: {e}",
|
||||
workflow_name=workflow_name,
|
||||
suggestions=["Ensure parameters is valid JSON object"]
|
||||
)
|
||||
)
|
||||
suggestions=["Ensure parameters is valid JSON object"],
|
||||
),
|
||||
) from e
|
||||
|
||||
# Upload to MinIO
|
||||
target_id = await temporal_mgr.upload_target(
|
||||
@@ -477,11 +473,11 @@ async def upload_and_submit_workflow(
|
||||
metadata={
|
||||
"workflow": workflow_name,
|
||||
"original_filename": file.filename,
|
||||
"upload_method": "multipart"
|
||||
}
|
||||
"upload_method": "multipart",
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Uploaded to MinIO with target_id: {target_id}")
|
||||
logger.info("Uploaded to MinIO with target_id: %s", target_id)
|
||||
|
||||
# Merge default parameters with user parameters
|
||||
workflow_info = temporal_mgr.workflows.get(workflow_name)
|
||||
@@ -493,74 +489,68 @@ async def upload_and_submit_workflow(
|
||||
handle = await temporal_mgr.run_workflow(
|
||||
workflow_name=workflow_name,
|
||||
target_id=target_id,
|
||||
workflow_params=workflow_params
|
||||
workflow_params=workflow_params,
|
||||
)
|
||||
|
||||
run_id = handle.id
|
||||
|
||||
# Initialize fuzzing tracking if needed
|
||||
workflow_info = temporal_mgr.workflows.get(workflow_name, {})
|
||||
workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, 'metadata') else []
|
||||
workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, "metadata") else []
|
||||
if "fuzzing" in workflow_tags or "fuzz" in workflow_name.lower():
|
||||
from src.api.fuzzing import initialize_fuzzing_tracking
|
||||
initialize_fuzzing_tracking(run_id, workflow_name)
|
||||
|
||||
return RunSubmissionResponse(
|
||||
run_id=run_id,
|
||||
status="RUNNING",
|
||||
workflow=workflow_name,
|
||||
message=f"Workflow '{workflow_name}' submitted successfully with uploaded target"
|
||||
message=f"Workflow '{workflow_name}' submitted successfully with uploaded target",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload and submit workflow '{workflow_name}': {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
logger.exception("Failed to upload and submit workflow '%s'", workflow_name)
|
||||
|
||||
error_response = create_structured_error_response(
|
||||
error_type="WorkflowSubmissionError",
|
||||
message=f"Failed to process upload and submit workflow: {str(e)}",
|
||||
message=f"Failed to process upload and submit workflow: {e!s}",
|
||||
workflow_name=workflow_name,
|
||||
suggestions=[
|
||||
"Check if the uploaded file is a valid tarball",
|
||||
"Verify MinIO storage is accessible",
|
||||
"Check backend logs for detailed error information",
|
||||
"Ensure Temporal workers are running"
|
||||
]
|
||||
"Ensure Temporal workers are running",
|
||||
],
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail=error_response)
|
||||
raise HTTPException(status_code=500, detail=error_response) from e
|
||||
|
||||
finally:
|
||||
# Cleanup temporary file
|
||||
if temp_file_path and Path(temp_file_path).exists():
|
||||
try:
|
||||
Path(temp_file_path).unlink()
|
||||
logger.debug(f"Cleaned up temp file: {temp_file_path}")
|
||||
logger.debug("Cleaned up temp file: %s", temp_file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup temp file {temp_file_path}: {e}")
|
||||
logger.warning("Failed to cleanup temp file %s: %s", temp_file_path, e)
|
||||
|
||||
|
||||
@router.get("/{workflow_name}/worker-info")
|
||||
async def get_workflow_worker_info(
|
||||
workflow_name: str,
|
||||
temporal_mgr=Depends(get_temporal_manager)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get worker information for a workflow.
|
||||
temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
|
||||
) -> dict[str, Any]:
|
||||
"""Get worker information for a workflow.
|
||||
|
||||
Returns details about which worker is required to execute this workflow,
|
||||
including container name, task queue, and vertical.
|
||||
|
||||
Args:
|
||||
workflow_name: Name of the workflow
|
||||
:param workflow_name: Name of the workflow
|
||||
:param temporal_mgr: The temporal manager instance.
|
||||
:return: Worker information including vertical, container name, and task queue
|
||||
:raises HTTPException: 404 if workflow not found
|
||||
|
||||
Returns:
|
||||
Worker information including vertical, container name, and task queue
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if workflow not found
|
||||
"""
|
||||
if workflow_name not in temporal_mgr.workflows:
|
||||
available_workflows = list(temporal_mgr.workflows.keys())
|
||||
@@ -570,12 +560,12 @@ async def get_workflow_worker_info(
|
||||
workflow_name=workflow_name,
|
||||
suggestions=[
|
||||
f"Available workflows: {', '.join(available_workflows)}",
|
||||
"Use GET /workflows/ to see all available workflows"
|
||||
]
|
||||
"Use GET /workflows/ to see all available workflows",
|
||||
],
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=error_response
|
||||
detail=error_response,
|
||||
)
|
||||
|
||||
info = temporal_mgr.workflows[workflow_name]
|
||||
@@ -591,12 +581,12 @@ async def get_workflow_worker_info(
|
||||
workflow_name=workflow_name,
|
||||
suggestions=[
|
||||
"Check workflow metadata.yaml for 'vertical' field",
|
||||
"Contact workflow author for support"
|
||||
]
|
||||
"Contact workflow author for support",
|
||||
],
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=error_response
|
||||
detail=error_response,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -604,26 +594,22 @@ async def get_workflow_worker_info(
|
||||
"vertical": vertical,
|
||||
"worker_service": f"worker-{vertical}",
|
||||
"task_queue": f"{vertical}-queue",
|
||||
"required": True
|
||||
"required": True,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{workflow_name}/parameters")
|
||||
async def get_workflow_parameters(
|
||||
workflow_name: str,
|
||||
temporal_mgr=Depends(get_temporal_manager)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the parameters schema for a workflow.
|
||||
temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
|
||||
) -> dict[str, Any]:
|
||||
"""Get the parameters schema for a workflow.
|
||||
|
||||
Args:
|
||||
workflow_name: Name of the workflow
|
||||
:param workflow_name: Name of the workflow
|
||||
:param temporal_mgr: The temporal manager instance.
|
||||
:return: Parameters schema with types, descriptions, and defaults
|
||||
:raises HTTPException: 404 if workflow not found
|
||||
|
||||
Returns:
|
||||
Parameters schema with types, descriptions, and defaults
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if workflow not found
|
||||
"""
|
||||
if workflow_name not in temporal_mgr.workflows:
|
||||
available_workflows = list(temporal_mgr.workflows.keys())
|
||||
@@ -633,12 +619,12 @@ async def get_workflow_parameters(
|
||||
workflow_name=workflow_name,
|
||||
suggestions=[
|
||||
f"Available workflows: {', '.join(available_workflows)}",
|
||||
"Use GET /workflows/ to see all available workflows"
|
||||
]
|
||||
"Use GET /workflows/ to see all available workflows",
|
||||
],
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=error_response
|
||||
detail=error_response,
|
||||
)
|
||||
|
||||
info = temporal_mgr.workflows[workflow_name]
|
||||
@@ -648,10 +634,7 @@ async def get_workflow_parameters(
|
||||
parameters_schema = metadata.get("parameters", {})
|
||||
|
||||
# Extract the actual parameter definitions from JSON schema structure
|
||||
if "properties" in parameters_schema:
|
||||
param_definitions = parameters_schema["properties"]
|
||||
else:
|
||||
param_definitions = parameters_schema
|
||||
param_definitions = parameters_schema.get("properties", parameters_schema)
|
||||
|
||||
# Extract default values from JSON Schema
|
||||
default_params = extract_defaults_from_json_schema(metadata)
|
||||
@@ -661,7 +644,8 @@ async def get_workflow_parameters(
|
||||
"parameters": param_definitions,
|
||||
"default_parameters": default_params,
|
||||
"required_parameters": [
|
||||
name for name, schema in param_definitions.items()
|
||||
name
|
||||
for name, schema in param_definitions.items()
|
||||
if isinstance(schema, dict) and schema.get("required", False)
|
||||
]
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
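The workflow hunks above define and call extract_defaults_from_json_schema, whose docstring says it pulls `default` values out of a JSON-Schema style `parameters` block; its body is mostly elided in this view. A minimal sketch consistent with that description and with the `parameters_schema.get("properties", parameters_schema)` fallback seen later in the file (an illustration under those assumptions, not the repository's implementation):

    from typing import Any

    def collect_parameter_defaults(metadata: dict[str, Any]) -> dict[str, Any]:
        """Gather `default` entries from a JSON-Schema style parameters block."""
        parameters = metadata.get("parameters", {})
        # Schemas may nest definitions under "properties"; fall back to the flat form.
        properties = parameters.get("properties", parameters)
        defaults: dict[str, Any] = {}
        for name, schema in properties.items():
            if isinstance(schema, dict) and "default" in schema:
                defaults[name] = schema["default"]
        return defaults

    # collect_parameter_defaults({"parameters": {"properties": {"timeout": {"default": 600}}}})
    # -> {"timeout": 600}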
@@ -1,6 +1,4 @@
"""
Setup utilities for FuzzForge infrastructure
"""
"""Setup utilities for FuzzForge infrastructure."""

# Copyright (c) 2025 FuzzingLabs
#
@@ -18,9 +16,8 @@ import logging
logger = logging.getLogger(__name__)


async def setup_result_storage():
"""
Setup result storage (MinIO).
async def setup_result_storage() -> bool:
"""Set up result storage (MinIO).

MinIO is used for both target upload and result storage.
This is a placeholder for any MinIO-specific setup if needed.
@@ -31,9 +28,8 @@ async def setup_result_storage():
return True


async def validate_infrastructure():
"""
Validate all required infrastructure components.
async def validate_infrastructure() -> None:
"""Validate all required infrastructure components.

This should be called during startup to ensure everything is ready.
"""
@@ -13,20 +13,19 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import AsyncExitStack, asynccontextmanager, suppress
|
||||
from typing import Any, Dict, Optional, List
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.http import create_sse_app
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount
|
||||
|
||||
from fastmcp.server.http import create_sse_app
|
||||
|
||||
from src.temporal.manager import TemporalManager
|
||||
from src.api import fuzzing, runs, system, workflows
|
||||
from src.core.setup import setup_result_storage, validate_infrastructure
|
||||
from src.api import workflows, runs, fuzzing, system
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from src.temporal.discovery import WorkflowDiscovery
|
||||
from src.temporal.manager import TemporalManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -38,12 +37,14 @@ class TemporalBootstrapState:
|
||||
"""Tracks Temporal initialization progress for API and MCP consumers."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an instance of the class."""
|
||||
self.ready: bool = False
|
||||
self.status: str = "not_started"
|
||||
self.last_error: Optional[str] = None
|
||||
self.last_error: str | None = None
|
||||
self.task_running: bool = False
|
||||
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return the current state as a Python dictionnary."""
|
||||
return {
|
||||
"ready": self.ready,
|
||||
"status": self.status,
|
||||
@@ -61,7 +62,7 @@ STARTUP_RETRY_MAX_SECONDS = max(
|
||||
int(os.getenv("FUZZFORGE_STARTUP_RETRY_MAX_SECONDS", "60")),
|
||||
)
|
||||
|
||||
temporal_bootstrap_task: Optional[asyncio.Task] = None
|
||||
temporal_bootstrap_task: asyncio.Task | None = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI application (REST API)
|
||||
@@ -79,17 +80,15 @@ app.include_router(fuzzing.router)
|
||||
app.include_router(system.router)
|
||||
|
||||
|
||||
def get_temporal_status() -> Dict[str, Any]:
|
||||
def get_temporal_status() -> dict[str, Any]:
|
||||
"""Return a snapshot of Temporal bootstrap state for diagnostics."""
|
||||
status = temporal_bootstrap_state.as_dict()
|
||||
status["workflows_loaded"] = len(temporal_mgr.workflows)
|
||||
status["bootstrap_task_running"] = (
|
||||
temporal_bootstrap_task is not None and not temporal_bootstrap_task.done()
|
||||
)
|
||||
status["bootstrap_task_running"] = temporal_bootstrap_task is not None and not temporal_bootstrap_task.done()
|
||||
return status
|
||||
|
||||
|
||||
def _temporal_not_ready_status() -> Optional[Dict[str, Any]]:
|
||||
def _temporal_not_ready_status() -> dict[str, Any] | None:
|
||||
"""Return status details if Temporal is not ready yet."""
|
||||
status = get_temporal_status()
|
||||
if status.get("ready"):
|
||||
@@ -98,7 +97,7 @@ def _temporal_not_ready_status() -> Optional[Dict[str, Any]]:
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root() -> Dict[str, Any]:
|
||||
async def root() -> dict[str, Any]:
|
||||
status = get_temporal_status()
|
||||
return {
|
||||
"name": "FuzzForge API",
|
||||
@@ -110,14 +109,14 @@ async def root() -> Dict[str, Any]:
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Dict[str, str]:
|
||||
async def health() -> dict[str, str]:
|
||||
status = get_temporal_status()
|
||||
health_status = "healthy" if status.get("ready") else "initializing"
|
||||
return {"status": health_status}
|
||||
|
||||
|
||||
# Map FastAPI OpenAPI operationIds to readable MCP tool names
|
||||
FASTAPI_MCP_NAME_OVERRIDES: Dict[str, str] = {
|
||||
FASTAPI_MCP_NAME_OVERRIDES: dict[str, str] = {
|
||||
"list_workflows_workflows__get": "api_list_workflows",
|
||||
"get_metadata_schema_workflows_metadata_schema_get": "api_get_metadata_schema",
|
||||
"get_workflow_metadata_workflows__workflow_name__metadata_get": "api_get_workflow_metadata",
|
||||
@@ -155,7 +154,6 @@ mcp = FastMCP(name="FuzzForge MCP")
|
||||
|
||||
async def _bootstrap_temporal_with_retries() -> None:
|
||||
"""Initialize Temporal infrastructure with exponential backoff retries."""
|
||||
|
||||
attempt = 0
|
||||
|
||||
while True:
|
||||
@@ -175,7 +173,6 @@ async def _bootstrap_temporal_with_retries() -> None:
|
||||
temporal_bootstrap_state.status = "ready"
|
||||
temporal_bootstrap_state.task_running = False
|
||||
logger.info("Temporal infrastructure ready")
|
||||
return
|
||||
|
||||
except asyncio.CancelledError:
|
||||
temporal_bootstrap_state.status = "cancelled"
|
||||
@@ -204,9 +201,11 @@ async def _bootstrap_temporal_with_retries() -> None:
|
||||
temporal_bootstrap_state.status = "cancelled"
|
||||
temporal_bootstrap_state.task_running = False
|
||||
raise
|
||||
else:
|
||||
return
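Only fragments of the retry loop are visible in the hunks above. As a rough, self-contained sketch of the pattern it applies (capped exponential backoff between connection attempts, with cancellation passed through untouched); connect, base_delay and max_delay are placeholder names, not the real FuzzForge internals:

import asyncio
import logging

logger = logging.getLogger(__name__)


async def bootstrap_with_retries(connect, base_delay: float = 1.0, max_delay: float = 60.0) -> None:
    """Sketch of capped exponential backoff; `connect` is any awaitable factory."""
    attempt = 0
    while True:
        try:
            await connect()
        except asyncio.CancelledError:
            raise  # shutdown: propagate cancellation untouched
        except Exception as exc:  # noqa: BLE001 - sketch only
            attempt += 1
            delay = min(base_delay * 2 ** (attempt - 1), max_delay)
            logger.warning("Bootstrap attempt %s failed (%s); retrying in %.1fs", attempt, exc, delay)
            await asyncio.sleep(delay)
        else:
            return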
|
||||
|
||||
|
||||
def _lookup_workflow(workflow_name: str):
|
||||
def _lookup_workflow(workflow_name: str) -> dict[str, Any] | None:
|
||||
info = temporal_mgr.workflows.get(workflow_name)
|
||||
if not info:
|
||||
return None
|
||||
@@ -222,12 +221,12 @@ def _lookup_workflow(workflow_name: str):
|
||||
"parameters": metadata.get("parameters", {}),
|
||||
"default_parameters": metadata.get("default_parameters", {}),
|
||||
"required_modules": metadata.get("required_modules", []),
|
||||
"default_target_path": default_target_path
|
||||
"default_target_path": default_target_path,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def list_workflows_mcp() -> Dict[str, Any]:
|
||||
async def list_workflows_mcp() -> dict[str, Any]:
|
||||
"""List all discovered workflows and their metadata summary."""
|
||||
not_ready = _temporal_not_ready_status()
|
||||
if not_ready:
|
||||
@@ -241,20 +240,21 @@ async def list_workflows_mcp() -> Dict[str, Any]:
|
||||
for name, info in temporal_mgr.workflows.items():
|
||||
metadata = info.metadata
|
||||
defaults = metadata.get("default_parameters", {})
|
||||
workflows_summary.append({
|
||||
"name": name,
|
||||
"version": metadata.get("version", "0.6.0"),
|
||||
"description": metadata.get("description", ""),
|
||||
"author": metadata.get("author"),
|
||||
"tags": metadata.get("tags", []),
|
||||
"default_target_path": metadata.get("default_target_path")
|
||||
or defaults.get("target_path")
|
||||
})
|
||||
workflows_summary.append(
|
||||
{
|
||||
"name": name,
|
||||
"version": metadata.get("version", "0.6.0"),
|
||||
"description": metadata.get("description", ""),
|
||||
"author": metadata.get("author"),
|
||||
"tags": metadata.get("tags", []),
|
||||
"default_target_path": metadata.get("default_target_path") or defaults.get("target_path"),
|
||||
},
|
||||
)
|
||||
return {"workflows": workflows_summary, "temporal": get_temporal_status()}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_workflow_metadata_mcp(workflow_name: str) -> Dict[str, Any]:
|
||||
async def get_workflow_metadata_mcp(workflow_name: str) -> dict[str, Any]:
|
||||
"""Fetch detailed metadata for a workflow."""
|
||||
not_ready = _temporal_not_ready_status()
|
||||
if not_ready:
|
||||
@@ -270,7 +270,7 @@ async def get_workflow_metadata_mcp(workflow_name: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_workflow_parameters_mcp(workflow_name: str) -> Dict[str, Any]:
|
||||
async def get_workflow_parameters_mcp(workflow_name: str) -> dict[str, Any]:
|
||||
"""Return the parameter schema and defaults for a workflow."""
|
||||
not_ready = _temporal_not_ready_status()
|
||||
if not_ready:
|
||||
@@ -289,9 +289,8 @@ async def get_workflow_parameters_mcp(workflow_name: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_workflow_metadata_schema_mcp() -> Dict[str, Any]:
|
||||
async def get_workflow_metadata_schema_mcp() -> dict[str, Any]:
|
||||
"""Return the JSON schema describing workflow metadata files."""
|
||||
from src.temporal.discovery import WorkflowDiscovery
|
||||
return WorkflowDiscovery.get_metadata_schema()
|
||||
|
||||
|
||||
@@ -299,8 +298,8 @@ async def get_workflow_metadata_schema_mcp() -> Dict[str, Any]:
|
||||
async def submit_security_scan_mcp(
|
||||
workflow_name: str,
|
||||
target_id: str,
|
||||
parameters: Dict[str, Any] | None = None,
|
||||
) -> Dict[str, Any] | Dict[str, str]:
|
||||
parameters: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | dict[str, str]:
|
||||
"""Submit a Temporal workflow via MCP."""
|
||||
try:
|
||||
not_ready = _temporal_not_ready_status()
|
||||
@@ -318,7 +317,7 @@ async def submit_security_scan_mcp(
|
||||
defaults = metadata.get("default_parameters", {})
|
||||
|
||||
parameters = parameters or {}
|
||||
cleaned_parameters: Dict[str, Any] = {**defaults, **parameters}
|
||||
cleaned_parameters: dict[str, Any] = {**defaults, **parameters}
|
||||
|
||||
# Ensure *_config structures default to dicts
|
||||
for key, value in list(cleaned_parameters.items()):
|
||||
@@ -327,9 +326,7 @@ async def submit_security_scan_mcp(
|
||||
|
||||
# Some workflows expect configuration dictionaries even when omitted
|
||||
parameter_definitions = (
|
||||
metadata.get("parameters", {}).get("properties", {})
|
||||
if isinstance(metadata.get("parameters"), dict)
|
||||
else {}
|
||||
metadata.get("parameters", {}).get("properties", {}) if isinstance(metadata.get("parameters"), dict) else {}
|
||||
)
|
||||
for key, definition in parameter_definitions.items():
|
||||
if not isinstance(key, str) or not key.endswith("_config"):
|
||||
@@ -347,6 +344,10 @@ async def submit_security_scan_mcp(
|
||||
workflow_params=cleaned_parameters,
|
||||
)
|
||||
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.exception("MCP submit failed")
|
||||
return {"error": f"Failed to submit workflow: {exc}"}
|
||||
else:
|
||||
return {
|
||||
"run_id": handle.id,
|
||||
"status": "RUNNING",
|
||||
@@ -356,13 +357,10 @@ async def submit_security_scan_mcp(
|
||||
"parameters": cleaned_parameters,
|
||||
"mcp_enabled": True,
|
||||
}
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.exception("MCP submit failed")
|
||||
return {"error": f"Failed to submit workflow: {exc}"}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_comprehensive_scan_summary(run_id: str) -> Dict[str, Any] | Dict[str, str]:
|
||||
async def get_comprehensive_scan_summary(run_id: str) -> dict[str, Any] | dict[str, str]:
|
||||
"""Return a summary for the given workflow run via MCP."""
|
||||
try:
|
||||
not_ready = _temporal_not_ready_status()
|
||||
@@ -385,7 +383,7 @@ async def get_comprehensive_scan_summary(run_id: str) -> Dict[str, Any] | Dict[s
|
||||
summary = result.get("summary", {})
|
||||
total_findings = summary.get("total_findings", 0)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not retrieve result for {run_id}: {e}")
|
||||
logger.debug("Could not retrieve result for %s: %s", run_id, e)
|
||||
|
||||
return {
|
||||
"run_id": run_id,
|
||||
@@ -412,7 +410,7 @@ async def get_comprehensive_scan_summary(run_id: str) -> Dict[str, Any] | Dict[s
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_run_status_mcp(run_id: str) -> Dict[str, Any]:
|
||||
async def get_run_status_mcp(run_id: str) -> dict[str, Any]:
|
||||
"""Return current status information for a Temporal run."""
|
||||
try:
|
||||
not_ready = _temporal_not_ready_status()
|
||||
@@ -440,7 +438,7 @@ async def get_run_status_mcp(run_id: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_run_findings_mcp(run_id: str) -> Dict[str, Any]:
|
||||
async def get_run_findings_mcp(run_id: str) -> dict[str, Any]:
|
||||
"""Return SARIF findings for a completed run."""
|
||||
try:
|
||||
not_ready = _temporal_not_ready_status()
|
||||
@@ -463,24 +461,24 @@ async def get_run_findings_mcp(run_id: str) -> Dict[str, Any]:
|
||||
|
||||
sarif = result.get("sarif", {}) if isinstance(result, dict) else {}
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("MCP findings failed")
|
||||
return {"error": f"Failed to retrieve findings: {exc}"}
|
||||
else:
|
||||
return {
|
||||
"workflow": "unknown",
|
||||
"run_id": run_id,
|
||||
"sarif": sarif,
|
||||
"metadata": metadata,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.exception("MCP findings failed")
|
||||
return {"error": f"Failed to retrieve findings: {exc}"}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def list_recent_runs_mcp(
|
||||
limit: int = 10,
|
||||
workflow_name: str | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""List recent Temporal runs with optional workflow filter."""
|
||||
|
||||
not_ready = _temporal_not_ready_status()
|
||||
if not_ready:
|
||||
return {
|
||||
@@ -505,19 +503,21 @@ async def list_recent_runs_mcp(
|
||||
|
||||
workflows = await temporal_mgr.list_workflows(filter_query, limit_value)
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
results: list[dict[str, Any]] = []
|
||||
for wf in workflows:
|
||||
results.append({
|
||||
"run_id": wf["workflow_id"],
|
||||
"workflow": workflow_name or "unknown",
|
||||
"state": wf["status"],
|
||||
"state_type": wf["status"],
|
||||
"is_completed": wf["status"] in ["COMPLETED", "FAILED", "CANCELLED"],
|
||||
"is_running": wf["status"] == "RUNNING",
|
||||
"is_failed": wf["status"] == "FAILED",
|
||||
"created_at": wf.get("start_time"),
|
||||
"updated_at": wf.get("close_time"),
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"run_id": wf["workflow_id"],
|
||||
"workflow": workflow_name or "unknown",
|
||||
"state": wf["status"],
|
||||
"state_type": wf["status"],
|
||||
"is_completed": wf["status"] in ["COMPLETED", "FAILED", "CANCELLED"],
|
||||
"is_running": wf["status"] == "RUNNING",
|
||||
"is_failed": wf["status"] == "FAILED",
|
||||
"created_at": wf.get("start_time"),
|
||||
"updated_at": wf.get("close_time"),
|
||||
},
|
||||
)
|
||||
|
||||
return {"runs": results, "temporal": get_temporal_status()}
|
||||
|
||||
@@ -526,12 +526,12 @@ async def list_recent_runs_mcp(
|
||||
return {
|
||||
"runs": [],
|
||||
"temporal": get_temporal_status(),
|
||||
"error": str(exc)
|
||||
"error": str(exc),
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_fuzzing_stats_mcp(run_id: str) -> Dict[str, Any]:
|
||||
async def get_fuzzing_stats_mcp(run_id: str) -> dict[str, Any]:
|
||||
"""Return fuzzing statistics for a run if available."""
|
||||
not_ready = _temporal_not_ready_status()
|
||||
if not_ready:
|
||||
@@ -555,7 +555,7 @@ async def get_fuzzing_stats_mcp(run_id: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_fuzzing_crash_reports_mcp(run_id: str) -> Dict[str, Any]:
|
||||
async def get_fuzzing_crash_reports_mcp(run_id: str) -> dict[str, Any]:
|
||||
"""Return crash reports collected for a fuzzing run."""
|
||||
not_ready = _temporal_not_ready_status()
|
||||
if not_ready:
|
||||
@@ -571,11 +571,10 @@ async def get_fuzzing_crash_reports_mcp(run_id: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_backend_status_mcp() -> Dict[str, Any]:
|
||||
async def get_backend_status_mcp() -> dict[str, Any]:
|
||||
"""Expose backend readiness, workflows, and registered MCP tools."""
|
||||
|
||||
status = get_temporal_status()
|
||||
response: Dict[str, Any] = {"temporal": status}
|
||||
response: dict[str, Any] = {"temporal": status}
|
||||
|
||||
if status.get("ready"):
|
||||
response["workflows"] = list(temporal_mgr.workflows.keys())
|
||||
@@ -591,7 +590,6 @@ async def get_backend_status_mcp() -> Dict[str, Any]:
|
||||
|
||||
def create_mcp_transport_app() -> Starlette:
|
||||
"""Build a Starlette app serving HTTP + SSE transports on one port."""
|
||||
|
||||
http_app = mcp.http_app(path="/", transport="streamable-http")
|
||||
sse_app = create_sse_app(
|
||||
server=mcp,
|
||||
@@ -609,10 +607,10 @@ def create_mcp_transport_app() -> Starlette:
|
||||
async def lifespan(app: Starlette): # pragma: no cover - integration wiring
|
||||
async with AsyncExitStack() as stack:
|
||||
await stack.enter_async_context(
|
||||
http_app.router.lifespan_context(http_app)
|
||||
http_app.router.lifespan_context(http_app),
|
||||
)
|
||||
await stack.enter_async_context(
|
||||
sse_app.router.lifespan_context(sse_app)
|
||||
sse_app.router.lifespan_context(sse_app),
|
||||
)
|
||||
yield
|
||||
|
||||
@@ -627,6 +625,7 @@ def create_mcp_transport_app() -> Starlette:
|
||||
# Combined lifespan: Temporal init + dedicated MCP transports
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def combined_lifespan(app: FastAPI):
|
||||
global temporal_bootstrap_task, _fastapi_mcp_imported
|
||||
@@ -675,13 +674,14 @@ async def combined_lifespan(app: FastAPI):
|
||||
if getattr(mcp_server, "started", False):
|
||||
return
|
||||
await asyncio.sleep(poll_interval)
|
||||
raise asyncio.TimeoutError
|
||||
raise TimeoutError
|
||||
|
||||
try:
|
||||
await _wait_for_uvicorn_startup()
|
||||
except asyncio.TimeoutError: # pragma: no cover - defensive logging
|
||||
except TimeoutError: # pragma: no cover - defensive logging
|
||||
if mcp_task.done():
|
||||
raise RuntimeError("MCP server failed to start") from mcp_task.exception()
|
||||
msg = "MCP server failed to start"
|
||||
raise RuntimeError(msg) from mcp_task.exception()
|
||||
logger.warning("Timed out waiting for MCP server startup; continuing anyway")
|
||||
|
||||
logger.info("MCP HTTP available at http://0.0.0.0:8010/mcp")
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Models for workflow findings and submissions
|
||||
"""
|
||||
"""Models for workflow findings and submissions."""
|
||||
|
||||
# Copyright (c) 2025 FuzzingLabs
|
||||
#
|
||||
@@ -13,40 +11,43 @@ Models for workflow findings and submissions
|
||||
#
|
||||
# Additional attribution and requirements are provided in the NOTICE file.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class WorkflowFindings(BaseModel):
|
||||
"""Findings from a workflow execution in SARIF format"""
|
||||
"""Findings from a workflow execution in SARIF format."""
|
||||
|
||||
workflow: str = Field(..., description="Workflow name")
|
||||
run_id: str = Field(..., description="Unique run identifier")
|
||||
sarif: Dict[str, Any] = Field(..., description="SARIF formatted findings")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||
sarif: dict[str, Any] = Field(..., description="SARIF formatted findings")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||
|
||||
|
||||
class WorkflowSubmission(BaseModel):
|
||||
"""
|
||||
Submit a workflow with configurable settings.
|
||||
"""Submit a workflow with configurable settings.
|
||||
|
||||
Note: This model is deprecated in favor of the /upload-and-submit endpoint
|
||||
which handles file uploads directly.
|
||||
"""
|
||||
parameters: Dict[str, Any] = Field(
|
||||
|
||||
parameters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Workflow-specific parameters"
|
||||
description="Workflow-specific parameters",
|
||||
)
|
||||
timeout: Optional[int] = Field(
|
||||
timeout: int | None = Field(
|
||||
default=None, # Allow workflow-specific defaults
|
||||
description="Timeout in seconds (None for workflow default)",
|
||||
ge=1,
|
||||
le=604800 # Max 7 days to support fuzzing campaigns
|
||||
le=604800, # Max 7 days to support fuzzing campaigns
|
||||
)
|
||||
|
||||
|
||||
class WorkflowStatus(BaseModel):
|
||||
"""Status of a workflow run"""
|
||||
"""Status of a workflow run."""
|
||||
|
||||
run_id: str = Field(..., description="Unique run identifier")
|
||||
workflow: str = Field(..., description="Workflow name")
|
||||
status: str = Field(..., description="Current status")
|
||||
@@ -58,34 +59,37 @@ class WorkflowStatus(BaseModel):
|
||||
|
||||
|
||||
class WorkflowMetadata(BaseModel):
|
||||
"""Complete metadata for a workflow"""
|
||||
"""Complete metadata for a workflow."""
|
||||
|
||||
name: str = Field(..., description="Workflow name")
|
||||
version: str = Field(..., description="Semantic version")
|
||||
description: str = Field(..., description="Workflow description")
|
||||
author: Optional[str] = Field(None, description="Workflow author")
|
||||
tags: List[str] = Field(default_factory=list, description="Workflow tags")
|
||||
parameters: Dict[str, Any] = Field(..., description="Parameters schema")
|
||||
default_parameters: Dict[str, Any] = Field(
|
||||
author: str | None = Field(None, description="Workflow author")
|
||||
tags: list[str] = Field(default_factory=list, description="Workflow tags")
|
||||
parameters: dict[str, Any] = Field(..., description="Parameters schema")
|
||||
default_parameters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Default parameter values"
|
||||
description="Default parameter values",
|
||||
)
|
||||
required_modules: List[str] = Field(
|
||||
required_modules: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Required module names"
|
||||
description="Required module names",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowListItem(BaseModel):
|
||||
"""Summary information for a workflow in list views"""
|
||||
"""Summary information for a workflow in list views."""
|
||||
|
||||
name: str = Field(..., description="Workflow name")
|
||||
version: str = Field(..., description="Semantic version")
|
||||
description: str = Field(..., description="Workflow description")
|
||||
author: Optional[str] = Field(None, description="Workflow author")
|
||||
tags: List[str] = Field(default_factory=list, description="Workflow tags")
|
||||
author: str | None = Field(None, description="Workflow author")
|
||||
tags: list[str] = Field(default_factory=list, description="Workflow tags")
|
||||
|
||||
|
||||
class RunSubmissionResponse(BaseModel):
|
||||
"""Response after submitting a workflow"""
|
||||
"""Response after submitting a workflow."""
|
||||
|
||||
run_id: str = Field(..., description="Unique run identifier")
|
||||
status: str = Field(..., description="Initial status")
|
||||
workflow: str = Field(..., description="Workflow name")
|
||||
@@ -93,28 +97,30 @@ class RunSubmissionResponse(BaseModel):
|
||||
|
||||
|
||||
class FuzzingStats(BaseModel):
|
||||
"""Real-time fuzzing statistics"""
|
||||
"""Real-time fuzzing statistics."""
|
||||
|
||||
run_id: str = Field(..., description="Unique run identifier")
|
||||
workflow: str = Field(..., description="Workflow name")
|
||||
executions: int = Field(default=0, description="Total executions")
|
||||
executions_per_sec: float = Field(default=0.0, description="Current execution rate")
|
||||
crashes: int = Field(default=0, description="Total crashes found")
|
||||
unique_crashes: int = Field(default=0, description="Unique crashes")
|
||||
coverage: Optional[float] = Field(None, description="Code coverage percentage")
|
||||
coverage: float | None = Field(None, description="Code coverage percentage")
|
||||
corpus_size: int = Field(default=0, description="Current corpus size")
|
||||
elapsed_time: int = Field(default=0, description="Elapsed time in seconds")
|
||||
last_crash_time: Optional[datetime] = Field(None, description="Time of last crash")
|
||||
last_crash_time: datetime | None = Field(None, description="Time of last crash")
|
||||
|
||||
|
||||
class CrashReport(BaseModel):
|
||||
"""Individual crash report from fuzzing"""
|
||||
"""Individual crash report from fuzzing."""
|
||||
|
||||
run_id: str = Field(..., description="Run identifier")
|
||||
crash_id: str = Field(..., description="Unique crash identifier")
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
signal: Optional[str] = Field(None, description="Crash signal (SIGSEGV, etc.)")
|
||||
crash_type: Optional[str] = Field(None, description="Type of crash")
|
||||
stack_trace: Optional[str] = Field(None, description="Stack trace")
|
||||
input_file: Optional[str] = Field(None, description="Path to crashing input")
|
||||
reproducer: Optional[str] = Field(None, description="Minimized reproducer")
|
||||
signal: str | None = Field(None, description="Crash signal (SIGSEGV, etc.)")
|
||||
crash_type: str | None = Field(None, description="Type of crash")
|
||||
stack_trace: str | None = Field(None, description="Stack trace")
|
||||
input_file: str | None = Field(None, description="Path to crashing input")
|
||||
reproducer: str | None = Field(None, description="Minimized reproducer")
|
||||
severity: str = Field(default="medium", description="Crash severity")
|
||||
exploitability: Optional[str] = Field(None, description="Exploitability assessment")
|
||||
exploitability: str | None = Field(None, description="Exploitability assessment")
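As a quick illustration of how these models are consumed, a hypothetical crash report could be built and serialized like this (the values are invented, and CrashReport is the Pydantic model defined just above):

from datetime import datetime, timezone

# CrashReport is the model defined above; import it from this models module.
report = CrashReport(
    run_id="run-1234",
    crash_id="crash-0001",
    timestamp=datetime.now(timezone.utc),
    signal="SIGSEGV",
    crash_type="heap-buffer-overflow",
    input_file="/corpus/crashes/id_0001",
    severity="high",
)
print(report.model_dump_json(indent=2))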
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Storage abstraction layer for FuzzForge.
|
||||
"""Storage abstraction layer for FuzzForge.
|
||||
|
||||
Provides unified interface for storing and retrieving targets and results.
|
||||
"""
|
||||
@@ -7,4 +6,4 @@ Provides unified interface for storing and retrieving targets and results.
|
||||
from .base import StorageBackend
|
||||
from .s3_cached import S3CachedStorage
|
||||
|
||||
__all__ = ["StorageBackend", "S3CachedStorage"]
|
||||
__all__ = ["S3CachedStorage", "StorageBackend"]
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
"""
|
||||
Base storage backend interface.
|
||||
"""Base storage backend interface.
|
||||
|
||||
All storage implementations must implement this interface.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StorageBackend(ABC):
|
||||
"""
|
||||
Abstract base class for storage backends.
|
||||
"""Abstract base class for storage backends.
|
||||
|
||||
Implementations handle storage and retrieval of:
|
||||
- Uploaded targets (code, binaries, etc.)
|
||||
@@ -24,10 +22,9 @@ class StorageBackend(ABC):
|
||||
self,
|
||||
file_path: Path,
|
||||
user_id: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Upload a target file to storage.
|
||||
"""Upload a target file to storage.
|
||||
|
||||
Args:
|
||||
file_path: Local path to file to upload
|
||||
@@ -40,13 +37,12 @@ class StorageBackend(ABC):
|
||||
Raises:
|
||||
FileNotFoundError: If file_path doesn't exist
|
||||
StorageError: If upload fails
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_target(self, target_id: str) -> Path:
|
||||
"""
|
||||
Get target file from storage.
|
||||
"""Get target file from storage.
|
||||
|
||||
Args:
|
||||
target_id: Unique identifier from upload_target()
|
||||
@@ -57,31 +53,29 @@ class StorageBackend(ABC):
|
||||
Raises:
|
||||
FileNotFoundError: If target doesn't exist
|
||||
StorageError: If download fails
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_target(self, target_id: str) -> None:
|
||||
"""
|
||||
Delete target from storage.
|
||||
"""Delete target from storage.
|
||||
|
||||
Args:
|
||||
target_id: Unique identifier to delete
|
||||
|
||||
Raises:
|
||||
StorageError: If deletion fails (doesn't raise if not found)
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def upload_results(
|
||||
self,
|
||||
workflow_id: str,
|
||||
results: Dict[str, Any],
|
||||
results_format: str = "json"
|
||||
results: dict[str, Any],
|
||||
results_format: str = "json",
|
||||
) -> str:
|
||||
"""
|
||||
Upload workflow results to storage.
|
||||
"""Upload workflow results to storage.
|
||||
|
||||
Args:
|
||||
workflow_id: Workflow execution ID
|
||||
@@ -93,13 +87,12 @@ class StorageBackend(ABC):
|
||||
|
||||
Raises:
|
||||
StorageError: If upload fails
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_results(self, workflow_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get workflow results from storage.
|
||||
async def get_results(self, workflow_id: str) -> dict[str, Any]:
|
||||
"""Get workflow results from storage.
|
||||
|
||||
Args:
|
||||
workflow_id: Workflow execution ID
|
||||
@@ -110,17 +103,16 @@ class StorageBackend(ABC):
|
||||
Raises:
|
||||
FileNotFoundError: If results don't exist
|
||||
StorageError: If download fails
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_targets(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> list[Dict[str, Any]]:
|
||||
"""
|
||||
List uploaded targets.
|
||||
user_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List uploaded targets.
|
||||
|
||||
Args:
|
||||
user_id: Filter by user ID (None = all users)
|
||||
@@ -131,23 +123,21 @@ class StorageBackend(ABC):
|
||||
|
||||
Raises:
|
||||
StorageError: If listing fails
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_cache(self) -> int:
|
||||
"""
|
||||
Clean up local cache (LRU eviction).
|
||||
"""Clean up local cache (LRU eviction).
|
||||
|
||||
Returns:
|
||||
Number of files removed
|
||||
|
||||
Raises:
|
||||
StorageError: If cleanup fails
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StorageError(Exception):
|
||||
"""Base exception for storage operations."""
|
||||
pass
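To make the contract concrete, here is a minimal sketch of how a caller might drive this interface; the function name, user id and paths are placeholders, while StorageBackend itself is re-exported from src.storage as shown earlier:

from pathlib import Path

from src.storage import StorageBackend


async def run_scan(storage: StorageBackend, upload: Path) -> dict:
    # Push the uploaded file to the backend and remember its identifier.
    target_id = await storage.upload_target(upload, user_id="demo-user")
    try:
        # Later, possibly on another host, the worker pulls it back down (cache permitting).
        local_copy = await storage.get_target(target_id)
        return {"target_id": target_id, "local_path": str(local_copy)}
    finally:
        await storage.delete_target(target_id)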
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
S3-compatible storage backend with local caching.
|
||||
"""S3-compatible storage backend with local caching.
|
||||
|
||||
Works with MinIO (dev/prod) or AWS S3 (cloud).
|
||||
"""
|
||||
@@ -10,7 +9,7 @@ import os
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import boto3
|
||||
@@ -22,8 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class S3CachedStorage(StorageBackend):
|
||||
"""
|
||||
S3-compatible storage with local caching.
|
||||
"""S3-compatible storage with local caching.
|
||||
|
||||
Features:
|
||||
- Upload targets to S3/MinIO
|
||||
@@ -34,17 +32,16 @@ class S3CachedStorage(StorageBackend):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_url: Optional[str] = None,
|
||||
access_key: Optional[str] = None,
|
||||
secret_key: Optional[str] = None,
|
||||
endpoint_url: str | None = None,
|
||||
access_key: str | None = None,
|
||||
secret_key: str | None = None,
|
||||
bucket: str = "targets",
|
||||
region: str = "us-east-1",
|
||||
use_ssl: bool = False,
|
||||
cache_dir: Optional[Path] = None,
|
||||
cache_max_size_gb: int = 10
|
||||
):
|
||||
"""
|
||||
Initialize S3 storage backend.
|
||||
cache_dir: Path | None = None,
|
||||
cache_max_size_gb: int = 10,
|
||||
) -> None:
|
||||
"""Initialize S3 storage backend.
|
||||
|
||||
Args:
|
||||
endpoint_url: S3 endpoint (None = AWS S3, or MinIO URL)
|
||||
@@ -55,18 +52,19 @@ class S3CachedStorage(StorageBackend):
|
||||
use_ssl: Use HTTPS
|
||||
cache_dir: Local cache directory
|
||||
cache_max_size_gb: Maximum cache size in GB
|
||||
|
||||
"""
|
||||
# Use environment variables as defaults
|
||||
self.endpoint_url = endpoint_url or os.getenv('S3_ENDPOINT', 'http://minio:9000')
|
||||
self.access_key = access_key or os.getenv('S3_ACCESS_KEY', 'fuzzforge')
|
||||
self.secret_key = secret_key or os.getenv('S3_SECRET_KEY', 'fuzzforge123')
|
||||
self.bucket = bucket or os.getenv('S3_BUCKET', 'targets')
|
||||
self.region = region or os.getenv('S3_REGION', 'us-east-1')
|
||||
self.use_ssl = use_ssl or os.getenv('S3_USE_SSL', 'false').lower() == 'true'
|
||||
self.endpoint_url = endpoint_url or os.getenv("S3_ENDPOINT", "http://minio:9000")
|
||||
self.access_key = access_key or os.getenv("S3_ACCESS_KEY", "fuzzforge")
|
||||
self.secret_key = secret_key or os.getenv("S3_SECRET_KEY", "fuzzforge123")
|
||||
self.bucket = bucket or os.getenv("S3_BUCKET", "targets")
|
||||
self.region = region or os.getenv("S3_REGION", "us-east-1")
|
||||
self.use_ssl = use_ssl or os.getenv("S3_USE_SSL", "false").lower() == "true"
|
||||
|
||||
# Cache configuration
|
||||
self.cache_dir = cache_dir or Path(os.getenv('CACHE_DIR', '/tmp/fuzzforge-cache'))
|
||||
self.cache_max_size = cache_max_size_gb * (1024 ** 3) # Convert to bytes
|
||||
self.cache_dir = cache_dir or Path(os.getenv("CACHE_DIR", "/tmp/fuzzforge-cache"))
|
||||
self.cache_max_size = cache_max_size_gb * (1024**3) # Convert to bytes
|
||||
|
||||
# Ensure cache directory exists
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -74,69 +72,75 @@ class S3CachedStorage(StorageBackend):
|
||||
# Initialize S3 client
|
||||
try:
|
||||
self.s3_client = boto3.client(
|
||||
's3',
|
||||
"s3",
|
||||
endpoint_url=self.endpoint_url,
|
||||
aws_access_key_id=self.access_key,
|
||||
aws_secret_access_key=self.secret_key,
|
||||
region_name=self.region,
|
||||
use_ssl=self.use_ssl
|
||||
use_ssl=self.use_ssl,
|
||||
)
|
||||
logger.info(f"Initialized S3 storage: {self.endpoint_url}/{self.bucket}")
|
||||
logger.info("Initialized S3 storage: %s/%s", self.endpoint_url, self.bucket)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize S3 client: {e}")
|
||||
raise StorageError(f"S3 initialization failed: {e}")
|
||||
logger.exception("Failed to initialize S3 client")
|
||||
msg = f"S3 initialization failed: {e}"
|
||||
raise StorageError(msg) from e
|
||||
|
||||
async def upload_target(
|
||||
self,
|
||||
file_path: Path,
|
||||
user_id: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Upload target file to S3/MinIO."""
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
msg = f"File not found: {file_path}"
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
# Generate unique target ID
|
||||
target_id = str(uuid4())
|
||||
|
||||
# Prepare metadata
|
||||
upload_metadata = {
|
||||
'user_id': user_id,
|
||||
'uploaded_at': datetime.now().isoformat(),
|
||||
'filename': file_path.name,
|
||||
'size': str(file_path.stat().st_size)
|
||||
"user_id": user_id,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
"filename": file_path.name,
|
||||
"size": str(file_path.stat().st_size),
|
||||
}
|
||||
if metadata:
|
||||
upload_metadata.update(metadata)
|
||||
|
||||
# Upload to S3
|
||||
s3_key = f'{target_id}/target'
|
||||
s3_key = f"{target_id}/target"
|
||||
try:
|
||||
logger.info(f"Uploading target to s3://{self.bucket}/{s3_key}")
|
||||
logger.info("Uploading target to s3://%s/%s", self.bucket, s3_key)
|
||||
|
||||
self.s3_client.upload_file(
|
||||
str(file_path),
|
||||
self.bucket,
|
||||
s3_key,
|
||||
ExtraArgs={
|
||||
'Metadata': upload_metadata
|
||||
}
|
||||
"Metadata": upload_metadata,
|
||||
},
|
||||
)
|
||||
|
||||
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
logger.info(
|
||||
f"✓ Uploaded target {target_id} "
|
||||
f"({file_path.name}, {file_size_mb:.2f} MB)"
|
||||
"✓ Uploaded target %s (%s, %s MB)",
|
||||
target_id,
|
||||
file_path.name,
|
||||
f"{file_size_mb:.2f}",
|
||||
)
|
||||
|
||||
return target_id
|
||||
|
||||
except ClientError as e:
|
||||
logger.error(f"S3 upload failed: {e}", exc_info=True)
|
||||
raise StorageError(f"Failed to upload target: {e}")
|
||||
logger.exception("S3 upload failed")
|
||||
msg = f"Failed to upload target: {e}"
|
||||
raise StorageError(msg) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Upload failed: {e}", exc_info=True)
|
||||
raise StorageError(f"Upload error: {e}")
|
||||
logger.exception("Upload failed")
|
||||
msg = f"Upload error: {e}"
|
||||
raise StorageError(msg) from e
|
||||
else:
|
||||
return target_id
|
||||
|
||||
async def get_target(self, target_id: str) -> Path:
|
||||
"""Get target from cache or download from S3/MinIO."""
|
||||
@@ -147,105 +151,110 @@ class S3CachedStorage(StorageBackend):
|
||||
if cached_file.exists():
|
||||
# Update access time for LRU
|
||||
cached_file.touch()
|
||||
logger.info(f"Cache HIT: {target_id}")
|
||||
logger.info("Cache HIT: %s", target_id)
|
||||
return cached_file
|
||||
|
||||
# Cache miss - download from S3
|
||||
logger.info(f"Cache MISS: {target_id}, downloading from S3...")
|
||||
logger.info("Cache MISS: %s, downloading from S3...", target_id)
|
||||
|
||||
try:
|
||||
# Create cache directory
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download from S3
|
||||
s3_key = f'{target_id}/target'
|
||||
logger.info(f"Downloading s3://{self.bucket}/{s3_key}")
|
||||
s3_key = f"{target_id}/target"
|
||||
logger.info("Downloading s3://%s/%s", self.bucket, s3_key)
|
||||
|
||||
self.s3_client.download_file(
|
||||
self.bucket,
|
||||
s3_key,
|
||||
str(cached_file)
|
||||
str(cached_file),
|
||||
)
|
||||
|
||||
# Verify download
|
||||
if not cached_file.exists():
|
||||
raise StorageError(f"Downloaded file not found: {cached_file}")
|
||||
msg = f"Downloaded file not found: {cached_file}"
|
||||
raise StorageError(msg)
|
||||
|
||||
file_size_mb = cached_file.stat().st_size / (1024 * 1024)
|
||||
logger.info(f"✓ Downloaded target {target_id} ({file_size_mb:.2f} MB)")
|
||||
|
||||
return cached_file
|
||||
logger.info("✓ Downloaded target %s (%s MB)", target_id, f"{file_size_mb:.2f}")
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response.get('Error', {}).get('Code')
|
||||
if error_code in ['404', 'NoSuchKey']:
|
||||
logger.error(f"Target not found: {target_id}")
|
||||
raise FileNotFoundError(f"Target {target_id} not found in storage")
|
||||
else:
|
||||
logger.error(f"S3 download failed: {e}", exc_info=True)
|
||||
raise StorageError(f"Download failed: {e}")
|
||||
error_code = e.response.get("Error", {}).get("Code")
|
||||
if error_code in ["404", "NoSuchKey"]:
|
||||
logger.exception("Target not found: %s", target_id)
|
||||
msg = f"Target {target_id} not found in storage"
|
||||
raise FileNotFoundError(msg) from e
|
||||
logger.exception("S3 download failed")
|
||||
msg = f"Download failed: {e}"
|
||||
raise StorageError(msg) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Download error: {e}", exc_info=True)
|
||||
logger.exception("Download error")
|
||||
# Cleanup partial download
|
||||
if cache_path.exists():
|
||||
shutil.rmtree(cache_path, ignore_errors=True)
|
||||
raise StorageError(f"Download error: {e}")
|
||||
msg = f"Download error: {e}"
|
||||
raise StorageError(msg) from e
|
||||
else:
|
||||
return cached_file
|
||||
|
||||
async def delete_target(self, target_id: str) -> None:
|
||||
"""Delete target from S3/MinIO."""
|
||||
try:
|
||||
s3_key = f'{target_id}/target'
|
||||
logger.info(f"Deleting s3://{self.bucket}/{s3_key}")
|
||||
s3_key = f"{target_id}/target"
|
||||
logger.info("Deleting s3://%s/%s", self.bucket, s3_key)
|
||||
|
||||
self.s3_client.delete_object(
|
||||
Bucket=self.bucket,
|
||||
Key=s3_key
|
||||
Key=s3_key,
|
||||
)
|
||||
|
||||
# Also delete from cache if present
|
||||
cache_path = self.cache_dir / target_id
|
||||
if cache_path.exists():
|
||||
shutil.rmtree(cache_path, ignore_errors=True)
|
||||
logger.info(f"✓ Deleted target {target_id} from S3 and cache")
|
||||
logger.info("✓ Deleted target %s from S3 and cache", target_id)
|
||||
else:
|
||||
logger.info(f"✓ Deleted target {target_id} from S3")
|
||||
logger.info("✓ Deleted target %s from S3", target_id)
|
||||
|
||||
except ClientError as e:
|
||||
logger.error(f"S3 delete failed: {e}", exc_info=True)
|
||||
logger.exception("S3 delete failed")
|
||||
# Don't raise error if object doesn't exist
|
||||
if e.response.get('Error', {}).get('Code') not in ['404', 'NoSuchKey']:
|
||||
raise StorageError(f"Delete failed: {e}")
|
||||
if e.response.get("Error", {}).get("Code") not in ["404", "NoSuchKey"]:
|
||||
msg = f"Delete failed: {e}"
|
||||
raise StorageError(msg) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Delete error: {e}", exc_info=True)
|
||||
raise StorageError(f"Delete error: {e}")
|
||||
logger.exception("Delete error")
|
||||
msg = f"Delete error: {e}"
|
||||
raise StorageError(msg) from e
|
||||
|
||||
async def upload_results(
|
||||
self,
|
||||
workflow_id: str,
|
||||
results: Dict[str, Any],
|
||||
results_format: str = "json"
|
||||
results: dict[str, Any],
|
||||
results_format: str = "json",
|
||||
) -> str:
|
||||
"""Upload workflow results to S3/MinIO."""
|
||||
try:
|
||||
# Prepare results content
|
||||
if results_format == "json":
|
||||
content = json.dumps(results, indent=2).encode('utf-8')
|
||||
content_type = 'application/json'
|
||||
file_ext = 'json'
|
||||
content = json.dumps(results, indent=2).encode("utf-8")
|
||||
content_type = "application/json"
|
||||
file_ext = "json"
|
||||
elif results_format == "sarif":
|
||||
content = json.dumps(results, indent=2).encode('utf-8')
|
||||
content_type = 'application/sarif+json'
|
||||
file_ext = 'sarif'
|
||||
content = json.dumps(results, indent=2).encode("utf-8")
|
||||
content_type = "application/sarif+json"
|
||||
file_ext = "sarif"
|
||||
else:
|
||||
content = json.dumps(results, indent=2).encode('utf-8')
|
||||
content_type = 'application/json'
|
||||
file_ext = 'json'
|
||||
content = json.dumps(results, indent=2).encode("utf-8")
|
||||
content_type = "application/json"
|
||||
file_ext = "json"
|
||||
|
||||
# Upload to results bucket
|
||||
results_bucket = 'results'
|
||||
s3_key = f'{workflow_id}/results.{file_ext}'
|
||||
results_bucket = "results"
|
||||
s3_key = f"{workflow_id}/results.{file_ext}"
|
||||
|
||||
logger.info(f"Uploading results to s3://{results_bucket}/{s3_key}")
|
||||
logger.info("Uploading results to s3://%s/%s", results_bucket, s3_key)
|
||||
|
||||
self.s3_client.put_object(
|
||||
Bucket=results_bucket,
|
||||
@@ -253,95 +262,103 @@ class S3CachedStorage(StorageBackend):
|
||||
Body=content,
|
||||
ContentType=content_type,
|
||||
Metadata={
|
||||
'workflow_id': workflow_id,
|
||||
'format': results_format,
|
||||
'uploaded_at': datetime.now().isoformat()
|
||||
}
|
||||
"workflow_id": workflow_id,
|
||||
"format": results_format,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
# Construct URL
|
||||
results_url = f"{self.endpoint_url}/{results_bucket}/{s3_key}"
|
||||
logger.info(f"✓ Uploaded results: {results_url}")
|
||||
|
||||
return results_url
|
||||
logger.info("✓ Uploaded results: %s", results_url)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Results upload failed: {e}", exc_info=True)
|
||||
raise StorageError(f"Results upload failed: {e}")
|
||||
logger.exception("Results upload failed")
|
||||
msg = f"Results upload failed: {e}"
|
||||
raise StorageError(msg) from e
|
||||
else:
|
||||
return results_url
|
||||
|
||||
async def get_results(self, workflow_id: str) -> Dict[str, Any]:
|
||||
async def get_results(self, workflow_id: str) -> dict[str, Any]:
|
||||
"""Get workflow results from S3/MinIO."""
|
||||
try:
|
||||
results_bucket = 'results'
|
||||
s3_key = f'{workflow_id}/results.json'
|
||||
results_bucket = "results"
|
||||
s3_key = f"{workflow_id}/results.json"
|
||||
|
||||
logger.info(f"Downloading results from s3://{results_bucket}/{s3_key}")
|
||||
logger.info("Downloading results from s3://%s/%s", results_bucket, s3_key)
|
||||
|
||||
response = self.s3_client.get_object(
|
||||
Bucket=results_bucket,
|
||||
Key=s3_key
|
||||
Key=s3_key,
|
||||
)
|
||||
|
||||
content = response['Body'].read().decode('utf-8')
|
||||
content = response["Body"].read().decode("utf-8")
|
||||
results = json.loads(content)
|
||||
|
||||
logger.info(f"✓ Downloaded results for workflow {workflow_id}")
|
||||
return results
|
||||
logger.info("✓ Downloaded results for workflow %s", workflow_id)
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response.get('Error', {}).get('Code')
|
||||
if error_code in ['404', 'NoSuchKey']:
|
||||
logger.error(f"Results not found: {workflow_id}")
|
||||
raise FileNotFoundError(f"Results for workflow {workflow_id} not found")
|
||||
else:
|
||||
logger.error(f"Results download failed: {e}", exc_info=True)
|
||||
raise StorageError(f"Results download failed: {e}")
|
||||
error_code = e.response.get("Error", {}).get("Code")
|
||||
if error_code in ["404", "NoSuchKey"]:
|
||||
logger.exception("Results not found: %s", workflow_id)
|
||||
msg = f"Results for workflow {workflow_id} not found"
|
||||
raise FileNotFoundError(msg) from e
|
||||
logger.exception("Results download failed")
|
||||
msg = f"Results download failed: {e}"
|
||||
raise StorageError(msg) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Results download error: {e}", exc_info=True)
|
||||
raise StorageError(f"Results download error: {e}")
|
||||
logger.exception("Results download error")
|
||||
msg = f"Results download error: {e}"
|
||||
raise StorageError(msg) from e
|
||||
else:
|
||||
return results
|
||||
|
||||
async def list_targets(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> list[Dict[str, Any]]:
|
||||
user_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List uploaded targets."""
|
||||
try:
|
||||
targets = []
|
||||
paginator = self.s3_client.get_paginator('list_objects_v2')
|
||||
paginator = self.s3_client.get_paginator("list_objects_v2")
|
||||
|
||||
for page in paginator.paginate(Bucket=self.bucket, PaginationConfig={'MaxItems': limit}):
|
||||
for obj in page.get('Contents', []):
|
||||
for page in paginator.paginate(Bucket=self.bucket, PaginationConfig={"MaxItems": limit}):
|
||||
for obj in page.get("Contents", []):
|
||||
# Get object metadata
|
||||
try:
|
||||
metadata_response = self.s3_client.head_object(
|
||||
Bucket=self.bucket,
|
||||
Key=obj['Key']
|
||||
Key=obj["Key"],
|
||||
)
|
||||
metadata = metadata_response.get('Metadata', {})
|
||||
metadata = metadata_response.get("Metadata", {})
|
||||
|
||||
# Filter by user_id if specified
|
||||
if user_id and metadata.get('user_id') != user_id:
|
||||
if user_id and metadata.get("user_id") != user_id:
|
||||
continue
|
||||
|
||||
targets.append({
|
||||
'target_id': obj['Key'].split('/')[0],
|
||||
'key': obj['Key'],
|
||||
'size': obj['Size'],
|
||||
'last_modified': obj['LastModified'].isoformat(),
|
||||
'metadata': metadata
|
||||
})
|
||||
targets.append(
|
||||
{
|
||||
"target_id": obj["Key"].split("/")[0],
|
||||
"key": obj["Key"],
|
||||
"size": obj["Size"],
|
||||
"last_modified": obj["LastModified"].isoformat(),
|
||||
"metadata": metadata,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get metadata for {obj['Key']}: {e}")
|
||||
logger.warning("Failed to get metadata for %s: %s", obj["Key"], e)
|
||||
continue
|
||||
|
||||
logger.info(f"Listed {len(targets)} targets (user_id={user_id})")
|
||||
return targets
|
||||
logger.info("Listed %s targets (user_id=%s)", len(targets), user_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"List targets failed: {e}", exc_info=True)
|
||||
raise StorageError(f"List targets failed: {e}")
|
||||
logger.exception("List targets failed")
|
||||
msg = f"List targets failed: {e}"
|
||||
raise StorageError(msg) from e
|
||||
else:
|
||||
return targets
|
||||
|
||||
async def cleanup_cache(self) -> int:
|
||||
"""Clean up local cache using LRU eviction."""
|
||||
@@ -350,30 +367,33 @@ class S3CachedStorage(StorageBackend):
|
||||
total_size = 0
|
||||
|
||||
# Gather all cached files with metadata
|
||||
for cache_file in self.cache_dir.rglob('*'):
|
||||
for cache_file in self.cache_dir.rglob("*"):
|
||||
if cache_file.is_file():
|
||||
try:
|
||||
stat = cache_file.stat()
|
||||
cache_files.append({
|
||||
'path': cache_file,
|
||||
'size': stat.st_size,
|
||||
'atime': stat.st_atime # Last access time
|
||||
})
|
||||
cache_files.append(
|
||||
{
|
||||
"path": cache_file,
|
||||
"size": stat.st_size,
|
||||
"atime": stat.st_atime, # Last access time
|
||||
},
|
||||
)
|
||||
total_size += stat.st_size
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stat {cache_file}: {e}")
|
||||
logger.warning("Failed to stat %s: %s", cache_file, e)
|
||||
continue
|
||||
|
||||
# Check if cleanup is needed
|
||||
if total_size <= self.cache_max_size:
|
||||
logger.info(
|
||||
f"Cache size OK: {total_size / (1024**3):.2f} GB / "
|
||||
f"{self.cache_max_size / (1024**3):.2f} GB"
|
||||
"Cache size OK: %s GB / %s GB",
|
||||
f"{total_size / (1024**3):.2f}",
|
||||
f"{self.cache_max_size / (1024**3):.2f}",
|
||||
)
|
||||
return 0
|
||||
|
||||
# Sort by access time (oldest first)
|
||||
cache_files.sort(key=lambda x: x['atime'])
|
||||
cache_files.sort(key=lambda x: x["atime"])
|
||||
|
||||
# Remove files until under limit
|
||||
removed_count = 0
|
||||
@@ -382,42 +402,46 @@ class S3CachedStorage(StorageBackend):
|
||||
break
|
||||
|
||||
try:
|
||||
file_info['path'].unlink()
|
||||
total_size -= file_info['size']
|
||||
file_info["path"].unlink()
|
||||
total_size -= file_info["size"]
|
||||
removed_count += 1
|
||||
logger.debug(f"Evicted from cache: {file_info['path']}")
|
||||
logger.debug("Evicted from cache: %s", file_info["path"])
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {file_info['path']}: {e}")
|
||||
logger.warning("Failed to delete %s: %s", file_info["path"], e)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"✓ Cache cleanup: removed {removed_count} files, "
|
||||
f"new size: {total_size / (1024**3):.2f} GB"
|
||||
"✓ Cache cleanup: removed %s files, new size: %s GB",
|
||||
removed_count,
|
||||
f"{total_size / (1024**3):.2f}",
|
||||
)
|
||||
return removed_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cache cleanup failed: {e}", exc_info=True)
|
||||
raise StorageError(f"Cache cleanup failed: {e}")
|
||||
logger.exception("Cache cleanup failed")
|
||||
msg = f"Cache cleanup failed: {e}"
|
||||
raise StorageError(msg) from e
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
else:
|
||||
return removed_count
|
||||
|
||||
def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
try:
|
||||
total_size = 0
|
||||
file_count = 0
|
||||
|
||||
for cache_file in self.cache_dir.rglob('*'):
|
||||
for cache_file in self.cache_dir.rglob("*"):
|
||||
if cache_file.is_file():
|
||||
total_size += cache_file.stat().st_size
|
||||
file_count += 1
|
||||
|
||||
return {
|
||||
'total_size_bytes': total_size,
|
||||
'total_size_gb': total_size / (1024 ** 3),
|
||||
'file_count': file_count,
|
||||
'max_size_gb': self.cache_max_size / (1024 ** 3),
|
||||
'usage_percent': (total_size / self.cache_max_size) * 100
|
||||
"total_size_bytes": total_size,
|
||||
"total_size_gb": total_size / (1024**3),
|
||||
"file_count": file_count,
|
||||
"max_size_gb": self.cache_max_size / (1024**3),
|
||||
"usage_percent": (total_size / self.cache_max_size) * 100,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cache stats: {e}")
|
||||
return {'error': str(e)}
|
||||
logger.exception("Failed to get cache stats")
|
||||
return {"error": str(e)}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""
|
||||
Temporal integration for FuzzForge.
|
||||
"""Temporal integration for FuzzForge.
|
||||
|
||||
Handles workflow execution, monitoring, and management.
|
||||
"""
|
||||
|
||||
from .manager import TemporalManager
|
||||
from .discovery import WorkflowDiscovery
|
||||
from .manager import TemporalManager
|
||||
|
||||
__all__ = ["TemporalManager", "WorkflowDiscovery"]
|
||||
|
||||
@@ -1,25 +1,26 @@
|
||||
"""
|
||||
Workflow Discovery for Temporal
|
||||
"""Workflow Discovery for Temporal.
|
||||
|
||||
Discovers workflows from the toolbox/workflows directory
|
||||
and provides metadata about available workflows.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowInfo(BaseModel):
|
||||
"""Information about a discovered workflow"""
|
||||
"""Information about a discovered workflow."""
|
||||
|
||||
name: str = Field(..., description="Workflow name")
|
||||
path: Path = Field(..., description="Path to workflow directory")
|
||||
workflow_file: Path = Field(..., description="Path to workflow.py file")
|
||||
metadata: Dict[str, Any] = Field(..., description="Workflow metadata from YAML")
|
||||
metadata: dict[str, Any] = Field(..., description="Workflow metadata from YAML")
|
||||
workflow_type: str = Field(..., description="Workflow class name")
|
||||
vertical: str = Field(..., description="Vertical (worker type) for this workflow")
|
||||
|
||||
@@ -27,8 +28,7 @@ class WorkflowInfo(BaseModel):
|
||||
|
||||
|
||||
class WorkflowDiscovery:
|
||||
"""
|
||||
Discovers workflows from the filesystem.
|
||||
"""Discovers workflows from the filesystem.
|
||||
|
||||
Scans toolbox/workflows/ for directories containing:
|
||||
- metadata.yaml (required)
|
||||
@@ -38,106 +38,109 @@ class WorkflowDiscovery:
|
||||
which determines which worker pool will execute it.
|
||||
"""
|
||||
|
||||
def __init__(self, workflows_dir: Path):
|
||||
"""
|
||||
Initialize workflow discovery.
|
||||
def __init__(self, workflows_dir: Path) -> None:
|
||||
"""Initialize workflow discovery.
|
||||
|
||||
Args:
|
||||
workflows_dir: Path to the workflows directory
|
||||
|
||||
"""
|
||||
self.workflows_dir = workflows_dir
|
||||
if not self.workflows_dir.exists():
|
||||
self.workflows_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Created workflows directory: {self.workflows_dir}")
|
||||
logger.info("Created workflows directory: %s", self.workflows_dir)
|
||||
|
||||
async def discover_workflows(self) -> Dict[str, WorkflowInfo]:
|
||||
"""
|
||||
Discover workflows by scanning the workflows directory.
|
||||
async def discover_workflows(self) -> dict[str, WorkflowInfo]:
|
||||
"""Discover workflows by scanning the workflows directory.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping workflow names to their information
|
||||
|
||||
"""
|
||||
workflows = {}
|
||||
|
||||
logger.info(f"Scanning for workflows in: {self.workflows_dir}")
|
||||
logger.info("Scanning for workflows in: %s", self.workflows_dir)
|
||||
|
||||
for workflow_dir in self.workflows_dir.iterdir():
|
||||
if not workflow_dir.is_dir():
|
||||
continue
|
||||
|
||||
# Skip special directories
|
||||
if workflow_dir.name.startswith('.') or workflow_dir.name == '__pycache__':
|
||||
if workflow_dir.name.startswith(".") or workflow_dir.name == "__pycache__":
|
||||
continue
|
||||
|
||||
metadata_file = workflow_dir / "metadata.yaml"
|
||||
if not metadata_file.exists():
|
||||
logger.debug(f"No metadata.yaml in {workflow_dir.name}, skipping")
|
||||
logger.debug("No metadata.yaml in %s, skipping", workflow_dir.name)
|
||||
continue
|
||||
|
||||
workflow_file = workflow_dir / "workflow.py"
|
||||
if not workflow_file.exists():
|
||||
logger.warning(
|
||||
f"Workflow {workflow_dir.name} has metadata but no workflow.py, skipping"
|
||||
"Workflow %s has metadata but no workflow.py, skipping",
|
||||
workflow_dir.name,
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
# Parse metadata
|
||||
with open(metadata_file) as f:
|
||||
with metadata_file.open() as f:
|
||||
metadata = yaml.safe_load(f)
|
||||
|
||||
# Validate required fields
|
||||
if 'name' not in metadata:
|
||||
logger.warning(f"Workflow {workflow_dir.name} metadata missing 'name' field")
|
||||
metadata['name'] = workflow_dir.name
|
||||
if "name" not in metadata:
|
||||
logger.warning("Workflow %s metadata missing 'name' field", workflow_dir.name)
|
||||
metadata["name"] = workflow_dir.name
|
||||
|
||||
if 'vertical' not in metadata:
|
||||
if "vertical" not in metadata:
|
||||
logger.warning(
|
||||
f"Workflow {workflow_dir.name} metadata missing 'vertical' field"
|
||||
"Workflow %s metadata missing 'vertical' field",
|
||||
workflow_dir.name,
|
||||
)
|
||||
continue
|
||||
|
||||
# Infer workflow class name from metadata or use convention
|
||||
workflow_type = metadata.get('workflow_class')
|
||||
workflow_type = metadata.get("workflow_class")
|
||||
if not workflow_type:
|
||||
# Convention: convert snake_case to PascalCase + Workflow
|
||||
# e.g., rust_test -> RustTestWorkflow
|
||||
parts = workflow_dir.name.split('_')
|
||||
workflow_type = ''.join(part.capitalize() for part in parts) + 'Workflow'
|
||||
parts = workflow_dir.name.split("_")
|
||||
workflow_type = "".join(part.capitalize() for part in parts) + "Workflow"
|
||||
|
||||
# Create workflow info
|
||||
info = WorkflowInfo(
|
||||
name=metadata['name'],
|
||||
name=metadata["name"],
|
||||
path=workflow_dir,
|
||||
workflow_file=workflow_file,
|
||||
metadata=metadata,
|
||||
workflow_type=workflow_type,
|
||||
vertical=metadata['vertical']
|
||||
vertical=metadata["vertical"],
|
||||
)
|
||||
|
||||
workflows[info.name] = info
|
||||
logger.info(
|
||||
f"✓ Discovered workflow: {info.name} "
|
||||
f"(vertical: {info.vertical}, class: {info.workflow_type})"
|
||||
"✓ Discovered workflow: %s (vertical: %s, class: %s)",
|
||||
info.name,
|
||||
info.vertical,
|
||||
info.workflow_type,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error discovering workflow {workflow_dir.name}: {e}",
|
||||
exc_info=True
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error discovering workflow %s",
|
||||
workflow_dir.name,
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Discovered {len(workflows)} workflows")
|
||||
logger.info("Discovered %s workflows", len(workflows))
|
||||
return workflows
|
||||
|
||||
def get_workflows_by_vertical(
|
||||
self,
|
||||
workflows: Dict[str, WorkflowInfo],
|
||||
vertical: str
|
||||
) -> Dict[str, WorkflowInfo]:
|
||||
"""
|
||||
Filter workflows by vertical.
|
||||
workflows: dict[str, WorkflowInfo],
|
||||
vertical: str,
|
||||
) -> dict[str, WorkflowInfo]:
|
||||
"""Filter workflows by vertical.
|
||||
|
||||
Args:
|
||||
workflows: All discovered workflows
|
||||
@@ -145,32 +148,29 @@ class WorkflowDiscovery:
|
||||
|
||||
Returns:
|
||||
Filtered workflows dictionary
|
||||
"""
|
||||
return {
|
||||
name: info
|
||||
for name, info in workflows.items()
|
||||
if info.vertical == vertical
|
||||
}
|
||||
|
||||
def get_available_verticals(self, workflows: Dict[str, WorkflowInfo]) -> list[str]:
|
||||
"""
|
||||
Get list of all verticals from discovered workflows.
|
||||
return {name: info for name, info in workflows.items() if info.vertical == vertical}
|
||||
|
||||
def get_available_verticals(self, workflows: dict[str, WorkflowInfo]) -> list[str]:
|
||||
"""Get list of all verticals from discovered workflows.
|
||||
|
||||
Args:
|
||||
workflows: All discovered workflows
|
||||
|
||||
Returns:
|
||||
List of unique vertical names
|
||||
|
||||
"""
|
||||
return list(set(info.vertical for info in workflows.values()))
|
||||
return list({info.vertical for info in workflows.values()})
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_schema() -> Dict[str, Any]:
|
||||
"""
|
||||
Get the JSON schema for workflow metadata.
|
||||
def get_metadata_schema() -> dict[str, Any]:
|
||||
"""Get the JSON schema for workflow metadata.
|
||||
|
||||
Returns:
|
||||
JSON schema dictionary
|
||||
|
||||
"""
|
||||
return {
|
||||
"type": "object",
|
||||
@@ -178,34 +178,34 @@ class WorkflowDiscovery:
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Workflow name"
|
||||
"description": "Workflow name",
|
||||
},
|
||||
"version": {
|
||||
"type": "string",
|
||||
"pattern": "^\\d+\\.\\d+\\.\\d+$",
|
||||
"description": "Semantic version (x.y.z)"
|
||||
"description": "Semantic version (x.y.z)",
|
||||
},
|
||||
"vertical": {
|
||||
"type": "string",
|
||||
"description": "Vertical worker type (rust, android, web, etc.)"
|
||||
"description": "Vertical worker type (rust, android, web, etc.)",
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Workflow description"
|
||||
"description": "Workflow description",
|
||||
},
|
||||
"author": {
|
||||
"type": "string",
|
||||
"description": "Workflow author"
|
||||
"description": "Workflow author",
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["comprehensive", "specialized", "fuzzing", "focused"],
|
||||
"description": "Workflow category"
|
||||
"description": "Workflow category",
|
||||
},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Workflow tags for categorization"
|
||||
"description": "Workflow tags for categorization",
|
||||
},
|
||||
"requirements": {
|
||||
"type": "object",
|
||||
@@ -214,7 +214,7 @@ class WorkflowDiscovery:
|
||||
"tools": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Required security tools"
|
||||
"description": "Required security tools",
|
||||
},
|
||||
"resources": {
|
||||
"type": "object",
|
||||
@@ -223,35 +223,35 @@ class WorkflowDiscovery:
|
||||
"memory": {
|
||||
"type": "string",
|
||||
"pattern": "^\\d+[GMK]i$",
|
||||
"description": "Memory limit (e.g., 1Gi, 512Mi)"
|
||||
"description": "Memory limit (e.g., 1Gi, 512Mi)",
|
||||
},
|
||||
"cpu": {
|
||||
"type": "string",
|
||||
"pattern": "^\\d+m?$",
|
||||
"description": "CPU limit (e.g., 1000m, 2)"
|
||||
"description": "CPU limit (e.g., 1000m, 2)",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"minimum": 60,
|
||||
"maximum": 7200,
|
||||
"description": "Workflow timeout in seconds"
|
||||
}
}
}
}
"description": "Workflow timeout in seconds",
},
},
},
},
},
"parameters": {
"type": "object",
"description": "Workflow parameters schema"
"description": "Workflow parameters schema",
},
"default_parameters": {
"type": "object",
"description": "Default parameter values"
"description": "Default parameter values",
},
"required_modules": {
"type": "array",
"items": {"type": "string"},
"description": "Required module names"
}
}
"description": "Required module names",
},
},
}

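For reference, here is a minimal sketch of a metadata document that satisfies the schema returned above, checked with the third-party jsonschema package; the package, the import path, and every field value are illustrative assumptions rather than part of this change.

# Illustrative sketch only: field values and the jsonschema dependency are assumptions.
from jsonschema import validate  # assumed validator, not necessarily used by FuzzForge

from src.temporal_manager.discovery import WorkflowDiscovery  # assumed import path

example_metadata = {
    "name": "security_assessment",
    "version": "1.0.0",                  # must match ^\d+\.\d+\.\d+$
    "vertical": "rust",                  # also selects the "rust-queue" task queue
    "description": "Example workflow",
    "author": "FuzzingLabs",
    "category": "comprehensive",         # one of the enum values above
    "tags": ["fuzzing"],
    "requirements": {
        "tools": ["cargo-fuzz"],
        "resources": {"memory": "1Gi", "cpu": "1000m", "timeout": 3600},
    },
    "parameters": {"properties": {"max_iterations": {"type": "integer", "default": 1000}}},
    "default_parameters": {"max_iterations": 1000},
    "required_modules": ["cargo_fuzzer"],
}

validate(instance=example_metadata, schema=WorkflowDiscovery.get_metadata_schema())
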
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Temporal Manager - Workflow execution and management
|
||||
"""Temporal Manager - Workflow execution and management.
|
||||
|
||||
Handles:
|
||||
- Workflow discovery from toolbox
|
||||
@@ -8,25 +7,26 @@ Handles:
|
||||
- Results retrieval
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from temporalio.client import Client, WorkflowHandle
|
||||
from temporalio.common import RetryPolicy
|
||||
from datetime import timedelta
|
||||
|
||||
from src.storage import S3CachedStorage
|
||||
|
||||
from .discovery import WorkflowDiscovery, WorkflowInfo
|
||||
from src.storage import S3CachedStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TemporalManager:
|
||||
"""
|
||||
Manages Temporal workflow execution for FuzzForge.
|
||||
"""Manages Temporal workflow execution for FuzzForge.
|
||||
|
||||
This class:
|
||||
- Discovers available workflows from toolbox
|
||||
@@ -37,41 +37,42 @@ class TemporalManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflows_dir: Optional[Path] = None,
|
||||
temporal_address: Optional[str] = None,
|
||||
workflows_dir: Path | None = None,
|
||||
temporal_address: str | None = None,
|
||||
temporal_namespace: str = "default",
|
||||
storage: Optional[S3CachedStorage] = None
|
||||
):
|
||||
"""
|
||||
Initialize Temporal manager.
|
||||
storage: S3CachedStorage | None = None,
|
||||
) -> None:
|
||||
"""Initialize Temporal manager.
|
||||
|
||||
Args:
|
||||
workflows_dir: Path to workflows directory (default: toolbox/workflows)
|
||||
temporal_address: Temporal server address (default: from env or localhost:7233)
|
||||
temporal_namespace: Temporal namespace
|
||||
storage: Storage backend for file uploads (default: S3CachedStorage)
|
||||
|
||||
"""
|
||||
if workflows_dir is None:
|
||||
workflows_dir = Path("toolbox/workflows")
|
||||
|
||||
self.temporal_address = temporal_address or os.getenv(
|
||||
'TEMPORAL_ADDRESS',
|
||||
'localhost:7233'
|
||||
"TEMPORAL_ADDRESS",
|
||||
"localhost:7233",
|
||||
)
|
||||
self.temporal_namespace = temporal_namespace
|
||||
self.discovery = WorkflowDiscovery(workflows_dir)
|
||||
self.workflows: Dict[str, WorkflowInfo] = {}
|
||||
self.client: Optional[Client] = None
|
||||
self.workflows: dict[str, WorkflowInfo] = {}
|
||||
self.client: Client | None = None
|
||||
|
||||
# Initialize storage backend
|
||||
self.storage = storage or S3CachedStorage()
|
||||
|
||||
logger.info(
|
||||
f"TemporalManager initialized: {self.temporal_address} "
|
||||
f"(namespace: {self.temporal_namespace})"
|
||||
"TemporalManager initialized: %s (namespace: %s)",
|
||||
self.temporal_address,
|
||||
self.temporal_namespace,
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the manager by discovering workflows and connecting to Temporal."""
|
||||
try:
|
||||
# Discover workflows
|
||||
@@ -81,45 +82,46 @@ class TemporalManager:
|
||||
logger.warning("No workflows discovered")
|
||||
else:
|
||||
logger.info(
|
||||
f"Discovered {len(self.workflows)} workflows: "
|
||||
f"{list(self.workflows.keys())}"
|
||||
"Discovered %s workflows: %s",
|
||||
len(self.workflows),
|
||||
list(self.workflows.keys()),
|
||||
)
|
||||
|
||||
# Connect to Temporal
|
||||
self.client = await Client.connect(
|
||||
self.temporal_address,
|
||||
namespace=self.temporal_namespace
|
||||
namespace=self.temporal_namespace,
|
||||
)
|
||||
logger.info(f"✓ Connected to Temporal: {self.temporal_address}")
|
||||
logger.info("✓ Connected to Temporal: %s", self.temporal_address)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Temporal manager: {e}", exc_info=True)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize Temporal manager")
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""Close Temporal client connection."""
|
||||
if self.client:
|
||||
# Temporal client doesn't need explicit close in Python SDK
|
||||
pass
|
||||
|
||||
async def get_workflows(self) -> Dict[str, WorkflowInfo]:
|
||||
"""
|
||||
Get all discovered workflows.
|
||||
async def get_workflows(self) -> dict[str, WorkflowInfo]:
|
||||
"""Get all discovered workflows.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping workflow names to their info
|
||||
|
||||
"""
|
||||
return self.workflows
|
||||
|
||||
async def get_workflow(self, name: str) -> Optional[WorkflowInfo]:
|
||||
"""
|
||||
Get workflow info by name.
|
||||
async def get_workflow(self, name: str) -> WorkflowInfo | None:
|
||||
"""Get workflow info by name.
|
||||
|
||||
Args:
|
||||
name: Workflow name
|
||||
|
||||
Returns:
|
||||
WorkflowInfo or None if not found
|
||||
|
||||
"""
|
||||
return self.workflows.get(name)
|
||||
|
||||
@@ -127,10 +129,9 @@ class TemporalManager:
|
||||
self,
|
||||
file_path: Path,
|
||||
user_id: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Upload target file to storage.
|
||||
"""Upload target file to storage.
|
||||
|
||||
Args:
|
||||
file_path: Local path to file
|
||||
@@ -139,20 +140,20 @@ class TemporalManager:
|
||||
|
||||
Returns:
|
||||
Target ID for use in workflow execution
|
||||
|
||||
"""
|
||||
target_id = await self.storage.upload_target(file_path, user_id, metadata)
|
||||
logger.info(f"Uploaded target: {target_id}")
|
||||
logger.info("Uploaded target: %s", target_id)
|
||||
return target_id
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
workflow_name: str,
|
||||
target_id: str,
|
||||
workflow_params: Optional[Dict[str, Any]] = None,
|
||||
workflow_id: Optional[str] = None
|
||||
workflow_params: dict[str, Any] | None = None,
|
||||
workflow_id: str | None = None,
|
||||
) -> WorkflowHandle:
|
||||
"""
|
||||
Execute a workflow.
|
||||
"""Execute a workflow.
|
||||
|
||||
Args:
|
||||
workflow_name: Name of workflow to execute
|
||||
@@ -165,14 +166,17 @@ class TemporalManager:
|
||||
|
||||
Raises:
|
||||
ValueError: If workflow not found or client not initialized
|
||||
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Temporal client not initialized. Call initialize() first.")
|
||||
msg = "Temporal client not initialized. Call initialize() first."
|
||||
raise ValueError(msg)
|
||||
|
||||
# Get workflow info
|
||||
workflow_info = self.workflows.get(workflow_name)
|
||||
if not workflow_info:
|
||||
raise ValueError(f"Workflow not found: {workflow_name}")
|
||||
msg = f"Workflow not found: {workflow_name}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Generate workflow ID if not provided
|
||||
if not workflow_id:
|
||||
@@ -188,23 +192,23 @@ class TemporalManager:
|
||||
# Add parameters in order based on metadata schema
|
||||
# This ensures parameters match the workflow signature order
|
||||
# Apply defaults from metadata.yaml if parameter not provided
|
||||
if 'parameters' in workflow_info.metadata:
|
||||
param_schema = workflow_info.metadata['parameters'].get('properties', {})
|
||||
logger.debug(f"Found {len(param_schema)} parameters in schema")
|
||||
if "parameters" in workflow_info.metadata:
|
||||
param_schema = workflow_info.metadata["parameters"].get("properties", {})
|
||||
logger.debug("Found %s parameters in schema", len(param_schema))
|
||||
# Iterate parameters in schema order and add values
|
||||
for param_name in param_schema.keys():
|
||||
for param_name in param_schema:
|
||||
param_spec = param_schema[param_name]
|
||||
|
||||
# Use provided param, or fall back to default from metadata
|
||||
if workflow_params and param_name in workflow_params:
|
||||
param_value = workflow_params[param_name]
|
||||
logger.debug(f"Using provided value for {param_name}: {param_value}")
|
||||
elif 'default' in param_spec:
|
||||
param_value = param_spec['default']
|
||||
logger.debug(f"Using default for {param_name}: {param_value}")
|
||||
logger.debug("Using provided value for %s: %s", param_name, param_value)
|
||||
elif "default" in param_spec:
|
||||
param_value = param_spec["default"]
|
||||
logger.debug("Using default for %s: %s", param_name, param_value)
|
||||
else:
|
||||
param_value = None
|
||||
logger.debug(f"No value or default for {param_name}, using None")
|
||||
logger.debug("No value or default for {param_name}, using None")
|
||||
|
||||
workflow_args.append(param_value)
|
||||
else:
|
||||
@@ -215,11 +219,14 @@ class TemporalManager:
|
||||
task_queue = f"{vertical}-queue"
|
||||
|
||||
logger.info(
|
||||
f"Starting workflow: {workflow_name} "
|
||||
f"(id={workflow_id}, queue={task_queue}, target={target_id})"
|
||||
"Starting workflow: %s (id=%s, queue=%s, target=%s)",
|
||||
workflow_name,
|
||||
workflow_id,
|
||||
task_queue,
|
||||
target_id,
|
||||
)
|
||||
logger.info(f"DEBUG: workflow_args = {workflow_args}")
|
||||
logger.info(f"DEBUG: workflow_params received = {workflow_params}")
|
||||
logger.info("DEBUG: workflow_args = %s", workflow_args)
|
||||
logger.infof("DEBUG: workflow_params received = %s", workflow_params)
|
||||
|
||||
try:
|
||||
# Start workflow execution with positional arguments
|
||||
@@ -231,20 +238,20 @@ class TemporalManager:
|
||||
retry_policy=RetryPolicy(
|
||||
initial_interval=timedelta(seconds=1),
|
||||
maximum_interval=timedelta(minutes=1),
|
||||
maximum_attempts=3
|
||||
)
|
||||
maximum_attempts=3,
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(f"✓ Workflow started: {workflow_id}")
|
||||
logger.info("✓ Workflow started: %s", workflow_id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to start workflow %s", workflow_name)
|
||||
raise
|
||||
else:
|
||||
return handle
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start workflow {workflow_name}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_workflow_status(self, workflow_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get workflow execution status.
|
||||
async def get_workflow_status(self, workflow_id: str) -> dict[str, Any]:
|
||||
"""Get workflow execution status.
|
||||
|
||||
Args:
|
||||
workflow_id: Workflow execution ID
|
||||
@@ -254,9 +261,11 @@ class TemporalManager:
|
||||
|
||||
Raises:
|
||||
ValueError: If client not initialized or workflow not found
|
||||
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Temporal client not initialized")
|
||||
msg = "Temporal client not initialized"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
# Get workflow handle
|
||||
@@ -274,20 +283,20 @@ class TemporalManager:
|
||||
"task_queue": description.task_queue,
|
||||
}
|
||||
|
||||
logger.info(f"Workflow {workflow_id} status: {status['status']}")
|
||||
return status
|
||||
logger.info("Workflow %s status: %s", workflow_id, status["status"])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get workflow status: {e}", exc_info=True)
|
||||
except Exception:
|
||||
logger.exception("Failed to get workflow status")
|
||||
raise
|
||||
else:
|
||||
return status
|
||||
|
||||
async def get_workflow_result(
|
||||
self,
|
||||
workflow_id: str,
|
||||
timeout: Optional[timedelta] = None
|
||||
timeout: timedelta | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Get workflow execution result (blocking).
|
||||
"""Get workflow execution result (blocking).
|
||||
|
||||
Args:
|
||||
workflow_id: Workflow execution ID
|
||||
@@ -299,60 +308,62 @@ class TemporalManager:
|
||||
Raises:
|
||||
ValueError: If client not initialized
|
||||
TimeoutError: If timeout exceeded
|
||||
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Temporal client not initialized")
|
||||
msg = "Temporal client not initialized"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
handle = self.client.get_workflow_handle(workflow_id)
|
||||
|
||||
logger.info(f"Waiting for workflow result: {workflow_id}")
|
||||
logger.info("Waiting for workflow result: %s", workflow_id)
|
||||
|
||||
# Wait for workflow to complete and get result
|
||||
if timeout:
|
||||
# Use asyncio timeout if provided
|
||||
import asyncio
|
||||
result = await asyncio.wait_for(handle.result(), timeout=timeout.total_seconds())
|
||||
else:
|
||||
result = await handle.result()
|
||||
|
||||
logger.info(f"✓ Workflow {workflow_id} completed")
|
||||
logger.info("✓ Workflow %s completed", workflow_id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to get workflow result")
|
||||
raise
|
||||
else:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get workflow result: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def cancel_workflow(self, workflow_id: str) -> None:
|
||||
"""
|
||||
Cancel a running workflow.
|
||||
"""Cancel a running workflow.
|
||||
|
||||
Args:
|
||||
workflow_id: Workflow execution ID
|
||||
|
||||
Raises:
|
||||
ValueError: If client not initialized
|
||||
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Temporal client not initialized")
|
||||
msg = "Temporal client not initialized"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
handle = self.client.get_workflow_handle(workflow_id)
|
||||
await handle.cancel()
|
||||
|
||||
logger.info(f"✓ Workflow cancelled: {workflow_id}")
|
||||
logger.info("✓ Workflow cancelled: %s", workflow_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel workflow: {e}", exc_info=True)
|
||||
except Exception:
logger.exception("Failed to cancel workflow: %s", workflow_id)
|
||||
raise
|
||||
|
||||
async def list_workflows(
|
||||
self,
|
||||
filter_query: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> list[Dict[str, Any]]:
|
||||
"""
|
||||
List workflow executions.
|
||||
filter_query: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List workflow executions.
|
||||
|
||||
Args:
|
||||
filter_query: Optional Temporal list filter query
|
||||
@@ -363,30 +374,36 @@ class TemporalManager:
|
||||
|
||||
Raises:
|
||||
ValueError: If client not initialized
|
||||
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Temporal client not initialized")
|
||||
msg = "Temporal client not initialized"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
workflows = []
|
||||
|
||||
# Use Temporal's list API
|
||||
async for workflow in self.client.list_workflows(filter_query):
|
||||
workflows.append({
|
||||
"workflow_id": workflow.id,
|
||||
"workflow_type": workflow.workflow_type,
|
||||
"status": workflow.status.name,
|
||||
"start_time": workflow.start_time.isoformat() if workflow.start_time else None,
|
||||
"close_time": workflow.close_time.isoformat() if workflow.close_time else None,
|
||||
"task_queue": workflow.task_queue,
|
||||
})
|
||||
workflows.append(
|
||||
{
|
||||
"workflow_id": workflow.id,
|
||||
"workflow_type": workflow.workflow_type,
|
||||
"status": workflow.status.name,
|
||||
"start_time": workflow.start_time.isoformat() if workflow.start_time else None,
|
||||
"close_time": workflow.close_time.isoformat() if workflow.close_time else None,
|
||||
"task_queue": workflow.task_queue,
|
||||
},
|
||||
)
|
||||
|
||||
if len(workflows) >= limit:
|
||||
break
|
||||
|
||||
logger.info(f"Listed {len(workflows)} workflows")
|
||||
logger.info("Listed %s workflows", len(workflows))
|
||||
return workflows
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list workflows: {e}", exc_info=True)
|
||||
except Exception:
|
||||
logger.exception("Failed to list workflows")
|
||||
raise
|
||||
else:
|
||||
return workflows
|
||||
|
||||
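Read end to end, the manager above implies a short lifecycle: initialize, upload a target, start a workflow on the vertical's task queue, then poll or await it. The sketch below walks through that sequence; the module path, workflow name, archive path, and user id are assumptions for illustration only.

# Illustrative sketch only: import path, workflow name, and file names are assumptions.
import asyncio
from pathlib import Path

from src.temporal_manager.manager import TemporalManager  # assumed import path


async def main() -> None:
    manager = TemporalManager()      # defaults: toolbox/workflows, TEMPORAL_ADDRESS or localhost:7233
    await manager.initialize()       # discover workflows and connect to Temporal

    target_id = await manager.upload_target(Path("./my_project.tar.gz"), user_id="demo-user")
    handle = await manager.run_workflow(
        "security_assessment",       # must match a discovered workflow name
        target_id,
        workflow_params={"max_iterations": 1000},
    )

    status = await manager.get_workflow_status(handle.id)
    print(status["status"])

    result = await manager.get_workflow_result(handle.id)
    print(result)


asyncio.run(main())
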
@@ -8,11 +8,19 @@
|
||||
# See the LICENSE-APACHE file or http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Additional attribution and requirements are provided in the NOTICE file.
|
||||
"""Fixtures used across tests."""
|
||||
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from types import CoroutineType
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from modules.analyzer.security_analyzer import SecurityAnalyzer
|
||||
from modules.fuzzer.atheris_fuzzer import AtherisFuzzer
|
||||
from modules.fuzzer.cargo_fuzzer import CargoFuzzer
|
||||
from modules.scanner.file_scanner import FileScanner
|
||||
|
||||
# Ensure project root is on sys.path so `src` is importable
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
@@ -29,17 +37,18 @@ if str(TOOLBOX) not in sys.path:
|
||||
# Workspace Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_workspace(tmp_path):
|
||||
"""Create a temporary workspace directory for testing"""
|
||||
def temp_workspace(tmp_path: Path) -> Path:
|
||||
"""Create a temporary workspace directory for testing."""
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir()
|
||||
return workspace
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def python_test_workspace(temp_workspace):
|
||||
"""Create a Python test workspace with sample files"""
|
||||
def python_test_workspace(temp_workspace: Path) -> Path:
|
||||
"""Create a Python test workspace with sample files."""
|
||||
# Create a simple Python project structure
|
||||
(temp_workspace / "main.py").write_text("""
|
||||
def process_data(data):
|
||||
@@ -62,8 +71,8 @@ AWS_SECRET = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rust_test_workspace(temp_workspace):
|
||||
"""Create a Rust test workspace with fuzz targets"""
|
||||
def rust_test_workspace(temp_workspace: Path) -> Path:
|
||||
"""Create a Rust test workspace with fuzz targets."""
|
||||
# Create Cargo.toml
|
||||
(temp_workspace / "Cargo.toml").write_text("""[package]
|
||||
name = "test_project"
|
||||
@@ -131,44 +140,45 @@ fuzz_target!(|data: &[u8]| {
|
||||
# Module Configuration Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def atheris_config():
|
||||
"""Default Atheris fuzzer configuration"""
|
||||
def atheris_config() -> dict[str, Any]:
|
||||
"""Return default Atheris fuzzer configuration."""
|
||||
return {
|
||||
"target_file": "auto-discover",
|
||||
"max_iterations": 1000,
|
||||
"timeout_seconds": 10,
|
||||
"corpus_dir": None
|
||||
"corpus_dir": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cargo_fuzz_config():
|
||||
"""Default cargo-fuzz configuration"""
|
||||
def cargo_fuzz_config() -> dict[str, Any]:
|
||||
"""Return default cargo-fuzz configuration."""
|
||||
return {
|
||||
"target_name": None,
|
||||
"max_iterations": 1000,
|
||||
"timeout_seconds": 10,
|
||||
"sanitizer": "address"
|
||||
"sanitizer": "address",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gitleaks_config():
|
||||
"""Default Gitleaks configuration"""
|
||||
def gitleaks_config() -> dict[str, Any]:
|
||||
"""Return default Gitleaks configuration."""
|
||||
return {
|
||||
"config_path": None,
|
||||
"scan_uncommitted": True
|
||||
"scan_uncommitted": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_scanner_config():
|
||||
"""Default file scanner configuration"""
|
||||
def file_scanner_config() -> dict[str, Any]:
|
||||
"""Return default file scanner configuration."""
|
||||
return {
|
||||
"scan_patterns": ["*.py", "*.rs", "*.js"],
|
||||
"exclude_patterns": ["*.test.*", "*.spec.*"],
|
||||
"max_file_size": 1048576 # 1MB
|
||||
"max_file_size": 1048576, # 1MB
|
||||
}
|
||||
|
||||
|
||||
@@ -176,55 +186,67 @@ def file_scanner_config():
|
||||
# Module Instance Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def atheris_fuzzer():
|
||||
"""Create an AtherisFuzzer instance"""
|
||||
from modules.fuzzer.atheris_fuzzer import AtherisFuzzer
|
||||
def atheris_fuzzer() -> AtherisFuzzer:
|
||||
"""Create an AtherisFuzzer instance."""
|
||||
return AtherisFuzzer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cargo_fuzzer():
|
||||
"""Create a CargoFuzzer instance"""
|
||||
from modules.fuzzer.cargo_fuzzer import CargoFuzzer
|
||||
def cargo_fuzzer() -> CargoFuzzer:
|
||||
"""Create a CargoFuzzer instance."""
|
||||
return CargoFuzzer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_scanner():
|
||||
"""Create a FileScanner instance"""
|
||||
from modules.scanner.file_scanner import FileScanner
|
||||
def file_scanner() -> FileScanner:
|
||||
"""Create a FileScanner instance."""
|
||||
return FileScanner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def security_analyzer() -> SecurityAnalyzer:
|
||||
"""Create SecurityAnalyzer instance."""
|
||||
return SecurityAnalyzer()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mock Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stats_callback():
|
||||
"""Mock stats callback for fuzzing"""
|
||||
def mock_stats_callback() -> Callable[[dict[str, Any]], CoroutineType]:
|
||||
"""Mock stats callback for fuzzing."""
|
||||
stats_received = []
|
||||
|
||||
async def callback(stats: Dict[str, Any]):
|
||||
async def callback(stats: dict[str, Any]) -> None:
|
||||
stats_received.append(stats)
|
||||
|
||||
callback.stats_received = stats_received
|
||||
return callback
|
||||
|
||||
|
||||
class MockActivityInfo:
|
||||
"""Mock activity info."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an instance of the class."""
|
||||
self.workflow_id = "test-workflow-123"
|
||||
self.activity_id = "test-activity-1"
|
||||
self.attempt = 1
|
||||
|
||||
|
||||
class MockContext:
|
||||
"""Mock context."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an instance of the class."""
|
||||
self.info = MockActivityInfo()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_temporal_context():
|
||||
"""Mock Temporal activity context"""
|
||||
class MockActivityInfo:
|
||||
def __init__(self):
|
||||
self.workflow_id = "test-workflow-123"
|
||||
self.activity_id = "test-activity-1"
|
||||
self.attempt = 1
|
||||
|
||||
class MockContext:
|
||||
def __init__(self):
|
||||
self.info = MockActivityInfo()
|
||||
|
||||
def mock_temporal_context() -> MockContext:
|
||||
"""Mock Temporal activity context."""
|
||||
return MockContext()
|
||||
|
||||
|
||||
0
backend/tests/fixtures/__init__.py
vendored
@@ -0,0 +1 @@
|
||||
"""Unit tests."""
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Unit tests for modules."""
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
"""
|
||||
Unit tests for AtherisFuzzer module
|
||||
"""
|
||||
"""Unit tests for AtherisFuzzer module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from modules.fuzzer.atheris_fuzzer import AtherisFuzzer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAtherisFuzzerMetadata:
|
||||
"""Test AtherisFuzzer metadata"""
|
||||
"""Test AtherisFuzzer metadata."""
|
||||
|
||||
async def test_metadata_structure(self, atheris_fuzzer):
|
||||
"""Test that module metadata is properly defined"""
|
||||
async def test_metadata_structure(self, atheris_fuzzer: AtherisFuzzer) -> None:
|
||||
"""Test that module metadata is properly defined."""
|
||||
metadata = atheris_fuzzer.get_metadata()
|
||||
|
||||
assert metadata.name == "atheris_fuzzer"
|
||||
@@ -22,28 +31,28 @@ class TestAtherisFuzzerMetadata:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAtherisFuzzerConfigValidation:
|
||||
"""Test configuration validation"""
|
||||
"""Test configuration validation."""
|
||||
|
||||
async def test_valid_config(self, atheris_fuzzer, atheris_config):
|
||||
"""Test validation of valid configuration"""
|
||||
async def test_valid_config(self, atheris_fuzzer: AtherisFuzzer, atheris_config: dict[str, Any]) -> None:
|
||||
"""Test validation of valid configuration."""
|
||||
assert atheris_fuzzer.validate_config(atheris_config) is True
|
||||
|
||||
async def test_invalid_max_iterations(self, atheris_fuzzer):
|
||||
"""Test validation fails with invalid max_iterations"""
|
||||
async def test_invalid_max_iterations(self, atheris_fuzzer: AtherisFuzzer) -> None:
|
||||
"""Test validation fails with invalid max_iterations."""
|
||||
config = {
|
||||
"target_file": "fuzz_target.py",
|
||||
"max_iterations": -1,
|
||||
"timeout_seconds": 10
|
||||
"timeout_seconds": 10,
|
||||
}
|
||||
with pytest.raises(ValueError, match="max_iterations"):
|
||||
atheris_fuzzer.validate_config(config)
|
||||
|
||||
async def test_invalid_timeout(self, atheris_fuzzer):
|
||||
"""Test validation fails with invalid timeout"""
|
||||
async def test_invalid_timeout(self, atheris_fuzzer: AtherisFuzzer) -> None:
|
||||
"""Test validation fails with invalid timeout."""
|
||||
config = {
|
||||
"target_file": "fuzz_target.py",
|
||||
"max_iterations": 1000,
|
||||
"timeout_seconds": 0
|
||||
"timeout_seconds": 0,
|
||||
}
|
||||
with pytest.raises(ValueError, match="timeout_seconds"):
|
||||
atheris_fuzzer.validate_config(config)
|
||||
@@ -51,10 +60,10 @@ class TestAtherisFuzzerConfigValidation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAtherisFuzzerDiscovery:
|
||||
"""Test fuzz target discovery"""
|
||||
"""Test fuzz target discovery."""
|
||||
|
||||
async def test_auto_discover(self, atheris_fuzzer, python_test_workspace):
|
||||
"""Test auto-discovery of Python fuzz targets"""
|
||||
async def test_auto_discover(self, atheris_fuzzer: AtherisFuzzer, python_test_workspace: Path) -> None:
|
||||
"""Test auto-discovery of Python fuzz targets."""
|
||||
# Create a fuzz target file
|
||||
(python_test_workspace / "fuzz_target.py").write_text("""
|
||||
import atheris
|
||||
@@ -69,7 +78,7 @@ if __name__ == "__main__":
|
||||
""")
|
||||
|
||||
# Pass None for auto-discovery
|
||||
target = atheris_fuzzer._discover_target(python_test_workspace, None)
|
||||
target = atheris_fuzzer._discover_target(python_test_workspace, None) # noqa: SLF001
|
||||
|
||||
assert target is not None
|
||||
assert "fuzz_target.py" in str(target)
|
||||
@@ -77,10 +86,14 @@ if __name__ == "__main__":
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAtherisFuzzerExecution:
|
||||
"""Test fuzzer execution logic"""
|
||||
"""Test fuzzer execution logic."""
|
||||
|
||||
async def test_execution_creates_result(self, atheris_fuzzer, python_test_workspace, atheris_config):
|
||||
"""Test that execution returns a ModuleResult"""
|
||||
async def test_execution_creates_result(
|
||||
self,
|
||||
atheris_fuzzer: AtherisFuzzer,
|
||||
python_test_workspace: Path,
|
||||
) -> None:
|
||||
"""Test that execution returns a ModuleResult."""
|
||||
# Create a simple fuzz target
|
||||
(python_test_workspace / "fuzz_target.py").write_text("""
|
||||
import atheris
|
||||
@@ -99,11 +112,16 @@ if __name__ == "__main__":
|
||||
test_config = {
|
||||
"target_file": "fuzz_target.py",
|
||||
"max_iterations": 10,
|
||||
"timeout_seconds": 1
|
||||
"timeout_seconds": 1,
|
||||
}
|
||||
|
||||
# Mock the fuzzing subprocess to avoid actual execution
|
||||
with patch.object(atheris_fuzzer, '_run_fuzzing', new_callable=AsyncMock, return_value=([], {"total_executions": 10})):
|
||||
with patch.object(
|
||||
atheris_fuzzer,
|
||||
"_run_fuzzing",
|
||||
new_callable=AsyncMock,
|
||||
return_value=([], {"total_executions": 10}),
|
||||
):
|
||||
result = await atheris_fuzzer.execute(test_config, python_test_workspace)
|
||||
|
||||
assert result.module == "atheris_fuzzer"
|
||||
@@ -113,10 +131,16 @@ if __name__ == "__main__":
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAtherisFuzzerStatsCallback:
|
||||
"""Test stats callback functionality"""
|
||||
"""Test stats callback functionality."""
|
||||
|
||||
async def test_stats_callback_invoked(self, atheris_fuzzer, python_test_workspace, atheris_config, mock_stats_callback):
|
||||
"""Test that stats callback is invoked during fuzzing"""
|
||||
async def test_stats_callback_invoked(
|
||||
self,
|
||||
atheris_fuzzer: AtherisFuzzer,
|
||||
python_test_workspace: Path,
|
||||
atheris_config: dict[str, Any],
|
||||
mock_stats_callback: Callable | None,
|
||||
) -> None:
|
||||
"""Test that stats callback is invoked during fuzzing."""
|
||||
(python_test_workspace / "fuzz_target.py").write_text("""
|
||||
import atheris
|
||||
import sys
|
||||
@@ -130,35 +154,45 @@ if __name__ == "__main__":
|
||||
""")
|
||||
|
||||
# Mock fuzzing to simulate stats
|
||||
async def mock_run_fuzzing(test_one_input, target_path, workspace, max_iterations, timeout_seconds, stats_callback):
|
||||
async def mock_run_fuzzing(
|
||||
test_one_input: Callable, # noqa: ARG001
|
||||
target_path: Path, # noqa: ARG001
|
||||
workspace: Path, # noqa: ARG001
|
||||
max_iterations: int, # noqa: ARG001
|
||||
timeout_seconds: int, # noqa: ARG001
|
||||
stats_callback: Callable | None,
|
||||
) -> None:
|
||||
if stats_callback:
|
||||
await stats_callback({
|
||||
"total_execs": 100,
|
||||
"execs_per_sec": 10.0,
|
||||
"crashes": 0,
|
||||
"coverage": 5,
|
||||
"corpus_size": 2,
|
||||
"elapsed_time": 10
|
||||
})
|
||||
return
|
||||
await stats_callback(
|
||||
{
|
||||
"total_execs": 100,
|
||||
"execs_per_sec": 10.0,
|
||||
"crashes": 0,
|
||||
"coverage": 5,
|
||||
"corpus_size": 2,
|
||||
"elapsed_time": 10,
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(atheris_fuzzer, '_run_fuzzing', side_effect=mock_run_fuzzing):
|
||||
with patch.object(atheris_fuzzer, '_load_target_module', return_value=lambda x: None):
|
||||
# Put stats_callback in config dict, not as kwarg
|
||||
atheris_config["target_file"] = "fuzz_target.py"
|
||||
atheris_config["stats_callback"] = mock_stats_callback
|
||||
await atheris_fuzzer.execute(atheris_config, python_test_workspace)
|
||||
with (
|
||||
patch.object(atheris_fuzzer, "_run_fuzzing", side_effect=mock_run_fuzzing),
|
||||
patch.object(atheris_fuzzer, "_load_target_module", return_value=lambda _x: None),
|
||||
):
|
||||
# Put stats_callback in config dict, not as kwarg
|
||||
atheris_config["target_file"] = "fuzz_target.py"
|
||||
atheris_config["stats_callback"] = mock_stats_callback
|
||||
await atheris_fuzzer.execute(atheris_config, python_test_workspace)
|
||||
|
||||
# Verify callback was invoked
|
||||
assert len(mock_stats_callback.stats_received) > 0
|
||||
# Verify callback was invoked
|
||||
assert len(mock_stats_callback.stats_received) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAtherisFuzzerFindingGeneration:
|
||||
"""Test finding generation from crashes"""
|
||||
"""Test finding generation from crashes."""
|
||||
|
||||
async def test_create_crash_finding(self, atheris_fuzzer):
|
||||
"""Test crash finding creation"""
|
||||
async def test_create_crash_finding(self, atheris_fuzzer: AtherisFuzzer) -> None:
|
||||
"""Test crash finding creation."""
|
||||
finding = atheris_fuzzer.create_finding(
|
||||
title="Crash: Exception in TestOneInput",
|
||||
description="IndexError: list index out of range",
|
||||
@@ -167,8 +201,8 @@ class TestAtherisFuzzerFindingGeneration:
|
||||
file_path="fuzz_target.py",
|
||||
metadata={
|
||||
"crash_type": "IndexError",
|
||||
"stack_trace": "Traceback..."
|
||||
}
|
||||
"stack_trace": "Traceback...",
|
||||
},
|
||||
)
|
||||
|
||||
assert finding.title == "Crash: Exception in TestOneInput"
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
"""
|
||||
Unit tests for CargoFuzzer module
|
||||
"""
|
||||
"""Unit tests for CargoFuzzer module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from modules.fuzzer.cargo_fuzzer import CargoFuzzer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCargoFuzzerMetadata:
|
||||
"""Test CargoFuzzer metadata"""
|
||||
"""Test CargoFuzzer metadata."""
|
||||
|
||||
async def test_metadata_structure(self, cargo_fuzzer):
|
||||
"""Test that module metadata is properly defined"""
|
||||
async def test_metadata_structure(self, cargo_fuzzer: CargoFuzzer) -> None:
|
||||
"""Test that module metadata is properly defined."""
|
||||
metadata = cargo_fuzzer.get_metadata()
|
||||
|
||||
assert metadata.name == "cargo_fuzz"
|
||||
@@ -23,38 +32,38 @@ class TestCargoFuzzerMetadata:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCargoFuzzerConfigValidation:
|
||||
"""Test configuration validation"""
|
||||
"""Test configuration validation."""
|
||||
|
||||
async def test_valid_config(self, cargo_fuzzer, cargo_fuzz_config):
|
||||
"""Test validation of valid configuration"""
|
||||
async def test_valid_config(self, cargo_fuzzer: CargoFuzzer, cargo_fuzz_config: dict[str, Any]) -> None:
|
||||
"""Test validation of valid configuration."""
|
||||
assert cargo_fuzzer.validate_config(cargo_fuzz_config) is True
|
||||
|
||||
async def test_invalid_max_iterations(self, cargo_fuzzer):
|
||||
"""Test validation fails with invalid max_iterations"""
|
||||
async def test_invalid_max_iterations(self, cargo_fuzzer: CargoFuzzer) -> None:
|
||||
"""Test validation fails with invalid max_iterations."""
|
||||
config = {
|
||||
"max_iterations": -1,
|
||||
"timeout_seconds": 10,
|
||||
"sanitizer": "address"
|
||||
"sanitizer": "address",
|
||||
}
|
||||
with pytest.raises(ValueError, match="max_iterations"):
|
||||
cargo_fuzzer.validate_config(config)
|
||||
|
||||
async def test_invalid_timeout(self, cargo_fuzzer):
|
||||
"""Test validation fails with invalid timeout"""
|
||||
async def test_invalid_timeout(self, cargo_fuzzer: CargoFuzzer) -> None:
|
||||
"""Test validation fails with invalid timeout."""
|
||||
config = {
|
||||
"max_iterations": 1000,
|
||||
"timeout_seconds": 0,
|
||||
"sanitizer": "address"
|
||||
"sanitizer": "address",
|
||||
}
|
||||
with pytest.raises(ValueError, match="timeout_seconds"):
|
||||
cargo_fuzzer.validate_config(config)
|
||||
|
||||
async def test_invalid_sanitizer(self, cargo_fuzzer):
|
||||
"""Test validation fails with invalid sanitizer"""
|
||||
async def test_invalid_sanitizer(self, cargo_fuzzer: CargoFuzzer) -> None:
|
||||
"""Test validation fails with invalid sanitizer."""
|
||||
config = {
|
||||
"max_iterations": 1000,
|
||||
"timeout_seconds": 10,
|
||||
"sanitizer": "invalid_sanitizer"
|
||||
"sanitizer": "invalid_sanitizer",
|
||||
}
|
||||
with pytest.raises(ValueError, match="sanitizer"):
|
||||
cargo_fuzzer.validate_config(config)
|
||||
@@ -62,20 +71,20 @@ class TestCargoFuzzerConfigValidation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCargoFuzzerWorkspaceValidation:
|
||||
"""Test workspace validation"""
|
||||
"""Test workspace validation."""
|
||||
|
||||
async def test_valid_workspace(self, cargo_fuzzer, rust_test_workspace):
|
||||
"""Test validation of valid workspace"""
|
||||
async def test_valid_workspace(self, cargo_fuzzer: CargoFuzzer, rust_test_workspace: Path) -> None:
|
||||
"""Test validation of valid workspace."""
|
||||
assert cargo_fuzzer.validate_workspace(rust_test_workspace) is True
|
||||
|
||||
async def test_nonexistent_workspace(self, cargo_fuzzer, tmp_path):
|
||||
"""Test validation fails with nonexistent workspace"""
|
||||
async def test_nonexistent_workspace(self, cargo_fuzzer: CargoFuzzer, tmp_path: Path) -> None:
|
||||
"""Test validation fails with nonexistent workspace."""
|
||||
nonexistent = tmp_path / "does_not_exist"
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
cargo_fuzzer.validate_workspace(nonexistent)
|
||||
|
||||
async def test_workspace_is_file(self, cargo_fuzzer, tmp_path):
|
||||
"""Test validation fails when workspace is a file"""
|
||||
async def test_workspace_is_file(self, cargo_fuzzer: CargoFuzzer, tmp_path: Path) -> None:
|
||||
"""Test validation fails when workspace is a file."""
|
||||
file_path = tmp_path / "file.txt"
|
||||
file_path.write_text("test")
|
||||
with pytest.raises(ValueError, match="not a directory"):
|
||||
@@ -84,41 +93,58 @@ class TestCargoFuzzerWorkspaceValidation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCargoFuzzerDiscovery:
|
||||
"""Test fuzz target discovery"""
|
||||
"""Test fuzz target discovery."""
|
||||
|
||||
async def test_discover_targets(self, cargo_fuzzer, rust_test_workspace):
|
||||
"""Test discovery of fuzz targets"""
|
||||
targets = await cargo_fuzzer._discover_fuzz_targets(rust_test_workspace)
|
||||
async def test_discover_targets(self, cargo_fuzzer: CargoFuzzer, rust_test_workspace: Path) -> None:
|
||||
"""Test discovery of fuzz targets."""
|
||||
targets = await cargo_fuzzer._discover_fuzz_targets(rust_test_workspace) # noqa: SLF001
|
||||
|
||||
assert len(targets) == 1
|
||||
assert "fuzz_target_1" in targets
|
||||
|
||||
async def test_no_fuzz_directory(self, cargo_fuzzer, temp_workspace):
|
||||
"""Test discovery with no fuzz directory"""
|
||||
targets = await cargo_fuzzer._discover_fuzz_targets(temp_workspace)
|
||||
async def test_no_fuzz_directory(self, cargo_fuzzer: CargoFuzzer, temp_workspace: Path) -> None:
|
||||
"""Test discovery with no fuzz directory."""
|
||||
targets = await cargo_fuzzer._discover_fuzz_targets(temp_workspace) # noqa: SLF001
|
||||
|
||||
assert targets == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCargoFuzzerExecution:
|
||||
"""Test fuzzer execution logic"""
|
||||
"""Test fuzzer execution logic."""
|
||||
|
||||
async def test_execution_creates_result(self, cargo_fuzzer, rust_test_workspace, cargo_fuzz_config):
|
||||
"""Test that execution returns a ModuleResult"""
|
||||
async def test_execution_creates_result(
|
||||
self,
|
||||
cargo_fuzzer: CargoFuzzer,
|
||||
rust_test_workspace: Path,
|
||||
cargo_fuzz_config: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test that execution returns a ModuleResult."""
|
||||
# Mock the build and run methods to avoid actual fuzzing
|
||||
with patch.object(cargo_fuzzer, '_build_fuzz_target', new_callable=AsyncMock, return_value=True):
|
||||
with patch.object(cargo_fuzzer, '_run_fuzzing', new_callable=AsyncMock, return_value=([], {"total_executions": 0, "crashes_found": 0})):
|
||||
with patch.object(cargo_fuzzer, '_parse_crash_artifacts', new_callable=AsyncMock, return_value=[]):
|
||||
result = await cargo_fuzzer.execute(cargo_fuzz_config, rust_test_workspace)
|
||||
with (
|
||||
patch.object(cargo_fuzzer, "_build_fuzz_target", new_callable=AsyncMock, return_value=True),
|
||||
patch.object(
|
||||
cargo_fuzzer,
|
||||
"_run_fuzzing",
|
||||
new_callable=AsyncMock,
|
||||
return_value=([], {"total_executions": 0, "crashes_found": 0}),
|
||||
),
|
||||
patch.object(cargo_fuzzer, "_parse_crash_artifacts", new_callable=AsyncMock, return_value=[]),
|
||||
):
|
||||
result = await cargo_fuzzer.execute(cargo_fuzz_config, rust_test_workspace)
|
||||
|
||||
assert result.module == "cargo_fuzz"
|
||||
assert result.status == "success"
|
||||
assert isinstance(result.execution_time, float)
|
||||
assert result.execution_time >= 0
|
||||
assert result.module == "cargo_fuzz"
|
||||
assert result.status == "success"
|
||||
assert isinstance(result.execution_time, float)
|
||||
assert result.execution_time >= 0
|
||||
|
||||
async def test_execution_with_no_targets(self, cargo_fuzzer, temp_workspace, cargo_fuzz_config):
|
||||
"""Test execution fails gracefully with no fuzz targets"""
|
||||
async def test_execution_with_no_targets(
|
||||
self,
|
||||
cargo_fuzzer: CargoFuzzer,
|
||||
temp_workspace: Path,
|
||||
cargo_fuzz_config: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test execution fails gracefully with no fuzz targets."""
|
||||
result = await cargo_fuzzer.execute(cargo_fuzz_config, temp_workspace)
|
||||
|
||||
assert result.status == "failed"
|
||||
@@ -127,47 +153,67 @@ class TestCargoFuzzerExecution:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCargoFuzzerStatsCallback:
|
||||
"""Test stats callback functionality"""
|
||||
"""Test stats callback functionality."""
|
||||
|
||||
async def test_stats_callback_invoked(
|
||||
self,
|
||||
cargo_fuzzer: CargoFuzzer,
|
||||
rust_test_workspace: Path,
|
||||
cargo_fuzz_config: dict[str, Any],
|
||||
mock_stats_callback: Callable | None,
|
||||
) -> None:
|
||||
"""Test that stats callback is invoked during fuzzing."""
|
||||
|
||||
async def test_stats_callback_invoked(self, cargo_fuzzer, rust_test_workspace, cargo_fuzz_config, mock_stats_callback):
|
||||
"""Test that stats callback is invoked during fuzzing"""
|
||||
# Mock build/run to simulate stats generation
|
||||
async def mock_run_fuzzing(workspace, target, config, callback):
|
||||
async def mock_run_fuzzing(
|
||||
_workspace: Path,
|
||||
_target: str,
|
||||
_config: dict[str, Any],
|
||||
callback: Callable | None,
|
||||
) -> tuple[list, dict[str, int]]:
|
||||
# Simulate stats callback
|
||||
if callback:
|
||||
await callback({
|
||||
"total_execs": 1000,
|
||||
"execs_per_sec": 100.0,
|
||||
"crashes": 0,
|
||||
"coverage": 10,
|
||||
"corpus_size": 5,
|
||||
"elapsed_time": 10
|
||||
})
|
||||
await callback(
|
||||
{
|
||||
"total_execs": 1000,
|
||||
"execs_per_sec": 100.0,
|
||||
"crashes": 0,
|
||||
"coverage": 10,
|
||||
"corpus_size": 5,
|
||||
"elapsed_time": 10,
|
||||
},
|
||||
)
|
||||
return [], {"total_executions": 1000}
|
||||
|
||||
with patch.object(cargo_fuzzer, '_build_fuzz_target', new_callable=AsyncMock, return_value=True):
|
||||
with patch.object(cargo_fuzzer, '_run_fuzzing', side_effect=mock_run_fuzzing):
|
||||
with patch.object(cargo_fuzzer, '_parse_crash_artifacts', new_callable=AsyncMock, return_value=[]):
|
||||
await cargo_fuzzer.execute(cargo_fuzz_config, rust_test_workspace, stats_callback=mock_stats_callback)
|
||||
with (
|
||||
patch.object(cargo_fuzzer, "_build_fuzz_target", new_callable=AsyncMock, return_value=True),
|
||||
patch.object(cargo_fuzzer, "_run_fuzzing", side_effect=mock_run_fuzzing),
|
||||
patch.object(cargo_fuzzer, "_parse_crash_artifacts", new_callable=AsyncMock, return_value=[]),
|
||||
):
|
||||
await cargo_fuzzer.execute(
|
||||
cargo_fuzz_config,
|
||||
rust_test_workspace,
|
||||
stats_callback=mock_stats_callback,
|
||||
)
|
||||
|
||||
# Verify callback was invoked
|
||||
assert len(mock_stats_callback.stats_received) > 0
|
||||
assert mock_stats_callback.stats_received[0]["total_execs"] == 1000
|
||||
# Verify callback was invoked
|
||||
assert len(mock_stats_callback.stats_received) > 0
|
||||
assert mock_stats_callback.stats_received[0]["total_execs"] == 1000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCargoFuzzerFindingGeneration:
|
||||
"""Test finding generation from crashes"""
|
||||
"""Test finding generation from crashes."""
|
||||
|
||||
async def test_create_finding_from_crash(self, cargo_fuzzer):
|
||||
"""Test finding creation"""
|
||||
async def test_create_finding_from_crash(self, cargo_fuzzer: CargoFuzzer) -> None:
|
||||
"""Test finding creation."""
|
||||
finding = cargo_fuzzer.create_finding(
|
||||
title="Crash: Segmentation Fault",
|
||||
description="Test crash",
|
||||
severity="critical",
|
||||
category="crash",
|
||||
file_path="fuzz/fuzz_targets/fuzz_target_1.rs",
|
||||
metadata={"crash_type": "SIGSEGV"}
|
||||
metadata={"crash_type": "SIGSEGV"},
|
||||
)
|
||||
|
||||
assert finding.title == "Crash: Segmentation Fault"
|
||||
|
||||
@@ -1,22 +1,25 @@
|
||||
"""
|
||||
Unit tests for FileScanner module
|
||||
"""
|
||||
"""Unit tests for FileScanner module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from modules.scanner.file_scanner import FileScanner
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "toolbox"))
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestFileScannerMetadata:
|
||||
"""Test FileScanner metadata"""
|
||||
"""Test FileScanner metadata."""
|
||||
|
||||
async def test_metadata_structure(self, file_scanner):
|
||||
"""Test that metadata has correct structure"""
|
||||
async def test_metadata_structure(self, file_scanner: FileScanner) -> None:
|
||||
"""Test that metadata has correct structure."""
|
||||
metadata = file_scanner.get_metadata()
|
||||
|
||||
assert metadata.name == "file_scanner"
|
||||
@@ -29,37 +32,37 @@ class TestFileScannerMetadata:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestFileScannerConfigValidation:
|
||||
"""Test configuration validation"""
|
||||
"""Test configuration validation."""
|
||||
|
||||
async def test_valid_config(self, file_scanner):
|
||||
"""Test that valid config passes validation"""
|
||||
async def test_valid_config(self, file_scanner: FileScanner) -> None:
|
||||
"""Test that valid config passes validation."""
|
||||
config = {
|
||||
"patterns": ["*.py", "*.js"],
|
||||
"max_file_size": 1048576,
|
||||
"check_sensitive": True,
|
||||
"calculate_hashes": False
|
||||
"calculate_hashes": False,
|
||||
}
|
||||
assert file_scanner.validate_config(config) is True
|
||||
|
||||
async def test_default_config(self, file_scanner):
|
||||
"""Test that empty config uses defaults"""
|
||||
async def test_default_config(self, file_scanner: FileScanner) -> None:
|
||||
"""Test that empty config uses defaults."""
|
||||
config = {}
|
||||
assert file_scanner.validate_config(config) is True
|
||||
|
||||
async def test_invalid_patterns_type(self, file_scanner):
|
||||
"""Test that non-list patterns raises error"""
|
||||
async def test_invalid_patterns_type(self, file_scanner: FileScanner) -> None:
|
||||
"""Test that non-list patterns raises error."""
|
||||
config = {"patterns": "*.py"}
|
||||
with pytest.raises(ValueError, match="patterns must be a list"):
|
||||
file_scanner.validate_config(config)
|
||||
|
||||
async def test_invalid_max_file_size(self, file_scanner):
|
||||
"""Test that invalid max_file_size raises error"""
|
||||
async def test_invalid_max_file_size(self, file_scanner: FileScanner) -> None:
|
||||
"""Test that invalid max_file_size raises error."""
|
||||
config = {"max_file_size": -1}
|
||||
with pytest.raises(ValueError, match="max_file_size must be a positive integer"):
|
||||
file_scanner.validate_config(config)
|
||||
|
||||
async def test_invalid_max_file_size_type(self, file_scanner):
|
||||
"""Test that non-integer max_file_size raises error"""
|
||||
async def test_invalid_max_file_size_type(self, file_scanner: FileScanner) -> None:
|
||||
"""Test that non-integer max_file_size raises error."""
|
||||
config = {"max_file_size": "large"}
|
||||
with pytest.raises(ValueError, match="max_file_size must be a positive integer"):
|
||||
file_scanner.validate_config(config)
|
||||
@@ -67,14 +70,14 @@ class TestFileScannerConfigValidation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestFileScannerExecution:
|
||||
"""Test scanner execution"""
|
||||
"""Test scanner execution."""
|
||||
|
||||
async def test_scan_python_files(self, file_scanner, python_test_workspace):
|
||||
"""Test scanning Python files"""
|
||||
async def test_scan_python_files(self, file_scanner: FileScanner, python_test_workspace: Path) -> None:
|
||||
"""Test scanning Python files."""
|
||||
config = {
|
||||
"patterns": ["*.py"],
|
||||
"check_sensitive": False,
|
||||
"calculate_hashes": False
|
||||
"calculate_hashes": False,
|
||||
}
|
||||
|
||||
result = await file_scanner.execute(config, python_test_workspace)
|
||||
@@ -84,15 +87,15 @@ class TestFileScannerExecution:
|
||||
assert len(result.findings) > 0
|
||||
|
||||
# Check that Python files were found
|
||||
python_files = [f for f in result.findings if f.file_path.endswith('.py')]
|
||||
python_files = [f for f in result.findings if f.file_path.endswith(".py")]
|
||||
assert len(python_files) > 0
|
||||
|
||||
async def test_scan_all_files(self, file_scanner, python_test_workspace):
|
||||
"""Test scanning all files with wildcard"""
|
||||
async def test_scan_all_files(self, file_scanner: FileScanner, python_test_workspace: Path) -> None:
|
||||
"""Test scanning all files with wildcard."""
|
||||
config = {
|
||||
"patterns": ["*"],
|
||||
"check_sensitive": False,
|
||||
"calculate_hashes": False
|
||||
"calculate_hashes": False,
|
||||
}
|
||||
|
||||
result = await file_scanner.execute(config, python_test_workspace)
|
||||
@@ -101,12 +104,12 @@ class TestFileScannerExecution:
|
||||
assert len(result.findings) > 0
|
||||
assert result.summary["total_files"] > 0
|
||||
|
||||
async def test_scan_with_multiple_patterns(self, file_scanner, python_test_workspace):
|
||||
"""Test scanning with multiple patterns"""
|
||||
async def test_scan_with_multiple_patterns(self, file_scanner: FileScanner, python_test_workspace: Path) -> None:
|
||||
"""Test scanning with multiple patterns."""
|
||||
config = {
|
||||
"patterns": ["*.py", "*.txt"],
|
||||
"check_sensitive": False,
|
||||
"calculate_hashes": False
|
||||
"calculate_hashes": False,
|
||||
}
|
||||
|
||||
result = await file_scanner.execute(config, python_test_workspace)
|
||||
@@ -114,11 +117,11 @@ class TestFileScannerExecution:
|
||||
assert result.status == "success"
|
||||
assert len(result.findings) > 0
|
||||
|
||||
async def test_empty_workspace(self, file_scanner, temp_workspace):
|
||||
"""Test scanning empty workspace"""
|
||||
async def test_empty_workspace(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
|
||||
"""Test scanning empty workspace."""
|
||||
config = {
|
||||
"patterns": ["*.py"],
|
||||
"check_sensitive": False
|
||||
"check_sensitive": False,
|
||||
}
|
||||
|
||||
result = await file_scanner.execute(config, temp_workspace)
|
||||
@@ -130,17 +133,17 @@ class TestFileScannerExecution:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestFileScannerSensitiveDetection:
|
||||
"""Test sensitive file detection"""
|
||||
"""Test sensitive file detection."""
|
||||
|
||||
async def test_detect_env_file(self, file_scanner, temp_workspace):
|
||||
"""Test detection of .env file"""
|
||||
async def test_detect_env_file(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
|
||||
"""Test detection of .env file."""
|
||||
# Create .env file
|
||||
(temp_workspace / ".env").write_text("API_KEY=secret123")
|
||||
|
||||
config = {
|
||||
"patterns": ["*"],
|
||||
"check_sensitive": True,
|
||||
"calculate_hashes": False
|
||||
"calculate_hashes": False,
|
||||
}
|
||||
|
||||
result = await file_scanner.execute(config, temp_workspace)
|
||||
@@ -152,14 +155,14 @@ class TestFileScannerSensitiveDetection:
|
||||
assert len(sensitive_findings) > 0
|
||||
assert any(".env" in f.title for f in sensitive_findings)
|
||||
|
||||
async def test_detect_private_key(self, file_scanner, temp_workspace):
|
||||
"""Test detection of private key file"""
|
||||
async def test_detect_private_key(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
|
||||
"""Test detection of private key file."""
|
||||
# Create private key file
|
||||
(temp_workspace / "id_rsa").write_text("-----BEGIN RSA PRIVATE KEY-----")
|
||||
|
||||
config = {
|
||||
"patterns": ["*"],
|
||||
"check_sensitive": True
|
||||
"check_sensitive": True,
|
||||
}
|
||||
|
||||
result = await file_scanner.execute(config, temp_workspace)
|
||||
@@ -168,13 +171,13 @@ class TestFileScannerSensitiveDetection:
|
||||
sensitive_findings = [f for f in result.findings if f.category == "sensitive_file"]
|
||||
assert len(sensitive_findings) > 0
|
||||
|
||||
async def test_no_sensitive_detection_when_disabled(self, file_scanner, temp_workspace):
|
||||
"""Test that sensitive detection can be disabled"""
|
||||
async def test_no_sensitive_detection_when_disabled(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
|
||||
"""Test that sensitive detection can be disabled."""
|
||||
(temp_workspace / ".env").write_text("API_KEY=secret123")
|
||||
|
||||
config = {
|
||||
"patterns": ["*"],
|
||||
"check_sensitive": False
|
||||
"check_sensitive": False,
|
||||
}
|
||||
|
||||
result = await file_scanner.execute(config, temp_workspace)
|
||||
@@ -186,17 +189,17 @@ class TestFileScannerSensitiveDetection:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestFileScannerHashing:
|
||||
"""Test file hashing functionality"""
|
||||
"""Test file hashing functionality."""
|
||||
|
||||
async def test_hash_calculation(self, file_scanner, temp_workspace):
|
||||
"""Test SHA256 hash calculation"""
|
||||
async def test_hash_calculation(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
|
||||
"""Test SHA256 hash calculation."""
|
||||
# Create test file
|
||||
test_file = temp_workspace / "test.txt"
|
||||
test_file.write_text("Hello World")
|
||||
|
||||
config = {
|
||||
"patterns": ["*.txt"],
|
||||
"calculate_hashes": True
|
||||
"calculate_hashes": True,
|
||||
}
|
||||
|
||||
result = await file_scanner.execute(config, temp_workspace)
|
||||
@@ -212,14 +215,14 @@ class TestFileScannerHashing:
|
||||
assert finding.metadata.get("file_hash") is not None
|
||||
assert len(finding.metadata["file_hash"]) == 64 # SHA256 hex length
|
||||
|
||||
async def test_no_hash_when_disabled(self, file_scanner, temp_workspace):
|
||||
"""Test that hashing can be disabled"""
|
||||
async def test_no_hash_when_disabled(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
|
||||
"""Test that hashing can be disabled."""
|
||||
test_file = temp_workspace / "test.txt"
|
||||
test_file.write_text("Hello World")
|
||||
|
||||
config = {
|
||||
"patterns": ["*.txt"],
|
||||
"calculate_hashes": False
|
||||
"calculate_hashes": False,
|
||||
}
|
||||
|
||||
result = await file_scanner.execute(config, temp_workspace)
|
||||
@@ -234,10 +237,10 @@ class TestFileScannerHashing:
|
||||
|
||||
@pytest.mark.asyncio
class TestFileScannerFileTypes:
"""Test file type detection"""
"""Test file type detection."""

async def test_detect_python_type(self, file_scanner, temp_workspace):
"""Test detection of Python file type"""
async def test_detect_python_type(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test detection of Python file type."""
(temp_workspace / "script.py").write_text("print('hello')")

config = {"patterns": ["*.py"]}
@@ -248,8 +251,8 @@ class TestFileScannerFileTypes:
assert len(py_findings) > 0
assert "python" in py_findings[0].metadata["file_type"]

async def test_detect_javascript_type(self, file_scanner, temp_workspace):
"""Test detection of JavaScript file type"""
async def test_detect_javascript_type(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test detection of JavaScript file type."""
(temp_workspace / "app.js").write_text("console.log('hello')")

config = {"patterns": ["*.js"]}
@@ -260,8 +263,8 @@ class TestFileScannerFileTypes:
assert len(js_findings) > 0
assert "javascript" in js_findings[0].metadata["file_type"]

async def test_file_type_summary(self, file_scanner, temp_workspace):
"""Test that file type summary is generated"""
async def test_file_type_summary(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that file type summary is generated."""
(temp_workspace / "script.py").write_text("print('hello')")
(temp_workspace / "app.js").write_text("console.log('hello')")
(temp_workspace / "readme.txt").write_text("Documentation")
@@ -276,17 +279,17 @@ class TestFileScannerFileTypes:

@pytest.mark.asyncio
class TestFileScannerSizeLimits:
"""Test file size handling"""
"""Test file size handling."""

async def test_skip_large_files(self, file_scanner, temp_workspace):
"""Test that large files are skipped"""
async def test_skip_large_files(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that large files are skipped."""
# Create a "large" file
large_file = temp_workspace / "large.txt"
large_file.write_text("x" * 1000)

config = {
"patterns": ["*.txt"],
"max_file_size": 500 # Set limit smaller than file
"max_file_size": 500, # Set limit smaller than file
}

result = await file_scanner.execute(config, temp_workspace)
@@ -297,14 +300,14 @@ class TestFileScannerSizeLimits:
# The file should still be counted but not have a detailed finding
assert result.summary["total_files"] > 0

async def test_process_small_files(self, file_scanner, temp_workspace):
"""Test that small files are processed"""
async def test_process_small_files(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that small files are processed."""
small_file = temp_workspace / "small.txt"
small_file.write_text("small content")

config = {
"patterns": ["*.txt"],
"max_file_size": 1048576 # 1MB
"max_file_size": 1048576, # 1MB
}

result = await file_scanner.execute(config, temp_workspace)
@@ -316,10 +319,10 @@ class TestFileScannerSizeLimits:

@pytest.mark.asyncio
class TestFileScannerSummary:
"""Test result summary generation"""
"""Test result summary generation."""

async def test_summary_structure(self, file_scanner, python_test_workspace):
"""Test that summary has correct structure"""
async def test_summary_structure(self, file_scanner: FileScanner, python_test_workspace: Path) -> None:
"""Test that summary has correct structure."""
config = {"patterns": ["*"]}
result = await file_scanner.execute(config, python_test_workspace)

@@ -334,8 +337,8 @@ class TestFileScannerSummary:
assert isinstance(result.summary["file_types"], dict)
assert isinstance(result.summary["patterns_scanned"], list)

async def test_summary_counts(self, file_scanner, temp_workspace):
"""Test that summary counts are accurate"""
async def test_summary_counts(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that summary counts are accurate."""
# Create known files
(temp_workspace / "file1.py").write_text("content1")
(temp_workspace / "file2.py").write_text("content2")

@@ -1,28 +1,25 @@
"""
Unit tests for SecurityAnalyzer module
"""
"""Unit tests for SecurityAnalyzer module."""

from __future__ import annotations

import pytest
import sys
from pathlib import Path
from typing import TYPE_CHECKING

import pytest

sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "toolbox"))

from modules.analyzer.security_analyzer import SecurityAnalyzer


@pytest.fixture
def security_analyzer():
"""Create SecurityAnalyzer instance"""
return SecurityAnalyzer()
if TYPE_CHECKING:
from modules.analyzer.security_analyzer import SecurityAnalyzer


@pytest.mark.asyncio
class TestSecurityAnalyzerMetadata:
"""Test SecurityAnalyzer metadata"""
"""Test SecurityAnalyzer metadata."""

async def test_metadata_structure(self, security_analyzer):
"""Test that metadata has correct structure"""
async def test_metadata_structure(self, security_analyzer: SecurityAnalyzer) -> None:
"""Test that metadata has correct structure."""
metadata = security_analyzer.get_metadata()

assert metadata.name == "security_analyzer"
@@ -35,25 +32,25 @@ class TestSecurityAnalyzerMetadata:

@pytest.mark.asyncio
class TestSecurityAnalyzerConfigValidation:
"""Test configuration validation"""
"""Test configuration validation."""

async def test_valid_config(self, security_analyzer):
"""Test that valid config passes validation"""
async def test_valid_config(self, security_analyzer: SecurityAnalyzer) -> None:
"""Test that valid config passes validation."""
config = {
"file_extensions": [".py", ".js"],
"check_secrets": True,
"check_sql": True,
"check_dangerous_functions": True
"check_dangerous_functions": True,
}
assert security_analyzer.validate_config(config) is True

async def test_default_config(self, security_analyzer):
"""Test that empty config uses defaults"""
async def test_default_config(self, security_analyzer: SecurityAnalyzer) -> None:
"""Test that empty config uses defaults."""
config = {}
assert security_analyzer.validate_config(config) is True

async def test_invalid_extensions_type(self, security_analyzer):
"""Test that non-list extensions raises error"""
async def test_invalid_extensions_type(self, security_analyzer: SecurityAnalyzer) -> None:
"""Test that non-list extensions raises error."""
config = {"file_extensions": ".py"}
with pytest.raises(ValueError, match="file_extensions must be a list"):
security_analyzer.validate_config(config)
@@ -61,10 +58,10 @@ class TestSecurityAnalyzerConfigValidation:

@pytest.mark.asyncio
class TestSecurityAnalyzerSecretDetection:
"""Test hardcoded secret detection"""
"""Test hardcoded secret detection."""

async def test_detect_api_key(self, security_analyzer, temp_workspace):
"""Test detection of hardcoded API key"""
async def test_detect_api_key(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of hardcoded API key."""
code_file = temp_workspace / "config.py"
code_file.write_text("""
# Configuration file
@@ -76,7 +73,7 @@ database_url = "postgresql://localhost/db"
"file_extensions": [".py"],
"check_secrets": True,
"check_sql": False,
"check_dangerous_functions": False
"check_dangerous_functions": False,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -86,8 +83,8 @@ database_url = "postgresql://localhost/db"
assert len(secret_findings) > 0
assert any("API Key" in f.title for f in secret_findings)

async def test_detect_password(self, security_analyzer, temp_workspace):
"""Test detection of hardcoded password"""
async def test_detect_password(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of hardcoded password."""
code_file = temp_workspace / "auth.py"
code_file.write_text("""
def connect():
@@ -99,7 +96,7 @@ def connect():
"file_extensions": [".py"],
"check_secrets": True,
"check_sql": False,
"check_dangerous_functions": False
"check_dangerous_functions": False,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -108,8 +105,8 @@ def connect():
secret_findings = [f for f in result.findings if f.category == "hardcoded_secret"]
assert len(secret_findings) > 0

async def test_detect_aws_credentials(self, security_analyzer, temp_workspace):
"""Test detection of AWS credentials"""
async def test_detect_aws_credentials(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of AWS credentials."""
code_file = temp_workspace / "aws_config.py"
code_file.write_text("""
aws_access_key = "AKIAIOSFODNN7REALKEY"
@@ -118,7 +115,7 @@ aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYREALKEY"

config = {
"file_extensions": [".py"],
"check_secrets": True
"check_secrets": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -127,14 +124,18 @@ aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYREALKEY"
aws_findings = [f for f in result.findings if "AWS" in f.title]
assert len(aws_findings) >= 2 # Both access key and secret key

async def test_no_secret_detection_when_disabled(self, security_analyzer, temp_workspace):
"""Test that secret detection can be disabled"""
async def test_no_secret_detection_when_disabled(
self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test that secret detection can be disabled."""
code_file = temp_workspace / "config.py"
code_file.write_text('api_key = "sk_live_1234567890abcdef"')

config = {
"file_extensions": [".py"],
"check_secrets": False
"check_secrets": False,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -146,10 +147,10 @@ aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYREALKEY"

@pytest.mark.asyncio
class TestSecurityAnalyzerSQLInjection:
"""Test SQL injection detection"""
"""Test SQL injection detection."""

async def test_detect_string_concatenation(self, security_analyzer, temp_workspace):
"""Test detection of SQL string concatenation"""
async def test_detect_string_concatenation(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of SQL string concatenation."""
code_file = temp_workspace / "db.py"
code_file.write_text("""
def get_user(user_id):
@@ -161,7 +162,7 @@ def get_user(user_id):
"file_extensions": [".py"],
"check_secrets": False,
"check_sql": True,
"check_dangerous_functions": False
"check_dangerous_functions": False,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -170,8 +171,8 @@ def get_user(user_id):
sql_findings = [f for f in result.findings if f.category == "sql_injection"]
assert len(sql_findings) > 0

async def test_detect_f_string_sql(self, security_analyzer, temp_workspace):
"""Test detection of f-string in SQL"""
async def test_detect_f_string_sql(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of f-string in SQL."""
code_file = temp_workspace / "db.py"
code_file.write_text("""
def get_user(name):
@@ -181,7 +182,7 @@ def get_user(name):

config = {
"file_extensions": [".py"],
"check_sql": True
"check_sql": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -190,8 +191,12 @@ def get_user(name):
sql_findings = [f for f in result.findings if f.category == "sql_injection"]
assert len(sql_findings) > 0

async def test_detect_dynamic_query_building(self, security_analyzer, temp_workspace):
"""Test detection of dynamic query building"""
async def test_detect_dynamic_query_building(
self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test detection of dynamic query building."""
code_file = temp_workspace / "queries.py"
code_file.write_text("""
def search(keyword):
@@ -201,7 +206,7 @@ def search(keyword):

config = {
"file_extensions": [".py"],
"check_sql": True
"check_sql": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -210,14 +215,18 @@ def search(keyword):
sql_findings = [f for f in result.findings if f.category == "sql_injection"]
assert len(sql_findings) > 0

async def test_no_sql_detection_when_disabled(self, security_analyzer, temp_workspace):
"""Test that SQL detection can be disabled"""
async def test_no_sql_detection_when_disabled(
self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test that SQL detection can be disabled."""
code_file = temp_workspace / "db.py"
code_file.write_text('query = "SELECT * FROM users WHERE id = " + user_id')

config = {
"file_extensions": [".py"],
"check_sql": False
"check_sql": False,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -229,10 +238,10 @@ def search(keyword):

@pytest.mark.asyncio
class TestSecurityAnalyzerDangerousFunctions:
"""Test dangerous function detection"""
"""Test dangerous function detection."""

async def test_detect_eval(self, security_analyzer, temp_workspace):
"""Test detection of eval() usage"""
async def test_detect_eval(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of eval() usage."""
code_file = temp_workspace / "dangerous.py"
code_file.write_text("""
def process_input(user_input):
@@ -244,7 +253,7 @@ def process_input(user_input):
"file_extensions": [".py"],
"check_secrets": False,
"check_sql": False,
"check_dangerous_functions": True
"check_dangerous_functions": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -254,8 +263,8 @@ def process_input(user_input):
assert len(dangerous_findings) > 0
assert any("eval" in f.title.lower() for f in dangerous_findings)

async def test_detect_exec(self, security_analyzer, temp_workspace):
"""Test detection of exec() usage"""
async def test_detect_exec(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of exec() usage."""
code_file = temp_workspace / "runner.py"
code_file.write_text("""
def run_code(code):
@@ -264,7 +273,7 @@ def run_code(code):

config = {
"file_extensions": [".py"],
"check_dangerous_functions": True
"check_dangerous_functions": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -273,8 +282,8 @@ def run_code(code):
dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"]
assert len(dangerous_findings) > 0

async def test_detect_os_system(self, security_analyzer, temp_workspace):
"""Test detection of os.system() usage"""
async def test_detect_os_system(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of os.system() usage."""
code_file = temp_workspace / "commands.py"
code_file.write_text("""
import os
@@ -285,7 +294,7 @@ def run_command(cmd):

config = {
"file_extensions": [".py"],
"check_dangerous_functions": True
"check_dangerous_functions": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -295,8 +304,8 @@ def run_command(cmd):
assert len(dangerous_findings) > 0
assert any("os.system" in f.title for f in dangerous_findings)

async def test_detect_pickle_loads(self, security_analyzer, temp_workspace):
"""Test detection of pickle.loads() usage"""
async def test_detect_pickle_loads(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of pickle.loads() usage."""
code_file = temp_workspace / "serializer.py"
code_file.write_text("""
import pickle
@@ -307,7 +316,7 @@ def deserialize(data):

config = {
"file_extensions": [".py"],
"check_dangerous_functions": True
"check_dangerous_functions": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -316,8 +325,8 @@ def deserialize(data):
dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"]
assert len(dangerous_findings) > 0

async def test_detect_javascript_eval(self, security_analyzer, temp_workspace):
"""Test detection of eval() in JavaScript"""
async def test_detect_javascript_eval(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of eval() in JavaScript."""
code_file = temp_workspace / "app.js"
code_file.write_text("""
function processInput(userInput) {
@@ -327,7 +336,7 @@ function processInput(userInput) {

config = {
"file_extensions": [".js"],
"check_dangerous_functions": True
"check_dangerous_functions": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -336,8 +345,8 @@ function processInput(userInput) {
dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"]
assert len(dangerous_findings) > 0

async def test_detect_innerHTML(self, security_analyzer, temp_workspace):
"""Test detection of innerHTML (XSS risk)"""
async def test_detect_inner_html(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of innerHTML (XSS risk)."""
code_file = temp_workspace / "dom.js"
code_file.write_text("""
function updateContent(html) {
@@ -347,7 +356,7 @@ function updateContent(html) {

config = {
"file_extensions": [".js"],
"check_dangerous_functions": True
"check_dangerous_functions": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -356,14 +365,18 @@ function updateContent(html) {
dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"]
assert len(dangerous_findings) > 0

async def test_no_dangerous_detection_when_disabled(self, security_analyzer, temp_workspace):
"""Test that dangerous function detection can be disabled"""
async def test_no_dangerous_detection_when_disabled(
self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test that dangerous function detection can be disabled."""
code_file = temp_workspace / "code.py"
code_file.write_text('result = eval(user_input)')
code_file.write_text("result = eval(user_input)")

config = {
"file_extensions": [".py"],
"check_dangerous_functions": False
"check_dangerous_functions": False,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -375,10 +388,14 @@ function updateContent(html) {

@pytest.mark.asyncio
class TestSecurityAnalyzerMultipleIssues:
"""Test detection of multiple issues in same file"""
"""Test detection of multiple issues in same file."""

async def test_detect_multiple_vulnerabilities(self, security_analyzer, temp_workspace):
"""Test detection of multiple vulnerability types"""
async def test_detect_multiple_vulnerabilities(
self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test detection of multiple vulnerability types."""
code_file = temp_workspace / "vulnerable.py"
code_file.write_text("""
import os
@@ -404,7 +421,7 @@ def process_query(user_input):
"file_extensions": [".py"],
"check_secrets": True,
"check_sql": True,
"check_dangerous_functions": True
"check_dangerous_functions": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -423,10 +440,10 @@ def process_query(user_input):

@pytest.mark.asyncio
class TestSecurityAnalyzerSummary:
"""Test result summary generation"""
"""Test result summary generation."""

async def test_summary_structure(self, security_analyzer, temp_workspace):
"""Test that summary has correct structure"""
async def test_summary_structure(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test that summary has correct structure."""
(temp_workspace / "test.py").write_text("print('hello')")

config = {"file_extensions": [".py"]}
@@ -441,16 +458,16 @@ class TestSecurityAnalyzerSummary:
assert isinstance(result.summary["total_findings"], int)
assert isinstance(result.summary["extensions_scanned"], list)

async def test_empty_workspace(self, security_analyzer, temp_workspace):
"""Test analyzing empty workspace"""
async def test_empty_workspace(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test analyzing empty workspace."""
config = {"file_extensions": [".py"]}
result = await security_analyzer.execute(config, temp_workspace)

assert result.status == "partial" # No files found
assert result.summary["files_analyzed"] == 0

async def test_analyze_multiple_file_types(self, security_analyzer, temp_workspace):
"""Test analyzing multiple file types"""
async def test_analyze_multiple_file_types(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test analyzing multiple file types."""
(temp_workspace / "app.py").write_text("print('hello')")
(temp_workspace / "script.js").write_text("console.log('hello')")
(temp_workspace / "index.php").write_text("<?php echo 'hello'; ?>")
@@ -464,10 +481,10 @@ class TestSecurityAnalyzerSummary:

@pytest.mark.asyncio
class TestSecurityAnalyzerFalsePositives:
"""Test false positive filtering"""
"""Test false positive filtering."""

async def test_skip_test_secrets(self, security_analyzer, temp_workspace):
"""Test that test/example secrets are filtered"""
async def test_skip_test_secrets(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test that test/example secrets are filtered."""
code_file = temp_workspace / "test_config.py"
code_file.write_text("""
# Test configuration - should be filtered
@@ -478,7 +495,7 @@ token = "sample_token_placeholder"

config = {
"file_extensions": [".py"],
"check_secrets": True
"check_secrets": True,
}

result = await security_analyzer.execute(config, temp_workspace)
@@ -488,6 +505,6 @@ token = "sample_token_placeholder"
secret_findings = [f for f in result.findings if f.category == "hardcoded_secret"]
# Should have fewer or no findings due to false positive filtering
assert len(secret_findings) == 0 or all(
not any(fp in f.description.lower() for fp in ['test', 'example', 'dummy', 'sample'])
not any(fp in f.description.lower() for fp in ["test", "example", "dummy", "sample"])
for f in secret_findings
)