Compare commits

...

4 Commits

Author SHA1 Message Date
fztee
a271a6bef7 chore: update Gitlab CI and add Makefile (backend package) with code quality commands to run. 2025-11-12 12:00:33 +01:00
fztee
40bbb18795 chore: improve code quality (backend package).
- add configuration file for 'ruff'.
- fix most of 'ruff' lints.
- format 'backend' package using 'ruff'.
2025-11-10 17:01:42 +01:00
fztee
a810e29f76 chore: update file 'pyproject.toml' (backend package).
- remove unused dependency 'httpx'.
- rename optional dependency 'dev' to 'tests'.
2025-11-07 16:29:05 +01:00
fztee
1dc0d967b3 chore: bump and fix versions (backend package). 2025-11-07 16:19:26 +01:00
28 changed files with 1319 additions and 1169 deletions

View File

@@ -110,7 +110,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install ruff mypy pip install ruff mypy bandit
- name: Run ruff - name: Run ruff
run: ruff check backend/src backend/toolbox backend/tests backend/benchmarks --output-format=github run: ruff check backend/src backend/toolbox backend/tests backend/benchmarks --output-format=github
@@ -119,6 +119,10 @@ jobs:
run: mypy backend/src backend/toolbox || true run: mypy backend/src backend/toolbox || true
continue-on-error: true continue-on-error: true
- name: Run bandit (continue on error)
run: bandit --recursive backend/src || true
continue-on-error: true
unit-tests: unit-tests:
name: Unit Tests name: Unit Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
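The CI pipeline now also installs and runs bandit, a security linter for Python, as a non-blocking step (`|| true` plus `continue-on-error`). As a rough illustration only — the snippet below is hypothetical and not taken from the repository — this is the kind of pattern bandit reports:

```python
import subprocess

def run_user_command(cmd: str) -> int:
    # bandit flags subprocess calls with shell=True as a potential
    # shell-injection risk when 'cmd' may contain user-controlled input.
    return subprocess.call(cmd, shell=True)
```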

backend/Makefile Normal file (+19)
View File

@@ -0,0 +1,19 @@
SOURCES=./src
TESTS=./tests
.PHONY: bandit format mypy pytest ruff
bandit:
uv run bandit --recursive $(SOURCES)
format:
uv run ruff format $(SOURCES) $(TESTS)
mypy:
uv run mypy $(SOURCES) $(TESTS)
pytest:
PYTHONPATH=./toolbox uv run pytest $(TESTS)
ruff:
uv run ruff check --fix $(SOURCES) $(TESTS)

View File

@@ -6,28 +6,31 @@ authors = []
readme = "README.md" readme = "README.md"
requires-python = ">=3.11" requires-python = ">=3.11"
dependencies = [ dependencies = [
"fastapi>=0.116.1", "aiofiles==25.1.0",
"temporalio>=1.6.0", "aiohttp==3.13.2",
"boto3>=1.34.0", "boto3==1.40.68",
"pydantic>=2.0.0", "docker==7.1.0",
"pyyaml>=6.0", "fastapi==0.121.0",
"docker>=7.0.0", "fastmcp==2.13.0.2",
"aiofiles>=23.0.0", "pydantic==2.12.4",
"uvicorn>=0.30.0", "pyyaml==6.0.3",
"aiohttp>=3.12.15", "temporalio==1.18.2",
"fastmcp", "uvicorn==0.38.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
dev = [ lints = [
"pytest>=8.0.0", "bandit==1.8.6",
"pytest-asyncio>=0.23.0", "mypy==1.18.2",
"pytest-benchmark>=4.0.0", "ruff==0.14.4",
"pytest-cov>=5.0.0", ]
"pytest-xdist>=3.5.0", tests = [
"pytest-mock>=3.12.0", "pytest==8.4.2",
"httpx>=0.27.0", "pytest-asyncio==1.2.0",
"ruff>=0.1.0", "pytest-benchmark==5.2.1",
"pytest-cov==7.0.0",
"pytest-mock==3.15.1",
"pytest-xdist==3.8.0",
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]

backend/ruff.toml Normal file (+11)
View File

@@ -0,0 +1,11 @@
line-length = 120
[lint]
select = [ "ALL" ]
ignore = []
[lint.per-file-ignores]
"tests/*" = [
"PLR2004", # allowing comparisons using unamed numerical constants in tests
"S101", # allowing 'assert' statements in tests
]
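For context, the two per-file ignores relax rules that are idiomatic in test code: S101 flags `assert` statements and PLR2004 flags comparisons against unnamed numeric constants. A hypothetical test, not taken from the repository, that stays clean only because of these ignores:

```python
def add(a: int, b: int) -> int:
    return a + b

def test_add() -> None:
    # S101 would flag the assert; PLR2004 would flag the bare constant 5.
    assert add(2, 3) == 5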

View File

@@ -1,6 +1,4 @@
""" """API endpoints for fuzzing workflow management and real-time monitoring."""
API endpoints for fuzzing workflow management and real-time monitoring
"""
# Copyright (c) 2025 FuzzingLabs # Copyright (c) 2025 FuzzingLabs
# #
@@ -13,32 +11,29 @@ API endpoints for fuzzing workflow management and real-time monitoring
# #
# Additional attribution and requirements are provided in the NOTICE file. # Additional attribution and requirements are provided in the NOTICE file.
import logging
from typing import List, Dict
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
import asyncio import asyncio
import contextlib
import json import json
import logging
from datetime import datetime from datetime import datetime
from src.models.findings import ( from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
FuzzingStats, from fastapi.responses import StreamingResponse
CrashReport
) from src.models.findings import CrashReport, FuzzingStats
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/fuzzing", tags=["fuzzing"]) router = APIRouter(prefix="/fuzzing", tags=["fuzzing"])
# In-memory storage for real-time stats (in production, use Redis or similar) # In-memory storage for real-time stats (in production, use Redis or similar)
fuzzing_stats: Dict[str, FuzzingStats] = {} fuzzing_stats: dict[str, FuzzingStats] = {}
crash_reports: Dict[str, List[CrashReport]] = {} crash_reports: dict[str, list[CrashReport]] = {}
active_connections: Dict[str, List[WebSocket]] = {} active_connections: dict[str, list[WebSocket]] = {}
def initialize_fuzzing_tracking(run_id: str, workflow_name: str): def initialize_fuzzing_tracking(run_id: str, workflow_name: str) -> None:
""" """Initialize fuzzing tracking for a new run.
Initialize fuzzing tracking for a new run.
This function should be called when a workflow is submitted to enable This function should be called when a workflow is submitted to enable
real-time monitoring and stats collection. real-time monitoring and stats collection.
@@ -46,19 +41,19 @@ def initialize_fuzzing_tracking(run_id: str, workflow_name: str):
Args: Args:
run_id: The run identifier run_id: The run identifier
workflow_name: Name of the workflow workflow_name: Name of the workflow
""" """
fuzzing_stats[run_id] = FuzzingStats( fuzzing_stats[run_id] = FuzzingStats(
run_id=run_id, run_id=run_id,
workflow=workflow_name workflow=workflow_name,
) )
crash_reports[run_id] = [] crash_reports[run_id] = []
active_connections[run_id] = [] active_connections[run_id] = []
@router.get("/{run_id}/stats", response_model=FuzzingStats) @router.get("/{run_id}/stats")
async def get_fuzzing_stats(run_id: str) -> FuzzingStats: async def get_fuzzing_stats(run_id: str) -> FuzzingStats:
""" """Get current fuzzing statistics for a run.
Get current fuzzing statistics for a run.
Args: Args:
run_id: The fuzzing run ID run_id: The fuzzing run ID
@@ -68,20 +63,20 @@ async def get_fuzzing_stats(run_id: str) -> FuzzingStats:
Raises: Raises:
HTTPException: 404 if run not found HTTPException: 404 if run not found
""" """
if run_id not in fuzzing_stats: if run_id not in fuzzing_stats:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Fuzzing run not found: {run_id}" detail=f"Fuzzing run not found: {run_id}",
) )
return fuzzing_stats[run_id] return fuzzing_stats[run_id]
@router.get("/{run_id}/crashes", response_model=List[CrashReport]) @router.get("/{run_id}/crashes")
async def get_crash_reports(run_id: str) -> List[CrashReport]: async def get_crash_reports(run_id: str) -> list[CrashReport]:
""" """Get crash reports for a fuzzing run.
Get crash reports for a fuzzing run.
Args: Args:
run_id: The fuzzing run ID run_id: The fuzzing run ID
@@ -91,11 +86,12 @@ async def get_crash_reports(run_id: str) -> List[CrashReport]:
Raises: Raises:
HTTPException: 404 if run not found HTTPException: 404 if run not found
""" """
if run_id not in crash_reports: if run_id not in crash_reports:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Fuzzing run not found: {run_id}" detail=f"Fuzzing run not found: {run_id}",
) )
return crash_reports[run_id] return crash_reports[run_id]
@@ -103,8 +99,7 @@ async def get_crash_reports(run_id: str) -> List[CrashReport]:
@router.post("/{run_id}/stats") @router.post("/{run_id}/stats")
async def update_fuzzing_stats(run_id: str, stats: FuzzingStats): async def update_fuzzing_stats(run_id: str, stats: FuzzingStats):
""" """Update fuzzing statistics (called by fuzzing workflows).
Update fuzzing statistics (called by fuzzing workflows).
Args: Args:
run_id: The fuzzing run ID run_id: The fuzzing run ID
@@ -112,18 +107,19 @@ async def update_fuzzing_stats(run_id: str, stats: FuzzingStats):
Raises: Raises:
HTTPException: 404 if run not found HTTPException: 404 if run not found
""" """
if run_id not in fuzzing_stats: if run_id not in fuzzing_stats:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Fuzzing run not found: {run_id}" detail=f"Fuzzing run not found: {run_id}",
) )
# Update stats # Update stats
fuzzing_stats[run_id] = stats fuzzing_stats[run_id] = stats
# Debug: log reception for live instrumentation # Debug: log reception for live instrumentation
try: with contextlib.suppress(Exception):
logger.info( logger.info(
"Received fuzzing stats update: run_id=%s exec=%s eps=%.2f crashes=%s corpus=%s coverage=%s elapsed=%ss", "Received fuzzing stats update: run_id=%s exec=%s eps=%.2f crashes=%s corpus=%s coverage=%s elapsed=%ss",
run_id, run_id,
@@ -134,14 +130,12 @@ async def update_fuzzing_stats(run_id: str, stats: FuzzingStats):
stats.coverage, stats.coverage,
stats.elapsed_time, stats.elapsed_time,
) )
except Exception:
pass
# Notify connected WebSocket clients # Notify connected WebSocket clients
if run_id in active_connections: if run_id in active_connections:
message = { message = {
"type": "stats_update", "type": "stats_update",
"data": stats.model_dump() "data": stats.model_dump(),
} }
for websocket in active_connections[run_id][:]: # Copy to avoid modification during iteration for websocket in active_connections[run_id][:]: # Copy to avoid modification during iteration
try: try:
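In the hunk above, the try/except Exception/pass wrapper around the debug log call is replaced with contextlib.suppress, which ruff's flake8-simplify rules (SIM105) prefer because the suppression is stated directly. A minimal before/after sketch, independent of the surrounding handler:

```python
import contextlib
import logging

logger = logging.getLogger(__name__)

# Before: an empty except block silently swallows any logging failure.
try:
    logger.info("stats update received")
except Exception:
    pass

# After: same behaviour, but the suppression is explicit and the body stays flat.
with contextlib.suppress(Exception):
    logger.info("stats update received")
```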
@@ -153,12 +147,12 @@ async def update_fuzzing_stats(run_id: str, stats: FuzzingStats):
@router.post("/{run_id}/crash") @router.post("/{run_id}/crash")
async def report_crash(run_id: str, crash: CrashReport): async def report_crash(run_id: str, crash: CrashReport):
""" """Report a new crash (called by fuzzing workflows).
Report a new crash (called by fuzzing workflows).
Args: Args:
run_id: The fuzzing run ID run_id: The fuzzing run ID
crash: Crash report details crash: Crash report details
""" """
if run_id not in crash_reports: if run_id not in crash_reports:
crash_reports[run_id] = [] crash_reports[run_id] = []
@@ -175,7 +169,7 @@ async def report_crash(run_id: str, crash: CrashReport):
if run_id in active_connections: if run_id in active_connections:
message = { message = {
"type": "crash_report", "type": "crash_report",
"data": crash.model_dump() "data": crash.model_dump(),
} }
for websocket in active_connections[run_id][:]: for websocket in active_connections[run_id][:]:
try: try:
@@ -186,12 +180,12 @@ async def report_crash(run_id: str, crash: CrashReport):
@router.websocket("/{run_id}/live") @router.websocket("/{run_id}/live")
async def websocket_endpoint(websocket: WebSocket, run_id: str): async def websocket_endpoint(websocket: WebSocket, run_id: str):
""" """WebSocket endpoint for real-time fuzzing updates.
WebSocket endpoint for real-time fuzzing updates.
Args: Args:
websocket: WebSocket connection websocket: WebSocket connection
run_id: The fuzzing run ID to monitor run_id: The fuzzing run ID to monitor
""" """
await websocket.accept() await websocket.accept()
@@ -223,7 +217,7 @@ async def websocket_endpoint(websocket: WebSocket, run_id: str):
# Echo back for ping-pong # Echo back for ping-pong
if data == "ping": if data == "ping":
await websocket.send_text("pong") await websocket.send_text("pong")
except asyncio.TimeoutError: except TimeoutError:
# Send periodic heartbeat # Send periodic heartbeat
await websocket.send_text(json.dumps({"type": "heartbeat"})) await websocket.send_text(json.dumps({"type": "heartbeat"}))
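Catching the built-in TimeoutError instead of asyncio.TimeoutError works because the package requires Python >= 3.11, where asyncio.TimeoutError is an alias of TimeoutError. A small self-contained sketch of this heartbeat-style timeout handling:

```python
import asyncio

async def wait_for_message() -> str:
    try:
        # Simulate waiting for a client message that never arrives in time.
        return await asyncio.wait_for(asyncio.sleep(10, result="ping"), timeout=0.1)
    except TimeoutError:
        # On Python 3.11+ this also catches asyncio.TimeoutError.
        return "heartbeat"

print(asyncio.run(wait_for_message()))
```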
@@ -231,31 +225,31 @@ async def websocket_endpoint(websocket: WebSocket, run_id: str):
# Clean up connection # Clean up connection
if run_id in active_connections and websocket in active_connections[run_id]: if run_id in active_connections and websocket in active_connections[run_id]:
active_connections[run_id].remove(websocket) active_connections[run_id].remove(websocket)
except Exception as e: except Exception:
logger.error(f"WebSocket error for run {run_id}: {e}") logger.exception("WebSocket error for run %s", run_id)
if run_id in active_connections and websocket in active_connections[run_id]: if run_id in active_connections and websocket in active_connections[run_id]:
active_connections[run_id].remove(websocket) active_connections[run_id].remove(websocket)
@router.get("/{run_id}/stream") @router.get("/{run_id}/stream")
async def stream_fuzzing_updates(run_id: str): async def stream_fuzzing_updates(run_id: str):
""" """Server-Sent Events endpoint for real-time fuzzing updates.
Server-Sent Events endpoint for real-time fuzzing updates.
Args: Args:
run_id: The fuzzing run ID to monitor run_id: The fuzzing run ID to monitor
Returns: Returns:
Streaming response with real-time updates Streaming response with real-time updates
""" """
if run_id not in fuzzing_stats: if run_id not in fuzzing_stats:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Fuzzing run not found: {run_id}" detail=f"Fuzzing run not found: {run_id}",
) )
async def event_stream(): async def event_stream():
"""Generate server-sent events for fuzzing updates""" """Generate server-sent events for fuzzing updates."""
last_stats_time = datetime.utcnow() last_stats_time = datetime.utcnow()
while True: while True:
@@ -276,10 +270,7 @@ async def stream_fuzzing_updates(run_id: str):
# Send recent crashes # Send recent crashes
if run_id in crash_reports: if run_id in crash_reports:
recent_crashes = [ recent_crashes = [crash for crash in crash_reports[run_id] if crash.timestamp > last_stats_time]
crash for crash in crash_reports[run_id]
if crash.timestamp > last_stats_time
]
for crash in recent_crashes: for crash in recent_crashes:
event_data = f"data: {json.dumps({'type': 'crash', 'data': crash.model_dump()})}\n\n" event_data = f"data: {json.dumps({'type': 'crash', 'data': crash.model_dump()})}\n\n"
yield event_data yield event_data
@@ -287,8 +278,8 @@ async def stream_fuzzing_updates(run_id: str):
last_stats_time = datetime.utcnow() last_stats_time = datetime.utcnow()
await asyncio.sleep(5) # Update every 5 seconds await asyncio.sleep(5) # Update every 5 seconds
except Exception as e: except Exception:
logger.error(f"Error in event stream for run {run_id}: {e}") logger.exception("Error in event stream for run %s", run_id)
break break
return StreamingResponse( return StreamingResponse(
@@ -297,17 +288,17 @@ async def stream_fuzzing_updates(run_id: str):
headers={ headers={
"Cache-Control": "no-cache", "Cache-Control": "no-cache",
"Connection": "keep-alive", "Connection": "keep-alive",
} },
) )
@router.delete("/{run_id}") @router.delete("/{run_id}")
async def cleanup_fuzzing_run(run_id: str): async def cleanup_fuzzing_run(run_id: str) -> dict[str, str]:
""" """Clean up fuzzing run data.
Clean up fuzzing run data.
Args: Args:
run_id: The fuzzing run ID to clean up run_id: The fuzzing run ID to clean up
""" """
# Clean up tracking data # Clean up tracking data
fuzzing_stats.pop(run_id, None) fuzzing_stats.pop(run_id, None)
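Several route decorators in this file drop the explicit response_model argument. On recent FastAPI versions (including the pinned 0.121.0), the return type annotation on the endpoint function is used to build the response model, so `response_model=FuzzingStats` and `-> FuzzingStats` were redundant together. A minimal sketch using a made-up model rather than the project's:

```python
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class Stats(BaseModel):
    run_id: str
    executions: int = 0

@app.get("/stats/{run_id}")
async def get_stats(run_id: str) -> Stats:
    # FastAPI derives the response model from the '-> Stats' annotation,
    # so no response_model= argument is needed on the decorator.
    return Stats(run_id=run_id)
```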

View File

@@ -1,6 +1,4 @@
""" """API endpoints for workflow run management and findings retrieval."""
API endpoints for workflow run management and findings retrieval
"""
# Copyright (c) 2025 FuzzingLabs # Copyright (c) 2025 FuzzingLabs
# #
@@ -14,37 +12,36 @@ API endpoints for workflow run management and findings retrieval
# Additional attribution and requirements are provided in the NOTICE file. # Additional attribution and requirements are provided in the NOTICE file.
import logging import logging
from fastapi import APIRouter, HTTPException, Depends from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from src.main import temporal_mgr
from src.models.findings import WorkflowFindings, WorkflowStatus from src.models.findings import WorkflowFindings, WorkflowStatus
from src.temporal import TemporalManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/runs", tags=["runs"]) router = APIRouter(prefix="/runs", tags=["runs"])
def get_temporal_manager(): def get_temporal_manager() -> TemporalManager:
"""Dependency to get the Temporal manager instance""" """Dependency to get the Temporal manager instance."""
from src.main import temporal_mgr
return temporal_mgr return temporal_mgr
@router.get("/{run_id}/status", response_model=WorkflowStatus) @router.get("/{run_id}/status")
async def get_run_status( async def get_run_status(
run_id: str, run_id: str,
temporal_mgr=Depends(get_temporal_manager) temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
) -> WorkflowStatus: ) -> WorkflowStatus:
""" """Get the current status of a workflow run.
Get the current status of a workflow run.
Args: :param run_id: The workflow run ID
run_id: The workflow run ID :param temporal_mgr: The temporal manager instance.
:return: Status information including state, timestamps, and completion flags
:raises HTTPException: 404 if run not found
Returns:
Status information including state, timestamps, and completion flags
Raises:
HTTPException: 404 if run not found
""" """
try: try:
status = await temporal_mgr.get_workflow_status(run_id) status = await temporal_mgr.get_workflow_status(run_id)
@@ -56,7 +53,7 @@ async def get_run_status(
is_running = workflow_status == "RUNNING" is_running = workflow_status == "RUNNING"
# Extract workflow name from run_id (format: workflow_name-unique_id) # Extract workflow name from run_id (format: workflow_name-unique_id)
workflow_name = run_id.rsplit('-', 1)[0] if '-' in run_id else "unknown" workflow_name = run_id.rsplit("-", 1)[0] if "-" in run_id else "unknown"
return WorkflowStatus( return WorkflowStatus(
run_id=run_id, run_id=run_id,
@@ -66,33 +63,29 @@ async def get_run_status(
is_failed=is_failed, is_failed=is_failed,
is_running=is_running, is_running=is_running,
created_at=status.get("start_time"), created_at=status.get("start_time"),
updated_at=status.get("close_time") or status.get("execution_time") updated_at=status.get("close_time") or status.get("execution_time"),
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get status for run {run_id}: {e}") logger.exception("Failed to get status for run %s", run_id)
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Run not found: {run_id}" detail=f"Run not found: {run_id}",
) ) from e
@router.get("/{run_id}/findings", response_model=WorkflowFindings) @router.get("/{run_id}/findings")
async def get_run_findings( async def get_run_findings(
run_id: str, run_id: str,
temporal_mgr=Depends(get_temporal_manager) temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
) -> WorkflowFindings: ) -> WorkflowFindings:
""" """Get the findings from a completed workflow run.
Get the findings from a completed workflow run.
Args: :param run_id: The workflow run ID
run_id: The workflow run ID :param temporal_mgr: The temporal manager instance.
:return: SARIF-formatted findings from the workflow execution
:raises HTTPException: 404 if run not found, 400 if run not completed
Returns:
SARIF-formatted findings from the workflow execution
Raises:
HTTPException: 404 if run not found, 400 if run not completed
""" """
try: try:
# Get run status first # Get run status first
@@ -103,80 +96,72 @@ async def get_run_findings(
if workflow_status == "RUNNING": if workflow_status == "RUNNING":
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Run {run_id} is still running. Current status: {workflow_status}" detail=f"Run {run_id} is still running. Current status: {workflow_status}",
)
else:
raise HTTPException(
status_code=400,
detail=f"Run {run_id} not completed. Status: {workflow_status}"
) )
raise HTTPException(
status_code=400,
detail=f"Run {run_id} not completed. Status: {workflow_status}",
)
if workflow_status == "FAILED": if workflow_status == "FAILED":
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Run {run_id} failed. Status: {workflow_status}" detail=f"Run {run_id} failed. Status: {workflow_status}",
) )
# Get the workflow result # Get the workflow result
result = await temporal_mgr.get_workflow_result(run_id) result = await temporal_mgr.get_workflow_result(run_id)
# Extract SARIF from result (handle None for backwards compatibility) # Extract SARIF from result (handle None for backwards compatibility)
if isinstance(result, dict): sarif = result.get("sarif", {}) if isinstance(result, dict) else {}
sarif = result.get("sarif") or {}
else:
sarif = {}
# Extract workflow name from run_id (format: workflow_name-unique_id) # Extract workflow name from run_id (format: workflow_name-unique_id)
workflow_name = run_id.rsplit('-', 1)[0] if '-' in run_id else "unknown" workflow_name = run_id.rsplit("-", 1)[0] if "-" in run_id else "unknown"
# Metadata # Metadata
metadata = { metadata = {
"completion_time": status.get("close_time"), "completion_time": status.get("close_time"),
"workflow_version": "unknown" "workflow_version": "unknown",
} }
return WorkflowFindings( return WorkflowFindings(
workflow=workflow_name, workflow=workflow_name,
run_id=run_id, run_id=run_id,
sarif=sarif, sarif=sarif,
metadata=metadata metadata=metadata,
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Failed to get findings for run {run_id}: {e}") logger.exception("Failed to get findings for run %s", run_id)
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"Failed to retrieve findings: {str(e)}" detail=f"Failed to retrieve findings: {e!s}",
) ) from e
@router.get("/{workflow_name}/findings/{run_id}", response_model=WorkflowFindings) @router.get("/{workflow_name}/findings/{run_id}")
async def get_workflow_findings( async def get_workflow_findings(
workflow_name: str, workflow_name: str,
run_id: str, run_id: str,
temporal_mgr=Depends(get_temporal_manager) temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
) -> WorkflowFindings: ) -> WorkflowFindings:
""" """Get findings for a specific workflow run.
Get findings for a specific workflow run.
Alternative endpoint that includes workflow name in the path for clarity. Alternative endpoint that includes workflow name in the path for clarity.
Args: :param workflow_name: Name of the workflow
workflow_name: Name of the workflow :param run_id: The workflow run ID
run_id: The workflow run ID :param temporal_mgr: The temporal manager instance.
:return: SARIF-formatted findings from the workflow execution
:raises HTTPException: 404 if workflow or run not found, 400 if run not completed
Returns:
SARIF-formatted findings from the workflow execution
Raises:
HTTPException: 404 if workflow or run not found, 400 if run not completed
""" """
if workflow_name not in temporal_mgr.workflows: if workflow_name not in temporal_mgr.workflows:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Workflow not found: {workflow_name}" detail=f"Workflow not found: {workflow_name}",
) )
# Delegate to the main findings endpoint # Delegate to the main findings endpoint
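Two recurring changes in the error handlers above: `logger.error(f"... {e}")` becomes `logger.exception(...)` with %-style arguments, which records the full traceback, and HTTPExceptions raised inside except blocks gain `from e`, so the original error is preserved as `__cause__`. A stripped-down sketch of the pattern, with a hypothetical lookup function standing in for the Temporal call:

```python
import logging

from fastapi import HTTPException

logger = logging.getLogger(__name__)

def load_status(run_id: str) -> dict:
    # Hypothetical stand-in for the real Temporal lookup.
    raise KeyError(run_id)

def get_status_or_404(run_id: str) -> dict:
    try:
        return load_status(run_id)
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback;
        # %-style arguments defer string formatting until the record is emitted.
        logger.exception("Failed to get status for run %s", run_id)
        # 'from e' chains the original exception as __cause__ of the HTTPException.
        raise HTTPException(status_code=404, detail=f"Run not found: {run_id}") from e
```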

View File

@@ -9,14 +9,12 @@
# #
# Additional attribution and requirements are provided in the NOTICE file. # Additional attribution and requirements are provided in the NOTICE file.
""" """System information endpoints for FuzzForge API.
System information endpoints for FuzzForge API.
Provides system configuration and filesystem paths to CLI for worker management. Provides system configuration and filesystem paths to CLI for worker management.
""" """
import os import os
from typing import Dict
from fastapi import APIRouter from fastapi import APIRouter
@@ -24,9 +22,8 @@ router = APIRouter(prefix="/system", tags=["system"])
@router.get("/info") @router.get("/info")
async def get_system_info() -> Dict[str, str]: async def get_system_info() -> dict[str, str]:
""" """Get system information including host filesystem paths.
Get system information including host filesystem paths.
This endpoint exposes paths needed by the CLI to manage workers via docker-compose. This endpoint exposes paths needed by the CLI to manage workers via docker-compose.
The FUZZFORGE_HOST_ROOT environment variable is set by docker-compose and points The FUZZFORGE_HOST_ROOT environment variable is set by docker-compose and points
@@ -37,6 +34,7 @@ async def get_system_info() -> Dict[str, str]:
- host_root: Absolute path to FuzzForge root on host - host_root: Absolute path to FuzzForge root on host
- docker_compose_path: Path to docker-compose.yml on host - docker_compose_path: Path to docker-compose.yml on host
- workers_dir: Path to workers directory on host - workers_dir: Path to workers directory on host
""" """
host_root = os.getenv("FUZZFORGE_HOST_ROOT", "") host_root = os.getenv("FUZZFORGE_HOST_ROOT", "")
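The typing changes visible across these files follow one modernization: typing.Dict/typing.List are replaced with the built-in generics dict/list (PEP 585) and Optional[X] with X | None (PEP 604), both fine under the requires-python >= 3.11 floor. A small illustrative sketch with invented names:

```python
def summarize(values: list[int], labels: dict[str, str], note: str | None = None) -> dict[str, int]:
    # Built-in generics and the '|' union replace typing.List, typing.Dict
    # and typing.Optional; no 'from typing import ...' is needed here.
    summary = {label: len(values) for label in labels}
    if note is not None:
        summary[note] = sum(values)
    return summary

print(summarize([1, 2, 3], {"count": "number of values"}, note="total"))
```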

View File

@@ -1,6 +1,4 @@
""" """API endpoints for workflow management with enhanced error handling."""
API endpoints for workflow management with enhanced error handling
"""
# Copyright (c) 2025 FuzzingLabs # Copyright (c) 2025 FuzzingLabs
# #
@@ -13,20 +11,24 @@ API endpoints for workflow management with enhanced error handling
# #
# Additional attribution and requirements are provided in the NOTICE file. # Additional attribution and requirements are provided in the NOTICE file.
import json
import logging import logging
import traceback
import tempfile import tempfile
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form
from pathlib import Path from pathlib import Path
from typing import Annotated, Any
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from src.api.fuzzing import initialize_fuzzing_tracking
from src.main import temporal_mgr
from src.models.findings import ( from src.models.findings import (
WorkflowSubmission, RunSubmissionResponse,
WorkflowMetadata,
WorkflowListItem, WorkflowListItem,
RunSubmissionResponse WorkflowMetadata,
WorkflowSubmission,
) )
from src.temporal.discovery import WorkflowDiscovery from src.temporal.discovery import WorkflowDiscovery
from src.temporal.manager import TemporalManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -43,9 +45,8 @@ ALLOWED_CONTENT_TYPES = [
router = APIRouter(prefix="/workflows", tags=["workflows"]) router = APIRouter(prefix="/workflows", tags=["workflows"])
def extract_defaults_from_json_schema(metadata: Dict[str, Any]) -> Dict[str, Any]: def extract_defaults_from_json_schema(metadata: dict[str, Any]) -> dict[str, Any]:
""" """Extract default parameter values from JSON Schema format.
Extract default parameter values from JSON Schema format.
Converts from: Converts from:
parameters: parameters:
@@ -61,6 +62,7 @@ def extract_defaults_from_json_schema(metadata: Dict[str, Any]) -> Dict[str, Any
Returns: Returns:
Dictionary of parameter defaults Dictionary of parameter defaults
""" """
defaults = {} defaults = {}
@@ -82,19 +84,19 @@ def extract_defaults_from_json_schema(metadata: Dict[str, Any]) -> Dict[str, Any
def create_structured_error_response( def create_structured_error_response(
error_type: str, error_type: str,
message: str, message: str,
workflow_name: Optional[str] = None, workflow_name: str | None = None,
run_id: Optional[str] = None, run_id: str | None = None,
container_info: Optional[Dict[str, Any]] = None, container_info: dict[str, Any] | None = None,
deployment_info: Optional[Dict[str, Any]] = None, deployment_info: dict[str, Any] | None = None,
suggestions: Optional[List[str]] = None suggestions: list[str] | None = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Create a structured error response with rich context.""" """Create a structured error response with rich context."""
error_response = { error_response = {
"error": { "error": {
"type": error_type, "type": error_type,
"message": message, "message": message,
"timestamp": __import__("datetime").datetime.utcnow().isoformat() + "Z" "timestamp": __import__("datetime").datetime.utcnow().isoformat() + "Z",
} },
} }
if workflow_name: if workflow_name:
@@ -115,39 +117,38 @@ def create_structured_error_response(
return error_response return error_response
def get_temporal_manager(): def get_temporal_manager() -> TemporalManager:
"""Dependency to get the Temporal manager instance""" """Dependency to get the Temporal manager instance."""
from src.main import temporal_mgr
return temporal_mgr return temporal_mgr
@router.get("/", response_model=List[WorkflowListItem]) @router.get("/")
async def list_workflows( async def list_workflows(
temporal_mgr=Depends(get_temporal_manager) temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
) -> List[WorkflowListItem]: ) -> list[WorkflowListItem]:
""" """List all discovered workflows with their metadata.
List all discovered workflows with their metadata.
Returns a summary of each workflow including name, version, description, Returns a summary of each workflow including name, version, description,
author, and tags. author, and tags.
""" """
workflows = [] workflows = []
for name, info in temporal_mgr.workflows.items(): for name, info in temporal_mgr.workflows.items():
workflows.append(WorkflowListItem( workflows.append(
name=name, WorkflowListItem(
version=info.metadata.get("version", "0.6.0"), name=name,
description=info.metadata.get("description", ""), version=info.metadata.get("version", "0.6.0"),
author=info.metadata.get("author"), description=info.metadata.get("description", ""),
tags=info.metadata.get("tags", []) author=info.metadata.get("author"),
)) tags=info.metadata.get("tags", []),
),
)
return workflows return workflows
@router.get("/metadata/schema") @router.get("/metadata/schema")
async def get_metadata_schema() -> Dict[str, Any]: async def get_metadata_schema() -> dict[str, Any]:
""" """Get the JSON schema for workflow metadata files.
Get the JSON schema for workflow metadata files.
This schema defines the structure and requirements for metadata.yaml files This schema defines the structure and requirements for metadata.yaml files
that must accompany each workflow. that must accompany each workflow.
@@ -155,23 +156,19 @@ async def get_metadata_schema() -> Dict[str, Any]:
return WorkflowDiscovery.get_metadata_schema() return WorkflowDiscovery.get_metadata_schema()
@router.get("/{workflow_name}/metadata", response_model=WorkflowMetadata) @router.get("/{workflow_name}/metadata")
async def get_workflow_metadata( async def get_workflow_metadata(
workflow_name: str, workflow_name: str,
temporal_mgr=Depends(get_temporal_manager) temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
) -> WorkflowMetadata: ) -> WorkflowMetadata:
""" """Get complete metadata for a specific workflow.
Get complete metadata for a specific workflow.
Args: :param workflow_name: Name of the workflow
workflow_name: Name of the workflow :param temporal_mgr: The temporal manager instance.
:return: Complete metadata including parameters schema, supported volume modes,
Returns:
Complete metadata including parameters schema, supported volume modes,
required modules, and more. required modules, and more.
:raises HTTPException: 404 if workflow not found
Raises:
HTTPException: 404 if workflow not found
""" """
if workflow_name not in temporal_mgr.workflows: if workflow_name not in temporal_mgr.workflows:
available_workflows = list(temporal_mgr.workflows.keys()) available_workflows = list(temporal_mgr.workflows.keys())
@@ -182,12 +179,12 @@ async def get_workflow_metadata(
suggestions=[ suggestions=[
f"Available workflows: {', '.join(available_workflows)}", f"Available workflows: {', '.join(available_workflows)}",
"Use GET /workflows/ to see all available workflows", "Use GET /workflows/ to see all available workflows",
"Check workflow name spelling and case sensitivity" "Check workflow name spelling and case sensitivity",
] ],
) )
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=error_response detail=error_response,
) )
info = temporal_mgr.workflows[workflow_name] info = temporal_mgr.workflows[workflow_name]
@@ -201,28 +198,24 @@ async def get_workflow_metadata(
tags=metadata.get("tags", []), tags=metadata.get("tags", []),
parameters=metadata.get("parameters", {}), parameters=metadata.get("parameters", {}),
default_parameters=extract_defaults_from_json_schema(metadata), default_parameters=extract_defaults_from_json_schema(metadata),
required_modules=metadata.get("required_modules", []) required_modules=metadata.get("required_modules", []),
) )
@router.post("/{workflow_name}/submit", response_model=RunSubmissionResponse) @router.post("/{workflow_name}/submit")
async def submit_workflow( async def submit_workflow(
workflow_name: str, workflow_name: str,
submission: WorkflowSubmission, submission: WorkflowSubmission,
temporal_mgr=Depends(get_temporal_manager) temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
) -> RunSubmissionResponse: ) -> RunSubmissionResponse:
""" """Submit a workflow for execution.
Submit a workflow for execution.
Args: :param workflow_name: Name of the workflow to execute
workflow_name: Name of the workflow to execute :param submission: Submission parameters including target path and parameters
submission: Submission parameters including target path and parameters :param temporal_mgr: The temporal manager instance.
:return: Run submission response with run_id and initial status
:raises HTTPException: 404 if workflow not found, 400 for invalid parameters
Returns:
Run submission response with run_id and initial status
Raises:
HTTPException: 404 if workflow not found, 400 for invalid parameters
""" """
if workflow_name not in temporal_mgr.workflows: if workflow_name not in temporal_mgr.workflows:
available_workflows = list(temporal_mgr.workflows.keys()) available_workflows = list(temporal_mgr.workflows.keys())
@@ -233,25 +226,26 @@ async def submit_workflow(
suggestions=[ suggestions=[
f"Available workflows: {', '.join(available_workflows)}", f"Available workflows: {', '.join(available_workflows)}",
"Use GET /workflows/ to see all available workflows", "Use GET /workflows/ to see all available workflows",
"Check workflow name spelling and case sensitivity" "Check workflow name spelling and case sensitivity",
] ],
) )
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=error_response detail=error_response,
) )
try: try:
# Upload target file to MinIO and get target_id # Upload target file to MinIO and get target_id
target_path = Path(submission.target_path) target_path = Path(submission.target_path)
if not target_path.exists(): if not target_path.exists():
raise ValueError(f"Target path does not exist: {submission.target_path}") msg = f"Target path does not exist: {submission.target_path}"
raise ValueError(msg)
# Upload target (using anonymous user for now) # Upload target (using anonymous user for now)
target_id = await temporal_mgr.upload_target( target_id = await temporal_mgr.upload_target(
file_path=target_path, file_path=target_path,
user_id="api-user", user_id="api-user",
metadata={"workflow": workflow_name} metadata={"workflow": workflow_name},
) )
# Merge default parameters with user parameters # Merge default parameters with user parameters
@@ -265,23 +259,22 @@ async def submit_workflow(
handle = await temporal_mgr.run_workflow( handle = await temporal_mgr.run_workflow(
workflow_name=workflow_name, workflow_name=workflow_name,
target_id=target_id, target_id=target_id,
workflow_params=workflow_params workflow_params=workflow_params,
) )
run_id = handle.id run_id = handle.id
# Initialize fuzzing tracking if this looks like a fuzzing workflow # Initialize fuzzing tracking if this looks like a fuzzing workflow
workflow_info = temporal_mgr.workflows.get(workflow_name, {}) workflow_info = temporal_mgr.workflows.get(workflow_name, {})
workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, 'metadata') else [] workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, "metadata") else []
if "fuzzing" in workflow_tags or "fuzz" in workflow_name.lower(): if "fuzzing" in workflow_tags or "fuzz" in workflow_name.lower():
from src.api.fuzzing import initialize_fuzzing_tracking
initialize_fuzzing_tracking(run_id, workflow_name) initialize_fuzzing_tracking(run_id, workflow_name)
return RunSubmissionResponse( return RunSubmissionResponse(
run_id=run_id, run_id=run_id,
status="RUNNING", status="RUNNING",
workflow=workflow_name, workflow=workflow_name,
message=f"Workflow '{workflow_name}' submitted successfully" message=f"Workflow '{workflow_name}' submitted successfully",
) )
except ValueError as e: except ValueError as e:
@@ -293,14 +286,13 @@ async def submit_workflow(
suggestions=[ suggestions=[
"Check parameter types and values", "Check parameter types and values",
"Use GET /workflows/{workflow_name}/parameters for schema", "Use GET /workflows/{workflow_name}/parameters for schema",
"Ensure all required parameters are provided" "Ensure all required parameters are provided",
] ],
) )
raise HTTPException(status_code=400, detail=error_response) raise HTTPException(status_code=400, detail=error_response) from e
except Exception as e: except Exception as e:
logger.error(f"Failed to submit workflow '{workflow_name}': {e}") logger.exception("Failed to submit workflow '%s'", workflow_name)
logger.error(f"Traceback: {traceback.format_exc()}")
# Try to get more context about the error # Try to get more context about the error
container_info = None container_info = None
@@ -313,47 +305,57 @@ async def submit_workflow(
# Detect specific error patterns # Detect specific error patterns
if "workflow" in error_message.lower() and "not found" in error_message.lower(): if "workflow" in error_message.lower() and "not found" in error_message.lower():
error_type = "WorkflowError" error_type = "WorkflowError"
suggestions.extend([ suggestions.extend(
"Check if Temporal server is running and accessible", [
"Verify workflow workers are running", "Check if Temporal server is running and accessible",
"Check if workflow is registered with correct vertical", "Verify workflow workers are running",
"Ensure Docker is running and has sufficient resources" "Check if workflow is registered with correct vertical",
]) "Ensure Docker is running and has sufficient resources",
],
)
elif "volume" in error_message.lower() or "mount" in error_message.lower(): elif "volume" in error_message.lower() or "mount" in error_message.lower():
error_type = "VolumeError" error_type = "VolumeError"
suggestions.extend([ suggestions.extend(
"Check if the target path exists and is accessible", [
"Verify file permissions (Docker needs read access)", "Check if the target path exists and is accessible",
"Ensure the path is not in use by another process", "Verify file permissions (Docker needs read access)",
"Try using an absolute path instead of relative path" "Ensure the path is not in use by another process",
]) "Try using an absolute path instead of relative path",
],
)
elif "memory" in error_message.lower() or "resource" in error_message.lower(): elif "memory" in error_message.lower() or "resource" in error_message.lower():
error_type = "ResourceError" error_type = "ResourceError"
suggestions.extend([ suggestions.extend(
"Check system memory and CPU availability", [
"Consider reducing resource limits or dataset size", "Check system memory and CPU availability",
"Monitor Docker resource usage", "Consider reducing resource limits or dataset size",
"Increase Docker memory limits if needed" "Monitor Docker resource usage",
]) "Increase Docker memory limits if needed",
],
)
elif "image" in error_message.lower(): elif "image" in error_message.lower():
error_type = "ImageError" error_type = "ImageError"
suggestions.extend([ suggestions.extend(
"Check if the workflow image exists", [
"Verify Docker registry access", "Check if the workflow image exists",
"Try rebuilding the workflow image", "Verify Docker registry access",
"Check network connectivity to registries" "Try rebuilding the workflow image",
]) "Check network connectivity to registries",
],
)
else: else:
suggestions.extend([ suggestions.extend(
"Check FuzzForge backend logs for details", [
"Verify all services are running (docker-compose up -d)", "Check FuzzForge backend logs for details",
"Try restarting the workflow deployment", "Verify all services are running (docker-compose up -d)",
"Contact support if the issue persists" "Try restarting the workflow deployment",
]) "Contact support if the issue persists",
],
)
error_response = create_structured_error_response( error_response = create_structured_error_response(
error_type=error_type, error_type=error_type,
@@ -361,41 +363,35 @@ async def submit_workflow(
workflow_name=workflow_name, workflow_name=workflow_name,
container_info=container_info, container_info=container_info,
deployment_info=deployment_info, deployment_info=deployment_info,
suggestions=suggestions suggestions=suggestions,
) )
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=error_response detail=error_response,
) ) from e
@router.post("/{workflow_name}/upload-and-submit", response_model=RunSubmissionResponse) @router.post("/{workflow_name}/upload-and-submit")
async def upload_and_submit_workflow( async def upload_and_submit_workflow(
workflow_name: str, workflow_name: str,
file: UploadFile = File(..., description="Target file or tarball to analyze"), temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
parameters: Optional[str] = Form(None, description="JSON-encoded workflow parameters"), file: Annotated[UploadFile, File(..., description="Target file or tarball to analyze")],
timeout: Optional[int] = Form(None, description="Timeout in seconds"), parameters: Annotated[str, Form(None, description="JSON-encoded workflow parameters")],
temporal_mgr=Depends(get_temporal_manager)
) -> RunSubmissionResponse: ) -> RunSubmissionResponse:
""" """Upload a target file/tarball and submit workflow for execution.
Upload a target file/tarball and submit workflow for execution.
This endpoint accepts multipart/form-data uploads and is the recommended This endpoint accepts multipart/form-data uploads and is the recommended
way to submit workflows from remote CLI clients. way to submit workflows from remote CLI clients.
Args: :param workflow_name: Name of the workflow to execute
workflow_name: Name of the workflow to execute :param temporal_mgr: The temporal manager instance.
file: Target file or tarball (compressed directory) :param file: Target file or tarball (compressed directory)
parameters: JSON string of workflow parameters (optional) :param parameters: JSON string of workflow parameters (optional)
timeout: Execution timeout in seconds (optional) :returns: Run submission response with run_id and initial status
:raises HTTPException: 404 if workflow not found, 400 for invalid parameters,
413 if file too large
Returns:
Run submission response with run_id and initial status
Raises:
HTTPException: 404 if workflow not found, 400 for invalid parameters,
413 if file too large
""" """
if workflow_name not in temporal_mgr.workflows: if workflow_name not in temporal_mgr.workflows:
available_workflows = list(temporal_mgr.workflows.keys()) available_workflows = list(temporal_mgr.workflows.keys())
@@ -405,8 +401,8 @@ async def upload_and_submit_workflow(
workflow_name=workflow_name, workflow_name=workflow_name,
suggestions=[ suggestions=[
f"Available workflows: {', '.join(available_workflows)}", f"Available workflows: {', '.join(available_workflows)}",
"Use GET /workflows/ to see all available workflows" "Use GET /workflows/ to see all available workflows",
] ],
) )
raise HTTPException(status_code=404, detail=error_response) raise HTTPException(status_code=404, detail=error_response)
@@ -420,10 +416,10 @@ async def upload_and_submit_workflow(
# Create temporary file # Create temporary file
temp_fd, temp_file_path = tempfile.mkstemp(suffix=".tar.gz") temp_fd, temp_file_path = tempfile.mkstemp(suffix=".tar.gz")
logger.info(f"Receiving file upload for workflow '{workflow_name}': {file.filename}") logger.info("Receiving file upload for workflow '%s': %s", workflow_name, file.filename)
# Stream file to disk # Stream file to disk
with open(temp_fd, 'wb') as temp_file: with open(temp_fd, "wb") as temp_file:
while True: while True:
chunk = await file.read(chunk_size) chunk = await file.read(chunk_size)
if not chunk: if not chunk:
@@ -442,33 +438,33 @@ async def upload_and_submit_workflow(
suggestions=[ suggestions=[
"Reduce the size of your target directory", "Reduce the size of your target directory",
"Exclude unnecessary files (build artifacts, dependencies, etc.)", "Exclude unnecessary files (build artifacts, dependencies, etc.)",
"Consider splitting into smaller analysis targets" "Consider splitting into smaller analysis targets",
] ],
) ),
) )
temp_file.write(chunk) temp_file.write(chunk)
logger.info(f"Received file: {file_size / (1024**2):.2f} MB") logger.info("Received file: %s MB", f"{file_size / (1024**2):.2f}")
# Parse parameters # Parse parameters
workflow_params = {} workflow_params = {}
if parameters: if parameters:
try: try:
import json
workflow_params = json.loads(parameters) workflow_params = json.loads(parameters)
if not isinstance(workflow_params, dict): if not isinstance(workflow_params, dict):
raise ValueError("Parameters must be a JSON object") msg = "Parameters must be a JSON object"
except (json.JSONDecodeError, ValueError) as e: raise TypeError(msg)
except (json.JSONDecodeError, TypeError) as e:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=create_structured_error_response( detail=create_structured_error_response(
error_type="InvalidParameters", error_type="InvalidParameters",
message=f"Invalid parameters JSON: {e}", message=f"Invalid parameters JSON: {e}",
workflow_name=workflow_name, workflow_name=workflow_name,
suggestions=["Ensure parameters is valid JSON object"] suggestions=["Ensure parameters is valid JSON object"],
) ),
) ) from e
# Upload to MinIO # Upload to MinIO
target_id = await temporal_mgr.upload_target( target_id = await temporal_mgr.upload_target(
@@ -477,11 +473,11 @@ async def upload_and_submit_workflow(
metadata={ metadata={
"workflow": workflow_name, "workflow": workflow_name,
"original_filename": file.filename, "original_filename": file.filename,
"upload_method": "multipart" "upload_method": "multipart",
} },
) )
logger.info(f"Uploaded to MinIO with target_id: {target_id}") logger.info("Uploaded to MinIO with target_id: %s", target_id)
# Merge default parameters with user parameters # Merge default parameters with user parameters
workflow_info = temporal_mgr.workflows.get(workflow_name) workflow_info = temporal_mgr.workflows.get(workflow_name)
@@ -493,74 +489,68 @@ async def upload_and_submit_workflow(
handle = await temporal_mgr.run_workflow( handle = await temporal_mgr.run_workflow(
workflow_name=workflow_name, workflow_name=workflow_name,
target_id=target_id, target_id=target_id,
workflow_params=workflow_params workflow_params=workflow_params,
) )
run_id = handle.id run_id = handle.id
# Initialize fuzzing tracking if needed # Initialize fuzzing tracking if needed
workflow_info = temporal_mgr.workflows.get(workflow_name, {}) workflow_info = temporal_mgr.workflows.get(workflow_name, {})
workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, 'metadata') else [] workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, "metadata") else []
if "fuzzing" in workflow_tags or "fuzz" in workflow_name.lower(): if "fuzzing" in workflow_tags or "fuzz" in workflow_name.lower():
from src.api.fuzzing import initialize_fuzzing_tracking
initialize_fuzzing_tracking(run_id, workflow_name) initialize_fuzzing_tracking(run_id, workflow_name)
return RunSubmissionResponse( return RunSubmissionResponse(
run_id=run_id, run_id=run_id,
status="RUNNING", status="RUNNING",
workflow=workflow_name, workflow=workflow_name,
message=f"Workflow '{workflow_name}' submitted successfully with uploaded target" message=f"Workflow '{workflow_name}' submitted successfully with uploaded target",
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Failed to upload and submit workflow '{workflow_name}': {e}") logger.exception("Failed to upload and submit workflow '%s'", workflow_name)
logger.error(f"Traceback: {traceback.format_exc()}")
error_response = create_structured_error_response( error_response = create_structured_error_response(
error_type="WorkflowSubmissionError", error_type="WorkflowSubmissionError",
message=f"Failed to process upload and submit workflow: {str(e)}", message=f"Failed to process upload and submit workflow: {e!s}",
workflow_name=workflow_name, workflow_name=workflow_name,
suggestions=[ suggestions=[
"Check if the uploaded file is a valid tarball", "Check if the uploaded file is a valid tarball",
"Verify MinIO storage is accessible", "Verify MinIO storage is accessible",
"Check backend logs for detailed error information", "Check backend logs for detailed error information",
"Ensure Temporal workers are running" "Ensure Temporal workers are running",
] ],
) )
raise HTTPException(status_code=500, detail=error_response) raise HTTPException(status_code=500, detail=error_response) from e
finally: finally:
# Cleanup temporary file # Cleanup temporary file
if temp_file_path and Path(temp_file_path).exists(): if temp_file_path and Path(temp_file_path).exists():
try: try:
Path(temp_file_path).unlink() Path(temp_file_path).unlink()
logger.debug(f"Cleaned up temp file: {temp_file_path}") logger.debug("Cleaned up temp file: %s", temp_file_path)
except Exception as e: except Exception as e:
logger.warning(f"Failed to cleanup temp file {temp_file_path}: {e}") logger.warning("Failed to cleanup temp file %s: %s", temp_file_path, e)
@router.get("/{workflow_name}/worker-info") @router.get("/{workflow_name}/worker-info")
async def get_workflow_worker_info( async def get_workflow_worker_info(
workflow_name: str, workflow_name: str,
temporal_mgr=Depends(get_temporal_manager) temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """Get worker information for a workflow.
Get worker information for a workflow.
Returns details about which worker is required to execute this workflow, Returns details about which worker is required to execute this workflow,
including container name, task queue, and vertical. including container name, task queue, and vertical.
Args: :param workflow_name: Name of the workflow
workflow_name: Name of the workflow :param temporal_mgr: The temporal manager instance.
:return: Worker information including vertical, container name, and task queue
:raises HTTPException: 404 if workflow not found
Returns:
Worker information including vertical, container name, and task queue
Raises:
HTTPException: 404 if workflow not found
""" """
if workflow_name not in temporal_mgr.workflows: if workflow_name not in temporal_mgr.workflows:
available_workflows = list(temporal_mgr.workflows.keys()) available_workflows = list(temporal_mgr.workflows.keys())
@@ -570,12 +560,12 @@ async def get_workflow_worker_info(
workflow_name=workflow_name, workflow_name=workflow_name,
suggestions=[ suggestions=[
f"Available workflows: {', '.join(available_workflows)}", f"Available workflows: {', '.join(available_workflows)}",
"Use GET /workflows/ to see all available workflows" "Use GET /workflows/ to see all available workflows",
] ],
) )
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=error_response detail=error_response,
) )
info = temporal_mgr.workflows[workflow_name] info = temporal_mgr.workflows[workflow_name]
@@ -591,12 +581,12 @@ async def get_workflow_worker_info(
workflow_name=workflow_name, workflow_name=workflow_name,
suggestions=[ suggestions=[
"Check workflow metadata.yaml for 'vertical' field", "Check workflow metadata.yaml for 'vertical' field",
"Contact workflow author for support" "Contact workflow author for support",
] ],
) )
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=error_response detail=error_response,
) )
return { return {
@@ -604,26 +594,22 @@ async def get_workflow_worker_info(
"vertical": vertical, "vertical": vertical,
"worker_service": f"worker-{vertical}", "worker_service": f"worker-{vertical}",
"task_queue": f"{vertical}-queue", "task_queue": f"{vertical}-queue",
"required": True "required": True,
} }
@router.get("/{workflow_name}/parameters") @router.get("/{workflow_name}/parameters")
async def get_workflow_parameters( async def get_workflow_parameters(
workflow_name: str, workflow_name: str,
temporal_mgr=Depends(get_temporal_manager) temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)],
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """Get the parameters schema for a workflow.
Get the parameters schema for a workflow.
Args: :param workflow_name: Name of the workflow
workflow_name: Name of the workflow :param temporal_mgr: The temporal manager instance.
:return: Parameters schema with types, descriptions, and defaults
:raises HTTPException: 404 if workflow not found
Returns:
Parameters schema with types, descriptions, and defaults
Raises:
HTTPException: 404 if workflow not found
""" """
if workflow_name not in temporal_mgr.workflows: if workflow_name not in temporal_mgr.workflows:
available_workflows = list(temporal_mgr.workflows.keys()) available_workflows = list(temporal_mgr.workflows.keys())
@@ -633,12 +619,12 @@ async def get_workflow_parameters(
workflow_name=workflow_name, workflow_name=workflow_name,
suggestions=[ suggestions=[
f"Available workflows: {', '.join(available_workflows)}", f"Available workflows: {', '.join(available_workflows)}",
"Use GET /workflows/ to see all available workflows" "Use GET /workflows/ to see all available workflows",
] ],
) )
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=error_response detail=error_response,
) )
info = temporal_mgr.workflows[workflow_name] info = temporal_mgr.workflows[workflow_name]
@@ -648,10 +634,7 @@ async def get_workflow_parameters(
parameters_schema = metadata.get("parameters", {}) parameters_schema = metadata.get("parameters", {})
# Extract the actual parameter definitions from JSON schema structure # Extract the actual parameter definitions from JSON schema structure
if "properties" in parameters_schema: param_definitions = parameters_schema.get("properties", parameters_schema)
param_definitions = parameters_schema["properties"]
else:
param_definitions = parameters_schema
# Extract default values from JSON Schema # Extract default values from JSON Schema
default_params = extract_defaults_from_json_schema(metadata) default_params = extract_defaults_from_json_schema(metadata)
@@ -661,7 +644,8 @@ async def get_workflow_parameters(
"parameters": param_definitions, "parameters": param_definitions,
"default_parameters": default_params, "default_parameters": default_params,
"required_parameters": [ "required_parameters": [
name for name, schema in param_definitions.items() name
for name, schema in param_definitions.items()
if isinstance(schema, dict) and schema.get("required", False) if isinstance(schema, dict) and schema.get("required", False)
] ],
} }
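Throughout this file the dependency parameters switch from `temporal_mgr=Depends(get_temporal_manager)` to the Annotated form, which keeps the real type on the parameter so mypy and FastAPI both see it. A minimal sketch with a made-up dependency, not the project's TemporalManager:

```python
from typing import Annotated

from fastapi import Depends, FastAPI

app = FastAPI()

class Settings:
    greeting: str = "hello"

def get_settings() -> Settings:
    """Dependency provider for application settings."""
    return Settings()

@app.get("/greet")
async def greet(settings: Annotated[Settings, Depends(get_settings)]) -> dict[str, str]:
    # The parameter is typed as Settings for the type checker while FastAPI
    # still resolves it through Depends(get_settings) at request time.
    return {"message": settings.greeting}
```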

View File

@@ -1,6 +1,4 @@
""" """Setup utilities for FuzzForge infrastructure."""
Setup utilities for FuzzForge infrastructure
"""
# Copyright (c) 2025 FuzzingLabs # Copyright (c) 2025 FuzzingLabs
# #
@@ -18,9 +16,8 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def setup_result_storage(): async def setup_result_storage() -> bool:
""" """Set up result storage (MinIO).
Setup result storage (MinIO).
MinIO is used for both target upload and result storage. MinIO is used for both target upload and result storage.
This is a placeholder for any MinIO-specific setup if needed. This is a placeholder for any MinIO-specific setup if needed.
@@ -31,9 +28,8 @@ async def setup_result_storage():
return True return True
async def validate_infrastructure(): async def validate_infrastructure() -> None:
""" """Validate all required infrastructure components.
Validate all required infrastructure components.
This should be called during startup to ensure everything is ready. This should be called during startup to ensure everything is ready.
""" """

View File

@@ -13,20 +13,19 @@ import asyncio
import logging import logging
import os import os
from contextlib import AsyncExitStack, asynccontextmanager, suppress from contextlib import AsyncExitStack, asynccontextmanager, suppress
from typing import Any, Dict, Optional, List from typing import Any
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastmcp import FastMCP
from fastmcp.server.http import create_sse_app
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Mount from starlette.routing import Mount
from fastmcp.server.http import create_sse_app from src.api import fuzzing, runs, system, workflows
from src.temporal.manager import TemporalManager
from src.core.setup import setup_result_storage, validate_infrastructure from src.core.setup import setup_result_storage, validate_infrastructure
from src.api import workflows, runs, fuzzing, system from src.temporal.discovery import WorkflowDiscovery
from src.temporal.manager import TemporalManager
from fastmcp import FastMCP
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -38,12 +37,14 @@ class TemporalBootstrapState:
"""Tracks Temporal initialization progress for API and MCP consumers.""" """Tracks Temporal initialization progress for API and MCP consumers."""
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize an instance of the class."""
self.ready: bool = False self.ready: bool = False
self.status: str = "not_started" self.status: str = "not_started"
self.last_error: Optional[str] = None self.last_error: str | None = None
self.task_running: bool = False self.task_running: bool = False
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> dict[str, Any]:
"""Return the current state as a Python dictionnary."""
return { return {
"ready": self.ready, "ready": self.ready,
"status": self.status, "status": self.status,
@@ -61,7 +62,7 @@ STARTUP_RETRY_MAX_SECONDS = max(
int(os.getenv("FUZZFORGE_STARTUP_RETRY_MAX_SECONDS", "60")), int(os.getenv("FUZZFORGE_STARTUP_RETRY_MAX_SECONDS", "60")),
) )
temporal_bootstrap_task: Optional[asyncio.Task] = None temporal_bootstrap_task: asyncio.Task | None = None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# FastAPI application (REST API) # FastAPI application (REST API)
@@ -79,17 +80,15 @@ app.include_router(fuzzing.router)
app.include_router(system.router) app.include_router(system.router)
def get_temporal_status() -> Dict[str, Any]: def get_temporal_status() -> dict[str, Any]:
"""Return a snapshot of Temporal bootstrap state for diagnostics.""" """Return a snapshot of Temporal bootstrap state for diagnostics."""
status = temporal_bootstrap_state.as_dict() status = temporal_bootstrap_state.as_dict()
status["workflows_loaded"] = len(temporal_mgr.workflows) status["workflows_loaded"] = len(temporal_mgr.workflows)
status["bootstrap_task_running"] = ( status["bootstrap_task_running"] = temporal_bootstrap_task is not None and not temporal_bootstrap_task.done()
temporal_bootstrap_task is not None and not temporal_bootstrap_task.done()
)
return status return status
def _temporal_not_ready_status() -> Optional[Dict[str, Any]]: def _temporal_not_ready_status() -> dict[str, Any] | None:
"""Return status details if Temporal is not ready yet.""" """Return status details if Temporal is not ready yet."""
status = get_temporal_status() status = get_temporal_status()
if status.get("ready"): if status.get("ready"):
@@ -98,7 +97,7 @@ def _temporal_not_ready_status() -> Optional[Dict[str, Any]]:
@app.get("/") @app.get("/")
async def root() -> Dict[str, Any]: async def root() -> dict[str, Any]:
status = get_temporal_status() status = get_temporal_status()
return { return {
"name": "FuzzForge API", "name": "FuzzForge API",
@@ -110,14 +109,14 @@ async def root() -> Dict[str, Any]:
@app.get("/health") @app.get("/health")
async def health() -> Dict[str, str]: async def health() -> dict[str, str]:
status = get_temporal_status() status = get_temporal_status()
health_status = "healthy" if status.get("ready") else "initializing" health_status = "healthy" if status.get("ready") else "initializing"
return {"status": health_status} return {"status": health_status}
# Map FastAPI OpenAPI operationIds to readable MCP tool names # Map FastAPI OpenAPI operationIds to readable MCP tool names
FASTAPI_MCP_NAME_OVERRIDES: Dict[str, str] = { FASTAPI_MCP_NAME_OVERRIDES: dict[str, str] = {
"list_workflows_workflows__get": "api_list_workflows", "list_workflows_workflows__get": "api_list_workflows",
"get_metadata_schema_workflows_metadata_schema_get": "api_get_metadata_schema", "get_metadata_schema_workflows_metadata_schema_get": "api_get_metadata_schema",
"get_workflow_metadata_workflows__workflow_name__metadata_get": "api_get_workflow_metadata", "get_workflow_metadata_workflows__workflow_name__metadata_get": "api_get_workflow_metadata",
@@ -155,7 +154,6 @@ mcp = FastMCP(name="FuzzForge MCP")
async def _bootstrap_temporal_with_retries() -> None: async def _bootstrap_temporal_with_retries() -> None:
"""Initialize Temporal infrastructure with exponential backoff retries.""" """Initialize Temporal infrastructure with exponential backoff retries."""
attempt = 0 attempt = 0
while True: while True:
@@ -175,7 +173,6 @@ async def _bootstrap_temporal_with_retries() -> None:
temporal_bootstrap_state.status = "ready" temporal_bootstrap_state.status = "ready"
temporal_bootstrap_state.task_running = False temporal_bootstrap_state.task_running = False
logger.info("Temporal infrastructure ready") logger.info("Temporal infrastructure ready")
return
except asyncio.CancelledError: except asyncio.CancelledError:
temporal_bootstrap_state.status = "cancelled" temporal_bootstrap_state.status = "cancelled"
@@ -204,9 +201,11 @@ async def _bootstrap_temporal_with_retries() -> None:
temporal_bootstrap_state.status = "cancelled" temporal_bootstrap_state.status = "cancelled"
temporal_bootstrap_state.task_running = False temporal_bootstrap_state.task_running = False
raise raise
else:
return
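
The bootstrap coroutine keeps retrying Temporal setup with exponential backoff, capped by `FUZZFORGE_STARTUP_RETRY_MAX_SECONDS`, and now exits through an `else:` branch instead of returning inside the `try`. A minimal sketch of that retry shape (the one-second initial delay is an assumption):

```python
import asyncio
import os

MAX_DELAY = float(os.getenv("FUZZFORGE_STARTUP_RETRY_MAX_SECONDS", "60"))
INITIAL_DELAY = 1.0  # assumed starting delay; the real value is defined elsewhere in main.py


async def bootstrap_with_retries(connect) -> None:
    """Keep calling `connect` until it succeeds, doubling the wait up to MAX_DELAY."""
    attempt = 0
    while True:
        try:
            await connect()
        except asyncio.CancelledError:
            raise  # shutdown requested: propagate instead of retrying
        except Exception:
            attempt += 1
            delay = min(INITIAL_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
            await asyncio.sleep(delay)
        else:
            return
```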
def _lookup_workflow(workflow_name: str): def _lookup_workflow(workflow_name: str) -> dict[str, Any]:
info = temporal_mgr.workflows.get(workflow_name) info = temporal_mgr.workflows.get(workflow_name)
if not info: if not info:
return None return None
@@ -222,12 +221,12 @@ def _lookup_workflow(workflow_name: str):
"parameters": metadata.get("parameters", {}), "parameters": metadata.get("parameters", {}),
"default_parameters": metadata.get("default_parameters", {}), "default_parameters": metadata.get("default_parameters", {}),
"required_modules": metadata.get("required_modules", []), "required_modules": metadata.get("required_modules", []),
"default_target_path": default_target_path "default_target_path": default_target_path,
} }
@mcp.tool @mcp.tool
async def list_workflows_mcp() -> Dict[str, Any]: async def list_workflows_mcp() -> dict[str, Any]:
"""List all discovered workflows and their metadata summary.""" """List all discovered workflows and their metadata summary."""
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
if not_ready: if not_ready:
@@ -241,20 +240,21 @@ async def list_workflows_mcp() -> Dict[str, Any]:
for name, info in temporal_mgr.workflows.items(): for name, info in temporal_mgr.workflows.items():
metadata = info.metadata metadata = info.metadata
defaults = metadata.get("default_parameters", {}) defaults = metadata.get("default_parameters", {})
workflows_summary.append({ workflows_summary.append(
"name": name, {
"version": metadata.get("version", "0.6.0"), "name": name,
"description": metadata.get("description", ""), "version": metadata.get("version", "0.6.0"),
"author": metadata.get("author"), "description": metadata.get("description", ""),
"tags": metadata.get("tags", []), "author": metadata.get("author"),
"default_target_path": metadata.get("default_target_path") "tags": metadata.get("tags", []),
or defaults.get("target_path") "default_target_path": metadata.get("default_target_path") or defaults.get("target_path"),
}) },
)
return {"workflows": workflows_summary, "temporal": get_temporal_status()} return {"workflows": workflows_summary, "temporal": get_temporal_status()}
@mcp.tool @mcp.tool
async def get_workflow_metadata_mcp(workflow_name: str) -> Dict[str, Any]: async def get_workflow_metadata_mcp(workflow_name: str) -> dict[str, Any]:
"""Fetch detailed metadata for a workflow.""" """Fetch detailed metadata for a workflow."""
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
if not_ready: if not_ready:
@@ -270,7 +270,7 @@ async def get_workflow_metadata_mcp(workflow_name: str) -> Dict[str, Any]:
@mcp.tool @mcp.tool
async def get_workflow_parameters_mcp(workflow_name: str) -> Dict[str, Any]: async def get_workflow_parameters_mcp(workflow_name: str) -> dict[str, Any]:
"""Return the parameter schema and defaults for a workflow.""" """Return the parameter schema and defaults for a workflow."""
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
if not_ready: if not_ready:
@@ -289,9 +289,8 @@ async def get_workflow_parameters_mcp(workflow_name: str) -> Dict[str, Any]:
@mcp.tool @mcp.tool
async def get_workflow_metadata_schema_mcp() -> Dict[str, Any]: async def get_workflow_metadata_schema_mcp() -> dict[str, Any]:
"""Return the JSON schema describing workflow metadata files.""" """Return the JSON schema describing workflow metadata files."""
from src.temporal.discovery import WorkflowDiscovery
return WorkflowDiscovery.get_metadata_schema() return WorkflowDiscovery.get_metadata_schema()
@@ -299,8 +298,8 @@ async def get_workflow_metadata_schema_mcp() -> Dict[str, Any]:
async def submit_security_scan_mcp( async def submit_security_scan_mcp(
workflow_name: str, workflow_name: str,
target_id: str, target_id: str,
parameters: Dict[str, Any] | None = None, parameters: dict[str, Any] | None = None,
) -> Dict[str, Any] | Dict[str, str]: ) -> dict[str, Any] | dict[str, str]:
"""Submit a Temporal workflow via MCP.""" """Submit a Temporal workflow via MCP."""
try: try:
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
@@ -318,7 +317,7 @@ async def submit_security_scan_mcp(
defaults = metadata.get("default_parameters", {}) defaults = metadata.get("default_parameters", {})
parameters = parameters or {} parameters = parameters or {}
cleaned_parameters: Dict[str, Any] = {**defaults, **parameters} cleaned_parameters: dict[str, Any] = {**defaults, **parameters}
# Ensure *_config structures default to dicts # Ensure *_config structures default to dicts
for key, value in list(cleaned_parameters.items()): for key, value in list(cleaned_parameters.items()):
@@ -327,9 +326,7 @@ async def submit_security_scan_mcp(
# Some workflows expect configuration dictionaries even when omitted # Some workflows expect configuration dictionaries even when omitted
parameter_definitions = ( parameter_definitions = (
metadata.get("parameters", {}).get("properties", {}) metadata.get("parameters", {}).get("properties", {}) if isinstance(metadata.get("parameters"), dict) else {}
if isinstance(metadata.get("parameters"), dict)
else {}
) )
for key, definition in parameter_definitions.items(): for key, definition in parameter_definitions.items():
if not isinstance(key, str) or not key.endswith("_config"): if not isinstance(key, str) or not key.endswith("_config"):
@@ -347,6 +344,10 @@ async def submit_security_scan_mcp(
workflow_params=cleaned_parameters, workflow_params=cleaned_parameters,
) )
except Exception as exc: # pragma: no cover - defensive logging
logger.exception("MCP submit failed")
return {"error": f"Failed to submit workflow: {exc}"}
else:
return { return {
"run_id": handle.id, "run_id": handle.id,
"status": "RUNNING", "status": "RUNNING",
@@ -356,13 +357,10 @@ async def submit_security_scan_mcp(
"parameters": cleaned_parameters, "parameters": cleaned_parameters,
"mcp_enabled": True, "mcp_enabled": True,
} }
except Exception as exc: # pragma: no cover - defensive logging
logger.exception("MCP submit failed")
return {"error": f"Failed to submit workflow: {exc}"}
@mcp.tool @mcp.tool
async def get_comprehensive_scan_summary(run_id: str) -> Dict[str, Any] | Dict[str, str]: async def get_comprehensive_scan_summary(run_id: str) -> dict[str, Any] | dict[str, str]:
"""Return a summary for the given workflow run via MCP.""" """Return a summary for the given workflow run via MCP."""
try: try:
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
@@ -385,7 +383,7 @@ async def get_comprehensive_scan_summary(run_id: str) -> Dict[str, Any] | Dict[s
summary = result.get("summary", {}) summary = result.get("summary", {})
total_findings = summary.get("total_findings", 0) total_findings = summary.get("total_findings", 0)
except Exception as e: except Exception as e:
logger.debug(f"Could not retrieve result for {run_id}: {e}") logger.debug("Could not retrieve result for %s: %s", run_id, e)
return { return {
"run_id": run_id, "run_id": run_id,
@@ -412,7 +410,7 @@ async def get_comprehensive_scan_summary(run_id: str) -> Dict[str, Any] | Dict[s
@mcp.tool @mcp.tool
async def get_run_status_mcp(run_id: str) -> Dict[str, Any]: async def get_run_status_mcp(run_id: str) -> dict[str, Any]:
"""Return current status information for a Temporal run.""" """Return current status information for a Temporal run."""
try: try:
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
@@ -440,7 +438,7 @@ async def get_run_status_mcp(run_id: str) -> Dict[str, Any]:
@mcp.tool @mcp.tool
async def get_run_findings_mcp(run_id: str) -> Dict[str, Any]: async def get_run_findings_mcp(run_id: str) -> dict[str, Any]:
"""Return SARIF findings for a completed run.""" """Return SARIF findings for a completed run."""
try: try:
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
@@ -463,24 +461,24 @@ async def get_run_findings_mcp(run_id: str) -> Dict[str, Any]:
sarif = result.get("sarif", {}) if isinstance(result, dict) else {} sarif = result.get("sarif", {}) if isinstance(result, dict) else {}
except Exception as exc:
logger.exception("MCP findings failed")
return {"error": f"Failed to retrieve findings: {exc}"}
else:
return { return {
"workflow": "unknown", "workflow": "unknown",
"run_id": run_id, "run_id": run_id,
"sarif": sarif, "sarif": sarif,
"metadata": metadata, "metadata": metadata,
} }
except Exception as exc:
logger.exception("MCP findings failed")
return {"error": f"Failed to retrieve findings: {exc}"}
@mcp.tool @mcp.tool
async def list_recent_runs_mcp( async def list_recent_runs_mcp(
limit: int = 10, limit: int = 10,
workflow_name: str | None = None, workflow_name: str | None = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""List recent Temporal runs with optional workflow filter.""" """List recent Temporal runs with optional workflow filter."""
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
if not_ready: if not_ready:
return { return {
@@ -505,19 +503,21 @@ async def list_recent_runs_mcp(
workflows = await temporal_mgr.list_workflows(filter_query, limit_value) workflows = await temporal_mgr.list_workflows(filter_query, limit_value)
results: List[Dict[str, Any]] = [] results: list[dict[str, Any]] = []
for wf in workflows: for wf in workflows:
results.append({ results.append(
"run_id": wf["workflow_id"], {
"workflow": workflow_name or "unknown", "run_id": wf["workflow_id"],
"state": wf["status"], "workflow": workflow_name or "unknown",
"state_type": wf["status"], "state": wf["status"],
"is_completed": wf["status"] in ["COMPLETED", "FAILED", "CANCELLED"], "state_type": wf["status"],
"is_running": wf["status"] == "RUNNING", "is_completed": wf["status"] in ["COMPLETED", "FAILED", "CANCELLED"],
"is_failed": wf["status"] == "FAILED", "is_running": wf["status"] == "RUNNING",
"created_at": wf.get("start_time"), "is_failed": wf["status"] == "FAILED",
"updated_at": wf.get("close_time"), "created_at": wf.get("start_time"),
}) "updated_at": wf.get("close_time"),
},
)
return {"runs": results, "temporal": get_temporal_status()} return {"runs": results, "temporal": get_temporal_status()}
@@ -526,12 +526,12 @@ async def list_recent_runs_mcp(
return { return {
"runs": [], "runs": [],
"temporal": get_temporal_status(), "temporal": get_temporal_status(),
"error": str(exc) "error": str(exc),
} }
@mcp.tool @mcp.tool
async def get_fuzzing_stats_mcp(run_id: str) -> Dict[str, Any]: async def get_fuzzing_stats_mcp(run_id: str) -> dict[str, Any]:
"""Return fuzzing statistics for a run if available.""" """Return fuzzing statistics for a run if available."""
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
if not_ready: if not_ready:
@@ -555,7 +555,7 @@ async def get_fuzzing_stats_mcp(run_id: str) -> Dict[str, Any]:
@mcp.tool @mcp.tool
async def get_fuzzing_crash_reports_mcp(run_id: str) -> Dict[str, Any]: async def get_fuzzing_crash_reports_mcp(run_id: str) -> dict[str, Any]:
"""Return crash reports collected for a fuzzing run.""" """Return crash reports collected for a fuzzing run."""
not_ready = _temporal_not_ready_status() not_ready = _temporal_not_ready_status()
if not_ready: if not_ready:
@@ -571,11 +571,10 @@ async def get_fuzzing_crash_reports_mcp(run_id: str) -> Dict[str, Any]:
@mcp.tool @mcp.tool
async def get_backend_status_mcp() -> Dict[str, Any]: async def get_backend_status_mcp() -> dict[str, Any]:
"""Expose backend readiness, workflows, and registered MCP tools.""" """Expose backend readiness, workflows, and registered MCP tools."""
status = get_temporal_status() status = get_temporal_status()
response: Dict[str, Any] = {"temporal": status} response: dict[str, Any] = {"temporal": status}
if status.get("ready"): if status.get("ready"):
response["workflows"] = list(temporal_mgr.workflows.keys()) response["workflows"] = list(temporal_mgr.workflows.keys())
@@ -591,7 +590,6 @@ async def get_backend_status_mcp() -> Dict[str, Any]:
def create_mcp_transport_app() -> Starlette: def create_mcp_transport_app() -> Starlette:
"""Build a Starlette app serving HTTP + SSE transports on one port.""" """Build a Starlette app serving HTTP + SSE transports on one port."""
http_app = mcp.http_app(path="/", transport="streamable-http") http_app = mcp.http_app(path="/", transport="streamable-http")
sse_app = create_sse_app( sse_app = create_sse_app(
server=mcp, server=mcp,
@@ -609,10 +607,10 @@ def create_mcp_transport_app() -> Starlette:
async def lifespan(app: Starlette): # pragma: no cover - integration wiring async def lifespan(app: Starlette): # pragma: no cover - integration wiring
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
await stack.enter_async_context( await stack.enter_async_context(
http_app.router.lifespan_context(http_app) http_app.router.lifespan_context(http_app),
) )
await stack.enter_async_context( await stack.enter_async_context(
sse_app.router.lifespan_context(sse_app) sse_app.router.lifespan_context(sse_app),
) )
yield yield
@@ -627,6 +625,7 @@ def create_mcp_transport_app() -> Starlette:
# Combined lifespan: Temporal init + dedicated MCP transports # Combined lifespan: Temporal init + dedicated MCP transports
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@asynccontextmanager @asynccontextmanager
async def combined_lifespan(app: FastAPI): async def combined_lifespan(app: FastAPI):
global temporal_bootstrap_task, _fastapi_mcp_imported global temporal_bootstrap_task, _fastapi_mcp_imported
@@ -675,13 +674,14 @@ async def combined_lifespan(app: FastAPI):
if getattr(mcp_server, "started", False): if getattr(mcp_server, "started", False):
return return
await asyncio.sleep(poll_interval) await asyncio.sleep(poll_interval)
raise asyncio.TimeoutError raise TimeoutError
try: try:
await _wait_for_uvicorn_startup() await _wait_for_uvicorn_startup()
except asyncio.TimeoutError: # pragma: no cover - defensive logging except TimeoutError: # pragma: no cover - defensive logging
if mcp_task.done(): if mcp_task.done():
raise RuntimeError("MCP server failed to start") from mcp_task.exception() msg = "MCP server failed to start"
raise RuntimeError(msg) from mcp_task.exception()
logger.warning("Timed out waiting for MCP server startup; continuing anyway") logger.warning("Timed out waiting for MCP server startup; continuing anyway")
logger.info("MCP HTTP available at http://0.0.0.0:8010/mcp") logger.info("MCP HTTP available at http://0.0.0.0:8010/mcp")

View File

@@ -1,6 +1,4 @@
""" """Models for workflow findings and submissions."""
Models for workflow findings and submissions
"""
# Copyright (c) 2025 FuzzingLabs # Copyright (c) 2025 FuzzingLabs
# #
@@ -13,40 +11,43 @@ Models for workflow findings and submissions
# #
# Additional attribution and requirements are provided in the NOTICE file. # Additional attribution and requirements are provided in the NOTICE file.
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional, List
from datetime import datetime from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class WorkflowFindings(BaseModel): class WorkflowFindings(BaseModel):
"""Findings from a workflow execution in SARIF format""" """Findings from a workflow execution in SARIF format."""
workflow: str = Field(..., description="Workflow name") workflow: str = Field(..., description="Workflow name")
run_id: str = Field(..., description="Unique run identifier") run_id: str = Field(..., description="Unique run identifier")
sarif: Dict[str, Any] = Field(..., description="SARIF formatted findings") sarif: dict[str, Any] = Field(..., description="SARIF formatted findings")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata") metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
class WorkflowSubmission(BaseModel): class WorkflowSubmission(BaseModel):
""" """Submit a workflow with configurable settings.
Submit a workflow with configurable settings.
Note: This model is deprecated in favor of the /upload-and-submit endpoint Note: This model is deprecated in favor of the /upload-and-submit endpoint
which handles file uploads directly. which handles file uploads directly.
""" """
parameters: Dict[str, Any] = Field(
parameters: dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
description="Workflow-specific parameters" description="Workflow-specific parameters",
) )
timeout: Optional[int] = Field( timeout: int | None = Field(
default=None, # Allow workflow-specific defaults default=None, # Allow workflow-specific defaults
description="Timeout in seconds (None for workflow default)", description="Timeout in seconds (None for workflow default)",
ge=1, ge=1,
le=604800 # Max 7 days to support fuzzing campaigns le=604800, # Max 7 days to support fuzzing campaigns
) )
class WorkflowStatus(BaseModel): class WorkflowStatus(BaseModel):
"""Status of a workflow run""" """Status of a workflow run."""
run_id: str = Field(..., description="Unique run identifier") run_id: str = Field(..., description="Unique run identifier")
workflow: str = Field(..., description="Workflow name") workflow: str = Field(..., description="Workflow name")
status: str = Field(..., description="Current status") status: str = Field(..., description="Current status")
@@ -58,34 +59,37 @@ class WorkflowStatus(BaseModel):
class WorkflowMetadata(BaseModel): class WorkflowMetadata(BaseModel):
"""Complete metadata for a workflow""" """Complete metadata for a workflow."""
name: str = Field(..., description="Workflow name") name: str = Field(..., description="Workflow name")
version: str = Field(..., description="Semantic version") version: str = Field(..., description="Semantic version")
description: str = Field(..., description="Workflow description") description: str = Field(..., description="Workflow description")
author: Optional[str] = Field(None, description="Workflow author") author: str | None = Field(None, description="Workflow author")
tags: List[str] = Field(default_factory=list, description="Workflow tags") tags: list[str] = Field(default_factory=list, description="Workflow tags")
parameters: Dict[str, Any] = Field(..., description="Parameters schema") parameters: dict[str, Any] = Field(..., description="Parameters schema")
default_parameters: Dict[str, Any] = Field( default_parameters: dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
description="Default parameter values" description="Default parameter values",
) )
required_modules: List[str] = Field( required_modules: list[str] = Field(
default_factory=list, default_factory=list,
description="Required module names" description="Required module names",
) )
class WorkflowListItem(BaseModel): class WorkflowListItem(BaseModel):
"""Summary information for a workflow in list views""" """Summary information for a workflow in list views."""
name: str = Field(..., description="Workflow name") name: str = Field(..., description="Workflow name")
version: str = Field(..., description="Semantic version") version: str = Field(..., description="Semantic version")
description: str = Field(..., description="Workflow description") description: str = Field(..., description="Workflow description")
author: Optional[str] = Field(None, description="Workflow author") author: str | None = Field(None, description="Workflow author")
tags: List[str] = Field(default_factory=list, description="Workflow tags") tags: list[str] = Field(default_factory=list, description="Workflow tags")
class RunSubmissionResponse(BaseModel): class RunSubmissionResponse(BaseModel):
"""Response after submitting a workflow""" """Response after submitting a workflow."""
run_id: str = Field(..., description="Unique run identifier") run_id: str = Field(..., description="Unique run identifier")
status: str = Field(..., description="Initial status") status: str = Field(..., description="Initial status")
workflow: str = Field(..., description="Workflow name") workflow: str = Field(..., description="Workflow name")
@@ -93,28 +97,30 @@ class RunSubmissionResponse(BaseModel):
class FuzzingStats(BaseModel): class FuzzingStats(BaseModel):
"""Real-time fuzzing statistics""" """Real-time fuzzing statistics."""
run_id: str = Field(..., description="Unique run identifier") run_id: str = Field(..., description="Unique run identifier")
workflow: str = Field(..., description="Workflow name") workflow: str = Field(..., description="Workflow name")
executions: int = Field(default=0, description="Total executions") executions: int = Field(default=0, description="Total executions")
executions_per_sec: float = Field(default=0.0, description="Current execution rate") executions_per_sec: float = Field(default=0.0, description="Current execution rate")
crashes: int = Field(default=0, description="Total crashes found") crashes: int = Field(default=0, description="Total crashes found")
unique_crashes: int = Field(default=0, description="Unique crashes") unique_crashes: int = Field(default=0, description="Unique crashes")
coverage: Optional[float] = Field(None, description="Code coverage percentage") coverage: float | None = Field(None, description="Code coverage percentage")
corpus_size: int = Field(default=0, description="Current corpus size") corpus_size: int = Field(default=0, description="Current corpus size")
elapsed_time: int = Field(default=0, description="Elapsed time in seconds") elapsed_time: int = Field(default=0, description="Elapsed time in seconds")
last_crash_time: Optional[datetime] = Field(None, description="Time of last crash") last_crash_time: datetime | None = Field(None, description="Time of last crash")
class CrashReport(BaseModel): class CrashReport(BaseModel):
"""Individual crash report from fuzzing""" """Individual crash report from fuzzing."""
run_id: str = Field(..., description="Run identifier") run_id: str = Field(..., description="Run identifier")
crash_id: str = Field(..., description="Unique crash identifier") crash_id: str = Field(..., description="Unique crash identifier")
timestamp: datetime = Field(default_factory=datetime.utcnow) timestamp: datetime = Field(default_factory=datetime.utcnow)
signal: Optional[str] = Field(None, description="Crash signal (SIGSEGV, etc.)") signal: str | None = Field(None, description="Crash signal (SIGSEGV, etc.)")
crash_type: Optional[str] = Field(None, description="Type of crash") crash_type: str | None = Field(None, description="Type of crash")
stack_trace: Optional[str] = Field(None, description="Stack trace") stack_trace: str | None = Field(None, description="Stack trace")
input_file: Optional[str] = Field(None, description="Path to crashing input") input_file: str | None = Field(None, description="Path to crashing input")
reproducer: Optional[str] = Field(None, description="Minimized reproducer") reproducer: str | None = Field(None, description="Minimized reproducer")
severity: str = Field(default="medium", description="Crash severity") severity: str = Field(default="medium", description="Crash severity")
exploitability: Optional[str] = Field(None, description="Exploitability assessment") exploitability: str | None = Field(None, description="Exploitability assessment")
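
The model changes above are the mechanical typing modernization applied by ruff's pyupgrade rules (`UP006`/`UP007`): `Optional[X]` becomes `X | None` and the `Dict`/`List` aliases become the builtin generics. A hypothetical model in the new style:

```python
from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field


class CrashSummary(BaseModel):
    """Illustrative model (not part of the codebase) using the modernized annotations."""

    run_id: str = Field(..., description="Run identifier")
    signal: str | None = Field(None, description="Crash signal")       # was Optional[str]
    tags: list[str] = Field(default_factory=list, description="Tags")  # was List[str]
    extra: dict[str, Any] = Field(default_factory=dict)                # was Dict[str, Any]
    last_seen: datetime | None = None                                  # was Optional[datetime]
```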

View File

@@ -1,5 +1,4 @@
""" """Storage abstraction layer for FuzzForge.
Storage abstraction layer for FuzzForge.
Provides unified interface for storing and retrieving targets and results. Provides unified interface for storing and retrieving targets and results.
""" """
@@ -7,4 +6,4 @@ Provides unified interface for storing and retrieving targets and results.
from .base import StorageBackend from .base import StorageBackend
from .s3_cached import S3CachedStorage from .s3_cached import S3CachedStorage
__all__ = ["StorageBackend", "S3CachedStorage"] __all__ = ["S3CachedStorage", "StorageBackend"]

View File

@@ -1,17 +1,15 @@
""" """Base storage backend interface.
Base storage backend interface.
All storage implementations must implement this interface. All storage implementations must implement this interface.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Optional, Dict, Any from typing import Any
class StorageBackend(ABC): class StorageBackend(ABC):
""" """Abstract base class for storage backends.
Abstract base class for storage backends.
Implementations handle storage and retrieval of: Implementations handle storage and retrieval of:
- Uploaded targets (code, binaries, etc.) - Uploaded targets (code, binaries, etc.)
@@ -24,10 +22,9 @@ class StorageBackend(ABC):
self, self,
file_path: Path, file_path: Path,
user_id: str, user_id: str,
metadata: Optional[Dict[str, Any]] = None metadata: dict[str, Any] | None = None,
) -> str: ) -> str:
""" """Upload a target file to storage.
Upload a target file to storage.
Args: Args:
file_path: Local path to file to upload file_path: Local path to file to upload
@@ -40,13 +37,12 @@ class StorageBackend(ABC):
Raises: Raises:
FileNotFoundError: If file_path doesn't exist FileNotFoundError: If file_path doesn't exist
StorageError: If upload fails StorageError: If upload fails
""" """
pass
@abstractmethod @abstractmethod
async def get_target(self, target_id: str) -> Path: async def get_target(self, target_id: str) -> Path:
""" """Get target file from storage.
Get target file from storage.
Args: Args:
target_id: Unique identifier from upload_target() target_id: Unique identifier from upload_target()
@@ -57,31 +53,29 @@ class StorageBackend(ABC):
Raises: Raises:
FileNotFoundError: If target doesn't exist FileNotFoundError: If target doesn't exist
StorageError: If download fails StorageError: If download fails
""" """
pass
@abstractmethod @abstractmethod
async def delete_target(self, target_id: str) -> None: async def delete_target(self, target_id: str) -> None:
""" """Delete target from storage.
Delete target from storage.
Args: Args:
target_id: Unique identifier to delete target_id: Unique identifier to delete
Raises: Raises:
StorageError: If deletion fails (doesn't raise if not found) StorageError: If deletion fails (doesn't raise if not found)
""" """
pass
@abstractmethod @abstractmethod
async def upload_results( async def upload_results(
self, self,
workflow_id: str, workflow_id: str,
results: Dict[str, Any], results: dict[str, Any],
results_format: str = "json" results_format: str = "json",
) -> str: ) -> str:
""" """Upload workflow results to storage.
Upload workflow results to storage.
Args: Args:
workflow_id: Workflow execution ID workflow_id: Workflow execution ID
@@ -93,13 +87,12 @@ class StorageBackend(ABC):
Raises: Raises:
StorageError: If upload fails StorageError: If upload fails
""" """
pass
@abstractmethod @abstractmethod
async def get_results(self, workflow_id: str) -> Dict[str, Any]: async def get_results(self, workflow_id: str) -> dict[str, Any]:
""" """Get workflow results from storage.
Get workflow results from storage.
Args: Args:
workflow_id: Workflow execution ID workflow_id: Workflow execution ID
@@ -110,17 +103,16 @@ class StorageBackend(ABC):
Raises: Raises:
FileNotFoundError: If results don't exist FileNotFoundError: If results don't exist
StorageError: If download fails StorageError: If download fails
""" """
pass
@abstractmethod @abstractmethod
async def list_targets( async def list_targets(
self, self,
user_id: Optional[str] = None, user_id: str | None = None,
limit: int = 100 limit: int = 100,
) -> list[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """List uploaded targets.
List uploaded targets.
Args: Args:
user_id: Filter by user ID (None = all users) user_id: Filter by user ID (None = all users)
@@ -131,23 +123,21 @@ class StorageBackend(ABC):
Raises: Raises:
StorageError: If listing fails StorageError: If listing fails
""" """
pass
@abstractmethod @abstractmethod
async def cleanup_cache(self) -> int: async def cleanup_cache(self) -> int:
""" """Clean up local cache (LRU eviction).
Clean up local cache (LRU eviction).
Returns: Returns:
Number of files removed Number of files removed
Raises: Raises:
StorageError: If cleanup fails StorageError: If cleanup fails
""" """
pass
class StorageError(Exception): class StorageError(Exception):
"""Base exception for storage operations.""" """Base exception for storage operations."""
pass
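
The `pass` statements removed throughout this file follow ruff's `PIE790` rule: a docstring is already a valid function or class body, so a trailing `pass` is redundant. A condensed sketch of the same interface style:

```python
from abc import ABC, abstractmethod


class ExampleBackend(ABC):
    """Condensed illustration of an abstract interface without trailing `pass` statements."""

    @abstractmethod
    async def cleanup_cache(self) -> int:
        """Clean up the local cache and return the number of files removed."""
        # No `pass` needed: the docstring alone is a complete body.


class ExampleError(Exception):
    """Base exception for the example backend."""
```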

View File

@@ -1,5 +1,4 @@
""" """S3-compatible storage backend with local caching.
S3-compatible storage backend with local caching.
Works with MinIO (dev/prod) or AWS S3 (cloud). Works with MinIO (dev/prod) or AWS S3 (cloud).
""" """
@@ -10,7 +9,7 @@ import os
import shutil import shutil
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Optional, Dict, Any from typing import Any
from uuid import uuid4 from uuid import uuid4
import boto3 import boto3
@@ -22,8 +21,7 @@ logger = logging.getLogger(__name__)
class S3CachedStorage(StorageBackend): class S3CachedStorage(StorageBackend):
""" """S3-compatible storage with local caching.
S3-compatible storage with local caching.
Features: Features:
- Upload targets to S3/MinIO - Upload targets to S3/MinIO
@@ -34,17 +32,16 @@ class S3CachedStorage(StorageBackend):
def __init__( def __init__(
self, self,
endpoint_url: Optional[str] = None, endpoint_url: str | None = None,
access_key: Optional[str] = None, access_key: str | None = None,
secret_key: Optional[str] = None, secret_key: str | None = None,
bucket: str = "targets", bucket: str = "targets",
region: str = "us-east-1", region: str = "us-east-1",
use_ssl: bool = False, use_ssl: bool = False,
cache_dir: Optional[Path] = None, cache_dir: Path | None = None,
cache_max_size_gb: int = 10 cache_max_size_gb: int = 10,
): ) -> None:
""" """Initialize S3 storage backend.
Initialize S3 storage backend.
Args: Args:
endpoint_url: S3 endpoint (None = AWS S3, or MinIO URL) endpoint_url: S3 endpoint (None = AWS S3, or MinIO URL)
@@ -55,18 +52,19 @@ class S3CachedStorage(StorageBackend):
use_ssl: Use HTTPS use_ssl: Use HTTPS
cache_dir: Local cache directory cache_dir: Local cache directory
cache_max_size_gb: Maximum cache size in GB cache_max_size_gb: Maximum cache size in GB
""" """
# Use environment variables as defaults # Use environment variables as defaults
self.endpoint_url = endpoint_url or os.getenv('S3_ENDPOINT', 'http://minio:9000') self.endpoint_url = endpoint_url or os.getenv("S3_ENDPOINT", "http://minio:9000")
self.access_key = access_key or os.getenv('S3_ACCESS_KEY', 'fuzzforge') self.access_key = access_key or os.getenv("S3_ACCESS_KEY", "fuzzforge")
self.secret_key = secret_key or os.getenv('S3_SECRET_KEY', 'fuzzforge123') self.secret_key = secret_key or os.getenv("S3_SECRET_KEY", "fuzzforge123")
self.bucket = bucket or os.getenv('S3_BUCKET', 'targets') self.bucket = bucket or os.getenv("S3_BUCKET", "targets")
self.region = region or os.getenv('S3_REGION', 'us-east-1') self.region = region or os.getenv("S3_REGION", "us-east-1")
self.use_ssl = use_ssl or os.getenv('S3_USE_SSL', 'false').lower() == 'true' self.use_ssl = use_ssl or os.getenv("S3_USE_SSL", "false").lower() == "true"
# Cache configuration # Cache configuration
self.cache_dir = cache_dir or Path(os.getenv('CACHE_DIR', '/tmp/fuzzforge-cache')) self.cache_dir = cache_dir or Path(os.getenv("CACHE_DIR", "/tmp/fuzzforge-cache"))
self.cache_max_size = cache_max_size_gb * (1024 ** 3) # Convert to bytes self.cache_max_size = cache_max_size_gb * (1024**3) # Convert to bytes
# Ensure cache directory exists # Ensure cache directory exists
self.cache_dir.mkdir(parents=True, exist_ok=True) self.cache_dir.mkdir(parents=True, exist_ok=True)
@@ -74,69 +72,75 @@ class S3CachedStorage(StorageBackend):
# Initialize S3 client # Initialize S3 client
try: try:
self.s3_client = boto3.client( self.s3_client = boto3.client(
's3', "s3",
endpoint_url=self.endpoint_url, endpoint_url=self.endpoint_url,
aws_access_key_id=self.access_key, aws_access_key_id=self.access_key,
aws_secret_access_key=self.secret_key, aws_secret_access_key=self.secret_key,
region_name=self.region, region_name=self.region,
use_ssl=self.use_ssl use_ssl=self.use_ssl,
) )
logger.info(f"Initialized S3 storage: {self.endpoint_url}/{self.bucket}") logger.info("Initialized S3 storage: %s/%s", self.endpoint_url, self.bucket)
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize S3 client: {e}") logger.exception("Failed to initialize S3 client")
raise StorageError(f"S3 initialization failed: {e}") msg = f"S3 initialization failed: {e}"
raise StorageError(msg) from e
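
The error handling throughout this class is rewritten to satisfy several ruff rules at once: lazy `%`-style logging instead of f-strings (`G004`), `logger.exception` inside handlers (`TRY400`), messages assigned to a variable before raising (`EM102`), and explicit chaining with `from e` (`B904`). A condensed sketch of the combined pattern:

```python
import logging

logger = logging.getLogger(__name__)


class StorageError(Exception):
    """Raised when a storage operation fails."""


def init_client(endpoint_url: str, factory):
    """Build a client via `factory`, illustrating the logging/raising style used in this diff."""
    try:
        client = factory(endpoint_url)
        # G004: pass arguments lazily instead of interpolating with an f-string.
        logger.info("Initialized S3 storage: %s", endpoint_url)
    except Exception as e:
        # TRY400: logger.exception records the traceback automatically.
        logger.exception("Failed to initialize S3 client")
        msg = f"S3 initialization failed: {e}"  # EM102: message built outside the raise call
        raise StorageError(msg) from e          # B904: preserve the causal chain
    else:
        return client
```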
async def upload_target( async def upload_target(
self, self,
file_path: Path, file_path: Path,
user_id: str, user_id: str,
metadata: Optional[Dict[str, Any]] = None metadata: dict[str, Any] | None = None,
) -> str: ) -> str:
"""Upload target file to S3/MinIO.""" """Upload target file to S3/MinIO."""
if not file_path.exists(): if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}") msg = f"File not found: {file_path}"
raise FileNotFoundError(msg)
# Generate unique target ID # Generate unique target ID
target_id = str(uuid4()) target_id = str(uuid4())
# Prepare metadata # Prepare metadata
upload_metadata = { upload_metadata = {
'user_id': user_id, "user_id": user_id,
'uploaded_at': datetime.now().isoformat(), "uploaded_at": datetime.now().isoformat(),
'filename': file_path.name, "filename": file_path.name,
'size': str(file_path.stat().st_size) "size": str(file_path.stat().st_size),
} }
if metadata: if metadata:
upload_metadata.update(metadata) upload_metadata.update(metadata)
# Upload to S3 # Upload to S3
s3_key = f'{target_id}/target' s3_key = f"{target_id}/target"
try: try:
logger.info(f"Uploading target to s3://{self.bucket}/{s3_key}") logger.info("Uploading target to s3://%s/%s", self.bucket, s3_key)
self.s3_client.upload_file( self.s3_client.upload_file(
str(file_path), str(file_path),
self.bucket, self.bucket,
s3_key, s3_key,
ExtraArgs={ ExtraArgs={
'Metadata': upload_metadata "Metadata": upload_metadata,
} },
) )
file_size_mb = file_path.stat().st_size / (1024 * 1024) file_size_mb = file_path.stat().st_size / (1024 * 1024)
logger.info( logger.info(
f"✓ Uploaded target {target_id} " "✓ Uploaded target %s (%s, %s MB)",
f"({file_path.name}, {file_size_mb:.2f} MB)" target_id,
file_path.name,
f"{file_size_mb:.2f}",
) )
return target_id
except ClientError as e: except ClientError as e:
logger.error(f"S3 upload failed: {e}", exc_info=True) logger.exception("S3 upload failed")
raise StorageError(f"Failed to upload target: {e}") msg = f"Failed to upload target: {e}"
raise StorageError(msg) from e
except Exception as e: except Exception as e:
logger.error(f"Upload failed: {e}", exc_info=True) logger.exception("Upload failed")
raise StorageError(f"Upload error: {e}") msg = f"Upload error: {e}"
raise StorageError(msg) from e
else:
return target_id
async def get_target(self, target_id: str) -> Path: async def get_target(self, target_id: str) -> Path:
"""Get target from cache or download from S3/MinIO.""" """Get target from cache or download from S3/MinIO."""
@@ -147,105 +151,110 @@ class S3CachedStorage(StorageBackend):
if cached_file.exists(): if cached_file.exists():
# Update access time for LRU # Update access time for LRU
cached_file.touch() cached_file.touch()
logger.info(f"Cache HIT: {target_id}") logger.info("Cache HIT: %s", target_id)
return cached_file return cached_file
# Cache miss - download from S3 # Cache miss - download from S3
logger.info(f"Cache MISS: {target_id}, downloading from S3...") logger.info("Cache MISS: %s, downloading from S3...", target_id)
try: try:
# Create cache directory # Create cache directory
cache_path.mkdir(parents=True, exist_ok=True) cache_path.mkdir(parents=True, exist_ok=True)
# Download from S3 # Download from S3
s3_key = f'{target_id}/target' s3_key = f"{target_id}/target"
logger.info(f"Downloading s3://{self.bucket}/{s3_key}") logger.info("Downloading s3://%s/%s", self.bucket, s3_key)
self.s3_client.download_file( self.s3_client.download_file(
self.bucket, self.bucket,
s3_key, s3_key,
str(cached_file) str(cached_file),
) )
# Verify download # Verify download
if not cached_file.exists(): if not cached_file.exists():
raise StorageError(f"Downloaded file not found: {cached_file}") msg = f"Downloaded file not found: {cached_file}"
raise StorageError(msg)
file_size_mb = cached_file.stat().st_size / (1024 * 1024) file_size_mb = cached_file.stat().st_size / (1024 * 1024)
logger.info(f"✓ Downloaded target {target_id} ({file_size_mb:.2f} MB)") logger.info("✓ Downloaded target %s (%s MB)", target_id, f"{file_size_mb:.2f}")
return cached_file
except ClientError as e: except ClientError as e:
error_code = e.response.get('Error', {}).get('Code') error_code = e.response.get("Error", {}).get("Code")
if error_code in ['404', 'NoSuchKey']: if error_code in ["404", "NoSuchKey"]:
logger.error(f"Target not found: {target_id}") logger.exception("Target not found: %s", target_id)
raise FileNotFoundError(f"Target {target_id} not found in storage") msg = f"Target {target_id} not found in storage"
else: raise FileNotFoundError(msg) from e
logger.error(f"S3 download failed: {e}", exc_info=True) logger.exception("S3 download failed")
raise StorageError(f"Download failed: {e}") msg = f"Download failed: {e}"
raise StorageError(msg) from e
except Exception as e: except Exception as e:
logger.error(f"Download error: {e}", exc_info=True) logger.exception("Download error")
# Cleanup partial download # Cleanup partial download
if cache_path.exists(): if cache_path.exists():
shutil.rmtree(cache_path, ignore_errors=True) shutil.rmtree(cache_path, ignore_errors=True)
raise StorageError(f"Download error: {e}") msg = f"Download error: {e}"
raise StorageError(msg) from e
else:
return cached_file
async def delete_target(self, target_id: str) -> None: async def delete_target(self, target_id: str) -> None:
"""Delete target from S3/MinIO.""" """Delete target from S3/MinIO."""
try: try:
s3_key = f'{target_id}/target' s3_key = f"{target_id}/target"
logger.info(f"Deleting s3://{self.bucket}/{s3_key}") logger.info("Deleting s3://%s/%s", self.bucket, s3_key)
self.s3_client.delete_object( self.s3_client.delete_object(
Bucket=self.bucket, Bucket=self.bucket,
Key=s3_key Key=s3_key,
) )
# Also delete from cache if present # Also delete from cache if present
cache_path = self.cache_dir / target_id cache_path = self.cache_dir / target_id
if cache_path.exists(): if cache_path.exists():
shutil.rmtree(cache_path, ignore_errors=True) shutil.rmtree(cache_path, ignore_errors=True)
logger.info(f"✓ Deleted target {target_id} from S3 and cache") logger.info("✓ Deleted target %s from S3 and cache", target_id)
else: else:
logger.info(f"✓ Deleted target {target_id} from S3") logger.info("✓ Deleted target %s from S3", target_id)
except ClientError as e: except ClientError as e:
logger.error(f"S3 delete failed: {e}", exc_info=True) logger.exception("S3 delete failed")
# Don't raise error if object doesn't exist # Don't raise error if object doesn't exist
if e.response.get('Error', {}).get('Code') not in ['404', 'NoSuchKey']: if e.response.get("Error", {}).get("Code") not in ["404", "NoSuchKey"]:
raise StorageError(f"Delete failed: {e}") msg = f"Delete failed: {e}"
raise StorageError(msg) from e
except Exception as e: except Exception as e:
logger.error(f"Delete error: {e}", exc_info=True) logger.exception("Delete error")
raise StorageError(f"Delete error: {e}") msg = f"Delete error: {e}"
raise StorageError(msg) from e
async def upload_results( async def upload_results(
self, self,
workflow_id: str, workflow_id: str,
results: Dict[str, Any], results: dict[str, Any],
results_format: str = "json" results_format: str = "json",
) -> str: ) -> str:
"""Upload workflow results to S3/MinIO.""" """Upload workflow results to S3/MinIO."""
try: try:
# Prepare results content # Prepare results content
if results_format == "json": if results_format == "json":
content = json.dumps(results, indent=2).encode('utf-8') content = json.dumps(results, indent=2).encode("utf-8")
content_type = 'application/json' content_type = "application/json"
file_ext = 'json' file_ext = "json"
elif results_format == "sarif": elif results_format == "sarif":
content = json.dumps(results, indent=2).encode('utf-8') content = json.dumps(results, indent=2).encode("utf-8")
content_type = 'application/sarif+json' content_type = "application/sarif+json"
file_ext = 'sarif' file_ext = "sarif"
else: else:
content = json.dumps(results, indent=2).encode('utf-8') content = json.dumps(results, indent=2).encode("utf-8")
content_type = 'application/json' content_type = "application/json"
file_ext = 'json' file_ext = "json"
# Upload to results bucket # Upload to results bucket
results_bucket = 'results' results_bucket = "results"
s3_key = f'{workflow_id}/results.{file_ext}' s3_key = f"{workflow_id}/results.{file_ext}"
logger.info(f"Uploading results to s3://{results_bucket}/{s3_key}") logger.info("Uploading results to s3://%s/%s", results_bucket, s3_key)
self.s3_client.put_object( self.s3_client.put_object(
Bucket=results_bucket, Bucket=results_bucket,
@@ -253,95 +262,103 @@ class S3CachedStorage(StorageBackend):
Body=content, Body=content,
ContentType=content_type, ContentType=content_type,
Metadata={ Metadata={
'workflow_id': workflow_id, "workflow_id": workflow_id,
'format': results_format, "format": results_format,
'uploaded_at': datetime.now().isoformat() "uploaded_at": datetime.now().isoformat(),
} },
) )
# Construct URL # Construct URL
results_url = f"{self.endpoint_url}/{results_bucket}/{s3_key}" results_url = f"{self.endpoint_url}/{results_bucket}/{s3_key}"
logger.info(f"✓ Uploaded results: {results_url}") logger.info("✓ Uploaded results: %s", results_url)
return results_url
except Exception as e: except Exception as e:
logger.error(f"Results upload failed: {e}", exc_info=True) logger.exception("Results upload failed")
raise StorageError(f"Results upload failed: {e}") msg = f"Results upload failed: {e}"
raise StorageError(msg) from e
else:
return results_url
async def get_results(self, workflow_id: str) -> Dict[str, Any]: async def get_results(self, workflow_id: str) -> dict[str, Any]:
"""Get workflow results from S3/MinIO.""" """Get workflow results from S3/MinIO."""
try: try:
results_bucket = 'results' results_bucket = "results"
s3_key = f'{workflow_id}/results.json' s3_key = f"{workflow_id}/results.json"
logger.info(f"Downloading results from s3://{results_bucket}/{s3_key}") logger.info("Downloading results from s3://%s/%s", results_bucket, s3_key)
response = self.s3_client.get_object( response = self.s3_client.get_object(
Bucket=results_bucket, Bucket=results_bucket,
Key=s3_key Key=s3_key,
) )
content = response['Body'].read().decode('utf-8') content = response["Body"].read().decode("utf-8")
results = json.loads(content) results = json.loads(content)
logger.info(f"✓ Downloaded results for workflow {workflow_id}") logger.info("✓ Downloaded results for workflow %s", workflow_id)
return results
except ClientError as e: except ClientError as e:
error_code = e.response.get('Error', {}).get('Code') error_code = e.response.get("Error", {}).get("Code")
if error_code in ['404', 'NoSuchKey']: if error_code in ["404", "NoSuchKey"]:
logger.error(f"Results not found: {workflow_id}") logger.exception("Results not found: %s", workflow_id)
raise FileNotFoundError(f"Results for workflow {workflow_id} not found") msg = f"Results for workflow {workflow_id} not found"
else: raise FileNotFoundError(msg) from e
logger.error(f"Results download failed: {e}", exc_info=True) logger.exception("Results download failed")
raise StorageError(f"Results download failed: {e}") msg = f"Results download failed: {e}"
raise StorageError(msg) from e
except Exception as e: except Exception as e:
logger.error(f"Results download error: {e}", exc_info=True) logger.exception("Results download error")
raise StorageError(f"Results download error: {e}") msg = f"Results download error: {e}"
raise StorageError(msg) from e
else:
return results
async def list_targets( async def list_targets(
self, self,
user_id: Optional[str] = None, user_id: str | None = None,
limit: int = 100 limit: int = 100,
) -> list[Dict[str, Any]]: ) -> list[dict[str, Any]]:
"""List uploaded targets.""" """List uploaded targets."""
try: try:
targets = [] targets = []
paginator = self.s3_client.get_paginator('list_objects_v2') paginator = self.s3_client.get_paginator("list_objects_v2")
for page in paginator.paginate(Bucket=self.bucket, PaginationConfig={'MaxItems': limit}): for page in paginator.paginate(Bucket=self.bucket, PaginationConfig={"MaxItems": limit}):
for obj in page.get('Contents', []): for obj in page.get("Contents", []):
# Get object metadata # Get object metadata
try: try:
metadata_response = self.s3_client.head_object( metadata_response = self.s3_client.head_object(
Bucket=self.bucket, Bucket=self.bucket,
Key=obj['Key'] Key=obj["Key"],
) )
metadata = metadata_response.get('Metadata', {}) metadata = metadata_response.get("Metadata", {})
# Filter by user_id if specified # Filter by user_id if specified
if user_id and metadata.get('user_id') != user_id: if user_id and metadata.get("user_id") != user_id:
continue continue
targets.append({ targets.append(
'target_id': obj['Key'].split('/')[0], {
'key': obj['Key'], "target_id": obj["Key"].split("/")[0],
'size': obj['Size'], "key": obj["Key"],
'last_modified': obj['LastModified'].isoformat(), "size": obj["Size"],
'metadata': metadata "last_modified": obj["LastModified"].isoformat(),
}) "metadata": metadata,
},
)
except Exception as e: except Exception as e:
logger.warning(f"Failed to get metadata for {obj['Key']}: {e}") logger.warning("Failed to get metadata for %s: %s", obj["Key"], e)
continue continue
logger.info(f"Listed {len(targets)} targets (user_id={user_id})") logger.info("Listed %s targets (user_id=%s)", len(targets), user_id)
return targets
except Exception as e: except Exception as e:
logger.error(f"List targets failed: {e}", exc_info=True) logger.exception("List targets failed")
raise StorageError(f"List targets failed: {e}") msg = f"List targets failed: {e}"
raise StorageError(msg) from e
else:
return targets
async def cleanup_cache(self) -> int: async def cleanup_cache(self) -> int:
"""Clean up local cache using LRU eviction.""" """Clean up local cache using LRU eviction."""
@@ -350,30 +367,33 @@ class S3CachedStorage(StorageBackend):
total_size = 0 total_size = 0
# Gather all cached files with metadata # Gather all cached files with metadata
for cache_file in self.cache_dir.rglob('*'): for cache_file in self.cache_dir.rglob("*"):
if cache_file.is_file(): if cache_file.is_file():
try: try:
stat = cache_file.stat() stat = cache_file.stat()
cache_files.append({ cache_files.append(
'path': cache_file, {
'size': stat.st_size, "path": cache_file,
'atime': stat.st_atime # Last access time "size": stat.st_size,
}) "atime": stat.st_atime, # Last access time
},
)
total_size += stat.st_size total_size += stat.st_size
except Exception as e: except Exception as e:
logger.warning(f"Failed to stat {cache_file}: {e}") logger.warning("Failed to stat %s: %s", cache_file, e)
continue continue
# Check if cleanup is needed # Check if cleanup is needed
if total_size <= self.cache_max_size: if total_size <= self.cache_max_size:
logger.info( logger.info(
f"Cache size OK: {total_size / (1024**3):.2f} GB / " "Cache size OK: %s GB / %s GB",
f"{self.cache_max_size / (1024**3):.2f} GB" f"{total_size / (1024**3):.2f}",
f"{self.cache_max_size / (1024**3):.2f}",
) )
return 0 return 0
# Sort by access time (oldest first) # Sort by access time (oldest first)
cache_files.sort(key=lambda x: x['atime']) cache_files.sort(key=lambda x: x["atime"])
# Remove files until under limit # Remove files until under limit
removed_count = 0 removed_count = 0
@@ -382,42 +402,46 @@ class S3CachedStorage(StorageBackend):
break break
try: try:
file_info['path'].unlink() file_info["path"].unlink()
total_size -= file_info['size'] total_size -= file_info["size"]
removed_count += 1 removed_count += 1
logger.debug(f"Evicted from cache: {file_info['path']}") logger.debug("Evicted from cache: %s", file_info["path"])
except Exception as e: except Exception as e:
logger.warning(f"Failed to delete {file_info['path']}: {e}") logger.warning("Failed to delete %s: %s", file_info["path"], e)
continue continue
logger.info( logger.info(
f"✓ Cache cleanup: removed {removed_count} files, " "✓ Cache cleanup: removed %s files, new size: %s GB",
f"new size: {total_size / (1024**3):.2f} GB" removed_count,
f"{total_size / (1024**3):.2f}",
) )
return removed_count
except Exception as e: except Exception as e:
logger.error(f"Cache cleanup failed: {e}", exc_info=True) logger.exception("Cache cleanup failed")
raise StorageError(f"Cache cleanup failed: {e}") msg = f"Cache cleanup failed: {e}"
raise StorageError(msg) from e
def get_cache_stats(self) -> Dict[str, Any]: else:
return removed_count
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache statistics.""" """Get cache statistics."""
try: try:
total_size = 0 total_size = 0
file_count = 0 file_count = 0
for cache_file in self.cache_dir.rglob('*'): for cache_file in self.cache_dir.rglob("*"):
if cache_file.is_file(): if cache_file.is_file():
total_size += cache_file.stat().st_size total_size += cache_file.stat().st_size
file_count += 1 file_count += 1
return { return {
'total_size_bytes': total_size, "total_size_bytes": total_size,
'total_size_gb': total_size / (1024 ** 3), "total_size_gb": total_size / (1024**3),
'file_count': file_count, "file_count": file_count,
'max_size_gb': self.cache_max_size / (1024 ** 3), "max_size_gb": self.cache_max_size / (1024**3),
'usage_percent': (total_size / self.cache_max_size) * 100 "usage_percent": (total_size / self.cache_max_size) * 100,
} }
except Exception as e: except Exception as e:
logger.error(f"Failed to get cache stats: {e}") logger.exception("Failed to get cache stats")
return {'error': str(e)} return {"error": str(e)}
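
Taken together, the class still exposes the same async surface; a hedged usage sketch (import path, endpoint, and file names are assumptions):

```python
import asyncio
from pathlib import Path

from src.storage.s3_cached import S3CachedStorage  # assumed import path


async def main() -> None:
    storage = S3CachedStorage(
        endpoint_url="http://localhost:9000",  # assumed local MinIO endpoint
        bucket="targets",
        cache_dir=Path("/tmp/fuzzforge-cache"),
        cache_max_size_gb=10,
    )

    target_id = await storage.upload_target(Path("./firmware.bin"), user_id="demo")
    local_copy = await storage.get_target(target_id)  # cache hit on repeated calls
    print(local_copy, storage.get_cache_stats())

    removed = await storage.cleanup_cache()  # LRU eviction once the cache exceeds its cap
    print(f"evicted {removed} cached files")


if __name__ == "__main__":
    asyncio.run(main())
```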

View File

@@ -1,10 +1,9 @@
""" """Temporal integration for FuzzForge.
Temporal integration for FuzzForge.
Handles workflow execution, monitoring, and management. Handles workflow execution, monitoring, and management.
""" """
from .manager import TemporalManager
from .discovery import WorkflowDiscovery from .discovery import WorkflowDiscovery
from .manager import TemporalManager
__all__ = ["TemporalManager", "WorkflowDiscovery"] __all__ = ["TemporalManager", "WorkflowDiscovery"]

View File

@@ -1,25 +1,26 @@
""" """Workflow Discovery for Temporal.
Workflow Discovery for Temporal
Discovers workflows from the toolbox/workflows directory Discovers workflows from the toolbox/workflows directory
and provides metadata about available workflows. and provides metadata about available workflows.
""" """
import logging import logging
import yaml
from pathlib import Path from pathlib import Path
from typing import Dict, Any from typing import Any
from pydantic import BaseModel, Field, ConfigDict
import yaml
from pydantic import BaseModel, ConfigDict, Field
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowInfo(BaseModel): class WorkflowInfo(BaseModel):
"""Information about a discovered workflow""" """Information about a discovered workflow."""
name: str = Field(..., description="Workflow name") name: str = Field(..., description="Workflow name")
path: Path = Field(..., description="Path to workflow directory") path: Path = Field(..., description="Path to workflow directory")
workflow_file: Path = Field(..., description="Path to workflow.py file") workflow_file: Path = Field(..., description="Path to workflow.py file")
metadata: Dict[str, Any] = Field(..., description="Workflow metadata from YAML") metadata: dict[str, Any] = Field(..., description="Workflow metadata from YAML")
workflow_type: str = Field(..., description="Workflow class name") workflow_type: str = Field(..., description="Workflow class name")
vertical: str = Field(..., description="Vertical (worker type) for this workflow") vertical: str = Field(..., description="Vertical (worker type) for this workflow")
@@ -27,8 +28,7 @@ class WorkflowInfo(BaseModel):
class WorkflowDiscovery: class WorkflowDiscovery:
""" """Discovers workflows from the filesystem.
Discovers workflows from the filesystem.
Scans toolbox/workflows/ for directories containing: Scans toolbox/workflows/ for directories containing:
- metadata.yaml (required) - metadata.yaml (required)
@@ -38,106 +38,109 @@ class WorkflowDiscovery:
which determines which worker pool will execute it. which determines which worker pool will execute it.
""" """
def __init__(self, workflows_dir: Path): def __init__(self, workflows_dir: Path) -> None:
""" """Initialize workflow discovery.
Initialize workflow discovery.
Args: Args:
workflows_dir: Path to the workflows directory workflows_dir: Path to the workflows directory
""" """
self.workflows_dir = workflows_dir self.workflows_dir = workflows_dir
if not self.workflows_dir.exists(): if not self.workflows_dir.exists():
self.workflows_dir.mkdir(parents=True, exist_ok=True) self.workflows_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Created workflows directory: {self.workflows_dir}") logger.info("Created workflows directory: %s", self.workflows_dir)
async def discover_workflows(self) -> Dict[str, WorkflowInfo]: async def discover_workflows(self) -> dict[str, WorkflowInfo]:
""" """Discover workflows by scanning the workflows directory.
Discover workflows by scanning the workflows directory.
Returns: Returns:
Dictionary mapping workflow names to their information Dictionary mapping workflow names to their information
""" """
workflows = {} workflows = {}
logger.info(f"Scanning for workflows in: {self.workflows_dir}") logger.info("Scanning for workflows in: %s", self.workflows_dir)
for workflow_dir in self.workflows_dir.iterdir(): for workflow_dir in self.workflows_dir.iterdir():
if not workflow_dir.is_dir(): if not workflow_dir.is_dir():
continue continue
# Skip special directories # Skip special directories
if workflow_dir.name.startswith('.') or workflow_dir.name == '__pycache__': if workflow_dir.name.startswith(".") or workflow_dir.name == "__pycache__":
continue continue
metadata_file = workflow_dir / "metadata.yaml" metadata_file = workflow_dir / "metadata.yaml"
if not metadata_file.exists(): if not metadata_file.exists():
logger.debug(f"No metadata.yaml in {workflow_dir.name}, skipping") logger.debug("No metadata.yaml in %s, skipping", workflow_dir.name)
continue continue
workflow_file = workflow_dir / "workflow.py" workflow_file = workflow_dir / "workflow.py"
if not workflow_file.exists(): if not workflow_file.exists():
logger.warning( logger.warning(
f"Workflow {workflow_dir.name} has metadata but no workflow.py, skipping" "Workflow %s has metadata but no workflow.py, skipping",
workflow_dir.name,
) )
continue continue
try: try:
# Parse metadata # Parse metadata
with open(metadata_file) as f: with metadata_file.open() as f:
metadata = yaml.safe_load(f) metadata = yaml.safe_load(f)
# Validate required fields # Validate required fields
if 'name' not in metadata: if "name" not in metadata:
logger.warning(f"Workflow {workflow_dir.name} metadata missing 'name' field") logger.warning("Workflow %s metadata missing 'name' field", workflow_dir.name)
metadata['name'] = workflow_dir.name metadata["name"] = workflow_dir.name
if 'vertical' not in metadata: if "vertical" not in metadata:
logger.warning( logger.warning(
f"Workflow {workflow_dir.name} metadata missing 'vertical' field" "Workflow %s metadata missing 'vertical' field",
workflow_dir.name,
) )
continue continue
# Infer workflow class name from metadata or use convention # Infer workflow class name from metadata or use convention
workflow_type = metadata.get('workflow_class') workflow_type = metadata.get("workflow_class")
if not workflow_type: if not workflow_type:
# Convention: convert snake_case to PascalCase + Workflow # Convention: convert snake_case to PascalCase + Workflow
# e.g., rust_test -> RustTestWorkflow # e.g., rust_test -> RustTestWorkflow
parts = workflow_dir.name.split('_') parts = workflow_dir.name.split("_")
workflow_type = ''.join(part.capitalize() for part in parts) + 'Workflow' workflow_type = "".join(part.capitalize() for part in parts) + "Workflow"
# Create workflow info # Create workflow info
info = WorkflowInfo( info = WorkflowInfo(
name=metadata['name'], name=metadata["name"],
path=workflow_dir, path=workflow_dir,
workflow_file=workflow_file, workflow_file=workflow_file,
metadata=metadata, metadata=metadata,
workflow_type=workflow_type, workflow_type=workflow_type,
vertical=metadata['vertical'] vertical=metadata["vertical"],
) )
workflows[info.name] = info workflows[info.name] = info
logger.info( logger.info(
f"✓ Discovered workflow: {info.name} " "✓ Discovered workflow: %s (vertical: %s, class: %s)",
f"(vertical: {info.vertical}, class: {info.workflow_type})" info.name,
info.vertical,
info.workflow_type,
) )
except Exception as e: except Exception:
logger.error( logger.exception(
f"Error discovering workflow {workflow_dir.name}: {e}", "Error discovering workflow %s",
exc_info=True workflow_dir.name,
) )
continue continue
logger.info(f"Discovered {len(workflows)} workflows") logger.info("Discovered %s workflows", len(workflows))
return workflows return workflows
def get_workflows_by_vertical( def get_workflows_by_vertical(
self, self,
workflows: Dict[str, WorkflowInfo], workflows: dict[str, WorkflowInfo],
vertical: str vertical: str,
) -> Dict[str, WorkflowInfo]: ) -> dict[str, WorkflowInfo]:
""" """Filter workflows by vertical.
Filter workflows by vertical.
Args: Args:
workflows: All discovered workflows workflows: All discovered workflows
@@ -145,32 +148,29 @@ class WorkflowDiscovery:
Returns: Returns:
Filtered workflows dictionary Filtered workflows dictionary
"""
return {
name: info
for name, info in workflows.items()
if info.vertical == vertical
}
def get_available_verticals(self, workflows: Dict[str, WorkflowInfo]) -> list[str]:
""" """
Get list of all verticals from discovered workflows. return {name: info for name, info in workflows.items() if info.vertical == vertical}
def get_available_verticals(self, workflows: dict[str, WorkflowInfo]) -> list[str]:
"""Get list of all verticals from discovered workflows.
Args: Args:
workflows: All discovered workflows workflows: All discovered workflows
Returns: Returns:
List of unique vertical names List of unique vertical names
""" """
        return list(set(info.vertical for info in workflows.values()))         return list({info.vertical for info in workflows.values()})
@staticmethod @staticmethod
def get_metadata_schema() -> Dict[str, Any]: def get_metadata_schema() -> dict[str, Any]:
""" """Get the JSON schema for workflow metadata.
Get the JSON schema for workflow metadata.
Returns: Returns:
JSON schema dictionary JSON schema dictionary
""" """
return { return {
"type": "object", "type": "object",
@@ -178,34 +178,34 @@ class WorkflowDiscovery:
"properties": { "properties": {
"name": { "name": {
"type": "string", "type": "string",
"description": "Workflow name" "description": "Workflow name",
}, },
"version": { "version": {
"type": "string", "type": "string",
"pattern": "^\\d+\\.\\d+\\.\\d+$", "pattern": "^\\d+\\.\\d+\\.\\d+$",
"description": "Semantic version (x.y.z)" "description": "Semantic version (x.y.z)",
}, },
"vertical": { "vertical": {
"type": "string", "type": "string",
"description": "Vertical worker type (rust, android, web, etc.)" "description": "Vertical worker type (rust, android, web, etc.)",
}, },
"description": { "description": {
"type": "string", "type": "string",
"description": "Workflow description" "description": "Workflow description",
}, },
"author": { "author": {
"type": "string", "type": "string",
"description": "Workflow author" "description": "Workflow author",
}, },
"category": { "category": {
"type": "string", "type": "string",
"enum": ["comprehensive", "specialized", "fuzzing", "focused"], "enum": ["comprehensive", "specialized", "fuzzing", "focused"],
"description": "Workflow category" "description": "Workflow category",
}, },
"tags": { "tags": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Workflow tags for categorization" "description": "Workflow tags for categorization",
}, },
"requirements": { "requirements": {
"type": "object", "type": "object",
@@ -214,7 +214,7 @@ class WorkflowDiscovery:
"tools": { "tools": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Required security tools" "description": "Required security tools",
}, },
"resources": { "resources": {
"type": "object", "type": "object",
@@ -223,35 +223,35 @@ class WorkflowDiscovery:
"memory": { "memory": {
"type": "string", "type": "string",
"pattern": "^\\d+[GMK]i$", "pattern": "^\\d+[GMK]i$",
"description": "Memory limit (e.g., 1Gi, 512Mi)" "description": "Memory limit (e.g., 1Gi, 512Mi)",
}, },
"cpu": { "cpu": {
"type": "string", "type": "string",
"pattern": "^\\d+m?$", "pattern": "^\\d+m?$",
"description": "CPU limit (e.g., 1000m, 2)" "description": "CPU limit (e.g., 1000m, 2)",
}, },
"timeout": { "timeout": {
"type": "integer", "type": "integer",
"minimum": 60, "minimum": 60,
"maximum": 7200, "maximum": 7200,
"description": "Workflow timeout in seconds" "description": "Workflow timeout in seconds",
} },
} },
} },
} },
}, },
"parameters": { "parameters": {
"type": "object", "type": "object",
"description": "Workflow parameters schema" "description": "Workflow parameters schema",
}, },
"default_parameters": { "default_parameters": {
"type": "object", "type": "object",
"description": "Default parameter values" "description": "Default parameter values",
}, },
"required_modules": { "required_modules": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Required module names" "description": "Required module names",
} },
} },
} }
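To make the schema above concrete, here is a hedged example of a metadata.yaml that discover_workflows() would accept; every value (rust_test, the resource limits, and so on) is invented for illustration and is not part of this change:
import yaml

example_metadata = yaml.safe_load("""
name: rust_test
version: 1.0.0
vertical: rust
description: Example cargo-fuzz workflow
category: fuzzing
tags: [rust, cargo-fuzz]
requirements:
  tools: [cargo-fuzz]
  resources:
    memory: 1Gi
    cpu: 1000m
    timeout: 600
default_parameters:
  max_iterations: 1000
""")

# Discovery only hard-requires 'name' and 'vertical'; without an explicit
# 'workflow_class' the class name is inferred from the directory name
# (e.g. rust_test -> RustTestWorkflow).
assert example_metadata["vertical"] == "rust"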

View File

@@ -1,5 +1,4 @@
""" """Temporal Manager - Workflow execution and management.
Temporal Manager - Workflow execution and management
Handles: Handles:
- Workflow discovery from toolbox - Workflow discovery from toolbox
@@ -8,25 +7,26 @@ Handles:
- Results retrieval - Results retrieval
""" """
import asyncio
import logging import logging
import os import os
from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Any from typing import Any
from uuid import uuid4 from uuid import uuid4
from temporalio.client import Client, WorkflowHandle from temporalio.client import Client, WorkflowHandle
from temporalio.common import RetryPolicy from temporalio.common import RetryPolicy
from datetime import timedelta
from src.storage import S3CachedStorage
from .discovery import WorkflowDiscovery, WorkflowInfo from .discovery import WorkflowDiscovery, WorkflowInfo
from src.storage import S3CachedStorage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TemporalManager: class TemporalManager:
""" """Manages Temporal workflow execution for FuzzForge.
Manages Temporal workflow execution for FuzzForge.
This class: This class:
- Discovers available workflows from toolbox - Discovers available workflows from toolbox
@@ -37,41 +37,42 @@ class TemporalManager:
def __init__( def __init__(
self, self,
workflows_dir: Optional[Path] = None, workflows_dir: Path | None = None,
temporal_address: Optional[str] = None, temporal_address: str | None = None,
temporal_namespace: str = "default", temporal_namespace: str = "default",
storage: Optional[S3CachedStorage] = None storage: S3CachedStorage | None = None,
): ) -> None:
""" """Initialize Temporal manager.
Initialize Temporal manager.
Args: Args:
workflows_dir: Path to workflows directory (default: toolbox/workflows) workflows_dir: Path to workflows directory (default: toolbox/workflows)
temporal_address: Temporal server address (default: from env or localhost:7233) temporal_address: Temporal server address (default: from env or localhost:7233)
temporal_namespace: Temporal namespace temporal_namespace: Temporal namespace
storage: Storage backend for file uploads (default: S3CachedStorage) storage: Storage backend for file uploads (default: S3CachedStorage)
""" """
if workflows_dir is None: if workflows_dir is None:
workflows_dir = Path("toolbox/workflows") workflows_dir = Path("toolbox/workflows")
self.temporal_address = temporal_address or os.getenv( self.temporal_address = temporal_address or os.getenv(
'TEMPORAL_ADDRESS', "TEMPORAL_ADDRESS",
'localhost:7233' "localhost:7233",
) )
self.temporal_namespace = temporal_namespace self.temporal_namespace = temporal_namespace
self.discovery = WorkflowDiscovery(workflows_dir) self.discovery = WorkflowDiscovery(workflows_dir)
self.workflows: Dict[str, WorkflowInfo] = {} self.workflows: dict[str, WorkflowInfo] = {}
self.client: Optional[Client] = None self.client: Client | None = None
# Initialize storage backend # Initialize storage backend
self.storage = storage or S3CachedStorage() self.storage = storage or S3CachedStorage()
logger.info( logger.info(
f"TemporalManager initialized: {self.temporal_address} " "TemporalManager initialized: %s (namespace: %s)",
f"(namespace: {self.temporal_namespace})" self.temporal_address,
self.temporal_namespace,
) )
async def initialize(self): async def initialize(self) -> None:
"""Initialize the manager by discovering workflows and connecting to Temporal.""" """Initialize the manager by discovering workflows and connecting to Temporal."""
try: try:
# Discover workflows # Discover workflows
@@ -81,45 +82,46 @@ class TemporalManager:
logger.warning("No workflows discovered") logger.warning("No workflows discovered")
else: else:
logger.info( logger.info(
f"Discovered {len(self.workflows)} workflows: " "Discovered %s workflows: %s",
f"{list(self.workflows.keys())}" len(self.workflows),
list(self.workflows.keys()),
) )
# Connect to Temporal # Connect to Temporal
self.client = await Client.connect( self.client = await Client.connect(
self.temporal_address, self.temporal_address,
namespace=self.temporal_namespace namespace=self.temporal_namespace,
) )
logger.info(f"✓ Connected to Temporal: {self.temporal_address}") logger.info("✓ Connected to Temporal: %s", self.temporal_address)
except Exception as e: except Exception:
logger.error(f"Failed to initialize Temporal manager: {e}", exc_info=True) logger.exception("Failed to initialize Temporal manager")
raise raise
async def close(self): async def close(self) -> None:
"""Close Temporal client connection.""" """Close Temporal client connection."""
if self.client: if self.client:
# Temporal client doesn't need explicit close in Python SDK # Temporal client doesn't need explicit close in Python SDK
pass pass
async def get_workflows(self) -> Dict[str, WorkflowInfo]: async def get_workflows(self) -> dict[str, WorkflowInfo]:
""" """Get all discovered workflows.
Get all discovered workflows.
Returns: Returns:
Dictionary mapping workflow names to their info Dictionary mapping workflow names to their info
""" """
return self.workflows return self.workflows
async def get_workflow(self, name: str) -> Optional[WorkflowInfo]: async def get_workflow(self, name: str) -> WorkflowInfo | None:
""" """Get workflow info by name.
Get workflow info by name.
Args: Args:
name: Workflow name name: Workflow name
Returns: Returns:
WorkflowInfo or None if not found WorkflowInfo or None if not found
""" """
return self.workflows.get(name) return self.workflows.get(name)
@@ -127,10 +129,9 @@ class TemporalManager:
self, self,
file_path: Path, file_path: Path,
user_id: str, user_id: str,
metadata: Optional[Dict[str, Any]] = None metadata: dict[str, Any] | None = None,
) -> str: ) -> str:
""" """Upload target file to storage.
Upload target file to storage.
Args: Args:
file_path: Local path to file file_path: Local path to file
@@ -139,20 +140,20 @@ class TemporalManager:
Returns: Returns:
Target ID for use in workflow execution Target ID for use in workflow execution
""" """
target_id = await self.storage.upload_target(file_path, user_id, metadata) target_id = await self.storage.upload_target(file_path, user_id, metadata)
logger.info(f"Uploaded target: {target_id}") logger.info("Uploaded target: %s", target_id)
return target_id return target_id
async def run_workflow( async def run_workflow(
self, self,
workflow_name: str, workflow_name: str,
target_id: str, target_id: str,
workflow_params: Optional[Dict[str, Any]] = None, workflow_params: dict[str, Any] | None = None,
workflow_id: Optional[str] = None workflow_id: str | None = None,
) -> WorkflowHandle: ) -> WorkflowHandle:
""" """Execute a workflow.
Execute a workflow.
Args: Args:
workflow_name: Name of workflow to execute workflow_name: Name of workflow to execute
@@ -165,14 +166,17 @@ class TemporalManager:
Raises: Raises:
ValueError: If workflow not found or client not initialized ValueError: If workflow not found or client not initialized
""" """
if not self.client: if not self.client:
raise ValueError("Temporal client not initialized. Call initialize() first.") msg = "Temporal client not initialized. Call initialize() first."
raise ValueError(msg)
# Get workflow info # Get workflow info
workflow_info = self.workflows.get(workflow_name) workflow_info = self.workflows.get(workflow_name)
if not workflow_info: if not workflow_info:
raise ValueError(f"Workflow not found: {workflow_name}") msg = f"Workflow not found: {workflow_name}"
raise ValueError(msg)
# Generate workflow ID if not provided # Generate workflow ID if not provided
if not workflow_id: if not workflow_id:
@@ -188,23 +192,23 @@ class TemporalManager:
# Add parameters in order based on metadata schema # Add parameters in order based on metadata schema
# This ensures parameters match the workflow signature order # This ensures parameters match the workflow signature order
# Apply defaults from metadata.yaml if parameter not provided # Apply defaults from metadata.yaml if parameter not provided
if 'parameters' in workflow_info.metadata: if "parameters" in workflow_info.metadata:
param_schema = workflow_info.metadata['parameters'].get('properties', {}) param_schema = workflow_info.metadata["parameters"].get("properties", {})
logger.debug(f"Found {len(param_schema)} parameters in schema") logger.debug("Found %s parameters in schema", len(param_schema))
# Iterate parameters in schema order and add values # Iterate parameters in schema order and add values
for param_name in param_schema.keys(): for param_name in param_schema:
param_spec = param_schema[param_name] param_spec = param_schema[param_name]
# Use provided param, or fall back to default from metadata # Use provided param, or fall back to default from metadata
if workflow_params and param_name in workflow_params: if workflow_params and param_name in workflow_params:
param_value = workflow_params[param_name] param_value = workflow_params[param_name]
logger.debug(f"Using provided value for {param_name}: {param_value}") logger.debug("Using provided value for %s: %s", param_name, param_value)
elif 'default' in param_spec: elif "default" in param_spec:
param_value = param_spec['default'] param_value = param_spec["default"]
logger.debug(f"Using default for {param_name}: {param_value}") logger.debug("Using default for %s: %s", param_name, param_value)
else: else:
param_value = None param_value = None
logger.debug(f"No value or default for {param_name}, using None") logger.debug("No value or default for {param_name}, using None")
workflow_args.append(param_value) workflow_args.append(param_value)
else: else:
@@ -215,11 +219,14 @@ class TemporalManager:
task_queue = f"{vertical}-queue" task_queue = f"{vertical}-queue"
logger.info( logger.info(
f"Starting workflow: {workflow_name} " "Starting workflow: %s (id=%s, queue=%s, target=%s)",
f"(id={workflow_id}, queue={task_queue}, target={target_id})" workflow_name,
workflow_id,
task_queue,
target_id,
) )
logger.info(f"DEBUG: workflow_args = {workflow_args}") logger.info("DEBUG: workflow_args = %s", workflow_args)
logger.info(f"DEBUG: workflow_params received = {workflow_params}") logger.infof("DEBUG: workflow_params received = %s", workflow_params)
try: try:
# Start workflow execution with positional arguments # Start workflow execution with positional arguments
@@ -231,20 +238,20 @@ class TemporalManager:
retry_policy=RetryPolicy( retry_policy=RetryPolicy(
initial_interval=timedelta(seconds=1), initial_interval=timedelta(seconds=1),
maximum_interval=timedelta(minutes=1), maximum_interval=timedelta(minutes=1),
maximum_attempts=3 maximum_attempts=3,
) ),
) )
logger.info(f"✓ Workflow started: {workflow_id}") logger.info("✓ Workflow started: %s", workflow_id)
except Exception:
logger.exception("Failed to start workflow %s", workflow_name)
raise
else:
return handle return handle
except Exception as e: async def get_workflow_status(self, workflow_id: str) -> dict[str, Any]:
logger.error(f"Failed to start workflow {workflow_name}: {e}", exc_info=True) """Get workflow execution status.
raise
async def get_workflow_status(self, workflow_id: str) -> Dict[str, Any]:
"""
Get workflow execution status.
Args: Args:
workflow_id: Workflow execution ID workflow_id: Workflow execution ID
@@ -254,9 +261,11 @@ class TemporalManager:
Raises: Raises:
ValueError: If client not initialized or workflow not found ValueError: If client not initialized or workflow not found
""" """
if not self.client: if not self.client:
raise ValueError("Temporal client not initialized") msg = "Temporal client not initialized"
raise ValueError(msg)
try: try:
# Get workflow handle # Get workflow handle
@@ -274,20 +283,20 @@ class TemporalManager:
"task_queue": description.task_queue, "task_queue": description.task_queue,
} }
logger.info(f"Workflow {workflow_id} status: {status['status']}") logger.info("Workflow %s status: %s", workflow_id, status["status"])
return status
except Exception as e: except Exception:
logger.error(f"Failed to get workflow status: {e}", exc_info=True) logger.exception("Failed to get workflow status")
raise raise
else:
return status
async def get_workflow_result( async def get_workflow_result(
self, self,
workflow_id: str, workflow_id: str,
timeout: Optional[timedelta] = None timeout: timedelta | None = None,
) -> Any: ) -> Any:
""" """Get workflow execution result (blocking).
Get workflow execution result (blocking).
Args: Args:
workflow_id: Workflow execution ID workflow_id: Workflow execution ID
@@ -299,60 +308,62 @@ class TemporalManager:
Raises: Raises:
ValueError: If client not initialized ValueError: If client not initialized
TimeoutError: If timeout exceeded TimeoutError: If timeout exceeded
""" """
if not self.client: if not self.client:
raise ValueError("Temporal client not initialized") msg = "Temporal client not initialized"
raise ValueError(msg)
try: try:
handle = self.client.get_workflow_handle(workflow_id) handle = self.client.get_workflow_handle(workflow_id)
logger.info(f"Waiting for workflow result: {workflow_id}") logger.info("Waiting for workflow result: %s", workflow_id)
# Wait for workflow to complete and get result # Wait for workflow to complete and get result
if timeout: if timeout:
# Use asyncio timeout if provided # Use asyncio timeout if provided
import asyncio
result = await asyncio.wait_for(handle.result(), timeout=timeout.total_seconds()) result = await asyncio.wait_for(handle.result(), timeout=timeout.total_seconds())
else: else:
result = await handle.result() result = await handle.result()
logger.info(f"✓ Workflow {workflow_id} completed") logger.info("✓ Workflow %s completed", workflow_id)
except Exception:
logger.exception("Failed to get workflow result")
raise
else:
return result return result
except Exception as e:
logger.error(f"Failed to get workflow result: {e}", exc_info=True)
raise
async def cancel_workflow(self, workflow_id: str) -> None: async def cancel_workflow(self, workflow_id: str) -> None:
""" """Cancel a running workflow.
Cancel a running workflow.
Args: Args:
workflow_id: Workflow execution ID workflow_id: Workflow execution ID
Raises: Raises:
ValueError: If client not initialized ValueError: If client not initialized
""" """
if not self.client: if not self.client:
raise ValueError("Temporal client not initialized") msg = "Temporal client not initialized"
raise ValueError(msg)
try: try:
handle = self.client.get_workflow_handle(workflow_id) handle = self.client.get_workflow_handle(workflow_id)
await handle.cancel() await handle.cancel()
logger.info(f"✓ Workflow cancelled: {workflow_id}") logger.info("✓ Workflow cancelled: %s", workflow_id)
except Exception as e: except Exception:
logger.error(f"Failed to cancel workflow: {e}", exc_info=True) logger.exception("Failed to cancel workflow: %s")
raise raise
async def list_workflows( async def list_workflows(
self, self,
filter_query: Optional[str] = None, filter_query: str | None = None,
limit: int = 100 limit: int = 100,
) -> list[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """List workflow executions.
List workflow executions.
Args: Args:
filter_query: Optional Temporal list filter query filter_query: Optional Temporal list filter query
@@ -363,30 +374,36 @@ class TemporalManager:
Raises: Raises:
ValueError: If client not initialized ValueError: If client not initialized
""" """
if not self.client: if not self.client:
raise ValueError("Temporal client not initialized") msg = "Temporal client not initialized"
raise ValueError(msg)
try: try:
workflows = [] workflows = []
# Use Temporal's list API # Use Temporal's list API
async for workflow in self.client.list_workflows(filter_query): async for workflow in self.client.list_workflows(filter_query):
workflows.append({ workflows.append(
"workflow_id": workflow.id, {
"workflow_type": workflow.workflow_type, "workflow_id": workflow.id,
"status": workflow.status.name, "workflow_type": workflow.workflow_type,
"start_time": workflow.start_time.isoformat() if workflow.start_time else None, "status": workflow.status.name,
"close_time": workflow.close_time.isoformat() if workflow.close_time else None, "start_time": workflow.start_time.isoformat() if workflow.start_time else None,
"task_queue": workflow.task_queue, "close_time": workflow.close_time.isoformat() if workflow.close_time else None,
}) "task_queue": workflow.task_queue,
},
)
if len(workflows) >= limit: if len(workflows) >= limit:
break break
logger.info(f"Listed {len(workflows)} workflows") logger.info("Listed %s workflows", len(workflows))
            return workflows
except Exception as e: except Exception:
logger.error(f"Failed to list workflows: {e}", exc_info=True) logger.exception("Failed to list workflows")
raise raise
else:
return workflows
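For context, the manager's public surface after this refactor can be exercised roughly as below. This is a sketch only: the import path, workflow name, file path and user id are assumptions, not taken from the diff:
import asyncio
from datetime import timedelta
from pathlib import Path

from src.temporal import TemporalManager  # import path assumed from the package layout above

async def main() -> None:
    manager = TemporalManager()  # discovers toolbox/workflows and connects to Temporal
    await manager.initialize()

    # Hypothetical target archive and user id.
    target_id = await manager.upload_target(Path("/tmp/target.tar.gz"), user_id="demo")
    handle = await manager.run_workflow("rust_test", target_id, {"max_iterations": 1000})

    status = await manager.get_workflow_status(handle.id)
    print(status["status"])

    result = await manager.get_workflow_result(handle.id, timeout=timedelta(minutes=10))
    print(result)

asyncio.run(main())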

View File

@@ -8,11 +8,19 @@
# See the LICENSE-APACHE file or http://www.apache.org/licenses/LICENSE-2.0 # See the LICENSE-APACHE file or http://www.apache.org/licenses/LICENSE-2.0
# #
# Additional attribution and requirements are provided in the NOTICE file. # Additional attribution and requirements are provided in the NOTICE file.
"""Fixtures used across tests."""
import sys import sys
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Dict, Any from types import CoroutineType
from typing import Any
import pytest import pytest
from modules.analyzer.security_analyzer import SecurityAnalyzer
from modules.fuzzer.atheris_fuzzer import AtherisFuzzer
from modules.fuzzer.cargo_fuzzer import CargoFuzzer
from modules.scanner.file_scanner import FileScanner
# Ensure project root is on sys.path so `src` is importable # Ensure project root is on sys.path so `src` is importable
ROOT = Path(__file__).resolve().parents[1] ROOT = Path(__file__).resolve().parents[1]
@@ -29,17 +37,18 @@ if str(TOOLBOX) not in sys.path:
# Workspace Fixtures # Workspace Fixtures
# ============================================================================ # ============================================================================
@pytest.fixture @pytest.fixture
def temp_workspace(tmp_path): def temp_workspace(tmp_path: Path) -> Path:
"""Create a temporary workspace directory for testing""" """Create a temporary workspace directory for testing."""
workspace = tmp_path / "workspace" workspace = tmp_path / "workspace"
workspace.mkdir() workspace.mkdir()
return workspace return workspace
@pytest.fixture @pytest.fixture
def python_test_workspace(temp_workspace): def python_test_workspace(temp_workspace: Path) -> Path:
"""Create a Python test workspace with sample files""" """Create a Python test workspace with sample files."""
# Create a simple Python project structure # Create a simple Python project structure
(temp_workspace / "main.py").write_text(""" (temp_workspace / "main.py").write_text("""
def process_data(data): def process_data(data):
@@ -62,8 +71,8 @@ AWS_SECRET = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
@pytest.fixture @pytest.fixture
def rust_test_workspace(temp_workspace): def rust_test_workspace(temp_workspace: Path) -> Path:
"""Create a Rust test workspace with fuzz targets""" """Create a Rust test workspace with fuzz targets."""
# Create Cargo.toml # Create Cargo.toml
(temp_workspace / "Cargo.toml").write_text("""[package] (temp_workspace / "Cargo.toml").write_text("""[package]
name = "test_project" name = "test_project"
@@ -131,44 +140,45 @@ fuzz_target!(|data: &[u8]| {
# Module Configuration Fixtures # Module Configuration Fixtures
# ============================================================================ # ============================================================================
@pytest.fixture @pytest.fixture
def atheris_config(): def atheris_config() -> dict[str, Any]:
"""Default Atheris fuzzer configuration""" """Return default Atheris fuzzer configuration."""
return { return {
"target_file": "auto-discover", "target_file": "auto-discover",
"max_iterations": 1000, "max_iterations": 1000,
"timeout_seconds": 10, "timeout_seconds": 10,
"corpus_dir": None "corpus_dir": None,
} }
@pytest.fixture @pytest.fixture
def cargo_fuzz_config(): def cargo_fuzz_config() -> dict[str, Any]:
"""Default cargo-fuzz configuration""" """Return default cargo-fuzz configuration."""
return { return {
"target_name": None, "target_name": None,
"max_iterations": 1000, "max_iterations": 1000,
"timeout_seconds": 10, "timeout_seconds": 10,
"sanitizer": "address" "sanitizer": "address",
} }
@pytest.fixture @pytest.fixture
def gitleaks_config(): def gitleaks_config() -> dict[str, Any]:
"""Default Gitleaks configuration""" """Return default Gitleaks configuration."""
return { return {
"config_path": None, "config_path": None,
"scan_uncommitted": True "scan_uncommitted": True,
} }
@pytest.fixture @pytest.fixture
def file_scanner_config(): def file_scanner_config() -> dict[str, Any]:
"""Default file scanner configuration""" """Return default file scanner configuration."""
return { return {
"scan_patterns": ["*.py", "*.rs", "*.js"], "scan_patterns": ["*.py", "*.rs", "*.js"],
"exclude_patterns": ["*.test.*", "*.spec.*"], "exclude_patterns": ["*.test.*", "*.spec.*"],
"max_file_size": 1048576 # 1MB "max_file_size": 1048576, # 1MB
} }
@@ -176,55 +186,67 @@ def file_scanner_config():
# Module Instance Fixtures # Module Instance Fixtures
# ============================================================================ # ============================================================================
@pytest.fixture @pytest.fixture
def atheris_fuzzer(): def atheris_fuzzer() -> AtherisFuzzer:
"""Create an AtherisFuzzer instance""" """Create an AtherisFuzzer instance."""
from modules.fuzzer.atheris_fuzzer import AtherisFuzzer
return AtherisFuzzer() return AtherisFuzzer()
@pytest.fixture @pytest.fixture
def cargo_fuzzer(): def cargo_fuzzer() -> CargoFuzzer:
"""Create a CargoFuzzer instance""" """Create a CargoFuzzer instance."""
from modules.fuzzer.cargo_fuzzer import CargoFuzzer
return CargoFuzzer() return CargoFuzzer()
@pytest.fixture @pytest.fixture
def file_scanner(): def file_scanner() -> FileScanner:
"""Create a FileScanner instance""" """Create a FileScanner instance."""
from modules.scanner.file_scanner import FileScanner
return FileScanner() return FileScanner()
@pytest.fixture
def security_analyzer() -> SecurityAnalyzer:
"""Create SecurityAnalyzer instance."""
return SecurityAnalyzer()
# ============================================================================ # ============================================================================
# Mock Fixtures # Mock Fixtures
# ============================================================================ # ============================================================================
@pytest.fixture @pytest.fixture
def mock_stats_callback(): def mock_stats_callback() -> Callable[[dict[str, Any]], CoroutineType]:
"""Mock stats callback for fuzzing""" """Mock stats callback for fuzzing."""
stats_received = [] stats_received = []
async def callback(stats: Dict[str, Any]): async def callback(stats: dict[str, Any]) -> None:
stats_received.append(stats) stats_received.append(stats)
callback.stats_received = stats_received callback.stats_received = stats_received
return callback return callback
class MockActivityInfo:
"""Mock activity info."""
def __init__(self) -> None:
"""Initialize an instance of the class."""
self.workflow_id = "test-workflow-123"
self.activity_id = "test-activity-1"
self.attempt = 1
class MockContext:
"""Mock context."""
def __init__(self) -> None:
"""Initialize an instance of the class."""
self.info = MockActivityInfo()
@pytest.fixture @pytest.fixture
def mock_temporal_context(): def mock_temporal_context() -> MockContext:
"""Mock Temporal activity context""" """Mock Temporal activity context."""
class MockActivityInfo:
def __init__(self):
self.workflow_id = "test-workflow-123"
self.activity_id = "test-activity-1"
self.attempt = 1
class MockContext:
def __init__(self):
self.info = MockActivityInfo()
return MockContext() return MockContext()

View File

View File

@@ -0,0 +1 @@
"""Unit tests."""

View File

@@ -0,0 +1 @@
"""Unit tests for modules."""

View File

@@ -1,17 +1,26 @@
""" """Unit tests for AtherisFuzzer module."""
Unit tests for AtherisFuzzer module
""" from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, patch
import pytest import pytest
from unittest.mock import AsyncMock, patch
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
from typing import Any
from modules.fuzzer.atheris_fuzzer import AtherisFuzzer
@pytest.mark.asyncio @pytest.mark.asyncio
class TestAtherisFuzzerMetadata: class TestAtherisFuzzerMetadata:
"""Test AtherisFuzzer metadata""" """Test AtherisFuzzer metadata."""
async def test_metadata_structure(self, atheris_fuzzer): async def test_metadata_structure(self, atheris_fuzzer: AtherisFuzzer) -> None:
"""Test that module metadata is properly defined""" """Test that module metadata is properly defined."""
metadata = atheris_fuzzer.get_metadata() metadata = atheris_fuzzer.get_metadata()
assert metadata.name == "atheris_fuzzer" assert metadata.name == "atheris_fuzzer"
@@ -22,28 +31,28 @@ class TestAtherisFuzzerMetadata:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestAtherisFuzzerConfigValidation: class TestAtherisFuzzerConfigValidation:
"""Test configuration validation""" """Test configuration validation."""
async def test_valid_config(self, atheris_fuzzer, atheris_config): async def test_valid_config(self, atheris_fuzzer: AtherisFuzzer, atheris_config: dict[str, Any]) -> None:
"""Test validation of valid configuration""" """Test validation of valid configuration."""
assert atheris_fuzzer.validate_config(atheris_config) is True assert atheris_fuzzer.validate_config(atheris_config) is True
async def test_invalid_max_iterations(self, atheris_fuzzer): async def test_invalid_max_iterations(self, atheris_fuzzer: AtherisFuzzer) -> None:
"""Test validation fails with invalid max_iterations""" """Test validation fails with invalid max_iterations."""
config = { config = {
"target_file": "fuzz_target.py", "target_file": "fuzz_target.py",
"max_iterations": -1, "max_iterations": -1,
"timeout_seconds": 10 "timeout_seconds": 10,
} }
with pytest.raises(ValueError, match="max_iterations"): with pytest.raises(ValueError, match="max_iterations"):
atheris_fuzzer.validate_config(config) atheris_fuzzer.validate_config(config)
async def test_invalid_timeout(self, atheris_fuzzer): async def test_invalid_timeout(self, atheris_fuzzer: AtherisFuzzer) -> None:
"""Test validation fails with invalid timeout""" """Test validation fails with invalid timeout."""
config = { config = {
"target_file": "fuzz_target.py", "target_file": "fuzz_target.py",
"max_iterations": 1000, "max_iterations": 1000,
"timeout_seconds": 0 "timeout_seconds": 0,
} }
with pytest.raises(ValueError, match="timeout_seconds"): with pytest.raises(ValueError, match="timeout_seconds"):
atheris_fuzzer.validate_config(config) atheris_fuzzer.validate_config(config)
@@ -51,10 +60,10 @@ class TestAtherisFuzzerConfigValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestAtherisFuzzerDiscovery: class TestAtherisFuzzerDiscovery:
"""Test fuzz target discovery""" """Test fuzz target discovery."""
async def test_auto_discover(self, atheris_fuzzer, python_test_workspace): async def test_auto_discover(self, atheris_fuzzer: AtherisFuzzer, python_test_workspace: Path) -> None:
"""Test auto-discovery of Python fuzz targets""" """Test auto-discovery of Python fuzz targets."""
# Create a fuzz target file # Create a fuzz target file
(python_test_workspace / "fuzz_target.py").write_text(""" (python_test_workspace / "fuzz_target.py").write_text("""
import atheris import atheris
@@ -69,7 +78,7 @@ if __name__ == "__main__":
""") """)
# Pass None for auto-discovery # Pass None for auto-discovery
target = atheris_fuzzer._discover_target(python_test_workspace, None) target = atheris_fuzzer._discover_target(python_test_workspace, None) # noqa: SLF001
assert target is not None assert target is not None
assert "fuzz_target.py" in str(target) assert "fuzz_target.py" in str(target)
@@ -77,10 +86,14 @@ if __name__ == "__main__":
@pytest.mark.asyncio @pytest.mark.asyncio
class TestAtherisFuzzerExecution: class TestAtherisFuzzerExecution:
"""Test fuzzer execution logic""" """Test fuzzer execution logic."""
async def test_execution_creates_result(self, atheris_fuzzer, python_test_workspace, atheris_config): async def test_execution_creates_result(
"""Test that execution returns a ModuleResult""" self,
atheris_fuzzer: AtherisFuzzer,
python_test_workspace: Path,
) -> None:
"""Test that execution returns a ModuleResult."""
# Create a simple fuzz target # Create a simple fuzz target
(python_test_workspace / "fuzz_target.py").write_text(""" (python_test_workspace / "fuzz_target.py").write_text("""
import atheris import atheris
@@ -99,11 +112,16 @@ if __name__ == "__main__":
test_config = { test_config = {
"target_file": "fuzz_target.py", "target_file": "fuzz_target.py",
"max_iterations": 10, "max_iterations": 10,
"timeout_seconds": 1 "timeout_seconds": 1,
} }
# Mock the fuzzing subprocess to avoid actual execution # Mock the fuzzing subprocess to avoid actual execution
with patch.object(atheris_fuzzer, '_run_fuzzing', new_callable=AsyncMock, return_value=([], {"total_executions": 10})): with patch.object(
atheris_fuzzer,
"_run_fuzzing",
new_callable=AsyncMock,
return_value=([], {"total_executions": 10}),
):
result = await atheris_fuzzer.execute(test_config, python_test_workspace) result = await atheris_fuzzer.execute(test_config, python_test_workspace)
assert result.module == "atheris_fuzzer" assert result.module == "atheris_fuzzer"
@@ -113,10 +131,16 @@ if __name__ == "__main__":
@pytest.mark.asyncio @pytest.mark.asyncio
class TestAtherisFuzzerStatsCallback: class TestAtherisFuzzerStatsCallback:
"""Test stats callback functionality""" """Test stats callback functionality."""
async def test_stats_callback_invoked(self, atheris_fuzzer, python_test_workspace, atheris_config, mock_stats_callback): async def test_stats_callback_invoked(
"""Test that stats callback is invoked during fuzzing""" self,
atheris_fuzzer: AtherisFuzzer,
python_test_workspace: Path,
atheris_config: dict[str, Any],
mock_stats_callback: Callable | None,
) -> None:
"""Test that stats callback is invoked during fuzzing."""
(python_test_workspace / "fuzz_target.py").write_text(""" (python_test_workspace / "fuzz_target.py").write_text("""
import atheris import atheris
import sys import sys
@@ -130,35 +154,45 @@ if __name__ == "__main__":
""") """)
# Mock fuzzing to simulate stats # Mock fuzzing to simulate stats
async def mock_run_fuzzing(test_one_input, target_path, workspace, max_iterations, timeout_seconds, stats_callback): async def mock_run_fuzzing(
test_one_input: Callable, # noqa: ARG001
target_path: Path, # noqa: ARG001
workspace: Path, # noqa: ARG001
max_iterations: int, # noqa: ARG001
timeout_seconds: int, # noqa: ARG001
stats_callback: Callable | None,
) -> None:
if stats_callback: if stats_callback:
await stats_callback({ await stats_callback(
"total_execs": 100, {
"execs_per_sec": 10.0, "total_execs": 100,
"crashes": 0, "execs_per_sec": 10.0,
"coverage": 5, "crashes": 0,
"corpus_size": 2, "coverage": 5,
"elapsed_time": 10 "corpus_size": 2,
}) "elapsed_time": 10,
return },
)
with patch.object(atheris_fuzzer, '_run_fuzzing', side_effect=mock_run_fuzzing): with (
with patch.object(atheris_fuzzer, '_load_target_module', return_value=lambda x: None): patch.object(atheris_fuzzer, "_run_fuzzing", side_effect=mock_run_fuzzing),
# Put stats_callback in config dict, not as kwarg patch.object(atheris_fuzzer, "_load_target_module", return_value=lambda _x: None),
atheris_config["target_file"] = "fuzz_target.py" ):
atheris_config["stats_callback"] = mock_stats_callback # Put stats_callback in config dict, not as kwarg
await atheris_fuzzer.execute(atheris_config, python_test_workspace) atheris_config["target_file"] = "fuzz_target.py"
atheris_config["stats_callback"] = mock_stats_callback
await atheris_fuzzer.execute(atheris_config, python_test_workspace)
# Verify callback was invoked # Verify callback was invoked
assert len(mock_stats_callback.stats_received) > 0 assert len(mock_stats_callback.stats_received) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
class TestAtherisFuzzerFindingGeneration: class TestAtherisFuzzerFindingGeneration:
"""Test finding generation from crashes""" """Test finding generation from crashes."""
async def test_create_crash_finding(self, atheris_fuzzer): async def test_create_crash_finding(self, atheris_fuzzer: AtherisFuzzer) -> None:
"""Test crash finding creation""" """Test crash finding creation."""
finding = atheris_fuzzer.create_finding( finding = atheris_fuzzer.create_finding(
title="Crash: Exception in TestOneInput", title="Crash: Exception in TestOneInput",
description="IndexError: list index out of range", description="IndexError: list index out of range",
@@ -167,8 +201,8 @@ class TestAtherisFuzzerFindingGeneration:
file_path="fuzz_target.py", file_path="fuzz_target.py",
metadata={ metadata={
"crash_type": "IndexError", "crash_type": "IndexError",
"stack_trace": "Traceback..." "stack_trace": "Traceback...",
} },
) )
assert finding.title == "Crash: Exception in TestOneInput" assert finding.title == "Crash: Exception in TestOneInput"

View File

@@ -1,17 +1,26 @@
""" """Unit tests for CargoFuzzer module."""
Unit tests for CargoFuzzer module
""" from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, patch
import pytest import pytest
from unittest.mock import AsyncMock, patch
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
from typing import Any
from modules.fuzzer.cargo_fuzzer import CargoFuzzer
@pytest.mark.asyncio @pytest.mark.asyncio
class TestCargoFuzzerMetadata: class TestCargoFuzzerMetadata:
"""Test CargoFuzzer metadata""" """Test CargoFuzzer metadata."""
async def test_metadata_structure(self, cargo_fuzzer): async def test_metadata_structure(self, cargo_fuzzer: CargoFuzzer) -> None:
"""Test that module metadata is properly defined""" """Test that module metadata is properly defined."""
metadata = cargo_fuzzer.get_metadata() metadata = cargo_fuzzer.get_metadata()
assert metadata.name == "cargo_fuzz" assert metadata.name == "cargo_fuzz"
@@ -23,38 +32,38 @@ class TestCargoFuzzerMetadata:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestCargoFuzzerConfigValidation: class TestCargoFuzzerConfigValidation:
"""Test configuration validation""" """Test configuration validation."""
async def test_valid_config(self, cargo_fuzzer, cargo_fuzz_config): async def test_valid_config(self, cargo_fuzzer: CargoFuzzer, cargo_fuzz_config: dict[str, Any]) -> None:
"""Test validation of valid configuration""" """Test validation of valid configuration."""
assert cargo_fuzzer.validate_config(cargo_fuzz_config) is True assert cargo_fuzzer.validate_config(cargo_fuzz_config) is True
async def test_invalid_max_iterations(self, cargo_fuzzer): async def test_invalid_max_iterations(self, cargo_fuzzer: CargoFuzzer) -> None:
"""Test validation fails with invalid max_iterations""" """Test validation fails with invalid max_iterations."""
config = { config = {
"max_iterations": -1, "max_iterations": -1,
"timeout_seconds": 10, "timeout_seconds": 10,
"sanitizer": "address" "sanitizer": "address",
} }
with pytest.raises(ValueError, match="max_iterations"): with pytest.raises(ValueError, match="max_iterations"):
cargo_fuzzer.validate_config(config) cargo_fuzzer.validate_config(config)
async def test_invalid_timeout(self, cargo_fuzzer): async def test_invalid_timeout(self, cargo_fuzzer: CargoFuzzer) -> None:
"""Test validation fails with invalid timeout""" """Test validation fails with invalid timeout."""
config = { config = {
"max_iterations": 1000, "max_iterations": 1000,
"timeout_seconds": 0, "timeout_seconds": 0,
"sanitizer": "address" "sanitizer": "address",
} }
with pytest.raises(ValueError, match="timeout_seconds"): with pytest.raises(ValueError, match="timeout_seconds"):
cargo_fuzzer.validate_config(config) cargo_fuzzer.validate_config(config)
async def test_invalid_sanitizer(self, cargo_fuzzer): async def test_invalid_sanitizer(self, cargo_fuzzer: CargoFuzzer) -> None:
"""Test validation fails with invalid sanitizer""" """Test validation fails with invalid sanitizer."""
config = { config = {
"max_iterations": 1000, "max_iterations": 1000,
"timeout_seconds": 10, "timeout_seconds": 10,
"sanitizer": "invalid_sanitizer" "sanitizer": "invalid_sanitizer",
} }
with pytest.raises(ValueError, match="sanitizer"): with pytest.raises(ValueError, match="sanitizer"):
cargo_fuzzer.validate_config(config) cargo_fuzzer.validate_config(config)
@@ -62,20 +71,20 @@ class TestCargoFuzzerConfigValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestCargoFuzzerWorkspaceValidation: class TestCargoFuzzerWorkspaceValidation:
"""Test workspace validation""" """Test workspace validation."""
async def test_valid_workspace(self, cargo_fuzzer, rust_test_workspace): async def test_valid_workspace(self, cargo_fuzzer: CargoFuzzer, rust_test_workspace: Path) -> None:
"""Test validation of valid workspace""" """Test validation of valid workspace."""
assert cargo_fuzzer.validate_workspace(rust_test_workspace) is True assert cargo_fuzzer.validate_workspace(rust_test_workspace) is True
async def test_nonexistent_workspace(self, cargo_fuzzer, tmp_path): async def test_nonexistent_workspace(self, cargo_fuzzer: CargoFuzzer, tmp_path: Path) -> None:
"""Test validation fails with nonexistent workspace""" """Test validation fails with nonexistent workspace."""
nonexistent = tmp_path / "does_not_exist" nonexistent = tmp_path / "does_not_exist"
with pytest.raises(ValueError, match="does not exist"): with pytest.raises(ValueError, match="does not exist"):
cargo_fuzzer.validate_workspace(nonexistent) cargo_fuzzer.validate_workspace(nonexistent)
async def test_workspace_is_file(self, cargo_fuzzer, tmp_path): async def test_workspace_is_file(self, cargo_fuzzer: CargoFuzzer, tmp_path: Path) -> None:
"""Test validation fails when workspace is a file""" """Test validation fails when workspace is a file."""
file_path = tmp_path / "file.txt" file_path = tmp_path / "file.txt"
file_path.write_text("test") file_path.write_text("test")
with pytest.raises(ValueError, match="not a directory"): with pytest.raises(ValueError, match="not a directory"):
@@ -84,41 +93,58 @@ class TestCargoFuzzerWorkspaceValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestCargoFuzzerDiscovery: class TestCargoFuzzerDiscovery:
"""Test fuzz target discovery""" """Test fuzz target discovery."""
async def test_discover_targets(self, cargo_fuzzer, rust_test_workspace): async def test_discover_targets(self, cargo_fuzzer: CargoFuzzer, rust_test_workspace: Path) -> None:
"""Test discovery of fuzz targets""" """Test discovery of fuzz targets."""
targets = await cargo_fuzzer._discover_fuzz_targets(rust_test_workspace) targets = await cargo_fuzzer._discover_fuzz_targets(rust_test_workspace) # noqa: SLF001
assert len(targets) == 1 assert len(targets) == 1
assert "fuzz_target_1" in targets assert "fuzz_target_1" in targets
async def test_no_fuzz_directory(self, cargo_fuzzer, temp_workspace): async def test_no_fuzz_directory(self, cargo_fuzzer: CargoFuzzer, temp_workspace: Path) -> None:
"""Test discovery with no fuzz directory""" """Test discovery with no fuzz directory."""
targets = await cargo_fuzzer._discover_fuzz_targets(temp_workspace) targets = await cargo_fuzzer._discover_fuzz_targets(temp_workspace) # noqa: SLF001
assert targets == [] assert targets == []
@pytest.mark.asyncio @pytest.mark.asyncio
class TestCargoFuzzerExecution: class TestCargoFuzzerExecution:
"""Test fuzzer execution logic""" """Test fuzzer execution logic."""
async def test_execution_creates_result(self, cargo_fuzzer, rust_test_workspace, cargo_fuzz_config): async def test_execution_creates_result(
"""Test that execution returns a ModuleResult""" self,
cargo_fuzzer: CargoFuzzer,
rust_test_workspace: Path,
cargo_fuzz_config: dict[str, Any],
) -> None:
"""Test that execution returns a ModuleResult."""
# Mock the build and run methods to avoid actual fuzzing # Mock the build and run methods to avoid actual fuzzing
with patch.object(cargo_fuzzer, '_build_fuzz_target', new_callable=AsyncMock, return_value=True): with (
with patch.object(cargo_fuzzer, '_run_fuzzing', new_callable=AsyncMock, return_value=([], {"total_executions": 0, "crashes_found": 0})): patch.object(cargo_fuzzer, "_build_fuzz_target", new_callable=AsyncMock, return_value=True),
with patch.object(cargo_fuzzer, '_parse_crash_artifacts', new_callable=AsyncMock, return_value=[]): patch.object(
result = await cargo_fuzzer.execute(cargo_fuzz_config, rust_test_workspace) cargo_fuzzer,
"_run_fuzzing",
new_callable=AsyncMock,
return_value=([], {"total_executions": 0, "crashes_found": 0}),
),
patch.object(cargo_fuzzer, "_parse_crash_artifacts", new_callable=AsyncMock, return_value=[]),
):
result = await cargo_fuzzer.execute(cargo_fuzz_config, rust_test_workspace)
assert result.module == "cargo_fuzz" assert result.module == "cargo_fuzz"
assert result.status == "success" assert result.status == "success"
assert isinstance(result.execution_time, float) assert isinstance(result.execution_time, float)
assert result.execution_time >= 0 assert result.execution_time >= 0
async def test_execution_with_no_targets(self, cargo_fuzzer, temp_workspace, cargo_fuzz_config): async def test_execution_with_no_targets(
"""Test execution fails gracefully with no fuzz targets""" self,
cargo_fuzzer: CargoFuzzer,
temp_workspace: Path,
cargo_fuzz_config: dict[str, Any],
) -> None:
"""Test execution fails gracefully with no fuzz targets."""
result = await cargo_fuzzer.execute(cargo_fuzz_config, temp_workspace) result = await cargo_fuzzer.execute(cargo_fuzz_config, temp_workspace)
assert result.status == "failed" assert result.status == "failed"
@@ -127,47 +153,67 @@ class TestCargoFuzzerExecution:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestCargoFuzzerStatsCallback: class TestCargoFuzzerStatsCallback:
"""Test stats callback functionality""" """Test stats callback functionality."""
async def test_stats_callback_invoked(
self,
cargo_fuzzer: CargoFuzzer,
rust_test_workspace: Path,
cargo_fuzz_config: dict[str, Any],
mock_stats_callback: Callable | None,
) -> None:
"""Test that stats callback is invoked during fuzzing."""
async def test_stats_callback_invoked(self, cargo_fuzzer, rust_test_workspace, cargo_fuzz_config, mock_stats_callback):
"""Test that stats callback is invoked during fuzzing"""
# Mock build/run to simulate stats generation # Mock build/run to simulate stats generation
async def mock_run_fuzzing(workspace, target, config, callback): async def mock_run_fuzzing(
_workspace: Path,
_target: str,
_config: dict[str, Any],
callback: Callable | None,
) -> tuple[list, dict[str, int]]:
# Simulate stats callback # Simulate stats callback
if callback: if callback:
await callback({ await callback(
"total_execs": 1000, {
"execs_per_sec": 100.0, "total_execs": 1000,
"crashes": 0, "execs_per_sec": 100.0,
"coverage": 10, "crashes": 0,
"corpus_size": 5, "coverage": 10,
"elapsed_time": 10 "corpus_size": 5,
}) "elapsed_time": 10,
},
)
return [], {"total_executions": 1000} return [], {"total_executions": 1000}
with patch.object(cargo_fuzzer, '_build_fuzz_target', new_callable=AsyncMock, return_value=True): with (
with patch.object(cargo_fuzzer, '_run_fuzzing', side_effect=mock_run_fuzzing): patch.object(cargo_fuzzer, "_build_fuzz_target", new_callable=AsyncMock, return_value=True),
with patch.object(cargo_fuzzer, '_parse_crash_artifacts', new_callable=AsyncMock, return_value=[]): patch.object(cargo_fuzzer, "_run_fuzzing", side_effect=mock_run_fuzzing),
await cargo_fuzzer.execute(cargo_fuzz_config, rust_test_workspace, stats_callback=mock_stats_callback) patch.object(cargo_fuzzer, "_parse_crash_artifacts", new_callable=AsyncMock, return_value=[]),
):
await cargo_fuzzer.execute(
cargo_fuzz_config,
rust_test_workspace,
stats_callback=mock_stats_callback,
)
# Verify callback was invoked # Verify callback was invoked
assert len(mock_stats_callback.stats_received) > 0 assert len(mock_stats_callback.stats_received) > 0
assert mock_stats_callback.stats_received[0]["total_execs"] == 1000 assert mock_stats_callback.stats_received[0]["total_execs"] == 1000
@pytest.mark.asyncio @pytest.mark.asyncio
class TestCargoFuzzerFindingGeneration: class TestCargoFuzzerFindingGeneration:
"""Test finding generation from crashes""" """Test finding generation from crashes."""
async def test_create_finding_from_crash(self, cargo_fuzzer): async def test_create_finding_from_crash(self, cargo_fuzzer: CargoFuzzer) -> None:
"""Test finding creation""" """Test finding creation."""
finding = cargo_fuzzer.create_finding( finding = cargo_fuzzer.create_finding(
title="Crash: Segmentation Fault", title="Crash: Segmentation Fault",
description="Test crash", description="Test crash",
severity="critical", severity="critical",
category="crash", category="crash",
file_path="fuzz/fuzz_targets/fuzz_target_1.rs", file_path="fuzz/fuzz_targets/fuzz_target_1.rs",
metadata={"crash_type": "SIGSEGV"} metadata={"crash_type": "SIGSEGV"},
) )
assert finding.title == "Crash: Segmentation Fault" assert finding.title == "Crash: Segmentation Fault"

View File

@@ -1,22 +1,25 @@
""" """Unit tests for FileScanner module."""
Unit tests for FileScanner module
""" from __future__ import annotations
import sys import sys
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
import pytest import pytest
if TYPE_CHECKING:
from modules.scanner.file_scanner import FileScanner
sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "toolbox")) sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "toolbox"))
@pytest.mark.asyncio @pytest.mark.asyncio
class TestFileScannerMetadata: class TestFileScannerMetadata:
"""Test FileScanner metadata""" """Test FileScanner metadata."""
async def test_metadata_structure(self, file_scanner): async def test_metadata_structure(self, file_scanner: FileScanner) -> None:
"""Test that metadata has correct structure""" """Test that metadata has correct structure."""
metadata = file_scanner.get_metadata() metadata = file_scanner.get_metadata()
assert metadata.name == "file_scanner" assert metadata.name == "file_scanner"
@@ -29,37 +32,37 @@ class TestFileScannerMetadata:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestFileScannerConfigValidation: class TestFileScannerConfigValidation:
"""Test configuration validation""" """Test configuration validation."""
async def test_valid_config(self, file_scanner): async def test_valid_config(self, file_scanner: FileScanner) -> None:
"""Test that valid config passes validation""" """Test that valid config passes validation."""
config = { config = {
"patterns": ["*.py", "*.js"], "patterns": ["*.py", "*.js"],
"max_file_size": 1048576, "max_file_size": 1048576,
"check_sensitive": True, "check_sensitive": True,
"calculate_hashes": False "calculate_hashes": False,
} }
assert file_scanner.validate_config(config) is True assert file_scanner.validate_config(config) is True
async def test_default_config(self, file_scanner): async def test_default_config(self, file_scanner: FileScanner) -> None:
"""Test that empty config uses defaults""" """Test that empty config uses defaults."""
config = {} config = {}
assert file_scanner.validate_config(config) is True assert file_scanner.validate_config(config) is True
async def test_invalid_patterns_type(self, file_scanner): async def test_invalid_patterns_type(self, file_scanner: FileScanner) -> None:
"""Test that non-list patterns raises error""" """Test that non-list patterns raises error."""
config = {"patterns": "*.py"} config = {"patterns": "*.py"}
with pytest.raises(ValueError, match="patterns must be a list"): with pytest.raises(ValueError, match="patterns must be a list"):
file_scanner.validate_config(config) file_scanner.validate_config(config)
async def test_invalid_max_file_size(self, file_scanner): async def test_invalid_max_file_size(self, file_scanner: FileScanner) -> None:
"""Test that invalid max_file_size raises error""" """Test that invalid max_file_size raises error."""
config = {"max_file_size": -1} config = {"max_file_size": -1}
with pytest.raises(ValueError, match="max_file_size must be a positive integer"): with pytest.raises(ValueError, match="max_file_size must be a positive integer"):
file_scanner.validate_config(config) file_scanner.validate_config(config)
async def test_invalid_max_file_size_type(self, file_scanner): async def test_invalid_max_file_size_type(self, file_scanner: FileScanner) -> None:
"""Test that non-integer max_file_size raises error""" """Test that non-integer max_file_size raises error."""
config = {"max_file_size": "large"} config = {"max_file_size": "large"}
with pytest.raises(ValueError, match="max_file_size must be a positive integer"): with pytest.raises(ValueError, match="max_file_size must be a positive integer"):
file_scanner.validate_config(config) file_scanner.validate_config(config)
@@ -67,14 +70,14 @@ class TestFileScannerConfigValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestFileScannerExecution: class TestFileScannerExecution:
"""Test scanner execution""" """Test scanner execution."""
async def test_scan_python_files(self, file_scanner, python_test_workspace): async def test_scan_python_files(self, file_scanner: FileScanner, python_test_workspace: Path) -> None:
"""Test scanning Python files""" """Test scanning Python files."""
config = { config = {
"patterns": ["*.py"], "patterns": ["*.py"],
"check_sensitive": False, "check_sensitive": False,
"calculate_hashes": False "calculate_hashes": False,
} }
result = await file_scanner.execute(config, python_test_workspace) result = await file_scanner.execute(config, python_test_workspace)
@@ -84,15 +87,15 @@ class TestFileScannerExecution:
assert len(result.findings) > 0 assert len(result.findings) > 0
# Check that Python files were found # Check that Python files were found
python_files = [f for f in result.findings if f.file_path.endswith('.py')] python_files = [f for f in result.findings if f.file_path.endswith(".py")]
assert len(python_files) > 0 assert len(python_files) > 0
async def test_scan_all_files(self, file_scanner, python_test_workspace): async def test_scan_all_files(self, file_scanner: FileScanner, python_test_workspace: Path) -> None:
"""Test scanning all files with wildcard""" """Test scanning all files with wildcard."""
config = { config = {
"patterns": ["*"], "patterns": ["*"],
"check_sensitive": False, "check_sensitive": False,
"calculate_hashes": False "calculate_hashes": False,
} }
result = await file_scanner.execute(config, python_test_workspace) result = await file_scanner.execute(config, python_test_workspace)
@@ -101,12 +104,12 @@ class TestFileScannerExecution:
assert len(result.findings) > 0 assert len(result.findings) > 0
assert result.summary["total_files"] > 0 assert result.summary["total_files"] > 0
async def test_scan_with_multiple_patterns(self, file_scanner, python_test_workspace): async def test_scan_with_multiple_patterns(self, file_scanner: FileScanner, python_test_workspace: Path) -> None:
"""Test scanning with multiple patterns""" """Test scanning with multiple patterns."""
config = { config = {
"patterns": ["*.py", "*.txt"], "patterns": ["*.py", "*.txt"],
"check_sensitive": False, "check_sensitive": False,
"calculate_hashes": False "calculate_hashes": False,
} }
result = await file_scanner.execute(config, python_test_workspace) result = await file_scanner.execute(config, python_test_workspace)
@@ -114,11 +117,11 @@ class TestFileScannerExecution:
assert result.status == "success" assert result.status == "success"
assert len(result.findings) > 0 assert len(result.findings) > 0
async def test_empty_workspace(self, file_scanner, temp_workspace): async def test_empty_workspace(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test scanning empty workspace""" """Test scanning empty workspace."""
config = { config = {
"patterns": ["*.py"], "patterns": ["*.py"],
"check_sensitive": False "check_sensitive": False,
} }
result = await file_scanner.execute(config, temp_workspace) result = await file_scanner.execute(config, temp_workspace)
@@ -130,17 +133,17 @@ class TestFileScannerExecution:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestFileScannerSensitiveDetection: class TestFileScannerSensitiveDetection:
"""Test sensitive file detection""" """Test sensitive file detection."""
async def test_detect_env_file(self, file_scanner, temp_workspace): async def test_detect_env_file(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test detection of .env file""" """Test detection of .env file."""
# Create .env file # Create .env file
(temp_workspace / ".env").write_text("API_KEY=secret123") (temp_workspace / ".env").write_text("API_KEY=secret123")
config = { config = {
"patterns": ["*"], "patterns": ["*"],
"check_sensitive": True, "check_sensitive": True,
"calculate_hashes": False "calculate_hashes": False,
} }
result = await file_scanner.execute(config, temp_workspace) result = await file_scanner.execute(config, temp_workspace)
@@ -152,14 +155,14 @@ class TestFileScannerSensitiveDetection:
assert len(sensitive_findings) > 0 assert len(sensitive_findings) > 0
assert any(".env" in f.title for f in sensitive_findings) assert any(".env" in f.title for f in sensitive_findings)
async def test_detect_private_key(self, file_scanner, temp_workspace): async def test_detect_private_key(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test detection of private key file""" """Test detection of private key file."""
# Create private key file # Create private key file
(temp_workspace / "id_rsa").write_text("-----BEGIN RSA PRIVATE KEY-----") (temp_workspace / "id_rsa").write_text("-----BEGIN RSA PRIVATE KEY-----")
config = { config = {
"patterns": ["*"], "patterns": ["*"],
"check_sensitive": True "check_sensitive": True,
} }
result = await file_scanner.execute(config, temp_workspace) result = await file_scanner.execute(config, temp_workspace)
@@ -168,13 +171,13 @@ class TestFileScannerSensitiveDetection:
sensitive_findings = [f for f in result.findings if f.category == "sensitive_file"] sensitive_findings = [f for f in result.findings if f.category == "sensitive_file"]
assert len(sensitive_findings) > 0 assert len(sensitive_findings) > 0
async def test_no_sensitive_detection_when_disabled(self, file_scanner, temp_workspace): async def test_no_sensitive_detection_when_disabled(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that sensitive detection can be disabled""" """Test that sensitive detection can be disabled."""
(temp_workspace / ".env").write_text("API_KEY=secret123") (temp_workspace / ".env").write_text("API_KEY=secret123")
config = { config = {
"patterns": ["*"], "patterns": ["*"],
"check_sensitive": False "check_sensitive": False,
} }
result = await file_scanner.execute(config, temp_workspace) result = await file_scanner.execute(config, temp_workspace)
@@ -186,17 +189,17 @@ class TestFileScannerSensitiveDetection:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestFileScannerHashing: class TestFileScannerHashing:
"""Test file hashing functionality""" """Test file hashing functionality."""
async def test_hash_calculation(self, file_scanner, temp_workspace): async def test_hash_calculation(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test SHA256 hash calculation""" """Test SHA256 hash calculation."""
# Create test file # Create test file
test_file = temp_workspace / "test.txt" test_file = temp_workspace / "test.txt"
test_file.write_text("Hello World") test_file.write_text("Hello World")
config = { config = {
"patterns": ["*.txt"], "patterns": ["*.txt"],
"calculate_hashes": True "calculate_hashes": True,
} }
result = await file_scanner.execute(config, temp_workspace) result = await file_scanner.execute(config, temp_workspace)
@@ -212,14 +215,14 @@ class TestFileScannerHashing:
assert finding.metadata.get("file_hash") is not None assert finding.metadata.get("file_hash") is not None
assert len(finding.metadata["file_hash"]) == 64 # SHA256 hex length assert len(finding.metadata["file_hash"]) == 64 # SHA256 hex length
async def test_no_hash_when_disabled(self, file_scanner, temp_workspace): async def test_no_hash_when_disabled(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that hashing can be disabled""" """Test that hashing can be disabled."""
test_file = temp_workspace / "test.txt" test_file = temp_workspace / "test.txt"
test_file.write_text("Hello World") test_file.write_text("Hello World")
config = { config = {
"patterns": ["*.txt"], "patterns": ["*.txt"],
"calculate_hashes": False "calculate_hashes": False,
} }
result = await file_scanner.execute(config, temp_workspace) result = await file_scanner.execute(config, temp_workspace)
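The 64-character expectation in test_hash_calculation above is simply the length of a SHA-256 digest rendered as hex; for reference, a quick check with the standard library:

    import hashlib

    digest = hashlib.sha256(b"Hello World").hexdigest()
    assert len(digest) == 64  # SHA-256 produces 32 bytes, i.e. 64 hex characters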
@@ -234,10 +237,10 @@ class TestFileScannerHashing:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestFileScannerFileTypes: class TestFileScannerFileTypes:
"""Test file type detection""" """Test file type detection."""
async def test_detect_python_type(self, file_scanner, temp_workspace): async def test_detect_python_type(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test detection of Python file type""" """Test detection of Python file type."""
(temp_workspace / "script.py").write_text("print('hello')") (temp_workspace / "script.py").write_text("print('hello')")
config = {"patterns": ["*.py"]} config = {"patterns": ["*.py"]}
@@ -248,8 +251,8 @@ class TestFileScannerFileTypes:
assert len(py_findings) > 0 assert len(py_findings) > 0
assert "python" in py_findings[0].metadata["file_type"] assert "python" in py_findings[0].metadata["file_type"]
async def test_detect_javascript_type(self, file_scanner, temp_workspace): async def test_detect_javascript_type(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test detection of JavaScript file type""" """Test detection of JavaScript file type."""
(temp_workspace / "app.js").write_text("console.log('hello')") (temp_workspace / "app.js").write_text("console.log('hello')")
config = {"patterns": ["*.js"]} config = {"patterns": ["*.js"]}
@@ -260,8 +263,8 @@ class TestFileScannerFileTypes:
assert len(js_findings) > 0 assert len(js_findings) > 0
assert "javascript" in js_findings[0].metadata["file_type"] assert "javascript" in js_findings[0].metadata["file_type"]
async def test_file_type_summary(self, file_scanner, temp_workspace): async def test_file_type_summary(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that file type summary is generated""" """Test that file type summary is generated."""
(temp_workspace / "script.py").write_text("print('hello')") (temp_workspace / "script.py").write_text("print('hello')")
(temp_workspace / "app.js").write_text("console.log('hello')") (temp_workspace / "app.js").write_text("console.log('hello')")
(temp_workspace / "readme.txt").write_text("Documentation") (temp_workspace / "readme.txt").write_text("Documentation")
@@ -276,17 +279,17 @@ class TestFileScannerFileTypes:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestFileScannerSizeLimits: class TestFileScannerSizeLimits:
"""Test file size handling""" """Test file size handling."""
async def test_skip_large_files(self, file_scanner, temp_workspace): async def test_skip_large_files(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that large files are skipped""" """Test that large files are skipped."""
# Create a "large" file # Create a "large" file
large_file = temp_workspace / "large.txt" large_file = temp_workspace / "large.txt"
large_file.write_text("x" * 1000) large_file.write_text("x" * 1000)
config = { config = {
"patterns": ["*.txt"], "patterns": ["*.txt"],
"max_file_size": 500 # Set limit smaller than file "max_file_size": 500, # Set limit smaller than file
} }
result = await file_scanner.execute(config, temp_workspace) result = await file_scanner.execute(config, temp_workspace)
@@ -297,14 +300,14 @@ class TestFileScannerSizeLimits:
# The file should still be counted but not have a detailed finding # The file should still be counted but not have a detailed finding
assert result.summary["total_files"] > 0 assert result.summary["total_files"] > 0
async def test_process_small_files(self, file_scanner, temp_workspace): async def test_process_small_files(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that small files are processed""" """Test that small files are processed."""
small_file = temp_workspace / "small.txt" small_file = temp_workspace / "small.txt"
small_file.write_text("small content") small_file.write_text("small content")
config = { config = {
"patterns": ["*.txt"], "patterns": ["*.txt"],
"max_file_size": 1048576 # 1MB "max_file_size": 1048576, # 1MB
} }
result = await file_scanner.execute(config, temp_workspace) result = await file_scanner.execute(config, temp_workspace)
@@ -316,10 +319,10 @@ class TestFileScannerSizeLimits:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestFileScannerSummary: class TestFileScannerSummary:
"""Test result summary generation""" """Test result summary generation."""
async def test_summary_structure(self, file_scanner, python_test_workspace): async def test_summary_structure(self, file_scanner: FileScanner, python_test_workspace: Path) -> None:
"""Test that summary has correct structure""" """Test that summary has correct structure."""
config = {"patterns": ["*"]} config = {"patterns": ["*"]}
result = await file_scanner.execute(config, python_test_workspace) result = await file_scanner.execute(config, python_test_workspace)
@@ -334,8 +337,8 @@ class TestFileScannerSummary:
assert isinstance(result.summary["file_types"], dict) assert isinstance(result.summary["file_types"], dict)
assert isinstance(result.summary["patterns_scanned"], list) assert isinstance(result.summary["patterns_scanned"], list)
async def test_summary_counts(self, file_scanner, temp_workspace): async def test_summary_counts(self, file_scanner: FileScanner, temp_workspace: Path) -> None:
"""Test that summary counts are accurate""" """Test that summary counts are accurate."""
# Create known files # Create known files
(temp_workspace / "file1.py").write_text("content1") (temp_workspace / "file1.py").write_text("content1")
(temp_workspace / "file2.py").write_text("content2") (temp_workspace / "file2.py").write_text("content2")

View File

@@ -1,28 +1,25 @@
""" """Unit tests for SecurityAnalyzer module."""
Unit tests for SecurityAnalyzer module
""" from __future__ import annotations
import pytest
import sys import sys
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "toolbox")) sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "toolbox"))
from modules.analyzer.security_analyzer import SecurityAnalyzer if TYPE_CHECKING:
from modules.analyzer.security_analyzer import SecurityAnalyzer
@pytest.fixture
def security_analyzer():
"""Create SecurityAnalyzer instance"""
return SecurityAnalyzer()
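The module-level security_analyzer fixture on the left is removed from this file, while the rewritten tests still request a security_analyzer argument, so the fixture presumably moves to a shared conftest.py that is not shown in this diff. A minimal sketch of that shared fixture, assuming the relocation:

    # conftest.py (hypothetical location, not part of this diff)
    import pytest

    from modules.analyzer.security_analyzer import SecurityAnalyzer


    @pytest.fixture
    def security_analyzer() -> SecurityAnalyzer:
        """Provide a fresh SecurityAnalyzer instance to each test."""
        return SecurityAnalyzer()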
@pytest.mark.asyncio @pytest.mark.asyncio
class TestSecurityAnalyzerMetadata: class TestSecurityAnalyzerMetadata:
"""Test SecurityAnalyzer metadata""" """Test SecurityAnalyzer metadata."""
async def test_metadata_structure(self, security_analyzer): async def test_metadata_structure(self, security_analyzer: SecurityAnalyzer) -> None:
"""Test that metadata has correct structure""" """Test that metadata has correct structure."""
metadata = security_analyzer.get_metadata() metadata = security_analyzer.get_metadata()
assert metadata.name == "security_analyzer" assert metadata.name == "security_analyzer"
@@ -35,25 +32,25 @@ class TestSecurityAnalyzerMetadata:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestSecurityAnalyzerConfigValidation: class TestSecurityAnalyzerConfigValidation:
"""Test configuration validation""" """Test configuration validation."""
async def test_valid_config(self, security_analyzer): async def test_valid_config(self, security_analyzer: SecurityAnalyzer) -> None:
"""Test that valid config passes validation""" """Test that valid config passes validation."""
config = { config = {
"file_extensions": [".py", ".js"], "file_extensions": [".py", ".js"],
"check_secrets": True, "check_secrets": True,
"check_sql": True, "check_sql": True,
"check_dangerous_functions": True "check_dangerous_functions": True,
} }
assert security_analyzer.validate_config(config) is True assert security_analyzer.validate_config(config) is True
async def test_default_config(self, security_analyzer): async def test_default_config(self, security_analyzer: SecurityAnalyzer) -> None:
"""Test that empty config uses defaults""" """Test that empty config uses defaults."""
config = {} config = {}
assert security_analyzer.validate_config(config) is True assert security_analyzer.validate_config(config) is True
async def test_invalid_extensions_type(self, security_analyzer): async def test_invalid_extensions_type(self, security_analyzer: SecurityAnalyzer) -> None:
"""Test that non-list extensions raises error""" """Test that non-list extensions raises error."""
config = {"file_extensions": ".py"} config = {"file_extensions": ".py"}
with pytest.raises(ValueError, match="file_extensions must be a list"): with pytest.raises(ValueError, match="file_extensions must be a list"):
security_analyzer.validate_config(config) security_analyzer.validate_config(config)
@@ -61,10 +58,10 @@ class TestSecurityAnalyzerConfigValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestSecurityAnalyzerSecretDetection: class TestSecurityAnalyzerSecretDetection:
"""Test hardcoded secret detection""" """Test hardcoded secret detection."""
async def test_detect_api_key(self, security_analyzer, temp_workspace): async def test_detect_api_key(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of hardcoded API key""" """Test detection of hardcoded API key."""
code_file = temp_workspace / "config.py" code_file = temp_workspace / "config.py"
code_file.write_text(""" code_file.write_text("""
# Configuration file # Configuration file
@@ -76,7 +73,7 @@ database_url = "postgresql://localhost/db"
"file_extensions": [".py"], "file_extensions": [".py"],
"check_secrets": True, "check_secrets": True,
"check_sql": False, "check_sql": False,
"check_dangerous_functions": False "check_dangerous_functions": False,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -86,8 +83,8 @@ database_url = "postgresql://localhost/db"
assert len(secret_findings) > 0 assert len(secret_findings) > 0
assert any("API Key" in f.title for f in secret_findings) assert any("API Key" in f.title for f in secret_findings)
async def test_detect_password(self, security_analyzer, temp_workspace): async def test_detect_password(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of hardcoded password""" """Test detection of hardcoded password."""
code_file = temp_workspace / "auth.py" code_file = temp_workspace / "auth.py"
code_file.write_text(""" code_file.write_text("""
def connect(): def connect():
@@ -99,7 +96,7 @@ def connect():
"file_extensions": [".py"], "file_extensions": [".py"],
"check_secrets": True, "check_secrets": True,
"check_sql": False, "check_sql": False,
"check_dangerous_functions": False "check_dangerous_functions": False,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -108,8 +105,8 @@ def connect():
secret_findings = [f for f in result.findings if f.category == "hardcoded_secret"] secret_findings = [f for f in result.findings if f.category == "hardcoded_secret"]
assert len(secret_findings) > 0 assert len(secret_findings) > 0
async def test_detect_aws_credentials(self, security_analyzer, temp_workspace): async def test_detect_aws_credentials(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of AWS credentials""" """Test detection of AWS credentials."""
code_file = temp_workspace / "aws_config.py" code_file = temp_workspace / "aws_config.py"
code_file.write_text(""" code_file.write_text("""
aws_access_key = "AKIAIOSFODNN7REALKEY" aws_access_key = "AKIAIOSFODNN7REALKEY"
@@ -118,7 +115,7 @@ aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYREALKEY"
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_secrets": True "check_secrets": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -127,14 +124,18 @@ aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYREALKEY"
aws_findings = [f for f in result.findings if "AWS" in f.title] aws_findings = [f for f in result.findings if "AWS" in f.title]
assert len(aws_findings) >= 2 # Both access key and secret key assert len(aws_findings) >= 2 # Both access key and secret key
async def test_no_secret_detection_when_disabled(self, security_analyzer, temp_workspace): async def test_no_secret_detection_when_disabled(
"""Test that secret detection can be disabled""" self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test that secret detection can be disabled."""
code_file = temp_workspace / "config.py" code_file = temp_workspace / "config.py"
code_file.write_text('api_key = "sk_live_1234567890abcdef"') code_file.write_text('api_key = "sk_live_1234567890abcdef"')
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_secrets": False "check_secrets": False,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -146,10 +147,10 @@ aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYREALKEY"
@pytest.mark.asyncio @pytest.mark.asyncio
class TestSecurityAnalyzerSQLInjection: class TestSecurityAnalyzerSQLInjection:
"""Test SQL injection detection""" """Test SQL injection detection."""
async def test_detect_string_concatenation(self, security_analyzer, temp_workspace): async def test_detect_string_concatenation(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of SQL string concatenation""" """Test detection of SQL string concatenation."""
code_file = temp_workspace / "db.py" code_file = temp_workspace / "db.py"
code_file.write_text(""" code_file.write_text("""
def get_user(user_id): def get_user(user_id):
@@ -161,7 +162,7 @@ def get_user(user_id):
"file_extensions": [".py"], "file_extensions": [".py"],
"check_secrets": False, "check_secrets": False,
"check_sql": True, "check_sql": True,
"check_dangerous_functions": False "check_dangerous_functions": False,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -170,8 +171,8 @@ def get_user(user_id):
sql_findings = [f for f in result.findings if f.category == "sql_injection"] sql_findings = [f for f in result.findings if f.category == "sql_injection"]
assert len(sql_findings) > 0 assert len(sql_findings) > 0
async def test_detect_f_string_sql(self, security_analyzer, temp_workspace): async def test_detect_f_string_sql(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of f-string in SQL""" """Test detection of f-string in SQL."""
code_file = temp_workspace / "db.py" code_file = temp_workspace / "db.py"
code_file.write_text(""" code_file.write_text("""
def get_user(name): def get_user(name):
@@ -181,7 +182,7 @@ def get_user(name):
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_sql": True "check_sql": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -190,8 +191,12 @@ def get_user(name):
sql_findings = [f for f in result.findings if f.category == "sql_injection"] sql_findings = [f for f in result.findings if f.category == "sql_injection"]
assert len(sql_findings) > 0 assert len(sql_findings) > 0
async def test_detect_dynamic_query_building(self, security_analyzer, temp_workspace): async def test_detect_dynamic_query_building(
"""Test detection of dynamic query building""" self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test detection of dynamic query building."""
code_file = temp_workspace / "queries.py" code_file = temp_workspace / "queries.py"
code_file.write_text(""" code_file.write_text("""
def search(keyword): def search(keyword):
@@ -201,7 +206,7 @@ def search(keyword):
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_sql": True "check_sql": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -210,14 +215,18 @@ def search(keyword):
sql_findings = [f for f in result.findings if f.category == "sql_injection"] sql_findings = [f for f in result.findings if f.category == "sql_injection"]
assert len(sql_findings) > 0 assert len(sql_findings) > 0
async def test_no_sql_detection_when_disabled(self, security_analyzer, temp_workspace): async def test_no_sql_detection_when_disabled(
"""Test that SQL detection can be disabled""" self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test that SQL detection can be disabled."""
code_file = temp_workspace / "db.py" code_file = temp_workspace / "db.py"
code_file.write_text('query = "SELECT * FROM users WHERE id = " + user_id') code_file.write_text('query = "SELECT * FROM users WHERE id = " + user_id')
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_sql": False "check_sql": False,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -229,10 +238,10 @@ def search(keyword):
@pytest.mark.asyncio @pytest.mark.asyncio
class TestSecurityAnalyzerDangerousFunctions: class TestSecurityAnalyzerDangerousFunctions:
"""Test dangerous function detection""" """Test dangerous function detection."""
async def test_detect_eval(self, security_analyzer, temp_workspace): async def test_detect_eval(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of eval() usage""" """Test detection of eval() usage."""
code_file = temp_workspace / "dangerous.py" code_file = temp_workspace / "dangerous.py"
code_file.write_text(""" code_file.write_text("""
def process_input(user_input): def process_input(user_input):
@@ -244,7 +253,7 @@ def process_input(user_input):
"file_extensions": [".py"], "file_extensions": [".py"],
"check_secrets": False, "check_secrets": False,
"check_sql": False, "check_sql": False,
"check_dangerous_functions": True "check_dangerous_functions": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -254,8 +263,8 @@ def process_input(user_input):
assert len(dangerous_findings) > 0 assert len(dangerous_findings) > 0
assert any("eval" in f.title.lower() for f in dangerous_findings) assert any("eval" in f.title.lower() for f in dangerous_findings)
async def test_detect_exec(self, security_analyzer, temp_workspace): async def test_detect_exec(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of exec() usage""" """Test detection of exec() usage."""
code_file = temp_workspace / "runner.py" code_file = temp_workspace / "runner.py"
code_file.write_text(""" code_file.write_text("""
def run_code(code): def run_code(code):
@@ -264,7 +273,7 @@ def run_code(code):
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_dangerous_functions": True "check_dangerous_functions": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -273,8 +282,8 @@ def run_code(code):
dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"] dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"]
assert len(dangerous_findings) > 0 assert len(dangerous_findings) > 0
async def test_detect_os_system(self, security_analyzer, temp_workspace): async def test_detect_os_system(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of os.system() usage""" """Test detection of os.system() usage."""
code_file = temp_workspace / "commands.py" code_file = temp_workspace / "commands.py"
code_file.write_text(""" code_file.write_text("""
import os import os
@@ -285,7 +294,7 @@ def run_command(cmd):
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_dangerous_functions": True "check_dangerous_functions": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -295,8 +304,8 @@ def run_command(cmd):
assert len(dangerous_findings) > 0 assert len(dangerous_findings) > 0
assert any("os.system" in f.title for f in dangerous_findings) assert any("os.system" in f.title for f in dangerous_findings)
async def test_detect_pickle_loads(self, security_analyzer, temp_workspace): async def test_detect_pickle_loads(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of pickle.loads() usage""" """Test detection of pickle.loads() usage."""
code_file = temp_workspace / "serializer.py" code_file = temp_workspace / "serializer.py"
code_file.write_text(""" code_file.write_text("""
import pickle import pickle
@@ -307,7 +316,7 @@ def deserialize(data):
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_dangerous_functions": True "check_dangerous_functions": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -316,8 +325,8 @@ def deserialize(data):
dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"] dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"]
assert len(dangerous_findings) > 0 assert len(dangerous_findings) > 0
async def test_detect_javascript_eval(self, security_analyzer, temp_workspace): async def test_detect_javascript_eval(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of eval() in JavaScript""" """Test detection of eval() in JavaScript."""
code_file = temp_workspace / "app.js" code_file = temp_workspace / "app.js"
code_file.write_text(""" code_file.write_text("""
function processInput(userInput) { function processInput(userInput) {
@@ -327,7 +336,7 @@ function processInput(userInput) {
config = { config = {
"file_extensions": [".js"], "file_extensions": [".js"],
"check_dangerous_functions": True "check_dangerous_functions": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -336,8 +345,8 @@ function processInput(userInput) {
dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"] dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"]
assert len(dangerous_findings) > 0 assert len(dangerous_findings) > 0
async def test_detect_innerHTML(self, security_analyzer, temp_workspace): async def test_detect_inner_html(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test detection of innerHTML (XSS risk)""" """Test detection of innerHTML (XSS risk)."""
code_file = temp_workspace / "dom.js" code_file = temp_workspace / "dom.js"
code_file.write_text(""" code_file.write_text("""
function updateContent(html) { function updateContent(html) {
@@ -347,7 +356,7 @@ function updateContent(html) {
config = { config = {
"file_extensions": [".js"], "file_extensions": [".js"],
"check_dangerous_functions": True "check_dangerous_functions": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -356,14 +365,18 @@ function updateContent(html) {
dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"] dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"]
assert len(dangerous_findings) > 0 assert len(dangerous_findings) > 0
async def test_no_dangerous_detection_when_disabled(self, security_analyzer, temp_workspace): async def test_no_dangerous_detection_when_disabled(
"""Test that dangerous function detection can be disabled""" self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test that dangerous function detection can be disabled."""
code_file = temp_workspace / "code.py" code_file = temp_workspace / "code.py"
code_file.write_text('result = eval(user_input)') code_file.write_text("result = eval(user_input)")
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_dangerous_functions": False "check_dangerous_functions": False,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -375,10 +388,14 @@ function updateContent(html) {
@pytest.mark.asyncio @pytest.mark.asyncio
class TestSecurityAnalyzerMultipleIssues: class TestSecurityAnalyzerMultipleIssues:
"""Test detection of multiple issues in same file""" """Test detection of multiple issues in same file."""
async def test_detect_multiple_vulnerabilities(self, security_analyzer, temp_workspace): async def test_detect_multiple_vulnerabilities(
"""Test detection of multiple vulnerability types""" self,
security_analyzer: SecurityAnalyzer,
temp_workspace: Path,
) -> None:
"""Test detection of multiple vulnerability types."""
code_file = temp_workspace / "vulnerable.py" code_file = temp_workspace / "vulnerable.py"
code_file.write_text(""" code_file.write_text("""
import os import os
@@ -404,7 +421,7 @@ def process_query(user_input):
"file_extensions": [".py"], "file_extensions": [".py"],
"check_secrets": True, "check_secrets": True,
"check_sql": True, "check_sql": True,
"check_dangerous_functions": True "check_dangerous_functions": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -423,10 +440,10 @@ def process_query(user_input):
@pytest.mark.asyncio @pytest.mark.asyncio
class TestSecurityAnalyzerSummary: class TestSecurityAnalyzerSummary:
"""Test result summary generation""" """Test result summary generation."""
async def test_summary_structure(self, security_analyzer, temp_workspace): async def test_summary_structure(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test that summary has correct structure""" """Test that summary has correct structure."""
(temp_workspace / "test.py").write_text("print('hello')") (temp_workspace / "test.py").write_text("print('hello')")
config = {"file_extensions": [".py"]} config = {"file_extensions": [".py"]}
@@ -441,16 +458,16 @@ class TestSecurityAnalyzerSummary:
assert isinstance(result.summary["total_findings"], int) assert isinstance(result.summary["total_findings"], int)
assert isinstance(result.summary["extensions_scanned"], list) assert isinstance(result.summary["extensions_scanned"], list)
async def test_empty_workspace(self, security_analyzer, temp_workspace): async def test_empty_workspace(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test analyzing empty workspace""" """Test analyzing empty workspace."""
config = {"file_extensions": [".py"]} config = {"file_extensions": [".py"]}
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
assert result.status == "partial" # No files found assert result.status == "partial" # No files found
assert result.summary["files_analyzed"] == 0 assert result.summary["files_analyzed"] == 0
async def test_analyze_multiple_file_types(self, security_analyzer, temp_workspace): async def test_analyze_multiple_file_types(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test analyzing multiple file types""" """Test analyzing multiple file types."""
(temp_workspace / "app.py").write_text("print('hello')") (temp_workspace / "app.py").write_text("print('hello')")
(temp_workspace / "script.js").write_text("console.log('hello')") (temp_workspace / "script.js").write_text("console.log('hello')")
(temp_workspace / "index.php").write_text("<?php echo 'hello'; ?>") (temp_workspace / "index.php").write_text("<?php echo 'hello'; ?>")
@@ -464,10 +481,10 @@ class TestSecurityAnalyzerSummary:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestSecurityAnalyzerFalsePositives: class TestSecurityAnalyzerFalsePositives:
"""Test false positive filtering""" """Test false positive filtering."""
async def test_skip_test_secrets(self, security_analyzer, temp_workspace): async def test_skip_test_secrets(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None:
"""Test that test/example secrets are filtered""" """Test that test/example secrets are filtered."""
code_file = temp_workspace / "test_config.py" code_file = temp_workspace / "test_config.py"
code_file.write_text(""" code_file.write_text("""
# Test configuration - should be filtered # Test configuration - should be filtered
@@ -478,7 +495,7 @@ token = "sample_token_placeholder"
config = { config = {
"file_extensions": [".py"], "file_extensions": [".py"],
"check_secrets": True "check_secrets": True,
} }
result = await security_analyzer.execute(config, temp_workspace) result = await security_analyzer.execute(config, temp_workspace)
@@ -488,6 +505,6 @@ token = "sample_token_placeholder"
secret_findings = [f for f in result.findings if f.category == "hardcoded_secret"] secret_findings = [f for f in result.findings if f.category == "hardcoded_secret"]
# Should have fewer or no findings due to false positive filtering # Should have fewer or no findings due to false positive filtering
assert len(secret_findings) == 0 or all( assert len(secret_findings) == 0 or all(
not any(fp in f.description.lower() for fp in ['test', 'example', 'dummy', 'sample']) not any(fp in f.description.lower() for fp in ["test", "example", "dummy", "sample"])
for f in secret_findings for f in secret_findings
) )
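The assertion only requires that obviously fake values containing "test", "example", "dummy", or "sample" are not reported; a rough illustration of that kind of placeholder check, not the analyzer's actual implementation, might be:

    FALSE_POSITIVE_MARKERS = ("test", "example", "dummy", "sample")

    def looks_like_placeholder_secret(value: str) -> bool:
        """Return True when a matched 'secret' is clearly a test or example value."""
        lowered = value.lower()
        return any(marker in lowered for marker in FALSE_POSITIVE_MARKERS)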