From 40bbb18795a122bb6a79e58c582aebb3b5e3108d Mon Sep 17 00:00:00 2001 From: fztee Date: Mon, 10 Nov 2025 17:01:42 +0100 Subject: [PATCH] chore: improve code quality (backend package). - add configuration file for 'ruff'. - fix most of 'ruff' lints. - format 'backend' package using 'ruff'. --- backend/ruff.toml | 11 + backend/src/api/fuzzing.py | 103 +++-- backend/src/api/runs.py | 113 +++--- backend/src/api/system.py | 10 +- backend/src/api/workflows.py | 380 +++++++++--------- backend/src/core/setup.py | 14 +- backend/src/main.py | 156 +++---- backend/src/models/findings.py | 80 ++-- backend/src/storage/__init__.py | 5 +- backend/src/storage/base.py | 58 ++- backend/src/storage/s3_cached.py | 346 ++++++++-------- backend/src/temporal/__init__.py | 5 +- backend/src/temporal/discovery.py | 150 +++---- backend/src/temporal/manager.py | 227 ++++++----- backend/tests/conftest.py | 110 +++-- backend/tests/fixtures/__init__.py | 0 backend/tests/integration/__init__.py | 0 backend/tests/unit/__init__.py | 1 + backend/tests/unit/test_api/__init__.py | 0 backend/tests/unit/test_modules/__init__.py | 1 + .../unit/test_modules/test_atheris_fuzzer.py | 136 ++++--- .../unit/test_modules/test_cargo_fuzzer.py | 184 +++++---- .../unit/test_modules/test_file_scanner.py | 141 +++---- .../test_modules/test_security_analyzer.py | 191 +++++---- backend/tests/unit/test_workflows/__init__.py | 0 25 files changed, 1273 insertions(+), 1149 deletions(-) create mode 100644 backend/ruff.toml delete mode 100644 backend/tests/fixtures/__init__.py delete mode 100644 backend/tests/integration/__init__.py delete mode 100644 backend/tests/unit/test_api/__init__.py delete mode 100644 backend/tests/unit/test_workflows/__init__.py diff --git a/backend/ruff.toml b/backend/ruff.toml new file mode 100644 index 0000000..3231a98 --- /dev/null +++ b/backend/ruff.toml @@ -0,0 +1,11 @@ +line-length = 120 + +[lint] +select = [ "ALL" ] +ignore = [] + +[lint.per-file-ignores] +"tests/*" = [ + "PLR2004", # allowing comparisons using unnamed numerical constants in tests + "S101", # allowing 'assert' statements in tests +] \ No newline at end of file diff --git a/backend/src/api/fuzzing.py b/backend/src/api/fuzzing.py index 166319a..d18cb3f 100644 --- a/backend/src/api/fuzzing.py +++ b/backend/src/api/fuzzing.py @@ -1,6 +1,4 @@ -""" -API endpoints for fuzzing workflow management and real-time monitoring -""" +"""API endpoints for fuzzing workflow management and real-time monitoring.""" # Copyright (c) 2025 FuzzingLabs # # Licensed under the Business Source License 1.1 (BSL). See the LICENSE file # at the root of this repository for details. # # After the Change Date specified in the LICENSE file, this version of the # Software will be made available under the Apache License, Version 2.0. # # Additional attribution and requirements are provided in the NOTICE file. 
-import logging -from typing import List, Dict -from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect -from fastapi.responses import StreamingResponse import asyncio +import contextlib import json +import logging from datetime import datetime -from src.models.findings import ( - FuzzingStats, - CrashReport -) +from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect +from fastapi.responses import StreamingResponse + +from src.models.findings import CrashReport, FuzzingStats logger = logging.getLogger(__name__) router = APIRouter(prefix="/fuzzing", tags=["fuzzing"]) # In-memory storage for real-time stats (in production, use Redis or similar) -fuzzing_stats: Dict[str, FuzzingStats] = {} -crash_reports: Dict[str, List[CrashReport]] = {} -active_connections: Dict[str, List[WebSocket]] = {} +fuzzing_stats: dict[str, FuzzingStats] = {} +crash_reports: dict[str, list[CrashReport]] = {} +active_connections: dict[str, list[WebSocket]] = {} -def initialize_fuzzing_tracking(run_id: str, workflow_name: str): - """ - Initialize fuzzing tracking for a new run. +def initialize_fuzzing_tracking(run_id: str, workflow_name: str) -> None: + """Initialize fuzzing tracking for a new run. This function should be called when a workflow is submitted to enable real-time monitoring and stats collection. @@ -46,19 +41,19 @@ def initialize_fuzzing_tracking(run_id: str, workflow_name: str): Args: run_id: The run identifier workflow_name: Name of the workflow + """ fuzzing_stats[run_id] = FuzzingStats( run_id=run_id, - workflow=workflow_name + workflow=workflow_name, ) crash_reports[run_id] = [] active_connections[run_id] = [] -@router.get("/{run_id}/stats", response_model=FuzzingStats) +@router.get("/{run_id}/stats") async def get_fuzzing_stats(run_id: str) -> FuzzingStats: - """ - Get current fuzzing statistics for a run. + """Get current fuzzing statistics for a run. Args: run_id: The fuzzing run ID @@ -68,20 +63,20 @@ async def get_fuzzing_stats(run_id: str) -> FuzzingStats: Raises: HTTPException: 404 if run not found + """ if run_id not in fuzzing_stats: raise HTTPException( status_code=404, - detail=f"Fuzzing run not found: {run_id}" + detail=f"Fuzzing run not found: {run_id}", ) return fuzzing_stats[run_id] -@router.get("/{run_id}/crashes", response_model=List[CrashReport]) -async def get_crash_reports(run_id: str) -> List[CrashReport]: - """ - Get crash reports for a fuzzing run. +@router.get("/{run_id}/crashes") +async def get_crash_reports(run_id: str) -> list[CrashReport]: + """Get crash reports for a fuzzing run. Args: run_id: The fuzzing run ID @@ -91,11 +86,12 @@ async def get_crash_reports(run_id: str) -> List[CrashReport]: Raises: HTTPException: 404 if run not found + """ if run_id not in crash_reports: raise HTTPException( status_code=404, - detail=f"Fuzzing run not found: {run_id}" + detail=f"Fuzzing run not found: {run_id}", ) return crash_reports[run_id] @@ -103,8 +99,7 @@ async def get_crash_reports(run_id: str) -> List[CrashReport]: @router.post("/{run_id}/stats") async def update_fuzzing_stats(run_id: str, stats: FuzzingStats): - """ - Update fuzzing statistics (called by fuzzing workflows). + """Update fuzzing statistics (called by fuzzing workflows). 
Args: run_id: The fuzzing run ID @@ -112,18 +107,19 @@ async def update_fuzzing_stats(run_id: str, stats: FuzzingStats): Raises: HTTPException: 404 if run not found + """ if run_id not in fuzzing_stats: raise HTTPException( status_code=404, - detail=f"Fuzzing run not found: {run_id}" + detail=f"Fuzzing run not found: {run_id}", ) # Update stats fuzzing_stats[run_id] = stats # Debug: log reception for live instrumentation - try: + with contextlib.suppress(Exception): logger.info( "Received fuzzing stats update: run_id=%s exec=%s eps=%.2f crashes=%s corpus=%s coverage=%s elapsed=%ss", run_id, @@ -134,14 +130,12 @@ async def update_fuzzing_stats(run_id: str, stats: FuzzingStats): stats.coverage, stats.elapsed_time, ) - except Exception: - pass # Notify connected WebSocket clients if run_id in active_connections: message = { "type": "stats_update", - "data": stats.model_dump() + "data": stats.model_dump(), } for websocket in active_connections[run_id][:]: # Copy to avoid modification during iteration try: @@ -153,12 +147,12 @@ async def update_fuzzing_stats(run_id: str, stats: FuzzingStats): @router.post("/{run_id}/crash") async def report_crash(run_id: str, crash: CrashReport): - """ - Report a new crash (called by fuzzing workflows). + """Report a new crash (called by fuzzing workflows). Args: run_id: The fuzzing run ID crash: Crash report details + """ if run_id not in crash_reports: crash_reports[run_id] = [] @@ -175,7 +169,7 @@ async def report_crash(run_id: str, crash: CrashReport): if run_id in active_connections: message = { "type": "crash_report", - "data": crash.model_dump() + "data": crash.model_dump(), } for websocket in active_connections[run_id][:]: try: @@ -186,12 +180,12 @@ async def report_crash(run_id: str, crash: CrashReport): @router.websocket("/{run_id}/live") async def websocket_endpoint(websocket: WebSocket, run_id: str): - """ - WebSocket endpoint for real-time fuzzing updates. + """WebSocket endpoint for real-time fuzzing updates. Args: websocket: WebSocket connection run_id: The fuzzing run ID to monitor + """ await websocket.accept() @@ -223,7 +217,7 @@ async def websocket_endpoint(websocket: WebSocket, run_id: str): # Echo back for ping-pong if data == "ping": await websocket.send_text("pong") - except asyncio.TimeoutError: + except TimeoutError: # Send periodic heartbeat await websocket.send_text(json.dumps({"type": "heartbeat"})) @@ -231,31 +225,31 @@ async def websocket_endpoint(websocket: WebSocket, run_id: str): # Clean up connection if run_id in active_connections and websocket in active_connections[run_id]: active_connections[run_id].remove(websocket) - except Exception as e: - logger.error(f"WebSocket error for run {run_id}: {e}") + except Exception: + logger.exception("WebSocket error for run %s", run_id) if run_id in active_connections and websocket in active_connections[run_id]: active_connections[run_id].remove(websocket) @router.get("/{run_id}/stream") async def stream_fuzzing_updates(run_id: str): - """ - Server-Sent Events endpoint for real-time fuzzing updates. + """Server-Sent Events endpoint for real-time fuzzing updates. 
Args: run_id: The fuzzing run ID to monitor Returns: Streaming response with real-time updates + """ if run_id not in fuzzing_stats: raise HTTPException( status_code=404, - detail=f"Fuzzing run not found: {run_id}" + detail=f"Fuzzing run not found: {run_id}", ) async def event_stream(): - """Generate server-sent events for fuzzing updates""" + """Generate server-sent events for fuzzing updates.""" last_stats_time = datetime.utcnow() while True: @@ -276,10 +270,7 @@ async def stream_fuzzing_updates(run_id: str): # Send recent crashes if run_id in crash_reports: - recent_crashes = [ - crash for crash in crash_reports[run_id] - if crash.timestamp > last_stats_time - ] + recent_crashes = [crash for crash in crash_reports[run_id] if crash.timestamp > last_stats_time] for crash in recent_crashes: event_data = f"data: {json.dumps({'type': 'crash', 'data': crash.model_dump()})}\n\n" yield event_data @@ -287,8 +278,8 @@ async def stream_fuzzing_updates(run_id: str): last_stats_time = datetime.utcnow() await asyncio.sleep(5) # Update every 5 seconds - except Exception as e: - logger.error(f"Error in event stream for run {run_id}: {e}") + except Exception: + logger.exception("Error in event stream for run %s", run_id) break return StreamingResponse( @@ -297,17 +288,17 @@ async def stream_fuzzing_updates(run_id: str): headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", - } + }, ) @router.delete("/{run_id}") -async def cleanup_fuzzing_run(run_id: str): - """ - Clean up fuzzing run data. +async def cleanup_fuzzing_run(run_id: str) -> dict[str, str]: + """Clean up fuzzing run data. Args: run_id: The fuzzing run ID to clean up + """ # Clean up tracking data fuzzing_stats.pop(run_id, None) diff --git a/backend/src/api/runs.py b/backend/src/api/runs.py index b975f4b..b86cc7e 100644 --- a/backend/src/api/runs.py +++ b/backend/src/api/runs.py @@ -1,6 +1,4 @@ -""" -API endpoints for workflow run management and findings retrieval -""" +"""API endpoints for workflow run management and findings retrieval.""" # Copyright (c) 2025 FuzzingLabs # @@ -14,37 +12,36 @@ API endpoints for workflow run management and findings retrieval # Additional attribution and requirements are provided in the NOTICE file. import logging -from fastapi import APIRouter, HTTPException, Depends +from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException + +from src.main import temporal_mgr from src.models.findings import WorkflowFindings, WorkflowStatus +from src.temporal import TemporalManager logger = logging.getLogger(__name__) router = APIRouter(prefix="/runs", tags=["runs"]) -def get_temporal_manager(): - """Dependency to get the Temporal manager instance""" - from src.main import temporal_mgr +def get_temporal_manager() -> TemporalManager: + """Dependency to get the Temporal manager instance.""" return temporal_mgr -@router.get("/{run_id}/status", response_model=WorkflowStatus) +@router.get("/{run_id}/status") async def get_run_status( run_id: str, - temporal_mgr=Depends(get_temporal_manager) + temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)], ) -> WorkflowStatus: - """ - Get the current status of a workflow run. + """Get the current status of a workflow run. - Args: - run_id: The workflow run ID + :param run_id: The workflow run ID + :param temporal_mgr: The temporal manager instance. 
+ :return: Status information including state, timestamps, and completion flags + :raises HTTPException: 404 if run not found - Returns: - Status information including state, timestamps, and completion flags - - Raises: - HTTPException: 404 if run not found """ try: status = await temporal_mgr.get_workflow_status(run_id) @@ -56,7 +53,7 @@ async def get_run_status( is_running = workflow_status == "RUNNING" # Extract workflow name from run_id (format: workflow_name-unique_id) - workflow_name = run_id.rsplit('-', 1)[0] if '-' in run_id else "unknown" + workflow_name = run_id.rsplit("-", 1)[0] if "-" in run_id else "unknown" return WorkflowStatus( run_id=run_id, @@ -66,33 +63,29 @@ async def get_run_status( is_failed=is_failed, is_running=is_running, created_at=status.get("start_time"), - updated_at=status.get("close_time") or status.get("execution_time") + updated_at=status.get("close_time") or status.get("execution_time"), ) except Exception as e: - logger.error(f"Failed to get status for run {run_id}: {e}") + logger.exception("Failed to get status for run %s", run_id) raise HTTPException( status_code=404, - detail=f"Run not found: {run_id}" - ) + detail=f"Run not found: {run_id}", + ) from e -@router.get("/{run_id}/findings", response_model=WorkflowFindings) +@router.get("/{run_id}/findings") async def get_run_findings( run_id: str, - temporal_mgr=Depends(get_temporal_manager) + temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)], ) -> WorkflowFindings: - """ - Get the findings from a completed workflow run. + """Get the findings from a completed workflow run. - Args: - run_id: The workflow run ID + :param run_id: The workflow run ID + :param temporal_mgr: The temporal manager instance. + :return: SARIF-formatted findings from the workflow execution + :raises HTTPException: 404 if run not found, 400 if run not completed - Returns: - SARIF-formatted findings from the workflow execution - - Raises: - HTTPException: 404 if run not found, 400 if run not completed """ try: # Get run status first @@ -103,80 +96,72 @@ async def get_run_findings( if workflow_status == "RUNNING": raise HTTPException( status_code=400, - detail=f"Run {run_id} is still running. Current status: {workflow_status}" - ) - else: - raise HTTPException( - status_code=400, - detail=f"Run {run_id} not completed. Status: {workflow_status}" + detail=f"Run {run_id} is still running. Current status: {workflow_status}", ) + raise HTTPException( + status_code=400, + detail=f"Run {run_id} not completed. Status: {workflow_status}", + ) if workflow_status == "FAILED": raise HTTPException( status_code=400, - detail=f"Run {run_id} failed. Status: {workflow_status}" + detail=f"Run {run_id} failed. 
Status: {workflow_status}", ) # Get the workflow result result = await temporal_mgr.get_workflow_result(run_id) # Extract SARIF from result (handle None for backwards compatibility) - if isinstance(result, dict): - sarif = result.get("sarif") or {} - else: - sarif = {} + sarif = result.get("sarif", {}) if isinstance(result, dict) else {} # Extract workflow name from run_id (format: workflow_name-unique_id) - workflow_name = run_id.rsplit('-', 1)[0] if '-' in run_id else "unknown" + workflow_name = run_id.rsplit("-", 1)[0] if "-" in run_id else "unknown" # Metadata metadata = { "completion_time": status.get("close_time"), - "workflow_version": "unknown" + "workflow_version": "unknown", } return WorkflowFindings( workflow=workflow_name, run_id=run_id, sarif=sarif, - metadata=metadata + metadata=metadata, ) except HTTPException: raise except Exception as e: - logger.error(f"Failed to get findings for run {run_id}: {e}") + logger.exception("Failed to get findings for run %s", run_id) raise HTTPException( status_code=500, - detail=f"Failed to retrieve findings: {str(e)}" - ) + detail=f"Failed to retrieve findings: {e!s}", + ) from e -@router.get("/{workflow_name}/findings/{run_id}", response_model=WorkflowFindings) +@router.get("/{workflow_name}/findings/{run_id}") async def get_workflow_findings( workflow_name: str, run_id: str, - temporal_mgr=Depends(get_temporal_manager) + temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)], ) -> WorkflowFindings: - """ - Get findings for a specific workflow run. + """Get findings for a specific workflow run. Alternative endpoint that includes workflow name in the path for clarity. - Args: - workflow_name: Name of the workflow - run_id: The workflow run ID + :param workflow_name: Name of the workflow + :param run_id: The workflow run ID + :param temporal_mgr: The temporal manager instance. + :return: SARIF-formatted findings from the workflow execution + :raises HTTPException: 404 if workflow or run not found, 400 if run not completed - Returns: - SARIF-formatted findings from the workflow execution - - Raises: - HTTPException: 404 if workflow or run not found, 400 if run not completed """ if workflow_name not in temporal_mgr.workflows: raise HTTPException( status_code=404, - detail=f"Workflow not found: {workflow_name}" + detail=f"Workflow not found: {workflow_name}", ) # Delegate to the main findings endpoint diff --git a/backend/src/api/system.py b/backend/src/api/system.py index a4ee1a6..057e5c1 100644 --- a/backend/src/api/system.py +++ b/backend/src/api/system.py @@ -9,14 +9,12 @@ # # Additional attribution and requirements are provided in the NOTICE file. -""" -System information endpoints for FuzzForge API. +"""System information endpoints for FuzzForge API. Provides system configuration and filesystem paths to CLI for worker management. """ import os -from typing import Dict from fastapi import APIRouter @@ -24,9 +22,8 @@ router = APIRouter(prefix="/system", tags=["system"]) @router.get("/info") -async def get_system_info() -> Dict[str, str]: - """ - Get system information including host filesystem paths. +async def get_system_info() -> dict[str, str]: + """Get system information including host filesystem paths. This endpoint exposes paths needed by the CLI to manage workers via docker-compose. 
The FUZZFORGE_HOST_ROOT environment variable is set by docker-compose and points @@ -37,6 +34,7 @@ async def get_system_info() -> Dict[str, str]: - host_root: Absolute path to FuzzForge root on host - docker_compose_path: Path to docker-compose.yml on host - workers_dir: Path to workers directory on host + """ host_root = os.getenv("FUZZFORGE_HOST_ROOT", "") diff --git a/backend/src/api/workflows.py b/backend/src/api/workflows.py index a4d1b7c..4e24af7 100644 --- a/backend/src/api/workflows.py +++ b/backend/src/api/workflows.py @@ -1,6 +1,4 @@ -""" -API endpoints for workflow management with enhanced error handling -""" +"""API endpoints for workflow management with enhanced error handling.""" # Copyright (c) 2025 FuzzingLabs # @@ -13,20 +11,24 @@ API endpoints for workflow management with enhanced error handling # # Additional attribution and requirements are provided in the NOTICE file. +import json import logging -import traceback import tempfile -from typing import List, Dict, Any, Optional -from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form from pathlib import Path +from typing import Annotated, Any +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile + +from src.api.fuzzing import initialize_fuzzing_tracking +from src.main import temporal_mgr from src.models.findings import ( - WorkflowSubmission, - WorkflowMetadata, + RunSubmissionResponse, WorkflowListItem, - RunSubmissionResponse + WorkflowMetadata, + WorkflowSubmission, ) from src.temporal.discovery import WorkflowDiscovery +from src.temporal.manager import TemporalManager logger = logging.getLogger(__name__) @@ -43,9 +45,8 @@ ALLOWED_CONTENT_TYPES = [ router = APIRouter(prefix="/workflows", tags=["workflows"]) -def extract_defaults_from_json_schema(metadata: Dict[str, Any]) -> Dict[str, Any]: - """ - Extract default parameter values from JSON Schema format. +def extract_defaults_from_json_schema(metadata: dict[str, Any]) -> dict[str, Any]: + """Extract default parameter values from JSON Schema format. 
Converts from: parameters: @@ -61,6 +62,7 @@ def extract_defaults_from_json_schema(metadata: Dict[str, Any]) -> Dict[str, Any Returns: Dictionary of parameter defaults + """ defaults = {} @@ -82,19 +84,19 @@ def extract_defaults_from_json_schema(metadata: Dict[str, Any]) -> Dict[str, Any def create_structured_error_response( error_type: str, message: str, - workflow_name: Optional[str] = None, - run_id: Optional[str] = None, - container_info: Optional[Dict[str, Any]] = None, - deployment_info: Optional[Dict[str, Any]] = None, - suggestions: Optional[List[str]] = None -) -> Dict[str, Any]: + workflow_name: str | None = None, + run_id: str | None = None, + container_info: dict[str, Any] | None = None, + deployment_info: dict[str, Any] | None = None, + suggestions: list[str] | None = None, +) -> dict[str, Any]: """Create a structured error response with rich context.""" error_response = { "error": { "type": error_type, "message": message, - "timestamp": __import__("datetime").datetime.utcnow().isoformat() + "Z" - } + "timestamp": __import__("datetime").datetime.utcnow().isoformat() + "Z", + }, } if workflow_name: @@ -115,39 +117,38 @@ def create_structured_error_response( return error_response -def get_temporal_manager(): - """Dependency to get the Temporal manager instance""" - from src.main import temporal_mgr +def get_temporal_manager() -> TemporalManager: + """Dependency to get the Temporal manager instance.""" return temporal_mgr -@router.get("/", response_model=List[WorkflowListItem]) +@router.get("/") async def list_workflows( - temporal_mgr=Depends(get_temporal_manager) -) -> List[WorkflowListItem]: - """ - List all discovered workflows with their metadata. + temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)], +) -> list[WorkflowListItem]: + """List all discovered workflows with their metadata. Returns a summary of each workflow including name, version, description, author, and tags. """ workflows = [] for name, info in temporal_mgr.workflows.items(): - workflows.append(WorkflowListItem( - name=name, - version=info.metadata.get("version", "0.6.0"), - description=info.metadata.get("description", ""), - author=info.metadata.get("author"), - tags=info.metadata.get("tags", []) - )) + workflows.append( + WorkflowListItem( + name=name, + version=info.metadata.get("version", "0.6.0"), + description=info.metadata.get("description", ""), + author=info.metadata.get("author"), + tags=info.metadata.get("tags", []), + ), + ) return workflows @router.get("/metadata/schema") -async def get_metadata_schema() -> Dict[str, Any]: - """ - Get the JSON schema for workflow metadata files. +async def get_metadata_schema() -> dict[str, Any]: + """Get the JSON schema for workflow metadata files. This schema defines the structure and requirements for metadata.yaml files that must accompany each workflow. @@ -155,23 +156,19 @@ async def get_metadata_schema() -> Dict[str, Any]: return WorkflowDiscovery.get_metadata_schema() -@router.get("/{workflow_name}/metadata", response_model=WorkflowMetadata) +@router.get("/{workflow_name}/metadata") async def get_workflow_metadata( workflow_name: str, - temporal_mgr=Depends(get_temporal_manager) + temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)], ) -> WorkflowMetadata: - """ - Get complete metadata for a specific workflow. + """Get complete metadata for a specific workflow. 
- Args: - workflow_name: Name of the workflow - - Returns: - Complete metadata including parameters schema, supported volume modes, + :param workflow_name: Name of the workflow + :param temporal_mgr: The temporal manager instance. + :return: Complete metadata including parameters schema, supported volume modes, required modules, and more. + :raises HTTPException: 404 if workflow not found - Raises: - HTTPException: 404 if workflow not found """ if workflow_name not in temporal_mgr.workflows: available_workflows = list(temporal_mgr.workflows.keys()) @@ -182,12 +179,12 @@ async def get_workflow_metadata( suggestions=[ f"Available workflows: {', '.join(available_workflows)}", "Use GET /workflows/ to see all available workflows", - "Check workflow name spelling and case sensitivity" - ] + "Check workflow name spelling and case sensitivity", + ], ) raise HTTPException( status_code=404, - detail=error_response + detail=error_response, ) info = temporal_mgr.workflows[workflow_name] @@ -201,28 +198,24 @@ async def get_workflow_metadata( tags=metadata.get("tags", []), parameters=metadata.get("parameters", {}), default_parameters=extract_defaults_from_json_schema(metadata), - required_modules=metadata.get("required_modules", []) + required_modules=metadata.get("required_modules", []), ) -@router.post("/{workflow_name}/submit", response_model=RunSubmissionResponse) +@router.post("/{workflow_name}/submit") async def submit_workflow( workflow_name: str, submission: WorkflowSubmission, - temporal_mgr=Depends(get_temporal_manager) + temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)], ) -> RunSubmissionResponse: - """ - Submit a workflow for execution. + """Submit a workflow for execution. - Args: - workflow_name: Name of the workflow to execute - submission: Submission parameters including target path and parameters + :param workflow_name: Name of the workflow to execute + :param submission: Submission parameters including target path and parameters + :param temporal_mgr: The temporal manager instance. 
+ :return: Run submission response with run_id and initial status + :raises HTTPException: 404 if workflow not found, 400 for invalid parameters - Returns: - Run submission response with run_id and initial status - - Raises: - HTTPException: 404 if workflow not found, 400 for invalid parameters """ if workflow_name not in temporal_mgr.workflows: available_workflows = list(temporal_mgr.workflows.keys()) @@ -233,25 +226,26 @@ async def submit_workflow( suggestions=[ f"Available workflows: {', '.join(available_workflows)}", "Use GET /workflows/ to see all available workflows", - "Check workflow name spelling and case sensitivity" - ] + "Check workflow name spelling and case sensitivity", + ], ) raise HTTPException( status_code=404, - detail=error_response + detail=error_response, ) try: # Upload target file to MinIO and get target_id target_path = Path(submission.target_path) if not target_path.exists(): - raise ValueError(f"Target path does not exist: {submission.target_path}") + msg = f"Target path does not exist: {submission.target_path}" + raise ValueError(msg) # Upload target (using anonymous user for now) target_id = await temporal_mgr.upload_target( file_path=target_path, user_id="api-user", - metadata={"workflow": workflow_name} + metadata={"workflow": workflow_name}, ) # Merge default parameters with user parameters @@ -265,23 +259,22 @@ async def submit_workflow( handle = await temporal_mgr.run_workflow( workflow_name=workflow_name, target_id=target_id, - workflow_params=workflow_params + workflow_params=workflow_params, ) run_id = handle.id # Initialize fuzzing tracking if this looks like a fuzzing workflow workflow_info = temporal_mgr.workflows.get(workflow_name, {}) - workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, 'metadata') else [] + workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, "metadata") else [] if "fuzzing" in workflow_tags or "fuzz" in workflow_name.lower(): - from src.api.fuzzing import initialize_fuzzing_tracking initialize_fuzzing_tracking(run_id, workflow_name) return RunSubmissionResponse( run_id=run_id, status="RUNNING", workflow=workflow_name, - message=f"Workflow '{workflow_name}' submitted successfully" + message=f"Workflow '{workflow_name}' submitted successfully", ) except ValueError as e: @@ -293,14 +286,13 @@ async def submit_workflow( suggestions=[ "Check parameter types and values", "Use GET /workflows/{workflow_name}/parameters for schema", - "Ensure all required parameters are provided" - ] + "Ensure all required parameters are provided", + ], ) - raise HTTPException(status_code=400, detail=error_response) + raise HTTPException(status_code=400, detail=error_response) from e except Exception as e: - logger.error(f"Failed to submit workflow '{workflow_name}': {e}") - logger.error(f"Traceback: {traceback.format_exc()}") + logger.exception("Failed to submit workflow '%s'", workflow_name) # Try to get more context about the error container_info = None @@ -313,47 +305,57 @@ async def submit_workflow( # Detect specific error patterns if "workflow" in error_message.lower() and "not found" in error_message.lower(): error_type = "WorkflowError" - suggestions.extend([ - "Check if Temporal server is running and accessible", - "Verify workflow workers are running", - "Check if workflow is registered with correct vertical", - "Ensure Docker is running and has sufficient resources" - ]) + suggestions.extend( + [ + "Check if Temporal server is running and accessible", + "Verify workflow workers are 
running", + "Check if workflow is registered with correct vertical", + "Ensure Docker is running and has sufficient resources", + ], + ) elif "volume" in error_message.lower() or "mount" in error_message.lower(): error_type = "VolumeError" - suggestions.extend([ - "Check if the target path exists and is accessible", - "Verify file permissions (Docker needs read access)", - "Ensure the path is not in use by another process", - "Try using an absolute path instead of relative path" - ]) + suggestions.extend( + [ + "Check if the target path exists and is accessible", + "Verify file permissions (Docker needs read access)", + "Ensure the path is not in use by another process", + "Try using an absolute path instead of relative path", + ], + ) elif "memory" in error_message.lower() or "resource" in error_message.lower(): error_type = "ResourceError" - suggestions.extend([ - "Check system memory and CPU availability", - "Consider reducing resource limits or dataset size", - "Monitor Docker resource usage", - "Increase Docker memory limits if needed" - ]) + suggestions.extend( + [ + "Check system memory and CPU availability", + "Consider reducing resource limits or dataset size", + "Monitor Docker resource usage", + "Increase Docker memory limits if needed", + ], + ) elif "image" in error_message.lower(): error_type = "ImageError" - suggestions.extend([ - "Check if the workflow image exists", - "Verify Docker registry access", - "Try rebuilding the workflow image", - "Check network connectivity to registries" - ]) + suggestions.extend( + [ + "Check if the workflow image exists", + "Verify Docker registry access", + "Try rebuilding the workflow image", + "Check network connectivity to registries", + ], + ) else: - suggestions.extend([ - "Check FuzzForge backend logs for details", - "Verify all services are running (docker-compose up -d)", - "Try restarting the workflow deployment", - "Contact support if the issue persists" - ]) + suggestions.extend( + [ + "Check FuzzForge backend logs for details", + "Verify all services are running (docker-compose up -d)", + "Try restarting the workflow deployment", + "Contact support if the issue persists", + ], + ) error_response = create_structured_error_response( error_type=error_type, @@ -361,41 +363,35 @@ async def submit_workflow( workflow_name=workflow_name, container_info=container_info, deployment_info=deployment_info, - suggestions=suggestions + suggestions=suggestions, ) raise HTTPException( status_code=500, - detail=error_response - ) + detail=error_response, + ) from e -@router.post("/{workflow_name}/upload-and-submit", response_model=RunSubmissionResponse) +@router.post("/{workflow_name}/upload-and-submit") async def upload_and_submit_workflow( workflow_name: str, - file: UploadFile = File(..., description="Target file or tarball to analyze"), - parameters: Optional[str] = Form(None, description="JSON-encoded workflow parameters"), - timeout: Optional[int] = Form(None, description="Timeout in seconds"), - temporal_mgr=Depends(get_temporal_manager) + temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)], + file: Annotated[UploadFile, File(..., description="Target file or tarball to analyze")], + parameters: Annotated[str, Form(None, description="JSON-encoded workflow parameters")], ) -> RunSubmissionResponse: - """ - Upload a target file/tarball and submit workflow for execution. + """Upload a target file/tarball and submit workflow for execution. 
This endpoint accepts multipart/form-data uploads and is the recommended way to submit workflows from remote CLI clients. - Args: - workflow_name: Name of the workflow to execute - file: Target file or tarball (compressed directory) - parameters: JSON string of workflow parameters (optional) - timeout: Execution timeout in seconds (optional) + :param workflow_name: Name of the workflow to execute + :param temporal_mgr: The temporal manager instance. + :param file: Target file or tarball (compressed directory) + :param parameters: JSON string of workflow parameters (optional) + :returns: Run submission response with run_id and initial status + :raises HTTPException: 404 if workflow not found, 400 for invalid parameters, + 413 if file too large - Returns: - Run submission response with run_id and initial status - - Raises: - HTTPException: 404 if workflow not found, 400 for invalid parameters, - 413 if file too large """ if workflow_name not in temporal_mgr.workflows: available_workflows = list(temporal_mgr.workflows.keys()) @@ -405,8 +401,8 @@ async def upload_and_submit_workflow( workflow_name=workflow_name, suggestions=[ f"Available workflows: {', '.join(available_workflows)}", - "Use GET /workflows/ to see all available workflows" - ] + "Use GET /workflows/ to see all available workflows", + ], ) raise HTTPException(status_code=404, detail=error_response) @@ -420,10 +416,10 @@ async def upload_and_submit_workflow( # Create temporary file temp_fd, temp_file_path = tempfile.mkstemp(suffix=".tar.gz") - logger.info(f"Receiving file upload for workflow '{workflow_name}': {file.filename}") + logger.info("Receiving file upload for workflow '%s': %s", workflow_name, file.filename) # Stream file to disk - with open(temp_fd, 'wb') as temp_file: + with open(temp_fd, "wb") as temp_file: while True: chunk = await file.read(chunk_size) if not chunk: @@ -442,33 +438,33 @@ async def upload_and_submit_workflow( suggestions=[ "Reduce the size of your target directory", "Exclude unnecessary files (build artifacts, dependencies, etc.)", - "Consider splitting into smaller analysis targets" - ] - ) + "Consider splitting into smaller analysis targets", + ], + ), ) temp_file.write(chunk) - logger.info(f"Received file: {file_size / (1024**2):.2f} MB") + logger.info("Received file: %s MB", f"{file_size / (1024**2):.2f}") # Parse parameters workflow_params = {} if parameters: try: - import json workflow_params = json.loads(parameters) if not isinstance(workflow_params, dict): - raise ValueError("Parameters must be a JSON object") - except (json.JSONDecodeError, ValueError) as e: + msg = "Parameters must be a JSON object" + raise TypeError(msg) + except (json.JSONDecodeError, TypeError) as e: raise HTTPException( status_code=400, detail=create_structured_error_response( error_type="InvalidParameters", message=f"Invalid parameters JSON: {e}", workflow_name=workflow_name, - suggestions=["Ensure parameters is valid JSON object"] - ) - ) + suggestions=["Ensure parameters is valid JSON object"], + ), + ) from e # Upload to MinIO target_id = await temporal_mgr.upload_target( @@ -477,11 +473,11 @@ async def upload_and_submit_workflow( metadata={ "workflow": workflow_name, "original_filename": file.filename, - "upload_method": "multipart" - } + "upload_method": "multipart", + }, ) - logger.info(f"Uploaded to MinIO with target_id: {target_id}") + logger.info("Uploaded to MinIO with target_id: %s", target_id) # Merge default parameters with user parameters workflow_info = temporal_mgr.workflows.get(workflow_name) @@ -493,74 
+489,68 @@ async def upload_and_submit_workflow( handle = await temporal_mgr.run_workflow( workflow_name=workflow_name, target_id=target_id, - workflow_params=workflow_params + workflow_params=workflow_params, ) run_id = handle.id # Initialize fuzzing tracking if needed workflow_info = temporal_mgr.workflows.get(workflow_name, {}) - workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, 'metadata') else [] + workflow_tags = workflow_info.metadata.get("tags", []) if hasattr(workflow_info, "metadata") else [] if "fuzzing" in workflow_tags or "fuzz" in workflow_name.lower(): - from src.api.fuzzing import initialize_fuzzing_tracking initialize_fuzzing_tracking(run_id, workflow_name) return RunSubmissionResponse( run_id=run_id, status="RUNNING", workflow=workflow_name, - message=f"Workflow '{workflow_name}' submitted successfully with uploaded target" + message=f"Workflow '{workflow_name}' submitted successfully with uploaded target", ) except HTTPException: raise except Exception as e: - logger.error(f"Failed to upload and submit workflow '{workflow_name}': {e}") - logger.error(f"Traceback: {traceback.format_exc()}") + logger.exception("Failed to upload and submit workflow '%s'", workflow_name) error_response = create_structured_error_response( error_type="WorkflowSubmissionError", - message=f"Failed to process upload and submit workflow: {str(e)}", + message=f"Failed to process upload and submit workflow: {e!s}", workflow_name=workflow_name, suggestions=[ "Check if the uploaded file is a valid tarball", "Verify MinIO storage is accessible", "Check backend logs for detailed error information", - "Ensure Temporal workers are running" - ] + "Ensure Temporal workers are running", + ], ) - raise HTTPException(status_code=500, detail=error_response) + raise HTTPException(status_code=500, detail=error_response) from e finally: # Cleanup temporary file if temp_file_path and Path(temp_file_path).exists(): try: Path(temp_file_path).unlink() - logger.debug(f"Cleaned up temp file: {temp_file_path}") + logger.debug("Cleaned up temp file: %s", temp_file_path) except Exception as e: - logger.warning(f"Failed to cleanup temp file {temp_file_path}: {e}") + logger.warning("Failed to cleanup temp file %s: %s", temp_file_path, e) @router.get("/{workflow_name}/worker-info") async def get_workflow_worker_info( workflow_name: str, - temporal_mgr=Depends(get_temporal_manager) -) -> Dict[str, Any]: - """ - Get worker information for a workflow. + temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)], +) -> dict[str, Any]: + """Get worker information for a workflow. Returns details about which worker is required to execute this workflow, including container name, task queue, and vertical. - Args: - workflow_name: Name of the workflow + :param workflow_name: Name of the workflow + :param temporal_mgr: The temporal manager instance. 
+ :return: Worker information including vertical, container name, and task queue + :raises HTTPException: 404 if workflow not found - Returns: - Worker information including vertical, container name, and task queue - - Raises: - HTTPException: 404 if workflow not found """ if workflow_name not in temporal_mgr.workflows: available_workflows = list(temporal_mgr.workflows.keys()) @@ -570,12 +560,12 @@ async def get_workflow_worker_info( workflow_name=workflow_name, suggestions=[ f"Available workflows: {', '.join(available_workflows)}", - "Use GET /workflows/ to see all available workflows" - ] + "Use GET /workflows/ to see all available workflows", + ], ) raise HTTPException( status_code=404, - detail=error_response + detail=error_response, ) info = temporal_mgr.workflows[workflow_name] @@ -591,12 +581,12 @@ async def get_workflow_worker_info( workflow_name=workflow_name, suggestions=[ "Check workflow metadata.yaml for 'vertical' field", - "Contact workflow author for support" - ] + "Contact workflow author for support", + ], ) raise HTTPException( status_code=500, - detail=error_response + detail=error_response, ) return { @@ -604,26 +594,22 @@ async def get_workflow_worker_info( "vertical": vertical, "worker_service": f"worker-{vertical}", "task_queue": f"{vertical}-queue", - "required": True + "required": True, } @router.get("/{workflow_name}/parameters") async def get_workflow_parameters( workflow_name: str, - temporal_mgr=Depends(get_temporal_manager) -) -> Dict[str, Any]: - """ - Get the parameters schema for a workflow. + temporal_mgr: Annotated[TemporalManager, Depends(get_temporal_manager)], +) -> dict[str, Any]: + """Get the parameters schema for a workflow. - Args: - workflow_name: Name of the workflow + :param workflow_name: Name of the workflow + :param temporal_mgr: The temporal manager instance. 
+ :return: Parameters schema with types, descriptions, and defaults + :raises HTTPException: 404 if workflow not found - Returns: - Parameters schema with types, descriptions, and defaults - - Raises: - HTTPException: 404 if workflow not found """ if workflow_name not in temporal_mgr.workflows: available_workflows = list(temporal_mgr.workflows.keys()) @@ -633,12 +619,12 @@ async def get_workflow_parameters( workflow_name=workflow_name, suggestions=[ f"Available workflows: {', '.join(available_workflows)}", - "Use GET /workflows/ to see all available workflows" - ] + "Use GET /workflows/ to see all available workflows", + ], ) raise HTTPException( status_code=404, - detail=error_response + detail=error_response, ) info = temporal_mgr.workflows[workflow_name] @@ -648,10 +634,7 @@ async def get_workflow_parameters( parameters_schema = metadata.get("parameters", {}) # Extract the actual parameter definitions from JSON schema structure - if "properties" in parameters_schema: - param_definitions = parameters_schema["properties"] - else: - param_definitions = parameters_schema + param_definitions = parameters_schema.get("properties", parameters_schema) # Extract default values from JSON Schema default_params = extract_defaults_from_json_schema(metadata) @@ -661,7 +644,8 @@ async def get_workflow_parameters( "parameters": param_definitions, "default_parameters": default_params, "required_parameters": [ - name for name, schema in param_definitions.items() + name + for name, schema in param_definitions.items() if isinstance(schema, dict) and schema.get("required", False) - ] - } \ No newline at end of file + ], + } diff --git a/backend/src/core/setup.py b/backend/src/core/setup.py index 97b3a46..9235aa7 100644 --- a/backend/src/core/setup.py +++ b/backend/src/core/setup.py @@ -1,6 +1,4 @@ -""" -Setup utilities for FuzzForge infrastructure -""" +"""Setup utilities for FuzzForge infrastructure.""" # Copyright (c) 2025 FuzzingLabs # @@ -18,9 +16,8 @@ import logging logger = logging.getLogger(__name__) -async def setup_result_storage(): - """ - Setup result storage (MinIO). +async def setup_result_storage() -> bool: + """Set up result storage (MinIO). MinIO is used for both target upload and result storage. This is a placeholder for any MinIO-specific setup if needed. @@ -31,9 +28,8 @@ async def setup_result_storage(): return True -async def validate_infrastructure(): - """ - Validate all required infrastructure components. +async def validate_infrastructure() -> None: + """Validate all required infrastructure components. This should be called during startup to ensure everything is ready. 
""" diff --git a/backend/src/main.py b/backend/src/main.py index c219742..f28e8f9 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -13,20 +13,19 @@ import asyncio import logging import os from contextlib import AsyncExitStack, asynccontextmanager, suppress -from typing import Any, Dict, Optional, List +from typing import Any import uvicorn from fastapi import FastAPI +from fastmcp import FastMCP +from fastmcp.server.http import create_sse_app from starlette.applications import Starlette from starlette.routing import Mount -from fastmcp.server.http import create_sse_app - -from src.temporal.manager import TemporalManager +from src.api import fuzzing, runs, system, workflows from src.core.setup import setup_result_storage, validate_infrastructure -from src.api import workflows, runs, fuzzing, system - -from fastmcp import FastMCP +from src.temporal.discovery import WorkflowDiscovery +from src.temporal.manager import TemporalManager logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -38,12 +37,14 @@ class TemporalBootstrapState: """Tracks Temporal initialization progress for API and MCP consumers.""" def __init__(self) -> None: + """Initialize an instance of the class.""" self.ready: bool = False self.status: str = "not_started" - self.last_error: Optional[str] = None + self.last_error: str | None = None self.task_running: bool = False - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: + """Return the current state as a Python dictionnary.""" return { "ready": self.ready, "status": self.status, @@ -61,7 +62,7 @@ STARTUP_RETRY_MAX_SECONDS = max( int(os.getenv("FUZZFORGE_STARTUP_RETRY_MAX_SECONDS", "60")), ) -temporal_bootstrap_task: Optional[asyncio.Task] = None +temporal_bootstrap_task: asyncio.Task | None = None # --------------------------------------------------------------------------- # FastAPI application (REST API) @@ -79,17 +80,15 @@ app.include_router(fuzzing.router) app.include_router(system.router) -def get_temporal_status() -> Dict[str, Any]: +def get_temporal_status() -> dict[str, Any]: """Return a snapshot of Temporal bootstrap state for diagnostics.""" status = temporal_bootstrap_state.as_dict() status["workflows_loaded"] = len(temporal_mgr.workflows) - status["bootstrap_task_running"] = ( - temporal_bootstrap_task is not None and not temporal_bootstrap_task.done() - ) + status["bootstrap_task_running"] = temporal_bootstrap_task is not None and not temporal_bootstrap_task.done() return status -def _temporal_not_ready_status() -> Optional[Dict[str, Any]]: +def _temporal_not_ready_status() -> dict[str, Any] | None: """Return status details if Temporal is not ready yet.""" status = get_temporal_status() if status.get("ready"): @@ -98,7 +97,7 @@ def _temporal_not_ready_status() -> Optional[Dict[str, Any]]: @app.get("/") -async def root() -> Dict[str, Any]: +async def root() -> dict[str, Any]: status = get_temporal_status() return { "name": "FuzzForge API", @@ -110,14 +109,14 @@ async def root() -> Dict[str, Any]: @app.get("/health") -async def health() -> Dict[str, str]: +async def health() -> dict[str, str]: status = get_temporal_status() health_status = "healthy" if status.get("ready") else "initializing" return {"status": health_status} # Map FastAPI OpenAPI operationIds to readable MCP tool names -FASTAPI_MCP_NAME_OVERRIDES: Dict[str, str] = { +FASTAPI_MCP_NAME_OVERRIDES: dict[str, str] = { "list_workflows_workflows__get": "api_list_workflows", "get_metadata_schema_workflows_metadata_schema_get": 
"api_get_metadata_schema", "get_workflow_metadata_workflows__workflow_name__metadata_get": "api_get_workflow_metadata", @@ -155,7 +154,6 @@ mcp = FastMCP(name="FuzzForge MCP") async def _bootstrap_temporal_with_retries() -> None: """Initialize Temporal infrastructure with exponential backoff retries.""" - attempt = 0 while True: @@ -175,7 +173,6 @@ async def _bootstrap_temporal_with_retries() -> None: temporal_bootstrap_state.status = "ready" temporal_bootstrap_state.task_running = False logger.info("Temporal infrastructure ready") - return except asyncio.CancelledError: temporal_bootstrap_state.status = "cancelled" @@ -204,9 +201,11 @@ async def _bootstrap_temporal_with_retries() -> None: temporal_bootstrap_state.status = "cancelled" temporal_bootstrap_state.task_running = False raise + else: + return -def _lookup_workflow(workflow_name: str): +def _lookup_workflow(workflow_name: str) -> dict[str, Any]: info = temporal_mgr.workflows.get(workflow_name) if not info: return None @@ -222,12 +221,12 @@ def _lookup_workflow(workflow_name: str): "parameters": metadata.get("parameters", {}), "default_parameters": metadata.get("default_parameters", {}), "required_modules": metadata.get("required_modules", []), - "default_target_path": default_target_path + "default_target_path": default_target_path, } @mcp.tool -async def list_workflows_mcp() -> Dict[str, Any]: +async def list_workflows_mcp() -> dict[str, Any]: """List all discovered workflows and their metadata summary.""" not_ready = _temporal_not_ready_status() if not_ready: @@ -241,20 +240,21 @@ async def list_workflows_mcp() -> Dict[str, Any]: for name, info in temporal_mgr.workflows.items(): metadata = info.metadata defaults = metadata.get("default_parameters", {}) - workflows_summary.append({ - "name": name, - "version": metadata.get("version", "0.6.0"), - "description": metadata.get("description", ""), - "author": metadata.get("author"), - "tags": metadata.get("tags", []), - "default_target_path": metadata.get("default_target_path") - or defaults.get("target_path") - }) + workflows_summary.append( + { + "name": name, + "version": metadata.get("version", "0.6.0"), + "description": metadata.get("description", ""), + "author": metadata.get("author"), + "tags": metadata.get("tags", []), + "default_target_path": metadata.get("default_target_path") or defaults.get("target_path"), + }, + ) return {"workflows": workflows_summary, "temporal": get_temporal_status()} @mcp.tool -async def get_workflow_metadata_mcp(workflow_name: str) -> Dict[str, Any]: +async def get_workflow_metadata_mcp(workflow_name: str) -> dict[str, Any]: """Fetch detailed metadata for a workflow.""" not_ready = _temporal_not_ready_status() if not_ready: @@ -270,7 +270,7 @@ async def get_workflow_metadata_mcp(workflow_name: str) -> Dict[str, Any]: @mcp.tool -async def get_workflow_parameters_mcp(workflow_name: str) -> Dict[str, Any]: +async def get_workflow_parameters_mcp(workflow_name: str) -> dict[str, Any]: """Return the parameter schema and defaults for a workflow.""" not_ready = _temporal_not_ready_status() if not_ready: @@ -289,9 +289,8 @@ async def get_workflow_parameters_mcp(workflow_name: str) -> Dict[str, Any]: @mcp.tool -async def get_workflow_metadata_schema_mcp() -> Dict[str, Any]: +async def get_workflow_metadata_schema_mcp() -> dict[str, Any]: """Return the JSON schema describing workflow metadata files.""" - from src.temporal.discovery import WorkflowDiscovery return WorkflowDiscovery.get_metadata_schema() @@ -299,8 +298,8 @@ async def 
get_workflow_metadata_schema_mcp() -> Dict[str, Any]: async def submit_security_scan_mcp( workflow_name: str, target_id: str, - parameters: Dict[str, Any] | None = None, -) -> Dict[str, Any] | Dict[str, str]: + parameters: dict[str, Any] | None = None, +) -> dict[str, Any] | dict[str, str]: """Submit a Temporal workflow via MCP.""" try: not_ready = _temporal_not_ready_status() @@ -318,7 +317,7 @@ async def submit_security_scan_mcp( defaults = metadata.get("default_parameters", {}) parameters = parameters or {} - cleaned_parameters: Dict[str, Any] = {**defaults, **parameters} + cleaned_parameters: dict[str, Any] = {**defaults, **parameters} # Ensure *_config structures default to dicts for key, value in list(cleaned_parameters.items()): @@ -327,9 +326,7 @@ async def submit_security_scan_mcp( # Some workflows expect configuration dictionaries even when omitted parameter_definitions = ( - metadata.get("parameters", {}).get("properties", {}) - if isinstance(metadata.get("parameters"), dict) - else {} + metadata.get("parameters", {}).get("properties", {}) if isinstance(metadata.get("parameters"), dict) else {} ) for key, definition in parameter_definitions.items(): if not isinstance(key, str) or not key.endswith("_config"): @@ -347,6 +344,10 @@ async def submit_security_scan_mcp( workflow_params=cleaned_parameters, ) + except Exception as exc: # pragma: no cover - defensive logging + logger.exception("MCP submit failed") + return {"error": f"Failed to submit workflow: {exc}"} + else: return { "run_id": handle.id, "status": "RUNNING", @@ -356,13 +357,10 @@ async def submit_security_scan_mcp( "parameters": cleaned_parameters, "mcp_enabled": True, } - except Exception as exc: # pragma: no cover - defensive logging - logger.exception("MCP submit failed") - return {"error": f"Failed to submit workflow: {exc}"} @mcp.tool -async def get_comprehensive_scan_summary(run_id: str) -> Dict[str, Any] | Dict[str, str]: +async def get_comprehensive_scan_summary(run_id: str) -> dict[str, Any] | dict[str, str]: """Return a summary for the given workflow run via MCP.""" try: not_ready = _temporal_not_ready_status() @@ -385,7 +383,7 @@ async def get_comprehensive_scan_summary(run_id: str) -> Dict[str, Any] | Dict[s summary = result.get("summary", {}) total_findings = summary.get("total_findings", 0) except Exception as e: - logger.debug(f"Could not retrieve result for {run_id}: {e}") + logger.debug("Could not retrieve result for %s: %s", run_id, e) return { "run_id": run_id, @@ -412,7 +410,7 @@ async def get_comprehensive_scan_summary(run_id: str) -> Dict[str, Any] | Dict[s @mcp.tool -async def get_run_status_mcp(run_id: str) -> Dict[str, Any]: +async def get_run_status_mcp(run_id: str) -> dict[str, Any]: """Return current status information for a Temporal run.""" try: not_ready = _temporal_not_ready_status() @@ -440,7 +438,7 @@ async def get_run_status_mcp(run_id: str) -> Dict[str, Any]: @mcp.tool -async def get_run_findings_mcp(run_id: str) -> Dict[str, Any]: +async def get_run_findings_mcp(run_id: str) -> dict[str, Any]: """Return SARIF findings for a completed run.""" try: not_ready = _temporal_not_ready_status() @@ -463,24 +461,24 @@ async def get_run_findings_mcp(run_id: str) -> Dict[str, Any]: sarif = result.get("sarif", {}) if isinstance(result, dict) else {} + except Exception as exc: + logger.exception("MCP findings failed") + return {"error": f"Failed to retrieve findings: {exc}"} + else: return { "workflow": "unknown", "run_id": run_id, "sarif": sarif, "metadata": metadata, } - except Exception as exc: 
- logger.exception("MCP findings failed") - return {"error": f"Failed to retrieve findings: {exc}"} @mcp.tool async def list_recent_runs_mcp( limit: int = 10, workflow_name: str | None = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """List recent Temporal runs with optional workflow filter.""" - not_ready = _temporal_not_ready_status() if not_ready: return { @@ -505,19 +503,21 @@ async def list_recent_runs_mcp( workflows = await temporal_mgr.list_workflows(filter_query, limit_value) - results: List[Dict[str, Any]] = [] + results: list[dict[str, Any]] = [] for wf in workflows: - results.append({ - "run_id": wf["workflow_id"], - "workflow": workflow_name or "unknown", - "state": wf["status"], - "state_type": wf["status"], - "is_completed": wf["status"] in ["COMPLETED", "FAILED", "CANCELLED"], - "is_running": wf["status"] == "RUNNING", - "is_failed": wf["status"] == "FAILED", - "created_at": wf.get("start_time"), - "updated_at": wf.get("close_time"), - }) + results.append( + { + "run_id": wf["workflow_id"], + "workflow": workflow_name or "unknown", + "state": wf["status"], + "state_type": wf["status"], + "is_completed": wf["status"] in ["COMPLETED", "FAILED", "CANCELLED"], + "is_running": wf["status"] == "RUNNING", + "is_failed": wf["status"] == "FAILED", + "created_at": wf.get("start_time"), + "updated_at": wf.get("close_time"), + }, + ) return {"runs": results, "temporal": get_temporal_status()} @@ -526,12 +526,12 @@ async def list_recent_runs_mcp( return { "runs": [], "temporal": get_temporal_status(), - "error": str(exc) + "error": str(exc), } @mcp.tool -async def get_fuzzing_stats_mcp(run_id: str) -> Dict[str, Any]: +async def get_fuzzing_stats_mcp(run_id: str) -> dict[str, Any]: """Return fuzzing statistics for a run if available.""" not_ready = _temporal_not_ready_status() if not_ready: @@ -555,7 +555,7 @@ async def get_fuzzing_stats_mcp(run_id: str) -> Dict[str, Any]: @mcp.tool -async def get_fuzzing_crash_reports_mcp(run_id: str) -> Dict[str, Any]: +async def get_fuzzing_crash_reports_mcp(run_id: str) -> dict[str, Any]: """Return crash reports collected for a fuzzing run.""" not_ready = _temporal_not_ready_status() if not_ready: @@ -571,11 +571,10 @@ async def get_fuzzing_crash_reports_mcp(run_id: str) -> Dict[str, Any]: @mcp.tool -async def get_backend_status_mcp() -> Dict[str, Any]: +async def get_backend_status_mcp() -> dict[str, Any]: """Expose backend readiness, workflows, and registered MCP tools.""" - status = get_temporal_status() - response: Dict[str, Any] = {"temporal": status} + response: dict[str, Any] = {"temporal": status} if status.get("ready"): response["workflows"] = list(temporal_mgr.workflows.keys()) @@ -591,7 +590,6 @@ async def get_backend_status_mcp() -> Dict[str, Any]: def create_mcp_transport_app() -> Starlette: """Build a Starlette app serving HTTP + SSE transports on one port.""" - http_app = mcp.http_app(path="/", transport="streamable-http") sse_app = create_sse_app( server=mcp, @@ -609,10 +607,10 @@ def create_mcp_transport_app() -> Starlette: async def lifespan(app: Starlette): # pragma: no cover - integration wiring async with AsyncExitStack() as stack: await stack.enter_async_context( - http_app.router.lifespan_context(http_app) + http_app.router.lifespan_context(http_app), ) await stack.enter_async_context( - sse_app.router.lifespan_context(sse_app) + sse_app.router.lifespan_context(sse_app), ) yield @@ -627,6 +625,7 @@ def create_mcp_transport_app() -> Starlette: # Combined lifespan: Temporal init + dedicated MCP transports # 
--------------------------------------------------------------------------- + @asynccontextmanager async def combined_lifespan(app: FastAPI): global temporal_bootstrap_task, _fastapi_mcp_imported @@ -675,13 +674,14 @@ async def combined_lifespan(app: FastAPI): if getattr(mcp_server, "started", False): return await asyncio.sleep(poll_interval) - raise asyncio.TimeoutError + raise TimeoutError try: await _wait_for_uvicorn_startup() - except asyncio.TimeoutError: # pragma: no cover - defensive logging + except TimeoutError: # pragma: no cover - defensive logging if mcp_task.done(): - raise RuntimeError("MCP server failed to start") from mcp_task.exception() + msg = "MCP server failed to start" + raise RuntimeError(msg) from mcp_task.exception() logger.warning("Timed out waiting for MCP server startup; continuing anyway") logger.info("MCP HTTP available at http://0.0.0.0:8010/mcp") diff --git a/backend/src/models/findings.py b/backend/src/models/findings.py index b71a9b6..c15e52f 100644 --- a/backend/src/models/findings.py +++ b/backend/src/models/findings.py @@ -1,6 +1,4 @@ -""" -Models for workflow findings and submissions -""" +"""Models for workflow findings and submissions.""" # Copyright (c) 2025 FuzzingLabs # @@ -13,40 +11,43 @@ Models for workflow findings and submissions # # Additional attribution and requirements are provided in the NOTICE file. -from pydantic import BaseModel, Field -from typing import Dict, Any, Optional, List from datetime import datetime +from typing import Any + +from pydantic import BaseModel, Field class WorkflowFindings(BaseModel): - """Findings from a workflow execution in SARIF format""" + """Findings from a workflow execution in SARIF format.""" + workflow: str = Field(..., description="Workflow name") run_id: str = Field(..., description="Unique run identifier") - sarif: Dict[str, Any] = Field(..., description="SARIF formatted findings") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata") + sarif: dict[str, Any] = Field(..., description="SARIF formatted findings") + metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata") class WorkflowSubmission(BaseModel): - """ - Submit a workflow with configurable settings. + """Submit a workflow with configurable settings. Note: This model is deprecated in favor of the /upload-and-submit endpoint which handles file uploads directly. 
""" - parameters: Dict[str, Any] = Field( + + parameters: dict[str, Any] = Field( default_factory=dict, - description="Workflow-specific parameters" + description="Workflow-specific parameters", ) - timeout: Optional[int] = Field( + timeout: int | None = Field( default=None, # Allow workflow-specific defaults description="Timeout in seconds (None for workflow default)", ge=1, - le=604800 # Max 7 days to support fuzzing campaigns + le=604800, # Max 7 days to support fuzzing campaigns ) class WorkflowStatus(BaseModel): - """Status of a workflow run""" + """Status of a workflow run.""" + run_id: str = Field(..., description="Unique run identifier") workflow: str = Field(..., description="Workflow name") status: str = Field(..., description="Current status") @@ -58,34 +59,37 @@ class WorkflowStatus(BaseModel): class WorkflowMetadata(BaseModel): - """Complete metadata for a workflow""" + """Complete metadata for a workflow.""" + name: str = Field(..., description="Workflow name") version: str = Field(..., description="Semantic version") description: str = Field(..., description="Workflow description") - author: Optional[str] = Field(None, description="Workflow author") - tags: List[str] = Field(default_factory=list, description="Workflow tags") - parameters: Dict[str, Any] = Field(..., description="Parameters schema") - default_parameters: Dict[str, Any] = Field( + author: str | None = Field(None, description="Workflow author") + tags: list[str] = Field(default_factory=list, description="Workflow tags") + parameters: dict[str, Any] = Field(..., description="Parameters schema") + default_parameters: dict[str, Any] = Field( default_factory=dict, - description="Default parameter values" + description="Default parameter values", ) - required_modules: List[str] = Field( + required_modules: list[str] = Field( default_factory=list, - description="Required module names" + description="Required module names", ) class WorkflowListItem(BaseModel): - """Summary information for a workflow in list views""" + """Summary information for a workflow in list views.""" + name: str = Field(..., description="Workflow name") version: str = Field(..., description="Semantic version") description: str = Field(..., description="Workflow description") - author: Optional[str] = Field(None, description="Workflow author") - tags: List[str] = Field(default_factory=list, description="Workflow tags") + author: str | None = Field(None, description="Workflow author") + tags: list[str] = Field(default_factory=list, description="Workflow tags") class RunSubmissionResponse(BaseModel): - """Response after submitting a workflow""" + """Response after submitting a workflow.""" + run_id: str = Field(..., description="Unique run identifier") status: str = Field(..., description="Initial status") workflow: str = Field(..., description="Workflow name") @@ -93,28 +97,30 @@ class RunSubmissionResponse(BaseModel): class FuzzingStats(BaseModel): - """Real-time fuzzing statistics""" + """Real-time fuzzing statistics.""" + run_id: str = Field(..., description="Unique run identifier") workflow: str = Field(..., description="Workflow name") executions: int = Field(default=0, description="Total executions") executions_per_sec: float = Field(default=0.0, description="Current execution rate") crashes: int = Field(default=0, description="Total crashes found") unique_crashes: int = Field(default=0, description="Unique crashes") - coverage: Optional[float] = Field(None, description="Code coverage percentage") + coverage: float | None = Field(None, 
description="Code coverage percentage") corpus_size: int = Field(default=0, description="Current corpus size") elapsed_time: int = Field(default=0, description="Elapsed time in seconds") - last_crash_time: Optional[datetime] = Field(None, description="Time of last crash") + last_crash_time: datetime | None = Field(None, description="Time of last crash") class CrashReport(BaseModel): - """Individual crash report from fuzzing""" + """Individual crash report from fuzzing.""" + run_id: str = Field(..., description="Run identifier") crash_id: str = Field(..., description="Unique crash identifier") timestamp: datetime = Field(default_factory=datetime.utcnow) - signal: Optional[str] = Field(None, description="Crash signal (SIGSEGV, etc.)") - crash_type: Optional[str] = Field(None, description="Type of crash") - stack_trace: Optional[str] = Field(None, description="Stack trace") - input_file: Optional[str] = Field(None, description="Path to crashing input") - reproducer: Optional[str] = Field(None, description="Minimized reproducer") + signal: str | None = Field(None, description="Crash signal (SIGSEGV, etc.)") + crash_type: str | None = Field(None, description="Type of crash") + stack_trace: str | None = Field(None, description="Stack trace") + input_file: str | None = Field(None, description="Path to crashing input") + reproducer: str | None = Field(None, description="Minimized reproducer") severity: str = Field(default="medium", description="Crash severity") - exploitability: Optional[str] = Field(None, description="Exploitability assessment") \ No newline at end of file + exploitability: str | None = Field(None, description="Exploitability assessment") diff --git a/backend/src/storage/__init__.py b/backend/src/storage/__init__.py index 4f78cff..570df61 100644 --- a/backend/src/storage/__init__.py +++ b/backend/src/storage/__init__.py @@ -1,5 +1,4 @@ -""" -Storage abstraction layer for FuzzForge. +"""Storage abstraction layer for FuzzForge. Provides unified interface for storing and retrieving targets and results. """ @@ -7,4 +6,4 @@ Provides unified interface for storing and retrieving targets and results. from .base import StorageBackend from .s3_cached import S3CachedStorage -__all__ = ["StorageBackend", "S3CachedStorage"] +__all__ = ["S3CachedStorage", "StorageBackend"] diff --git a/backend/src/storage/base.py b/backend/src/storage/base.py index 7323fd3..d4131ac 100644 --- a/backend/src/storage/base.py +++ b/backend/src/storage/base.py @@ -1,17 +1,15 @@ -""" -Base storage backend interface. +"""Base storage backend interface. All storage implementations must implement this interface. """ from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, Dict, Any +from typing import Any class StorageBackend(ABC): - """ - Abstract base class for storage backends. + """Abstract base class for storage backends. Implementations handle storage and retrieval of: - Uploaded targets (code, binaries, etc.) @@ -24,10 +22,9 @@ class StorageBackend(ABC): self, file_path: Path, user_id: str, - metadata: Optional[Dict[str, Any]] = None + metadata: dict[str, Any] | None = None, ) -> str: - """ - Upload a target file to storage. + """Upload a target file to storage. Args: file_path: Local path to file to upload @@ -40,13 +37,12 @@ class StorageBackend(ABC): Raises: FileNotFoundError: If file_path doesn't exist StorageError: If upload fails + """ - pass @abstractmethod async def get_target(self, target_id: str) -> Path: - """ - Get target file from storage. 
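# Illustrative sketch, not from the FuzzForge sources: with the Optional[X] -> X | None
# rewrite above, the findings models validate exactly as before; optional fields still
# default to None. Assuming the backend package is importable as `src`:
from src.models.findings import CrashReport, FuzzingStats

stats = FuzzingStats(run_id="run-123", workflow="cargo_fuzzer", executions=1_000, crashes=1)
crash = CrashReport(run_id="run-123", crash_id="crash-0001", signal="SIGSEGV")
assert stats.coverage is None          # no coverage reported yet
assert crash.severity == "medium"      # default severity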
+ """Get target file from storage. Args: target_id: Unique identifier from upload_target() @@ -57,31 +53,29 @@ class StorageBackend(ABC): Raises: FileNotFoundError: If target doesn't exist StorageError: If download fails + """ - pass @abstractmethod async def delete_target(self, target_id: str) -> None: - """ - Delete target from storage. + """Delete target from storage. Args: target_id: Unique identifier to delete Raises: StorageError: If deletion fails (doesn't raise if not found) + """ - pass @abstractmethod async def upload_results( self, workflow_id: str, - results: Dict[str, Any], - results_format: str = "json" + results: dict[str, Any], + results_format: str = "json", ) -> str: - """ - Upload workflow results to storage. + """Upload workflow results to storage. Args: workflow_id: Workflow execution ID @@ -93,13 +87,12 @@ class StorageBackend(ABC): Raises: StorageError: If upload fails + """ - pass @abstractmethod - async def get_results(self, workflow_id: str) -> Dict[str, Any]: - """ - Get workflow results from storage. + async def get_results(self, workflow_id: str) -> dict[str, Any]: + """Get workflow results from storage. Args: workflow_id: Workflow execution ID @@ -110,17 +103,16 @@ class StorageBackend(ABC): Raises: FileNotFoundError: If results don't exist StorageError: If download fails + """ - pass @abstractmethod async def list_targets( self, - user_id: Optional[str] = None, - limit: int = 100 - ) -> list[Dict[str, Any]]: - """ - List uploaded targets. + user_id: str | None = None, + limit: int = 100, + ) -> list[dict[str, Any]]: + """List uploaded targets. Args: user_id: Filter by user ID (None = all users) @@ -131,23 +123,21 @@ class StorageBackend(ABC): Raises: StorageError: If listing fails + """ - pass @abstractmethod async def cleanup_cache(self) -> int: - """ - Clean up local cache (LRU eviction). + """Clean up local cache (LRU eviction). Returns: Number of files removed Raises: StorageError: If cleanup fails + """ - pass class StorageError(Exception): """Base exception for storage operations.""" - pass diff --git a/backend/src/storage/s3_cached.py b/backend/src/storage/s3_cached.py index 99c8e3a..3802951 100644 --- a/backend/src/storage/s3_cached.py +++ b/backend/src/storage/s3_cached.py @@ -1,5 +1,4 @@ -""" -S3-compatible storage backend with local caching. +"""S3-compatible storage backend with local caching. Works with MinIO (dev/prod) or AWS S3 (cloud). """ @@ -10,7 +9,7 @@ import os import shutil from datetime import datetime from pathlib import Path -from typing import Optional, Dict, Any +from typing import Any from uuid import uuid4 import boto3 @@ -22,8 +21,7 @@ logger = logging.getLogger(__name__) class S3CachedStorage(StorageBackend): - """ - S3-compatible storage with local caching. + """S3-compatible storage with local caching. Features: - Upload targets to S3/MinIO @@ -34,17 +32,16 @@ class S3CachedStorage(StorageBackend): def __init__( self, - endpoint_url: Optional[str] = None, - access_key: Optional[str] = None, - secret_key: Optional[str] = None, + endpoint_url: str | None = None, + access_key: str | None = None, + secret_key: str | None = None, bucket: str = "targets", region: str = "us-east-1", use_ssl: bool = False, - cache_dir: Optional[Path] = None, - cache_max_size_gb: int = 10 - ): - """ - Initialize S3 storage backend. + cache_dir: Path | None = None, + cache_max_size_gb: int = 10, + ) -> None: + """Initialize S3 storage backend. 
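# Illustrative sketch, not from the FuzzForge sources: callers are expected to program
# against the StorageBackend interface rather than a concrete backend. A minimal usage
# sketch, assuming a reachable MinIO/S3 endpoint (run with asyncio.run(...)):
from pathlib import Path

from src.storage import S3CachedStorage, StorageBackend


async def upload_and_fetch(storage: StorageBackend, path: Path) -> Path:
    target_id = await storage.upload_target(path, user_id="demo")
    return await storage.get_target(target_id)  # served from the local cache on repeat calls


# asyncio.run(upload_and_fetch(S3CachedStorage(), Path("target.tar.gz")))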
Args: endpoint_url: S3 endpoint (None = AWS S3, or MinIO URL) @@ -55,18 +52,19 @@ class S3CachedStorage(StorageBackend): use_ssl: Use HTTPS cache_dir: Local cache directory cache_max_size_gb: Maximum cache size in GB + """ # Use environment variables as defaults - self.endpoint_url = endpoint_url or os.getenv('S3_ENDPOINT', 'http://minio:9000') - self.access_key = access_key or os.getenv('S3_ACCESS_KEY', 'fuzzforge') - self.secret_key = secret_key or os.getenv('S3_SECRET_KEY', 'fuzzforge123') - self.bucket = bucket or os.getenv('S3_BUCKET', 'targets') - self.region = region or os.getenv('S3_REGION', 'us-east-1') - self.use_ssl = use_ssl or os.getenv('S3_USE_SSL', 'false').lower() == 'true' + self.endpoint_url = endpoint_url or os.getenv("S3_ENDPOINT", "http://minio:9000") + self.access_key = access_key or os.getenv("S3_ACCESS_KEY", "fuzzforge") + self.secret_key = secret_key or os.getenv("S3_SECRET_KEY", "fuzzforge123") + self.bucket = bucket or os.getenv("S3_BUCKET", "targets") + self.region = region or os.getenv("S3_REGION", "us-east-1") + self.use_ssl = use_ssl or os.getenv("S3_USE_SSL", "false").lower() == "true" # Cache configuration - self.cache_dir = cache_dir or Path(os.getenv('CACHE_DIR', '/tmp/fuzzforge-cache')) - self.cache_max_size = cache_max_size_gb * (1024 ** 3) # Convert to bytes + self.cache_dir = cache_dir or Path(os.getenv("CACHE_DIR", "/tmp/fuzzforge-cache")) + self.cache_max_size = cache_max_size_gb * (1024**3) # Convert to bytes # Ensure cache directory exists self.cache_dir.mkdir(parents=True, exist_ok=True) @@ -74,69 +72,75 @@ class S3CachedStorage(StorageBackend): # Initialize S3 client try: self.s3_client = boto3.client( - 's3', + "s3", endpoint_url=self.endpoint_url, aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, region_name=self.region, - use_ssl=self.use_ssl + use_ssl=self.use_ssl, ) - logger.info(f"Initialized S3 storage: {self.endpoint_url}/{self.bucket}") + logger.info("Initialized S3 storage: %s/%s", self.endpoint_url, self.bucket) except Exception as e: - logger.error(f"Failed to initialize S3 client: {e}") - raise StorageError(f"S3 initialization failed: {e}") + logger.exception("Failed to initialize S3 client") + msg = f"S3 initialization failed: {e}" + raise StorageError(msg) from e async def upload_target( self, file_path: Path, user_id: str, - metadata: Optional[Dict[str, Any]] = None + metadata: dict[str, Any] | None = None, ) -> str: """Upload target file to S3/MinIO.""" if not file_path.exists(): - raise FileNotFoundError(f"File not found: {file_path}") + msg = f"File not found: {file_path}" + raise FileNotFoundError(msg) # Generate unique target ID target_id = str(uuid4()) # Prepare metadata upload_metadata = { - 'user_id': user_id, - 'uploaded_at': datetime.now().isoformat(), - 'filename': file_path.name, - 'size': str(file_path.stat().st_size) + "user_id": user_id, + "uploaded_at": datetime.now().isoformat(), + "filename": file_path.name, + "size": str(file_path.stat().st_size), } if metadata: upload_metadata.update(metadata) # Upload to S3 - s3_key = f'{target_id}/target' + s3_key = f"{target_id}/target" try: - logger.info(f"Uploading target to s3://{self.bucket}/{s3_key}") + logger.info("Uploading target to s3://%s/%s", self.bucket, s3_key) self.s3_client.upload_file( str(file_path), self.bucket, s3_key, ExtraArgs={ - 'Metadata': upload_metadata - } + "Metadata": upload_metadata, + }, ) file_size_mb = file_path.stat().st_size / (1024 * 1024) logger.info( - f"✓ Uploaded target {target_id} " - f"({file_path.name}, 
{file_size_mb:.2f} MB)" + "✓ Uploaded target %s (%s, %s MB)", + target_id, + file_path.name, + f"{file_size_mb:.2f}", ) - return target_id - except ClientError as e: - logger.error(f"S3 upload failed: {e}", exc_info=True) - raise StorageError(f"Failed to upload target: {e}") + logger.exception("S3 upload failed") + msg = f"Failed to upload target: {e}" + raise StorageError(msg) from e except Exception as e: - logger.error(f"Upload failed: {e}", exc_info=True) - raise StorageError(f"Upload error: {e}") + logger.exception("Upload failed") + msg = f"Upload error: {e}" + raise StorageError(msg) from e + else: + return target_id async def get_target(self, target_id: str) -> Path: """Get target from cache or download from S3/MinIO.""" @@ -147,105 +151,110 @@ class S3CachedStorage(StorageBackend): if cached_file.exists(): # Update access time for LRU cached_file.touch() - logger.info(f"Cache HIT: {target_id}") + logger.info("Cache HIT: %s", target_id) return cached_file # Cache miss - download from S3 - logger.info(f"Cache MISS: {target_id}, downloading from S3...") + logger.info("Cache MISS: %s, downloading from S3...", target_id) try: # Create cache directory cache_path.mkdir(parents=True, exist_ok=True) # Download from S3 - s3_key = f'{target_id}/target' - logger.info(f"Downloading s3://{self.bucket}/{s3_key}") + s3_key = f"{target_id}/target" + logger.info("Downloading s3://%s/%s", self.bucket, s3_key) self.s3_client.download_file( self.bucket, s3_key, - str(cached_file) + str(cached_file), ) # Verify download if not cached_file.exists(): - raise StorageError(f"Downloaded file not found: {cached_file}") + msg = f"Downloaded file not found: {cached_file}" + raise StorageError(msg) file_size_mb = cached_file.stat().st_size / (1024 * 1024) - logger.info(f"✓ Downloaded target {target_id} ({file_size_mb:.2f} MB)") - - return cached_file + logger.info("✓ Downloaded target %s (%s MB)", target_id, f"{file_size_mb:.2f}") except ClientError as e: - error_code = e.response.get('Error', {}).get('Code') - if error_code in ['404', 'NoSuchKey']: - logger.error(f"Target not found: {target_id}") - raise FileNotFoundError(f"Target {target_id} not found in storage") - else: - logger.error(f"S3 download failed: {e}", exc_info=True) - raise StorageError(f"Download failed: {e}") + error_code = e.response.get("Error", {}).get("Code") + if error_code in ["404", "NoSuchKey"]: + logger.exception("Target not found: %s", target_id) + msg = f"Target {target_id} not found in storage" + raise FileNotFoundError(msg) from e + logger.exception("S3 download failed") + msg = f"Download failed: {e}" + raise StorageError(msg) from e except Exception as e: - logger.error(f"Download error: {e}", exc_info=True) + logger.exception("Download error") # Cleanup partial download if cache_path.exists(): shutil.rmtree(cache_path, ignore_errors=True) - raise StorageError(f"Download error: {e}") + msg = f"Download error: {e}" + raise StorageError(msg) from e + else: + return cached_file async def delete_target(self, target_id: str) -> None: """Delete target from S3/MinIO.""" try: - s3_key = f'{target_id}/target' - logger.info(f"Deleting s3://{self.bucket}/{s3_key}") + s3_key = f"{target_id}/target" + logger.info("Deleting s3://%s/%s", self.bucket, s3_key) self.s3_client.delete_object( Bucket=self.bucket, - Key=s3_key + Key=s3_key, ) # Also delete from cache if present cache_path = self.cache_dir / target_id if cache_path.exists(): shutil.rmtree(cache_path, ignore_errors=True) - logger.info(f"✓ Deleted target {target_id} from S3 and cache") + 
logger.info("✓ Deleted target %s from S3 and cache", target_id) else: - logger.info(f"✓ Deleted target {target_id} from S3") + logger.info("✓ Deleted target %s from S3", target_id) except ClientError as e: - logger.error(f"S3 delete failed: {e}", exc_info=True) + logger.exception("S3 delete failed") # Don't raise error if object doesn't exist - if e.response.get('Error', {}).get('Code') not in ['404', 'NoSuchKey']: - raise StorageError(f"Delete failed: {e}") + if e.response.get("Error", {}).get("Code") not in ["404", "NoSuchKey"]: + msg = f"Delete failed: {e}" + raise StorageError(msg) from e except Exception as e: - logger.error(f"Delete error: {e}", exc_info=True) - raise StorageError(f"Delete error: {e}") + logger.exception("Delete error") + msg = f"Delete error: {e}" + raise StorageError(msg) from e async def upload_results( self, workflow_id: str, - results: Dict[str, Any], - results_format: str = "json" + results: dict[str, Any], + results_format: str = "json", ) -> str: """Upload workflow results to S3/MinIO.""" try: # Prepare results content if results_format == "json": - content = json.dumps(results, indent=2).encode('utf-8') - content_type = 'application/json' - file_ext = 'json' + content = json.dumps(results, indent=2).encode("utf-8") + content_type = "application/json" + file_ext = "json" elif results_format == "sarif": - content = json.dumps(results, indent=2).encode('utf-8') - content_type = 'application/sarif+json' - file_ext = 'sarif' + content = json.dumps(results, indent=2).encode("utf-8") + content_type = "application/sarif+json" + file_ext = "sarif" else: - content = json.dumps(results, indent=2).encode('utf-8') - content_type = 'application/json' - file_ext = 'json' + content = json.dumps(results, indent=2).encode("utf-8") + content_type = "application/json" + file_ext = "json" # Upload to results bucket - results_bucket = 'results' - s3_key = f'{workflow_id}/results.{file_ext}' + results_bucket = "results" + s3_key = f"{workflow_id}/results.{file_ext}" - logger.info(f"Uploading results to s3://{results_bucket}/{s3_key}") + logger.info("Uploading results to s3://%s/%s", results_bucket, s3_key) self.s3_client.put_object( Bucket=results_bucket, @@ -253,95 +262,103 @@ class S3CachedStorage(StorageBackend): Body=content, ContentType=content_type, Metadata={ - 'workflow_id': workflow_id, - 'format': results_format, - 'uploaded_at': datetime.now().isoformat() - } + "workflow_id": workflow_id, + "format": results_format, + "uploaded_at": datetime.now().isoformat(), + }, ) # Construct URL results_url = f"{self.endpoint_url}/{results_bucket}/{s3_key}" - logger.info(f"✓ Uploaded results: {results_url}") - - return results_url + logger.info("✓ Uploaded results: %s", results_url) except Exception as e: - logger.error(f"Results upload failed: {e}", exc_info=True) - raise StorageError(f"Results upload failed: {e}") + logger.exception("Results upload failed") + msg = f"Results upload failed: {e}" + raise StorageError(msg) from e + else: + return results_url - async def get_results(self, workflow_id: str) -> Dict[str, Any]: + async def get_results(self, workflow_id: str) -> dict[str, Any]: """Get workflow results from S3/MinIO.""" try: - results_bucket = 'results' - s3_key = f'{workflow_id}/results.json' + results_bucket = "results" + s3_key = f"{workflow_id}/results.json" - logger.info(f"Downloading results from s3://{results_bucket}/{s3_key}") + logger.info("Downloading results from s3://%s/%s", results_bucket, s3_key) response = self.s3_client.get_object( Bucket=results_bucket, - 
Key=s3_key + Key=s3_key, ) - content = response['Body'].read().decode('utf-8') + content = response["Body"].read().decode("utf-8") results = json.loads(content) - logger.info(f"✓ Downloaded results for workflow {workflow_id}") - return results + logger.info("✓ Downloaded results for workflow %s", workflow_id) except ClientError as e: - error_code = e.response.get('Error', {}).get('Code') - if error_code in ['404', 'NoSuchKey']: - logger.error(f"Results not found: {workflow_id}") - raise FileNotFoundError(f"Results for workflow {workflow_id} not found") - else: - logger.error(f"Results download failed: {e}", exc_info=True) - raise StorageError(f"Results download failed: {e}") + error_code = e.response.get("Error", {}).get("Code") + if error_code in ["404", "NoSuchKey"]: + logger.exception("Results not found: %s", workflow_id) + msg = f"Results for workflow {workflow_id} not found" + raise FileNotFoundError(msg) from e + logger.exception("Results download failed") + msg = f"Results download failed: {e}" + raise StorageError(msg) from e except Exception as e: - logger.error(f"Results download error: {e}", exc_info=True) - raise StorageError(f"Results download error: {e}") + logger.exception("Results download error") + msg = f"Results download error: {e}" + raise StorageError(msg) from e + else: + return results async def list_targets( self, - user_id: Optional[str] = None, - limit: int = 100 - ) -> list[Dict[str, Any]]: + user_id: str | None = None, + limit: int = 100, + ) -> list[dict[str, Any]]: """List uploaded targets.""" try: targets = [] - paginator = self.s3_client.get_paginator('list_objects_v2') + paginator = self.s3_client.get_paginator("list_objects_v2") - for page in paginator.paginate(Bucket=self.bucket, PaginationConfig={'MaxItems': limit}): - for obj in page.get('Contents', []): + for page in paginator.paginate(Bucket=self.bucket, PaginationConfig={"MaxItems": limit}): + for obj in page.get("Contents", []): # Get object metadata try: metadata_response = self.s3_client.head_object( Bucket=self.bucket, - Key=obj['Key'] + Key=obj["Key"], ) - metadata = metadata_response.get('Metadata', {}) + metadata = metadata_response.get("Metadata", {}) # Filter by user_id if specified - if user_id and metadata.get('user_id') != user_id: + if user_id and metadata.get("user_id") != user_id: continue - targets.append({ - 'target_id': obj['Key'].split('/')[0], - 'key': obj['Key'], - 'size': obj['Size'], - 'last_modified': obj['LastModified'].isoformat(), - 'metadata': metadata - }) + targets.append( + { + "target_id": obj["Key"].split("/")[0], + "key": obj["Key"], + "size": obj["Size"], + "last_modified": obj["LastModified"].isoformat(), + "metadata": metadata, + }, + ) except Exception as e: - logger.warning(f"Failed to get metadata for {obj['Key']}: {e}") + logger.warning("Failed to get metadata for %s: %s", obj["Key"], e) continue - logger.info(f"Listed {len(targets)} targets (user_id={user_id})") - return targets + logger.info("Listed %s targets (user_id=%s)", len(targets), user_id) except Exception as e: - logger.error(f"List targets failed: {e}", exc_info=True) - raise StorageError(f"List targets failed: {e}") + logger.exception("List targets failed") + msg = f"List targets failed: {e}" + raise StorageError(msg) from e + else: + return targets async def cleanup_cache(self) -> int: """Clean up local cache using LRU eviction.""" @@ -350,30 +367,33 @@ class S3CachedStorage(StorageBackend): total_size = 0 # Gather all cached files with metadata - for cache_file in self.cache_dir.rglob('*'): + 
for cache_file in self.cache_dir.rglob("*"): if cache_file.is_file(): try: stat = cache_file.stat() - cache_files.append({ - 'path': cache_file, - 'size': stat.st_size, - 'atime': stat.st_atime # Last access time - }) + cache_files.append( + { + "path": cache_file, + "size": stat.st_size, + "atime": stat.st_atime, # Last access time + }, + ) total_size += stat.st_size except Exception as e: - logger.warning(f"Failed to stat {cache_file}: {e}") + logger.warning("Failed to stat %s: %s", cache_file, e) continue # Check if cleanup is needed if total_size <= self.cache_max_size: logger.info( - f"Cache size OK: {total_size / (1024**3):.2f} GB / " - f"{self.cache_max_size / (1024**3):.2f} GB" + "Cache size OK: %s GB / %s GB", + f"{total_size / (1024**3):.2f}", + f"{self.cache_max_size / (1024**3):.2f}", ) return 0 # Sort by access time (oldest first) - cache_files.sort(key=lambda x: x['atime']) + cache_files.sort(key=lambda x: x["atime"]) # Remove files until under limit removed_count = 0 @@ -382,42 +402,46 @@ class S3CachedStorage(StorageBackend): break try: - file_info['path'].unlink() - total_size -= file_info['size'] + file_info["path"].unlink() + total_size -= file_info["size"] removed_count += 1 - logger.debug(f"Evicted from cache: {file_info['path']}") + logger.debug("Evicted from cache: %s", file_info["path"]) except Exception as e: - logger.warning(f"Failed to delete {file_info['path']}: {e}") + logger.warning("Failed to delete %s: %s", file_info["path"], e) continue logger.info( - f"✓ Cache cleanup: removed {removed_count} files, " - f"new size: {total_size / (1024**3):.2f} GB" + "✓ Cache cleanup: removed %s files, new size: %s GB", + removed_count, + f"{total_size / (1024**3):.2f}", ) - return removed_count except Exception as e: - logger.error(f"Cache cleanup failed: {e}", exc_info=True) - raise StorageError(f"Cache cleanup failed: {e}") + logger.exception("Cache cleanup failed") + msg = f"Cache cleanup failed: {e}" + raise StorageError(msg) from e - def get_cache_stats(self) -> Dict[str, Any]: + else: + return removed_count + + def get_cache_stats(self) -> dict[str, Any]: """Get cache statistics.""" try: total_size = 0 file_count = 0 - for cache_file in self.cache_dir.rglob('*'): + for cache_file in self.cache_dir.rglob("*"): if cache_file.is_file(): total_size += cache_file.stat().st_size file_count += 1 return { - 'total_size_bytes': total_size, - 'total_size_gb': total_size / (1024 ** 3), - 'file_count': file_count, - 'max_size_gb': self.cache_max_size / (1024 ** 3), - 'usage_percent': (total_size / self.cache_max_size) * 100 + "total_size_bytes": total_size, + "total_size_gb": total_size / (1024**3), + "file_count": file_count, + "max_size_gb": self.cache_max_size / (1024**3), + "usage_percent": (total_size / self.cache_max_size) * 100, } except Exception as e: - logger.error(f"Failed to get cache stats: {e}") - return {'error': str(e)} + logger.exception("Failed to get cache stats") + return {"error": str(e)} diff --git a/backend/src/temporal/__init__.py b/backend/src/temporal/__init__.py index acaa368..9eb66fb 100644 --- a/backend/src/temporal/__init__.py +++ b/backend/src/temporal/__init__.py @@ -1,10 +1,9 @@ -""" -Temporal integration for FuzzForge. +"""Temporal integration for FuzzForge. Handles workflow execution, monitoring, and management. 
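# Illustrative sketch, not from the FuzzForge sources: cleanup_cache() above is a plain
# LRU sweep - collect cached files with their last-access time, then evict oldest-first
# until the cache fits the configured byte budget. Roughly:
from pathlib import Path


def evict_lru(cache_dir: Path, max_bytes: int) -> int:
    files = [(p, p.stat().st_size, p.stat().st_atime) for p in cache_dir.rglob("*") if p.is_file()]
    total = sum(size for _, size, _ in files)
    removed = 0
    for path, size, _ in sorted(files, key=lambda f: f[2]):  # oldest access first
        if total <= max_bytes:
            break
        path.unlink(missing_ok=True)
        total -= size
        removed += 1
    return removed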
""" -from .manager import TemporalManager from .discovery import WorkflowDiscovery +from .manager import TemporalManager __all__ = ["TemporalManager", "WorkflowDiscovery"] diff --git a/backend/src/temporal/discovery.py b/backend/src/temporal/discovery.py index 07da6f8..c91a55c 100644 --- a/backend/src/temporal/discovery.py +++ b/backend/src/temporal/discovery.py @@ -1,25 +1,26 @@ -""" -Workflow Discovery for Temporal +"""Workflow Discovery for Temporal. Discovers workflows from the toolbox/workflows directory and provides metadata about available workflows. """ import logging -import yaml from pathlib import Path -from typing import Dict, Any -from pydantic import BaseModel, Field, ConfigDict +from typing import Any + +import yaml +from pydantic import BaseModel, ConfigDict, Field logger = logging.getLogger(__name__) class WorkflowInfo(BaseModel): - """Information about a discovered workflow""" + """Information about a discovered workflow.""" + name: str = Field(..., description="Workflow name") path: Path = Field(..., description="Path to workflow directory") workflow_file: Path = Field(..., description="Path to workflow.py file") - metadata: Dict[str, Any] = Field(..., description="Workflow metadata from YAML") + metadata: dict[str, Any] = Field(..., description="Workflow metadata from YAML") workflow_type: str = Field(..., description="Workflow class name") vertical: str = Field(..., description="Vertical (worker type) for this workflow") @@ -27,8 +28,7 @@ class WorkflowInfo(BaseModel): class WorkflowDiscovery: - """ - Discovers workflows from the filesystem. + """Discovers workflows from the filesystem. Scans toolbox/workflows/ for directories containing: - metadata.yaml (required) @@ -38,106 +38,109 @@ class WorkflowDiscovery: which determines which worker pool will execute it. """ - def __init__(self, workflows_dir: Path): - """ - Initialize workflow discovery. + def __init__(self, workflows_dir: Path) -> None: + """Initialize workflow discovery. Args: workflows_dir: Path to the workflows directory + """ self.workflows_dir = workflows_dir if not self.workflows_dir.exists(): self.workflows_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Created workflows directory: {self.workflows_dir}") + logger.info("Created workflows directory: %s", self.workflows_dir) - async def discover_workflows(self) -> Dict[str, WorkflowInfo]: - """ - Discover workflows by scanning the workflows directory. + async def discover_workflows(self) -> dict[str, WorkflowInfo]: + """Discover workflows by scanning the workflows directory. 
Returns: Dictionary mapping workflow names to their information + """ workflows = {} - logger.info(f"Scanning for workflows in: {self.workflows_dir}") + logger.info("Scanning for workflows in: %s", self.workflows_dir) for workflow_dir in self.workflows_dir.iterdir(): if not workflow_dir.is_dir(): continue # Skip special directories - if workflow_dir.name.startswith('.') or workflow_dir.name == '__pycache__': + if workflow_dir.name.startswith(".") or workflow_dir.name == "__pycache__": continue metadata_file = workflow_dir / "metadata.yaml" if not metadata_file.exists(): - logger.debug(f"No metadata.yaml in {workflow_dir.name}, skipping") + logger.debug("No metadata.yaml in %s, skipping", workflow_dir.name) continue workflow_file = workflow_dir / "workflow.py" if not workflow_file.exists(): logger.warning( - f"Workflow {workflow_dir.name} has metadata but no workflow.py, skipping" + "Workflow %s has metadata but no workflow.py, skipping", + workflow_dir.name, ) continue try: # Parse metadata - with open(metadata_file) as f: + with metadata_file.open() as f: metadata = yaml.safe_load(f) # Validate required fields - if 'name' not in metadata: - logger.warning(f"Workflow {workflow_dir.name} metadata missing 'name' field") - metadata['name'] = workflow_dir.name + if "name" not in metadata: + logger.warning("Workflow %s metadata missing 'name' field", workflow_dir.name) + metadata["name"] = workflow_dir.name - if 'vertical' not in metadata: + if "vertical" not in metadata: logger.warning( - f"Workflow {workflow_dir.name} metadata missing 'vertical' field" + "Workflow %s metadata missing 'vertical' field", + workflow_dir.name, ) continue # Infer workflow class name from metadata or use convention - workflow_type = metadata.get('workflow_class') + workflow_type = metadata.get("workflow_class") if not workflow_type: # Convention: convert snake_case to PascalCase + Workflow # e.g., rust_test -> RustTestWorkflow - parts = workflow_dir.name.split('_') - workflow_type = ''.join(part.capitalize() for part in parts) + 'Workflow' + parts = workflow_dir.name.split("_") + workflow_type = "".join(part.capitalize() for part in parts) + "Workflow" # Create workflow info info = WorkflowInfo( - name=metadata['name'], + name=metadata["name"], path=workflow_dir, workflow_file=workflow_file, metadata=metadata, workflow_type=workflow_type, - vertical=metadata['vertical'] + vertical=metadata["vertical"], ) workflows[info.name] = info logger.info( - f"✓ Discovered workflow: {info.name} " - f"(vertical: {info.vertical}, class: {info.workflow_type})" + "✓ Discovered workflow: %s (vertical: %s, class: %s)", + info.name, + info.vertical, + info.workflow_type, ) - except Exception as e: - logger.error( - f"Error discovering workflow {workflow_dir.name}: {e}", - exc_info=True + except Exception: + logger.exception( + "Error discovering workflow %s", + workflow_dir.name, ) continue - logger.info(f"Discovered {len(workflows)} workflows") + logger.info("Discovered %s workflows", len(workflows)) return workflows def get_workflows_by_vertical( self, - workflows: Dict[str, WorkflowInfo], - vertical: str - ) -> Dict[str, WorkflowInfo]: - """ - Filter workflows by vertical. + workflows: dict[str, WorkflowInfo], + vertical: str, + ) -> dict[str, WorkflowInfo]: + """Filter workflows by vertical. 
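# Illustrative sketch, not from the FuzzForge sources: when metadata.yaml does not set
# `workflow_class`, discovery derives the class name from the workflow directory name
# using the snake_case -> PascalCase + "Workflow" convention shown above:
def default_workflow_class(directory_name: str) -> str:
    parts = directory_name.split("_")
    return "".join(part.capitalize() for part in parts) + "Workflow"


# default_workflow_class("rust_test") -> "RustTestWorkflow"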
         Args:
             workflows: All discovered workflows
@@ -145,32 +148,29 @@ class WorkflowDiscovery:
 
         Returns:
             Filtered workflows dictionary
-        """
-        return {
-            name: info
-            for name, info in workflows.items()
-            if info.vertical == vertical
-        }
 
-    def get_available_verticals(self, workflows: Dict[str, WorkflowInfo]) -> list[str]:
         """
-        Get list of all verticals from discovered workflows.
+        return {name: info for name, info in workflows.items() if info.vertical == vertical}
+
+    def get_available_verticals(self, workflows: dict[str, WorkflowInfo]) -> list[str]:
+        """Get list of all verticals from discovered workflows.
 
         Args:
            workflows: All discovered workflows
 
         Returns:
             List of unique vertical names
+
         """
-        return list(set(info.vertical for info in workflows.values()))
+        return list({info.vertical for info in workflows.values()})
 
     @staticmethod
-    def get_metadata_schema() -> Dict[str, Any]:
-        """
-        Get the JSON schema for workflow metadata.
+    def get_metadata_schema() -> dict[str, Any]:
+        """Get the JSON schema for workflow metadata.
 
         Returns:
             JSON schema dictionary
+
         """
         return {
             "type": "object",
@@ -178,34 +178,34 @@ class WorkflowDiscovery:
             "properties": {
                 "name": {
                     "type": "string",
-                    "description": "Workflow name"
+                    "description": "Workflow name",
                 },
                 "version": {
                     "type": "string",
                     "pattern": "^\\d+\\.\\d+\\.\\d+$",
-                    "description": "Semantic version (x.y.z)"
+                    "description": "Semantic version (x.y.z)",
                 },
                 "vertical": {
                     "type": "string",
-                    "description": "Vertical worker type (rust, android, web, etc.)"
+                    "description": "Vertical worker type (rust, android, web, etc.)",
                 },
                 "description": {
                     "type": "string",
-                    "description": "Workflow description"
+                    "description": "Workflow description",
                 },
                 "author": {
                     "type": "string",
-                    "description": "Workflow author"
+                    "description": "Workflow author",
                 },
                 "category": {
                     "type": "string",
                     "enum": ["comprehensive", "specialized", "fuzzing", "focused"],
-                    "description": "Workflow category"
+                    "description": "Workflow category",
                 },
                 "tags": {
                     "type": "array",
                     "items": {"type": "string"},
-                    "description": "Workflow tags for categorization"
+                    "description": "Workflow tags for categorization",
                 },
                 "requirements": {
                     "type": "object",
@@ -214,7 +214,7 @@ class WorkflowDiscovery:
                         "tools": {
                             "type": "array",
                             "items": {"type": "string"},
-                            "description": "Required security tools"
+                            "description": "Required security tools",
                         },
                         "resources": {
                             "type": "object",
@@ -223,35 +223,35 @@ class WorkflowDiscovery:
                                 "memory": {
                                     "type": "string",
                                     "pattern": "^\\d+[GMK]i$",
-                                    "description": "Memory limit (e.g., 1Gi, 512Mi)"
+                                    "description": "Memory limit (e.g., 1Gi, 512Mi)",
                                 },
                                 "cpu": {
                                     "type": "string",
                                     "pattern": "^\\d+m?$",
-                                    "description": "CPU limit (e.g., 1000m, 2)"
+                                    "description": "CPU limit (e.g., 1000m, 2)",
                                 },
                                 "timeout": {
                                     "type": "integer",
                                     "minimum": 60,
                                     "maximum": 7200,
-                                    "description": "Workflow timeout in seconds"
-                                }
-                            }
-                        }
-                    }
+                                    "description": "Workflow timeout in seconds",
+                                },
+                            },
+                        },
+                    },
                 },
                 "parameters": {
                     "type": "object",
-                    "description": "Workflow parameters schema"
+                    "description": "Workflow parameters schema",
                },
                 "default_parameters": {
                     "type": "object",
-                    "description": "Default parameter values"
+                    "description": "Default parameter values",
                 },
                 "required_modules": {
                     "type": "array",
                     "items": {"type": "string"},
-                    "description": "Required module names"
-                }
-            }
+                    "description": "Required module names",
+                },
+            },
         }
diff --git a/backend/src/temporal/manager.py b/backend/src/temporal/manager.py
index 96d9a84..8abe305 100644
--- a/backend/src/temporal/manager.py
+++ 
b/backend/src/temporal/manager.py @@ -1,5 +1,4 @@ -""" -Temporal Manager - Workflow execution and management +"""Temporal Manager - Workflow execution and management. Handles: - Workflow discovery from toolbox @@ -8,25 +7,26 @@ Handles: - Results retrieval """ +import asyncio import logging import os +from datetime import timedelta from pathlib import Path -from typing import Dict, Optional, Any +from typing import Any from uuid import uuid4 from temporalio.client import Client, WorkflowHandle from temporalio.common import RetryPolicy -from datetime import timedelta + +from src.storage import S3CachedStorage from .discovery import WorkflowDiscovery, WorkflowInfo -from src.storage import S3CachedStorage logger = logging.getLogger(__name__) class TemporalManager: - """ - Manages Temporal workflow execution for FuzzForge. + """Manages Temporal workflow execution for FuzzForge. This class: - Discovers available workflows from toolbox @@ -37,41 +37,42 @@ class TemporalManager: def __init__( self, - workflows_dir: Optional[Path] = None, - temporal_address: Optional[str] = None, + workflows_dir: Path | None = None, + temporal_address: str | None = None, temporal_namespace: str = "default", - storage: Optional[S3CachedStorage] = None - ): - """ - Initialize Temporal manager. + storage: S3CachedStorage | None = None, + ) -> None: + """Initialize Temporal manager. Args: workflows_dir: Path to workflows directory (default: toolbox/workflows) temporal_address: Temporal server address (default: from env or localhost:7233) temporal_namespace: Temporal namespace storage: Storage backend for file uploads (default: S3CachedStorage) + """ if workflows_dir is None: workflows_dir = Path("toolbox/workflows") self.temporal_address = temporal_address or os.getenv( - 'TEMPORAL_ADDRESS', - 'localhost:7233' + "TEMPORAL_ADDRESS", + "localhost:7233", ) self.temporal_namespace = temporal_namespace self.discovery = WorkflowDiscovery(workflows_dir) - self.workflows: Dict[str, WorkflowInfo] = {} - self.client: Optional[Client] = None + self.workflows: dict[str, WorkflowInfo] = {} + self.client: Client | None = None # Initialize storage backend self.storage = storage or S3CachedStorage() logger.info( - f"TemporalManager initialized: {self.temporal_address} " - f"(namespace: {self.temporal_namespace})" + "TemporalManager initialized: %s (namespace: %s)", + self.temporal_address, + self.temporal_namespace, ) - async def initialize(self): + async def initialize(self) -> None: """Initialize the manager by discovering workflows and connecting to Temporal.""" try: # Discover workflows @@ -81,45 +82,46 @@ class TemporalManager: logger.warning("No workflows discovered") else: logger.info( - f"Discovered {len(self.workflows)} workflows: " - f"{list(self.workflows.keys())}" + "Discovered %s workflows: %s", + len(self.workflows), + list(self.workflows.keys()), ) # Connect to Temporal self.client = await Client.connect( self.temporal_address, - namespace=self.temporal_namespace + namespace=self.temporal_namespace, ) - logger.info(f"✓ Connected to Temporal: {self.temporal_address}") + logger.info("✓ Connected to Temporal: %s", self.temporal_address) - except Exception as e: - logger.error(f"Failed to initialize Temporal manager: {e}", exc_info=True) + except Exception: + logger.exception("Failed to initialize Temporal manager") raise - async def close(self): + async def close(self) -> None: """Close Temporal client connection.""" if self.client: # Temporal client doesn't need explicit close in Python SDK pass - async def 
get_workflows(self) -> Dict[str, WorkflowInfo]: - """ - Get all discovered workflows. + async def get_workflows(self) -> dict[str, WorkflowInfo]: + """Get all discovered workflows. Returns: Dictionary mapping workflow names to their info + """ return self.workflows - async def get_workflow(self, name: str) -> Optional[WorkflowInfo]: - """ - Get workflow info by name. + async def get_workflow(self, name: str) -> WorkflowInfo | None: + """Get workflow info by name. Args: name: Workflow name Returns: WorkflowInfo or None if not found + """ return self.workflows.get(name) @@ -127,10 +129,9 @@ class TemporalManager: self, file_path: Path, user_id: str, - metadata: Optional[Dict[str, Any]] = None + metadata: dict[str, Any] | None = None, ) -> str: - """ - Upload target file to storage. + """Upload target file to storage. Args: file_path: Local path to file @@ -139,20 +140,20 @@ class TemporalManager: Returns: Target ID for use in workflow execution + """ target_id = await self.storage.upload_target(file_path, user_id, metadata) - logger.info(f"Uploaded target: {target_id}") + logger.info("Uploaded target: %s", target_id) return target_id async def run_workflow( self, workflow_name: str, target_id: str, - workflow_params: Optional[Dict[str, Any]] = None, - workflow_id: Optional[str] = None + workflow_params: dict[str, Any] | None = None, + workflow_id: str | None = None, ) -> WorkflowHandle: - """ - Execute a workflow. + """Execute a workflow. Args: workflow_name: Name of workflow to execute @@ -165,14 +166,17 @@ class TemporalManager: Raises: ValueError: If workflow not found or client not initialized + """ if not self.client: - raise ValueError("Temporal client not initialized. Call initialize() first.") + msg = "Temporal client not initialized. Call initialize() first." 
+            raise ValueError(msg)
 
         # Get workflow info
         workflow_info = self.workflows.get(workflow_name)
         if not workflow_info:
-            raise ValueError(f"Workflow not found: {workflow_name}")
+            msg = f"Workflow not found: {workflow_name}"
+            raise ValueError(msg)
 
         # Generate workflow ID if not provided
         if not workflow_id:
@@ -188,23 +192,23 @@ class TemporalManager:
         # Add parameters in order based on metadata schema
         # This ensures parameters match the workflow signature order
         # Apply defaults from metadata.yaml if parameter not provided
-        if 'parameters' in workflow_info.metadata:
-            param_schema = workflow_info.metadata['parameters'].get('properties', {})
-            logger.debug(f"Found {len(param_schema)} parameters in schema")
+        if "parameters" in workflow_info.metadata:
+            param_schema = workflow_info.metadata["parameters"].get("properties", {})
+            logger.debug("Found %s parameters in schema", len(param_schema))
 
             # Iterate parameters in schema order and add values
-            for param_name in param_schema.keys():
+            for param_name in param_schema:
                 param_spec = param_schema[param_name]
                 # Use provided param, or fall back to default from metadata
                 if workflow_params and param_name in workflow_params:
                     param_value = workflow_params[param_name]
-                    logger.debug(f"Using provided value for {param_name}: {param_value}")
-                elif 'default' in param_spec:
-                    param_value = param_spec['default']
-                    logger.debug(f"Using default for {param_name}: {param_value}")
+                    logger.debug("Using provided value for %s: %s", param_name, param_value)
+                elif "default" in param_spec:
+                    param_value = param_spec["default"]
+                    logger.debug("Using default for %s: %s", param_name, param_value)
                 else:
                     param_value = None
-                    logger.debug(f"No value or default for {param_name}, using None")
+                    logger.debug("No value or default for %s, using None", param_name)
 
                 workflow_args.append(param_value)
         else:
@@ -215,11 +219,14 @@ class TemporalManager:
         task_queue = f"{vertical}-queue"
 
         logger.info(
-            f"Starting workflow: {workflow_name} "
-            f"(id={workflow_id}, queue={task_queue}, target={target_id})"
+            "Starting workflow: %s (id=%s, queue=%s, target=%s)",
+            workflow_name,
+            workflow_id,
+            task_queue,
+            target_id,
         )
-        logger.info(f"DEBUG: workflow_args = {workflow_args}")
-        logger.info(f"DEBUG: workflow_params received = {workflow_params}")
+        logger.info("DEBUG: workflow_args = %s", workflow_args)
+        logger.info("DEBUG: workflow_params received = %s", workflow_params)
 
         try:
             # Start workflow execution with positional arguments
@@ -231,20 +238,20 @@ class TemporalManager:
                 retry_policy=RetryPolicy(
                     initial_interval=timedelta(seconds=1),
                     maximum_interval=timedelta(minutes=1),
-                    maximum_attempts=3
-                )
+                    maximum_attempts=3,
+                ),
             )
 
-            logger.info(f"✓ Workflow started: {workflow_id}")
+            logger.info("✓ Workflow started: %s", workflow_id)
+
+        except Exception:
+            logger.exception("Failed to start workflow %s", workflow_name)
+            raise
+        else:
             return handle
 
-        except Exception as e:
-            logger.error(f"Failed to start workflow {workflow_name}: {e}", exc_info=True)
-            raise
-
-    async def get_workflow_status(self, workflow_id: str) -> Dict[str, Any]:
-        """
-        Get workflow execution status.
+    async def get_workflow_status(self, workflow_id: str) -> dict[str, Any]:
+        """Get workflow execution status.
 
         Args:
             workflow_id: Workflow execution ID
@@ -254,9 +261,11 @@ class TemporalManager:
 
         Raises:
             ValueError: If client not initialized or workflow not found
+
         """
         if not self.client:
-            raise ValueError("Temporal client not initialized")
+            msg = "Temporal client not initialized"
+            raise ValueError(msg)
 
         try:
             # Get workflow handle
@@ -274,20 +283,20 @@ class TemporalManager:
                 "task_queue": description.task_queue,
             }
 
-            logger.info(f"Workflow {workflow_id} status: {status['status']}")
-            return status
+            logger.info("Workflow %s status: %s", workflow_id, status["status"])
 
-        except Exception as e:
-            logger.error(f"Failed to get workflow status: {e}", exc_info=True)
+        except Exception:
+            logger.exception("Failed to get workflow status")
             raise
+        else:
+            return status
 
     async def get_workflow_result(
         self,
         workflow_id: str,
-        timeout: Optional[timedelta] = None
+        timeout: timedelta | None = None,
     ) -> Any:
-        """
-        Get workflow execution result (blocking).
+        """Get workflow execution result (blocking).
 
         Args:
             workflow_id: Workflow execution ID
@@ -299,60 +308,62 @@ class TemporalManager:
         Raises:
             ValueError: If client not initialized
             TimeoutError: If timeout exceeded
+
         """
         if not self.client:
-            raise ValueError("Temporal client not initialized")
+            msg = "Temporal client not initialized"
+            raise ValueError(msg)
 
         try:
             handle = self.client.get_workflow_handle(workflow_id)
-            logger.info(f"Waiting for workflow result: {workflow_id}")
+            logger.info("Waiting for workflow result: %s", workflow_id)
 
             # Wait for workflow to complete and get result
             if timeout:
                 # Use asyncio timeout if provided
-                import asyncio
                 result = await asyncio.wait_for(handle.result(), timeout=timeout.total_seconds())
             else:
                 result = await handle.result()
 
-            logger.info(f"✓ Workflow {workflow_id} completed")
+            logger.info("✓ Workflow %s completed", workflow_id)
+
+        except Exception:
+            logger.exception("Failed to get workflow result")
+            raise
+        else:
             return result
 
-        except Exception as e:
-            logger.error(f"Failed to get workflow result: {e}", exc_info=True)
-            raise
-
     async def cancel_workflow(self, workflow_id: str) -> None:
-        """
-        Cancel a running workflow.
+        """Cancel a running workflow.
 
         Args:
             workflow_id: Workflow execution ID
 
         Raises:
             ValueError: If client not initialized
+
         """
         if not self.client:
-            raise ValueError("Temporal client not initialized")
+            msg = "Temporal client not initialized"
+            raise ValueError(msg)
 
         try:
             handle = self.client.get_workflow_handle(workflow_id)
             await handle.cancel()
-            logger.info(f"✓ Workflow cancelled: {workflow_id}")
+            logger.info("✓ Workflow cancelled: %s", workflow_id)
 
-        except Exception as e:
-            logger.error(f"Failed to cancel workflow: {e}", exc_info=True)
+        except Exception:
+            logger.exception("Failed to cancel workflow: %s", workflow_id)
             raise
 
     async def list_workflows(
         self,
-        filter_query: Optional[str] = None,
-        limit: int = 100
-    ) -> list[Dict[str, Any]]:
-        """
-        List workflow executions.
+        filter_query: str | None = None,
+        limit: int = 100,
+    ) -> list[dict[str, Any]]:
+        """List workflow executions.
 
Args: filter_query: Optional Temporal list filter query @@ -363,30 +374,36 @@ class TemporalManager: Raises: ValueError: If client not initialized + """ if not self.client: - raise ValueError("Temporal client not initialized") + msg = "Temporal client not initialized" + raise ValueError(msg) try: workflows = [] # Use Temporal's list API async for workflow in self.client.list_workflows(filter_query): - workflows.append({ - "workflow_id": workflow.id, - "workflow_type": workflow.workflow_type, - "status": workflow.status.name, - "start_time": workflow.start_time.isoformat() if workflow.start_time else None, - "close_time": workflow.close_time.isoformat() if workflow.close_time else None, - "task_queue": workflow.task_queue, - }) + workflows.append( + { + "workflow_id": workflow.id, + "workflow_type": workflow.workflow_type, + "status": workflow.status.name, + "start_time": workflow.start_time.isoformat() if workflow.start_time else None, + "close_time": workflow.close_time.isoformat() if workflow.close_time else None, + "task_queue": workflow.task_queue, + }, + ) if len(workflows) >= limit: break - logger.info(f"Listed {len(workflows)} workflows") + logger.info("Listed %s workflows", len(workflows)) return workflows - except Exception as e: - logger.error(f"Failed to list workflows: {e}", exc_info=True) + except Exception: + logger.exception("Failed to list workflows") raise + else: + return workflows diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 0bc6eee..a40bc5d 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -8,11 +8,19 @@ # See the LICENSE-APACHE file or http://www.apache.org/licenses/LICENSE-2.0 # # Additional attribution and requirements are provided in the NOTICE file. +"""Fixtures used across tests.""" import sys +from collections.abc import Callable from pathlib import Path -from typing import Dict, Any +from types import CoroutineType +from typing import Any + import pytest +from modules.analyzer.security_analyzer import SecurityAnalyzer +from modules.fuzzer.atheris_fuzzer import AtherisFuzzer +from modules.fuzzer.cargo_fuzzer import CargoFuzzer +from modules.scanner.file_scanner import FileScanner # Ensure project root is on sys.path so `src` is importable ROOT = Path(__file__).resolve().parents[1] @@ -29,17 +37,18 @@ if str(TOOLBOX) not in sys.path: # Workspace Fixtures # ============================================================================ + @pytest.fixture -def temp_workspace(tmp_path): - """Create a temporary workspace directory for testing""" +def temp_workspace(tmp_path: Path) -> Path: + """Create a temporary workspace directory for testing.""" workspace = tmp_path / "workspace" workspace.mkdir() return workspace @pytest.fixture -def python_test_workspace(temp_workspace): - """Create a Python test workspace with sample files""" +def python_test_workspace(temp_workspace: Path) -> Path: + """Create a Python test workspace with sample files.""" # Create a simple Python project structure (temp_workspace / "main.py").write_text(""" def process_data(data): @@ -62,8 +71,8 @@ AWS_SECRET = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" @pytest.fixture -def rust_test_workspace(temp_workspace): - """Create a Rust test workspace with fuzz targets""" +def rust_test_workspace(temp_workspace: Path) -> Path: + """Create a Rust test workspace with fuzz targets.""" # Create Cargo.toml (temp_workspace / "Cargo.toml").write_text("""[package] name = "test_project" @@ -131,44 +140,45 @@ fuzz_target!(|data: &[u8]| { # Module Configuration 
Fixtures # ============================================================================ + @pytest.fixture -def atheris_config(): - """Default Atheris fuzzer configuration""" +def atheris_config() -> dict[str, Any]: + """Return default Atheris fuzzer configuration.""" return { "target_file": "auto-discover", "max_iterations": 1000, "timeout_seconds": 10, - "corpus_dir": None + "corpus_dir": None, } @pytest.fixture -def cargo_fuzz_config(): - """Default cargo-fuzz configuration""" +def cargo_fuzz_config() -> dict[str, Any]: + """Return default cargo-fuzz configuration.""" return { "target_name": None, "max_iterations": 1000, "timeout_seconds": 10, - "sanitizer": "address" + "sanitizer": "address", } @pytest.fixture -def gitleaks_config(): - """Default Gitleaks configuration""" +def gitleaks_config() -> dict[str, Any]: + """Return default Gitleaks configuration.""" return { "config_path": None, - "scan_uncommitted": True + "scan_uncommitted": True, } @pytest.fixture -def file_scanner_config(): - """Default file scanner configuration""" +def file_scanner_config() -> dict[str, Any]: + """Return default file scanner configuration.""" return { "scan_patterns": ["*.py", "*.rs", "*.js"], "exclude_patterns": ["*.test.*", "*.spec.*"], - "max_file_size": 1048576 # 1MB + "max_file_size": 1048576, # 1MB } @@ -176,55 +186,67 @@ def file_scanner_config(): # Module Instance Fixtures # ============================================================================ + @pytest.fixture -def atheris_fuzzer(): - """Create an AtherisFuzzer instance""" - from modules.fuzzer.atheris_fuzzer import AtherisFuzzer +def atheris_fuzzer() -> AtherisFuzzer: + """Create an AtherisFuzzer instance.""" return AtherisFuzzer() @pytest.fixture -def cargo_fuzzer(): - """Create a CargoFuzzer instance""" - from modules.fuzzer.cargo_fuzzer import CargoFuzzer +def cargo_fuzzer() -> CargoFuzzer: + """Create a CargoFuzzer instance.""" return CargoFuzzer() @pytest.fixture -def file_scanner(): - """Create a FileScanner instance""" - from modules.scanner.file_scanner import FileScanner +def file_scanner() -> FileScanner: + """Create a FileScanner instance.""" return FileScanner() +@pytest.fixture +def security_analyzer() -> SecurityAnalyzer: + """Create SecurityAnalyzer instance.""" + return SecurityAnalyzer() + + # ============================================================================ # Mock Fixtures # ============================================================================ + @pytest.fixture -def mock_stats_callback(): - """Mock stats callback for fuzzing""" +def mock_stats_callback() -> Callable[[], CoroutineType]: + """Mock stats callback for fuzzing.""" stats_received = [] - async def callback(stats: Dict[str, Any]): + async def callback(stats: dict[str, Any]) -> None: stats_received.append(stats) callback.stats_received = stats_received return callback +class MockActivityInfo: + """Mock activity info.""" + + def __init__(self) -> None: + """Initialize an instance of the class.""" + self.workflow_id = "test-workflow-123" + self.activity_id = "test-activity-1" + self.attempt = 1 + + +class MockContext: + """Mock context.""" + + def __init__(self) -> None: + """Initialize an instance of the class.""" + self.info = MockActivityInfo() + + @pytest.fixture -def mock_temporal_context(): - """Mock Temporal activity context""" - class MockActivityInfo: - def __init__(self): - self.workflow_id = "test-workflow-123" - self.activity_id = "test-activity-1" - self.attempt = 1 - - class MockContext: - def __init__(self): - self.info = 
MockActivityInfo() - +def mock_temporal_context() -> MockContext: + """Mock Temporal activity context.""" return MockContext() - diff --git a/backend/tests/fixtures/__init__.py b/backend/tests/fixtures/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/tests/integration/__init__.py b/backend/tests/integration/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/tests/unit/__init__.py b/backend/tests/unit/__init__.py index e69de29..e0310a0 100644 --- a/backend/tests/unit/__init__.py +++ b/backend/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests.""" diff --git a/backend/tests/unit/test_api/__init__.py b/backend/tests/unit/test_api/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/tests/unit/test_modules/__init__.py b/backend/tests/unit/test_modules/__init__.py index e69de29..ef70ba2 100644 --- a/backend/tests/unit/test_modules/__init__.py +++ b/backend/tests/unit/test_modules/__init__.py @@ -0,0 +1 @@ +"""Unit tests for modules.""" diff --git a/backend/tests/unit/test_modules/test_atheris_fuzzer.py b/backend/tests/unit/test_modules/test_atheris_fuzzer.py index 9cd01ce..bcd345b 100644 --- a/backend/tests/unit/test_modules/test_atheris_fuzzer.py +++ b/backend/tests/unit/test_modules/test_atheris_fuzzer.py @@ -1,17 +1,26 @@ -""" -Unit tests for AtherisFuzzer module -""" +"""Unit tests for AtherisFuzzer module.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, patch import pytest -from unittest.mock import AsyncMock, patch + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + from typing import Any + + from modules.fuzzer.atheris_fuzzer import AtherisFuzzer @pytest.mark.asyncio class TestAtherisFuzzerMetadata: - """Test AtherisFuzzer metadata""" + """Test AtherisFuzzer metadata.""" - async def test_metadata_structure(self, atheris_fuzzer): - """Test that module metadata is properly defined""" + async def test_metadata_structure(self, atheris_fuzzer: AtherisFuzzer) -> None: + """Test that module metadata is properly defined.""" metadata = atheris_fuzzer.get_metadata() assert metadata.name == "atheris_fuzzer" @@ -22,28 +31,28 @@ class TestAtherisFuzzerMetadata: @pytest.mark.asyncio class TestAtherisFuzzerConfigValidation: - """Test configuration validation""" + """Test configuration validation.""" - async def test_valid_config(self, atheris_fuzzer, atheris_config): - """Test validation of valid configuration""" + async def test_valid_config(self, atheris_fuzzer: AtherisFuzzer, atheris_config: dict[str, Any]) -> None: + """Test validation of valid configuration.""" assert atheris_fuzzer.validate_config(atheris_config) is True - async def test_invalid_max_iterations(self, atheris_fuzzer): - """Test validation fails with invalid max_iterations""" + async def test_invalid_max_iterations(self, atheris_fuzzer: AtherisFuzzer) -> None: + """Test validation fails with invalid max_iterations.""" config = { "target_file": "fuzz_target.py", "max_iterations": -1, - "timeout_seconds": 10 + "timeout_seconds": 10, } with pytest.raises(ValueError, match="max_iterations"): atheris_fuzzer.validate_config(config) - async def test_invalid_timeout(self, atheris_fuzzer): - """Test validation fails with invalid timeout""" + async def test_invalid_timeout(self, atheris_fuzzer: AtherisFuzzer) -> None: + """Test validation fails with invalid timeout.""" config = { "target_file": "fuzz_target.py", "max_iterations": 1000, - 
"timeout_seconds": 0 + "timeout_seconds": 0, } with pytest.raises(ValueError, match="timeout_seconds"): atheris_fuzzer.validate_config(config) @@ -51,10 +60,10 @@ class TestAtherisFuzzerConfigValidation: @pytest.mark.asyncio class TestAtherisFuzzerDiscovery: - """Test fuzz target discovery""" + """Test fuzz target discovery.""" - async def test_auto_discover(self, atheris_fuzzer, python_test_workspace): - """Test auto-discovery of Python fuzz targets""" + async def test_auto_discover(self, atheris_fuzzer: AtherisFuzzer, python_test_workspace: Path) -> None: + """Test auto-discovery of Python fuzz targets.""" # Create a fuzz target file (python_test_workspace / "fuzz_target.py").write_text(""" import atheris @@ -69,7 +78,7 @@ if __name__ == "__main__": """) # Pass None for auto-discovery - target = atheris_fuzzer._discover_target(python_test_workspace, None) + target = atheris_fuzzer._discover_target(python_test_workspace, None) # noqa: SLF001 assert target is not None assert "fuzz_target.py" in str(target) @@ -77,10 +86,14 @@ if __name__ == "__main__": @pytest.mark.asyncio class TestAtherisFuzzerExecution: - """Test fuzzer execution logic""" + """Test fuzzer execution logic.""" - async def test_execution_creates_result(self, atheris_fuzzer, python_test_workspace, atheris_config): - """Test that execution returns a ModuleResult""" + async def test_execution_creates_result( + self, + atheris_fuzzer: AtherisFuzzer, + python_test_workspace: Path, + ) -> None: + """Test that execution returns a ModuleResult.""" # Create a simple fuzz target (python_test_workspace / "fuzz_target.py").write_text(""" import atheris @@ -99,11 +112,16 @@ if __name__ == "__main__": test_config = { "target_file": "fuzz_target.py", "max_iterations": 10, - "timeout_seconds": 1 + "timeout_seconds": 1, } # Mock the fuzzing subprocess to avoid actual execution - with patch.object(atheris_fuzzer, '_run_fuzzing', new_callable=AsyncMock, return_value=([], {"total_executions": 10})): + with patch.object( + atheris_fuzzer, + "_run_fuzzing", + new_callable=AsyncMock, + return_value=([], {"total_executions": 10}), + ): result = await atheris_fuzzer.execute(test_config, python_test_workspace) assert result.module == "atheris_fuzzer" @@ -113,10 +131,16 @@ if __name__ == "__main__": @pytest.mark.asyncio class TestAtherisFuzzerStatsCallback: - """Test stats callback functionality""" + """Test stats callback functionality.""" - async def test_stats_callback_invoked(self, atheris_fuzzer, python_test_workspace, atheris_config, mock_stats_callback): - """Test that stats callback is invoked during fuzzing""" + async def test_stats_callback_invoked( + self, + atheris_fuzzer: AtherisFuzzer, + python_test_workspace: Path, + atheris_config: dict[str, Any], + mock_stats_callback: Callable | None, + ) -> None: + """Test that stats callback is invoked during fuzzing.""" (python_test_workspace / "fuzz_target.py").write_text(""" import atheris import sys @@ -130,35 +154,45 @@ if __name__ == "__main__": """) # Mock fuzzing to simulate stats - async def mock_run_fuzzing(test_one_input, target_path, workspace, max_iterations, timeout_seconds, stats_callback): + async def mock_run_fuzzing( + test_one_input: Callable, # noqa: ARG001 + target_path: Path, # noqa: ARG001 + workspace: Path, # noqa: ARG001 + max_iterations: int, # noqa: ARG001 + timeout_seconds: int, # noqa: ARG001 + stats_callback: Callable | None, + ) -> None: if stats_callback: - await stats_callback({ - "total_execs": 100, - "execs_per_sec": 10.0, - "crashes": 0, - "coverage": 5, - 
"corpus_size": 2, - "elapsed_time": 10 - }) - return + await stats_callback( + { + "total_execs": 100, + "execs_per_sec": 10.0, + "crashes": 0, + "coverage": 5, + "corpus_size": 2, + "elapsed_time": 10, + }, + ) - with patch.object(atheris_fuzzer, '_run_fuzzing', side_effect=mock_run_fuzzing): - with patch.object(atheris_fuzzer, '_load_target_module', return_value=lambda x: None): - # Put stats_callback in config dict, not as kwarg - atheris_config["target_file"] = "fuzz_target.py" - atheris_config["stats_callback"] = mock_stats_callback - await atheris_fuzzer.execute(atheris_config, python_test_workspace) + with ( + patch.object(atheris_fuzzer, "_run_fuzzing", side_effect=mock_run_fuzzing), + patch.object(atheris_fuzzer, "_load_target_module", return_value=lambda _x: None), + ): + # Put stats_callback in config dict, not as kwarg + atheris_config["target_file"] = "fuzz_target.py" + atheris_config["stats_callback"] = mock_stats_callback + await atheris_fuzzer.execute(atheris_config, python_test_workspace) - # Verify callback was invoked - assert len(mock_stats_callback.stats_received) > 0 + # Verify callback was invoked + assert len(mock_stats_callback.stats_received) > 0 @pytest.mark.asyncio class TestAtherisFuzzerFindingGeneration: - """Test finding generation from crashes""" + """Test finding generation from crashes.""" - async def test_create_crash_finding(self, atheris_fuzzer): - """Test crash finding creation""" + async def test_create_crash_finding(self, atheris_fuzzer: AtherisFuzzer) -> None: + """Test crash finding creation.""" finding = atheris_fuzzer.create_finding( title="Crash: Exception in TestOneInput", description="IndexError: list index out of range", @@ -167,8 +201,8 @@ class TestAtherisFuzzerFindingGeneration: file_path="fuzz_target.py", metadata={ "crash_type": "IndexError", - "stack_trace": "Traceback..." 
- } + "stack_trace": "Traceback...", + }, ) assert finding.title == "Crash: Exception in TestOneInput" diff --git a/backend/tests/unit/test_modules/test_cargo_fuzzer.py b/backend/tests/unit/test_modules/test_cargo_fuzzer.py index f550b9a..317773a 100644 --- a/backend/tests/unit/test_modules/test_cargo_fuzzer.py +++ b/backend/tests/unit/test_modules/test_cargo_fuzzer.py @@ -1,17 +1,26 @@ -""" -Unit tests for CargoFuzzer module -""" +"""Unit tests for CargoFuzzer module.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, patch import pytest -from unittest.mock import AsyncMock, patch + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + from typing import Any + + from modules.fuzzer.cargo_fuzzer import CargoFuzzer @pytest.mark.asyncio class TestCargoFuzzerMetadata: - """Test CargoFuzzer metadata""" + """Test CargoFuzzer metadata.""" - async def test_metadata_structure(self, cargo_fuzzer): - """Test that module metadata is properly defined""" + async def test_metadata_structure(self, cargo_fuzzer: CargoFuzzer) -> None: + """Test that module metadata is properly defined.""" metadata = cargo_fuzzer.get_metadata() assert metadata.name == "cargo_fuzz" @@ -23,38 +32,38 @@ class TestCargoFuzzerMetadata: @pytest.mark.asyncio class TestCargoFuzzerConfigValidation: - """Test configuration validation""" + """Test configuration validation.""" - async def test_valid_config(self, cargo_fuzzer, cargo_fuzz_config): - """Test validation of valid configuration""" + async def test_valid_config(self, cargo_fuzzer: CargoFuzzer, cargo_fuzz_config: dict[str, Any]) -> None: + """Test validation of valid configuration.""" assert cargo_fuzzer.validate_config(cargo_fuzz_config) is True - async def test_invalid_max_iterations(self, cargo_fuzzer): - """Test validation fails with invalid max_iterations""" + async def test_invalid_max_iterations(self, cargo_fuzzer: CargoFuzzer) -> None: + """Test validation fails with invalid max_iterations.""" config = { "max_iterations": -1, "timeout_seconds": 10, - "sanitizer": "address" + "sanitizer": "address", } with pytest.raises(ValueError, match="max_iterations"): cargo_fuzzer.validate_config(config) - async def test_invalid_timeout(self, cargo_fuzzer): - """Test validation fails with invalid timeout""" + async def test_invalid_timeout(self, cargo_fuzzer: CargoFuzzer) -> None: + """Test validation fails with invalid timeout.""" config = { "max_iterations": 1000, "timeout_seconds": 0, - "sanitizer": "address" + "sanitizer": "address", } with pytest.raises(ValueError, match="timeout_seconds"): cargo_fuzzer.validate_config(config) - async def test_invalid_sanitizer(self, cargo_fuzzer): - """Test validation fails with invalid sanitizer""" + async def test_invalid_sanitizer(self, cargo_fuzzer: CargoFuzzer) -> None: + """Test validation fails with invalid sanitizer.""" config = { "max_iterations": 1000, "timeout_seconds": 10, - "sanitizer": "invalid_sanitizer" + "sanitizer": "invalid_sanitizer", } with pytest.raises(ValueError, match="sanitizer"): cargo_fuzzer.validate_config(config) @@ -62,20 +71,20 @@ class TestCargoFuzzerConfigValidation: @pytest.mark.asyncio class TestCargoFuzzerWorkspaceValidation: - """Test workspace validation""" + """Test workspace validation.""" - async def test_valid_workspace(self, cargo_fuzzer, rust_test_workspace): - """Test validation of valid workspace""" + async def test_valid_workspace(self, cargo_fuzzer: CargoFuzzer, rust_test_workspace: Path) -> None: + 
"""Test validation of valid workspace.""" assert cargo_fuzzer.validate_workspace(rust_test_workspace) is True - async def test_nonexistent_workspace(self, cargo_fuzzer, tmp_path): - """Test validation fails with nonexistent workspace""" + async def test_nonexistent_workspace(self, cargo_fuzzer: CargoFuzzer, tmp_path: Path) -> None: + """Test validation fails with nonexistent workspace.""" nonexistent = tmp_path / "does_not_exist" with pytest.raises(ValueError, match="does not exist"): cargo_fuzzer.validate_workspace(nonexistent) - async def test_workspace_is_file(self, cargo_fuzzer, tmp_path): - """Test validation fails when workspace is a file""" + async def test_workspace_is_file(self, cargo_fuzzer: CargoFuzzer, tmp_path: Path) -> None: + """Test validation fails when workspace is a file.""" file_path = tmp_path / "file.txt" file_path.write_text("test") with pytest.raises(ValueError, match="not a directory"): @@ -84,41 +93,58 @@ class TestCargoFuzzerWorkspaceValidation: @pytest.mark.asyncio class TestCargoFuzzerDiscovery: - """Test fuzz target discovery""" + """Test fuzz target discovery.""" - async def test_discover_targets(self, cargo_fuzzer, rust_test_workspace): - """Test discovery of fuzz targets""" - targets = await cargo_fuzzer._discover_fuzz_targets(rust_test_workspace) + async def test_discover_targets(self, cargo_fuzzer: CargoFuzzer, rust_test_workspace: Path) -> None: + """Test discovery of fuzz targets.""" + targets = await cargo_fuzzer._discover_fuzz_targets(rust_test_workspace) # noqa: SLF001 assert len(targets) == 1 assert "fuzz_target_1" in targets - async def test_no_fuzz_directory(self, cargo_fuzzer, temp_workspace): - """Test discovery with no fuzz directory""" - targets = await cargo_fuzzer._discover_fuzz_targets(temp_workspace) + async def test_no_fuzz_directory(self, cargo_fuzzer: CargoFuzzer, temp_workspace: Path) -> None: + """Test discovery with no fuzz directory.""" + targets = await cargo_fuzzer._discover_fuzz_targets(temp_workspace) # noqa: SLF001 assert targets == [] @pytest.mark.asyncio class TestCargoFuzzerExecution: - """Test fuzzer execution logic""" + """Test fuzzer execution logic.""" - async def test_execution_creates_result(self, cargo_fuzzer, rust_test_workspace, cargo_fuzz_config): - """Test that execution returns a ModuleResult""" + async def test_execution_creates_result( + self, + cargo_fuzzer: CargoFuzzer, + rust_test_workspace: Path, + cargo_fuzz_config: dict[str, Any], + ) -> None: + """Test that execution returns a ModuleResult.""" # Mock the build and run methods to avoid actual fuzzing - with patch.object(cargo_fuzzer, '_build_fuzz_target', new_callable=AsyncMock, return_value=True): - with patch.object(cargo_fuzzer, '_run_fuzzing', new_callable=AsyncMock, return_value=([], {"total_executions": 0, "crashes_found": 0})): - with patch.object(cargo_fuzzer, '_parse_crash_artifacts', new_callable=AsyncMock, return_value=[]): - result = await cargo_fuzzer.execute(cargo_fuzz_config, rust_test_workspace) + with ( + patch.object(cargo_fuzzer, "_build_fuzz_target", new_callable=AsyncMock, return_value=True), + patch.object( + cargo_fuzzer, + "_run_fuzzing", + new_callable=AsyncMock, + return_value=([], {"total_executions": 0, "crashes_found": 0}), + ), + patch.object(cargo_fuzzer, "_parse_crash_artifacts", new_callable=AsyncMock, return_value=[]), + ): + result = await cargo_fuzzer.execute(cargo_fuzz_config, rust_test_workspace) - assert result.module == "cargo_fuzz" - assert result.status == "success" - assert isinstance(result.execution_time, 
float) - assert result.execution_time >= 0 + assert result.module == "cargo_fuzz" + assert result.status == "success" + assert isinstance(result.execution_time, float) + assert result.execution_time >= 0 - async def test_execution_with_no_targets(self, cargo_fuzzer, temp_workspace, cargo_fuzz_config): - """Test execution fails gracefully with no fuzz targets""" + async def test_execution_with_no_targets( + self, + cargo_fuzzer: CargoFuzzer, + temp_workspace: Path, + cargo_fuzz_config: dict[str, Any], + ) -> None: + """Test execution fails gracefully with no fuzz targets.""" result = await cargo_fuzzer.execute(cargo_fuzz_config, temp_workspace) assert result.status == "failed" @@ -127,47 +153,67 @@ class TestCargoFuzzerExecution: @pytest.mark.asyncio class TestCargoFuzzerStatsCallback: - """Test stats callback functionality""" + """Test stats callback functionality.""" + + async def test_stats_callback_invoked( + self, + cargo_fuzzer: CargoFuzzer, + rust_test_workspace: Path, + cargo_fuzz_config: dict[str, Any], + mock_stats_callback: Callable | None, + ) -> None: + """Test that stats callback is invoked during fuzzing.""" - async def test_stats_callback_invoked(self, cargo_fuzzer, rust_test_workspace, cargo_fuzz_config, mock_stats_callback): - """Test that stats callback is invoked during fuzzing""" # Mock build/run to simulate stats generation - async def mock_run_fuzzing(workspace, target, config, callback): + async def mock_run_fuzzing( + _workspace: Path, + _target: str, + _config: dict[str, Any], + callback: Callable | None, + ) -> tuple[list, dict[str, int]]: # Simulate stats callback if callback: - await callback({ - "total_execs": 1000, - "execs_per_sec": 100.0, - "crashes": 0, - "coverage": 10, - "corpus_size": 5, - "elapsed_time": 10 - }) + await callback( + { + "total_execs": 1000, + "execs_per_sec": 100.0, + "crashes": 0, + "coverage": 10, + "corpus_size": 5, + "elapsed_time": 10, + }, + ) return [], {"total_executions": 1000} - with patch.object(cargo_fuzzer, '_build_fuzz_target', new_callable=AsyncMock, return_value=True): - with patch.object(cargo_fuzzer, '_run_fuzzing', side_effect=mock_run_fuzzing): - with patch.object(cargo_fuzzer, '_parse_crash_artifacts', new_callable=AsyncMock, return_value=[]): - await cargo_fuzzer.execute(cargo_fuzz_config, rust_test_workspace, stats_callback=mock_stats_callback) + with ( + patch.object(cargo_fuzzer, "_build_fuzz_target", new_callable=AsyncMock, return_value=True), + patch.object(cargo_fuzzer, "_run_fuzzing", side_effect=mock_run_fuzzing), + patch.object(cargo_fuzzer, "_parse_crash_artifacts", new_callable=AsyncMock, return_value=[]), + ): + await cargo_fuzzer.execute( + cargo_fuzz_config, + rust_test_workspace, + stats_callback=mock_stats_callback, + ) - # Verify callback was invoked - assert len(mock_stats_callback.stats_received) > 0 - assert mock_stats_callback.stats_received[0]["total_execs"] == 1000 + # Verify callback was invoked + assert len(mock_stats_callback.stats_received) > 0 + assert mock_stats_callback.stats_received[0]["total_execs"] == 1000 @pytest.mark.asyncio class TestCargoFuzzerFindingGeneration: - """Test finding generation from crashes""" + """Test finding generation from crashes.""" - async def test_create_finding_from_crash(self, cargo_fuzzer): - """Test finding creation""" + async def test_create_finding_from_crash(self, cargo_fuzzer: CargoFuzzer) -> None: + """Test finding creation.""" finding = cargo_fuzzer.create_finding( title="Crash: Segmentation Fault", description="Test crash", severity="critical", 
category="crash", file_path="fuzz/fuzz_targets/fuzz_target_1.rs", - metadata={"crash_type": "SIGSEGV"} + metadata={"crash_type": "SIGSEGV"}, ) assert finding.title == "Crash: Segmentation Fault" diff --git a/backend/tests/unit/test_modules/test_file_scanner.py b/backend/tests/unit/test_modules/test_file_scanner.py index 12332f0..eccd8e7 100644 --- a/backend/tests/unit/test_modules/test_file_scanner.py +++ b/backend/tests/unit/test_modules/test_file_scanner.py @@ -1,22 +1,25 @@ -""" -Unit tests for FileScanner module -""" +"""Unit tests for FileScanner module.""" + +from __future__ import annotations import sys from pathlib import Path +from typing import TYPE_CHECKING import pytest +if TYPE_CHECKING: + from modules.scanner.file_scanner import FileScanner + sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "toolbox")) - @pytest.mark.asyncio class TestFileScannerMetadata: - """Test FileScanner metadata""" + """Test FileScanner metadata.""" - async def test_metadata_structure(self, file_scanner): - """Test that metadata has correct structure""" + async def test_metadata_structure(self, file_scanner: FileScanner) -> None: + """Test that metadata has correct structure.""" metadata = file_scanner.get_metadata() assert metadata.name == "file_scanner" @@ -29,37 +32,37 @@ class TestFileScannerMetadata: @pytest.mark.asyncio class TestFileScannerConfigValidation: - """Test configuration validation""" + """Test configuration validation.""" - async def test_valid_config(self, file_scanner): - """Test that valid config passes validation""" + async def test_valid_config(self, file_scanner: FileScanner) -> None: + """Test that valid config passes validation.""" config = { "patterns": ["*.py", "*.js"], "max_file_size": 1048576, "check_sensitive": True, - "calculate_hashes": False + "calculate_hashes": False, } assert file_scanner.validate_config(config) is True - async def test_default_config(self, file_scanner): - """Test that empty config uses defaults""" + async def test_default_config(self, file_scanner: FileScanner) -> None: + """Test that empty config uses defaults.""" config = {} assert file_scanner.validate_config(config) is True - async def test_invalid_patterns_type(self, file_scanner): - """Test that non-list patterns raises error""" + async def test_invalid_patterns_type(self, file_scanner: FileScanner) -> None: + """Test that non-list patterns raises error.""" config = {"patterns": "*.py"} with pytest.raises(ValueError, match="patterns must be a list"): file_scanner.validate_config(config) - async def test_invalid_max_file_size(self, file_scanner): - """Test that invalid max_file_size raises error""" + async def test_invalid_max_file_size(self, file_scanner: FileScanner) -> None: + """Test that invalid max_file_size raises error.""" config = {"max_file_size": -1} with pytest.raises(ValueError, match="max_file_size must be a positive integer"): file_scanner.validate_config(config) - async def test_invalid_max_file_size_type(self, file_scanner): - """Test that non-integer max_file_size raises error""" + async def test_invalid_max_file_size_type(self, file_scanner: FileScanner) -> None: + """Test that non-integer max_file_size raises error.""" config = {"max_file_size": "large"} with pytest.raises(ValueError, match="max_file_size must be a positive integer"): file_scanner.validate_config(config) @@ -67,14 +70,14 @@ class TestFileScannerConfigValidation: @pytest.mark.asyncio class TestFileScannerExecution: - """Test scanner execution""" + """Test scanner execution.""" - async def 
test_scan_python_files(self, file_scanner, python_test_workspace): - """Test scanning Python files""" + async def test_scan_python_files(self, file_scanner: FileScanner, python_test_workspace: Path) -> None: + """Test scanning Python files.""" config = { "patterns": ["*.py"], "check_sensitive": False, - "calculate_hashes": False + "calculate_hashes": False, } result = await file_scanner.execute(config, python_test_workspace) @@ -84,15 +87,15 @@ class TestFileScannerExecution: assert len(result.findings) > 0 # Check that Python files were found - python_files = [f for f in result.findings if f.file_path.endswith('.py')] + python_files = [f for f in result.findings if f.file_path.endswith(".py")] assert len(python_files) > 0 - async def test_scan_all_files(self, file_scanner, python_test_workspace): - """Test scanning all files with wildcard""" + async def test_scan_all_files(self, file_scanner: FileScanner, python_test_workspace: Path) -> None: + """Test scanning all files with wildcard.""" config = { "patterns": ["*"], "check_sensitive": False, - "calculate_hashes": False + "calculate_hashes": False, } result = await file_scanner.execute(config, python_test_workspace) @@ -101,12 +104,12 @@ class TestFileScannerExecution: assert len(result.findings) > 0 assert result.summary["total_files"] > 0 - async def test_scan_with_multiple_patterns(self, file_scanner, python_test_workspace): - """Test scanning with multiple patterns""" + async def test_scan_with_multiple_patterns(self, file_scanner: FileScanner, python_test_workspace: Path) -> None: + """Test scanning with multiple patterns.""" config = { "patterns": ["*.py", "*.txt"], "check_sensitive": False, - "calculate_hashes": False + "calculate_hashes": False, } result = await file_scanner.execute(config, python_test_workspace) @@ -114,11 +117,11 @@ class TestFileScannerExecution: assert result.status == "success" assert len(result.findings) > 0 - async def test_empty_workspace(self, file_scanner, temp_workspace): - """Test scanning empty workspace""" + async def test_empty_workspace(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test scanning empty workspace.""" config = { "patterns": ["*.py"], - "check_sensitive": False + "check_sensitive": False, } result = await file_scanner.execute(config, temp_workspace) @@ -130,17 +133,17 @@ class TestFileScannerExecution: @pytest.mark.asyncio class TestFileScannerSensitiveDetection: - """Test sensitive file detection""" + """Test sensitive file detection.""" - async def test_detect_env_file(self, file_scanner, temp_workspace): - """Test detection of .env file""" + async def test_detect_env_file(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test detection of .env file.""" # Create .env file (temp_workspace / ".env").write_text("API_KEY=secret123") config = { "patterns": ["*"], "check_sensitive": True, - "calculate_hashes": False + "calculate_hashes": False, } result = await file_scanner.execute(config, temp_workspace) @@ -152,14 +155,14 @@ class TestFileScannerSensitiveDetection: assert len(sensitive_findings) > 0 assert any(".env" in f.title for f in sensitive_findings) - async def test_detect_private_key(self, file_scanner, temp_workspace): - """Test detection of private key file""" + async def test_detect_private_key(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test detection of private key file.""" # Create private key file (temp_workspace / "id_rsa").write_text("-----BEGIN RSA PRIVATE KEY-----") config = { "patterns": ["*"], - 
"check_sensitive": True + "check_sensitive": True, } result = await file_scanner.execute(config, temp_workspace) @@ -168,13 +171,13 @@ class TestFileScannerSensitiveDetection: sensitive_findings = [f for f in result.findings if f.category == "sensitive_file"] assert len(sensitive_findings) > 0 - async def test_no_sensitive_detection_when_disabled(self, file_scanner, temp_workspace): - """Test that sensitive detection can be disabled""" + async def test_no_sensitive_detection_when_disabled(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test that sensitive detection can be disabled.""" (temp_workspace / ".env").write_text("API_KEY=secret123") config = { "patterns": ["*"], - "check_sensitive": False + "check_sensitive": False, } result = await file_scanner.execute(config, temp_workspace) @@ -186,17 +189,17 @@ class TestFileScannerSensitiveDetection: @pytest.mark.asyncio class TestFileScannerHashing: - """Test file hashing functionality""" + """Test file hashing functionality.""" - async def test_hash_calculation(self, file_scanner, temp_workspace): - """Test SHA256 hash calculation""" + async def test_hash_calculation(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test SHA256 hash calculation.""" # Create test file test_file = temp_workspace / "test.txt" test_file.write_text("Hello World") config = { "patterns": ["*.txt"], - "calculate_hashes": True + "calculate_hashes": True, } result = await file_scanner.execute(config, temp_workspace) @@ -212,14 +215,14 @@ class TestFileScannerHashing: assert finding.metadata.get("file_hash") is not None assert len(finding.metadata["file_hash"]) == 64 # SHA256 hex length - async def test_no_hash_when_disabled(self, file_scanner, temp_workspace): - """Test that hashing can be disabled""" + async def test_no_hash_when_disabled(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test that hashing can be disabled.""" test_file = temp_workspace / "test.txt" test_file.write_text("Hello World") config = { "patterns": ["*.txt"], - "calculate_hashes": False + "calculate_hashes": False, } result = await file_scanner.execute(config, temp_workspace) @@ -234,10 +237,10 @@ class TestFileScannerHashing: @pytest.mark.asyncio class TestFileScannerFileTypes: - """Test file type detection""" + """Test file type detection.""" - async def test_detect_python_type(self, file_scanner, temp_workspace): - """Test detection of Python file type""" + async def test_detect_python_type(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test detection of Python file type.""" (temp_workspace / "script.py").write_text("print('hello')") config = {"patterns": ["*.py"]} @@ -248,8 +251,8 @@ class TestFileScannerFileTypes: assert len(py_findings) > 0 assert "python" in py_findings[0].metadata["file_type"] - async def test_detect_javascript_type(self, file_scanner, temp_workspace): - """Test detection of JavaScript file type""" + async def test_detect_javascript_type(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test detection of JavaScript file type.""" (temp_workspace / "app.js").write_text("console.log('hello')") config = {"patterns": ["*.js"]} @@ -260,8 +263,8 @@ class TestFileScannerFileTypes: assert len(js_findings) > 0 assert "javascript" in js_findings[0].metadata["file_type"] - async def test_file_type_summary(self, file_scanner, temp_workspace): - """Test that file type summary is generated""" + async def test_file_type_summary(self, file_scanner: FileScanner, temp_workspace: Path) -> 
None: + """Test that file type summary is generated.""" (temp_workspace / "script.py").write_text("print('hello')") (temp_workspace / "app.js").write_text("console.log('hello')") (temp_workspace / "readme.txt").write_text("Documentation") @@ -276,17 +279,17 @@ class TestFileScannerFileTypes: @pytest.mark.asyncio class TestFileScannerSizeLimits: - """Test file size handling""" + """Test file size handling.""" - async def test_skip_large_files(self, file_scanner, temp_workspace): - """Test that large files are skipped""" + async def test_skip_large_files(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test that large files are skipped.""" # Create a "large" file large_file = temp_workspace / "large.txt" large_file.write_text("x" * 1000) config = { "patterns": ["*.txt"], - "max_file_size": 500 # Set limit smaller than file + "max_file_size": 500, # Set limit smaller than file } result = await file_scanner.execute(config, temp_workspace) @@ -297,14 +300,14 @@ class TestFileScannerSizeLimits: # The file should still be counted but not have a detailed finding assert result.summary["total_files"] > 0 - async def test_process_small_files(self, file_scanner, temp_workspace): - """Test that small files are processed""" + async def test_process_small_files(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test that small files are processed.""" small_file = temp_workspace / "small.txt" small_file.write_text("small content") config = { "patterns": ["*.txt"], - "max_file_size": 1048576 # 1MB + "max_file_size": 1048576, # 1MB } result = await file_scanner.execute(config, temp_workspace) @@ -316,10 +319,10 @@ class TestFileScannerSizeLimits: @pytest.mark.asyncio class TestFileScannerSummary: - """Test result summary generation""" + """Test result summary generation.""" - async def test_summary_structure(self, file_scanner, python_test_workspace): - """Test that summary has correct structure""" + async def test_summary_structure(self, file_scanner: FileScanner, python_test_workspace: Path) -> None: + """Test that summary has correct structure.""" config = {"patterns": ["*"]} result = await file_scanner.execute(config, python_test_workspace) @@ -334,8 +337,8 @@ class TestFileScannerSummary: assert isinstance(result.summary["file_types"], dict) assert isinstance(result.summary["patterns_scanned"], list) - async def test_summary_counts(self, file_scanner, temp_workspace): - """Test that summary counts are accurate""" + async def test_summary_counts(self, file_scanner: FileScanner, temp_workspace: Path) -> None: + """Test that summary counts are accurate.""" # Create known files (temp_workspace / "file1.py").write_text("content1") (temp_workspace / "file2.py").write_text("content2") diff --git a/backend/tests/unit/test_modules/test_security_analyzer.py b/backend/tests/unit/test_modules/test_security_analyzer.py index 7365a78..6b5c413 100644 --- a/backend/tests/unit/test_modules/test_security_analyzer.py +++ b/backend/tests/unit/test_modules/test_security_analyzer.py @@ -1,28 +1,25 @@ -""" -Unit tests for SecurityAnalyzer module -""" +"""Unit tests for SecurityAnalyzer module.""" + +from __future__ import annotations -import pytest import sys from pathlib import Path +from typing import TYPE_CHECKING + +import pytest sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "toolbox")) -from modules.analyzer.security_analyzer import SecurityAnalyzer - - -@pytest.fixture -def security_analyzer(): - """Create SecurityAnalyzer instance""" - return SecurityAnalyzer() +if 
TYPE_CHECKING: + from modules.analyzer.security_analyzer import SecurityAnalyzer @pytest.mark.asyncio class TestSecurityAnalyzerMetadata: - """Test SecurityAnalyzer metadata""" + """Test SecurityAnalyzer metadata.""" - async def test_metadata_structure(self, security_analyzer): - """Test that metadata has correct structure""" + async def test_metadata_structure(self, security_analyzer: SecurityAnalyzer) -> None: + """Test that metadata has correct structure.""" metadata = security_analyzer.get_metadata() assert metadata.name == "security_analyzer" @@ -35,25 +32,25 @@ class TestSecurityAnalyzerMetadata: @pytest.mark.asyncio class TestSecurityAnalyzerConfigValidation: - """Test configuration validation""" + """Test configuration validation.""" - async def test_valid_config(self, security_analyzer): - """Test that valid config passes validation""" + async def test_valid_config(self, security_analyzer: SecurityAnalyzer) -> None: + """Test that valid config passes validation.""" config = { "file_extensions": [".py", ".js"], "check_secrets": True, "check_sql": True, - "check_dangerous_functions": True + "check_dangerous_functions": True, } assert security_analyzer.validate_config(config) is True - async def test_default_config(self, security_analyzer): - """Test that empty config uses defaults""" + async def test_default_config(self, security_analyzer: SecurityAnalyzer) -> None: + """Test that empty config uses defaults.""" config = {} assert security_analyzer.validate_config(config) is True - async def test_invalid_extensions_type(self, security_analyzer): - """Test that non-list extensions raises error""" + async def test_invalid_extensions_type(self, security_analyzer: SecurityAnalyzer) -> None: + """Test that non-list extensions raises error.""" config = {"file_extensions": ".py"} with pytest.raises(ValueError, match="file_extensions must be a list"): security_analyzer.validate_config(config) @@ -61,10 +58,10 @@ class TestSecurityAnalyzerConfigValidation: @pytest.mark.asyncio class TestSecurityAnalyzerSecretDetection: - """Test hardcoded secret detection""" + """Test hardcoded secret detection.""" - async def test_detect_api_key(self, security_analyzer, temp_workspace): - """Test detection of hardcoded API key""" + async def test_detect_api_key(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of hardcoded API key.""" code_file = temp_workspace / "config.py" code_file.write_text(""" # Configuration file @@ -76,7 +73,7 @@ database_url = "postgresql://localhost/db" "file_extensions": [".py"], "check_secrets": True, "check_sql": False, - "check_dangerous_functions": False + "check_dangerous_functions": False, } result = await security_analyzer.execute(config, temp_workspace) @@ -86,8 +83,8 @@ database_url = "postgresql://localhost/db" assert len(secret_findings) > 0 assert any("API Key" in f.title for f in secret_findings) - async def test_detect_password(self, security_analyzer, temp_workspace): - """Test detection of hardcoded password""" + async def test_detect_password(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of hardcoded password.""" code_file = temp_workspace / "auth.py" code_file.write_text(""" def connect(): @@ -99,7 +96,7 @@ def connect(): "file_extensions": [".py"], "check_secrets": True, "check_sql": False, - "check_dangerous_functions": False + "check_dangerous_functions": False, } result = await security_analyzer.execute(config, temp_workspace) @@ -108,8 +105,8 @@ def connect(): 
secret_findings = [f for f in result.findings if f.category == "hardcoded_secret"] assert len(secret_findings) > 0 - async def test_detect_aws_credentials(self, security_analyzer, temp_workspace): - """Test detection of AWS credentials""" + async def test_detect_aws_credentials(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of AWS credentials.""" code_file = temp_workspace / "aws_config.py" code_file.write_text(""" aws_access_key = "AKIAIOSFODNN7REALKEY" @@ -118,7 +115,7 @@ aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYREALKEY" config = { "file_extensions": [".py"], - "check_secrets": True + "check_secrets": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -127,14 +124,18 @@ aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYREALKEY" aws_findings = [f for f in result.findings if "AWS" in f.title] assert len(aws_findings) >= 2 # Both access key and secret key - async def test_no_secret_detection_when_disabled(self, security_analyzer, temp_workspace): - """Test that secret detection can be disabled""" + async def test_no_secret_detection_when_disabled( + self, + security_analyzer: SecurityAnalyzer, + temp_workspace: Path, + ) -> None: + """Test that secret detection can be disabled.""" code_file = temp_workspace / "config.py" code_file.write_text('api_key = "sk_live_1234567890abcdef"') config = { "file_extensions": [".py"], - "check_secrets": False + "check_secrets": False, } result = await security_analyzer.execute(config, temp_workspace) @@ -146,10 +147,10 @@ aws_secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYREALKEY" @pytest.mark.asyncio class TestSecurityAnalyzerSQLInjection: - """Test SQL injection detection""" + """Test SQL injection detection.""" - async def test_detect_string_concatenation(self, security_analyzer, temp_workspace): - """Test detection of SQL string concatenation""" + async def test_detect_string_concatenation(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of SQL string concatenation.""" code_file = temp_workspace / "db.py" code_file.write_text(""" def get_user(user_id): @@ -161,7 +162,7 @@ def get_user(user_id): "file_extensions": [".py"], "check_secrets": False, "check_sql": True, - "check_dangerous_functions": False + "check_dangerous_functions": False, } result = await security_analyzer.execute(config, temp_workspace) @@ -170,8 +171,8 @@ def get_user(user_id): sql_findings = [f for f in result.findings if f.category == "sql_injection"] assert len(sql_findings) > 0 - async def test_detect_f_string_sql(self, security_analyzer, temp_workspace): - """Test detection of f-string in SQL""" + async def test_detect_f_string_sql(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of f-string in SQL.""" code_file = temp_workspace / "db.py" code_file.write_text(""" def get_user(name): @@ -181,7 +182,7 @@ def get_user(name): config = { "file_extensions": [".py"], - "check_sql": True + "check_sql": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -190,8 +191,12 @@ def get_user(name): sql_findings = [f for f in result.findings if f.category == "sql_injection"] assert len(sql_findings) > 0 - async def test_detect_dynamic_query_building(self, security_analyzer, temp_workspace): - """Test detection of dynamic query building""" + async def test_detect_dynamic_query_building( + self, + security_analyzer: SecurityAnalyzer, + temp_workspace: Path, + ) -> None: + """Test detection of dynamic query 
building.""" code_file = temp_workspace / "queries.py" code_file.write_text(""" def search(keyword): @@ -201,7 +206,7 @@ def search(keyword): config = { "file_extensions": [".py"], - "check_sql": True + "check_sql": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -210,14 +215,18 @@ def search(keyword): sql_findings = [f for f in result.findings if f.category == "sql_injection"] assert len(sql_findings) > 0 - async def test_no_sql_detection_when_disabled(self, security_analyzer, temp_workspace): - """Test that SQL detection can be disabled""" + async def test_no_sql_detection_when_disabled( + self, + security_analyzer: SecurityAnalyzer, + temp_workspace: Path, + ) -> None: + """Test that SQL detection can be disabled.""" code_file = temp_workspace / "db.py" code_file.write_text('query = "SELECT * FROM users WHERE id = " + user_id') config = { "file_extensions": [".py"], - "check_sql": False + "check_sql": False, } result = await security_analyzer.execute(config, temp_workspace) @@ -229,10 +238,10 @@ def search(keyword): @pytest.mark.asyncio class TestSecurityAnalyzerDangerousFunctions: - """Test dangerous function detection""" + """Test dangerous function detection.""" - async def test_detect_eval(self, security_analyzer, temp_workspace): - """Test detection of eval() usage""" + async def test_detect_eval(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of eval() usage.""" code_file = temp_workspace / "dangerous.py" code_file.write_text(""" def process_input(user_input): @@ -244,7 +253,7 @@ def process_input(user_input): "file_extensions": [".py"], "check_secrets": False, "check_sql": False, - "check_dangerous_functions": True + "check_dangerous_functions": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -254,8 +263,8 @@ def process_input(user_input): assert len(dangerous_findings) > 0 assert any("eval" in f.title.lower() for f in dangerous_findings) - async def test_detect_exec(self, security_analyzer, temp_workspace): - """Test detection of exec() usage""" + async def test_detect_exec(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of exec() usage.""" code_file = temp_workspace / "runner.py" code_file.write_text(""" def run_code(code): @@ -264,7 +273,7 @@ def run_code(code): config = { "file_extensions": [".py"], - "check_dangerous_functions": True + "check_dangerous_functions": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -273,8 +282,8 @@ def run_code(code): dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"] assert len(dangerous_findings) > 0 - async def test_detect_os_system(self, security_analyzer, temp_workspace): - """Test detection of os.system() usage""" + async def test_detect_os_system(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of os.system() usage.""" code_file = temp_workspace / "commands.py" code_file.write_text(""" import os @@ -285,7 +294,7 @@ def run_command(cmd): config = { "file_extensions": [".py"], - "check_dangerous_functions": True + "check_dangerous_functions": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -295,8 +304,8 @@ def run_command(cmd): assert len(dangerous_findings) > 0 assert any("os.system" in f.title for f in dangerous_findings) - async def test_detect_pickle_loads(self, security_analyzer, temp_workspace): - """Test detection of pickle.loads() usage""" + async def 
test_detect_pickle_loads(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of pickle.loads() usage.""" code_file = temp_workspace / "serializer.py" code_file.write_text(""" import pickle @@ -307,7 +316,7 @@ def deserialize(data): config = { "file_extensions": [".py"], - "check_dangerous_functions": True + "check_dangerous_functions": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -316,8 +325,8 @@ def deserialize(data): dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"] assert len(dangerous_findings) > 0 - async def test_detect_javascript_eval(self, security_analyzer, temp_workspace): - """Test detection of eval() in JavaScript""" + async def test_detect_javascript_eval(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of eval() in JavaScript.""" code_file = temp_workspace / "app.js" code_file.write_text(""" function processInput(userInput) { @@ -327,7 +336,7 @@ function processInput(userInput) { config = { "file_extensions": [".js"], - "check_dangerous_functions": True + "check_dangerous_functions": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -336,8 +345,8 @@ function processInput(userInput) { dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"] assert len(dangerous_findings) > 0 - async def test_detect_innerHTML(self, security_analyzer, temp_workspace): - """Test detection of innerHTML (XSS risk)""" + async def test_detect_inner_html(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test detection of innerHTML (XSS risk).""" code_file = temp_workspace / "dom.js" code_file.write_text(""" function updateContent(html) { @@ -347,7 +356,7 @@ function updateContent(html) { config = { "file_extensions": [".js"], - "check_dangerous_functions": True + "check_dangerous_functions": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -356,14 +365,18 @@ function updateContent(html) { dangerous_findings = [f for f in result.findings if f.category == "dangerous_function"] assert len(dangerous_findings) > 0 - async def test_no_dangerous_detection_when_disabled(self, security_analyzer, temp_workspace): - """Test that dangerous function detection can be disabled""" + async def test_no_dangerous_detection_when_disabled( + self, + security_analyzer: SecurityAnalyzer, + temp_workspace: Path, + ) -> None: + """Test that dangerous function detection can be disabled.""" code_file = temp_workspace / "code.py" - code_file.write_text('result = eval(user_input)') + code_file.write_text("result = eval(user_input)") config = { "file_extensions": [".py"], - "check_dangerous_functions": False + "check_dangerous_functions": False, } result = await security_analyzer.execute(config, temp_workspace) @@ -375,10 +388,14 @@ function updateContent(html) { @pytest.mark.asyncio class TestSecurityAnalyzerMultipleIssues: - """Test detection of multiple issues in same file""" + """Test detection of multiple issues in same file.""" - async def test_detect_multiple_vulnerabilities(self, security_analyzer, temp_workspace): - """Test detection of multiple vulnerability types""" + async def test_detect_multiple_vulnerabilities( + self, + security_analyzer: SecurityAnalyzer, + temp_workspace: Path, + ) -> None: + """Test detection of multiple vulnerability types.""" code_file = temp_workspace / "vulnerable.py" code_file.write_text(""" import os @@ -404,7 +421,7 @@ def 
process_query(user_input): "file_extensions": [".py"], "check_secrets": True, "check_sql": True, - "check_dangerous_functions": True + "check_dangerous_functions": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -423,10 +440,10 @@ def process_query(user_input): @pytest.mark.asyncio class TestSecurityAnalyzerSummary: - """Test result summary generation""" + """Test result summary generation.""" - async def test_summary_structure(self, security_analyzer, temp_workspace): - """Test that summary has correct structure""" + async def test_summary_structure(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test that summary has correct structure.""" (temp_workspace / "test.py").write_text("print('hello')") config = {"file_extensions": [".py"]} @@ -441,16 +458,16 @@ class TestSecurityAnalyzerSummary: assert isinstance(result.summary["total_findings"], int) assert isinstance(result.summary["extensions_scanned"], list) - async def test_empty_workspace(self, security_analyzer, temp_workspace): - """Test analyzing empty workspace""" + async def test_empty_workspace(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test analyzing empty workspace.""" config = {"file_extensions": [".py"]} result = await security_analyzer.execute(config, temp_workspace) assert result.status == "partial" # No files found assert result.summary["files_analyzed"] == 0 - async def test_analyze_multiple_file_types(self, security_analyzer, temp_workspace): - """Test analyzing multiple file types""" + async def test_analyze_multiple_file_types(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test analyzing multiple file types.""" (temp_workspace / "app.py").write_text("print('hello')") (temp_workspace / "script.js").write_text("console.log('hello')") (temp_workspace / "index.php").write_text("") @@ -464,10 +481,10 @@ class TestSecurityAnalyzerSummary: @pytest.mark.asyncio class TestSecurityAnalyzerFalsePositives: - """Test false positive filtering""" + """Test false positive filtering.""" - async def test_skip_test_secrets(self, security_analyzer, temp_workspace): - """Test that test/example secrets are filtered""" + async def test_skip_test_secrets(self, security_analyzer: SecurityAnalyzer, temp_workspace: Path) -> None: + """Test that test/example secrets are filtered.""" code_file = temp_workspace / "test_config.py" code_file.write_text(""" # Test configuration - should be filtered @@ -478,7 +495,7 @@ token = "sample_token_placeholder" config = { "file_extensions": [".py"], - "check_secrets": True + "check_secrets": True, } result = await security_analyzer.execute(config, temp_workspace) @@ -488,6 +505,6 @@ token = "sample_token_placeholder" secret_findings = [f for f in result.findings if f.category == "hardcoded_secret"] # Should have fewer or no findings due to false positive filtering assert len(secret_findings) == 0 or all( - not any(fp in f.description.lower() for fp in ['test', 'example', 'dummy', 'sample']) + not any(fp in f.description.lower() for fp in ["test", "example", "dummy", "sample"]) for f in secret_findings ) diff --git a/backend/tests/unit/test_workflows/__init__.py b/backend/tests/unit/test_workflows/__init__.py deleted file mode 100644 index e69de29..0000000