import io import asyncio import json from datetime import datetime, timedelta from threading import Event import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from agentic_security.routes import scan # Dummy LLMSpec for success tests class DummyLLMSpec: def __init__(self, spec_string): self.spec_string = spec_string async def verify(self): class DummyResponse: status_code = 200 text = "verification succeeded" elapsed = timedelta(seconds=0.5) return DummyResponse() @classmethod def from_string(cls, spec_string): return DummyLLMSpec(spec_string) # Dummy scan_router generator to simulate streaming responses async def dummy_scan_router(request_factory, scan_parameters, tools_inbox, stop_event): for i in range(2): yield f"result {i}" # Define a dummy Secrets class for testing purposes. class DummySecrets: def __init__(self): self.secrets = {} # Create FastAPI app for testing and include the scan router. @pytest.fixture def app(): app = FastAPI() app.include_router(scan.router) return app @pytest.fixture def client(app): return TestClient(app) @pytest.fixture(autouse=True) def patch_dependencies(monkeypatch): # Patch LLMSpec used in the routes with our dummy implementation. monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpec) # Patch fuzzer.scan_router to use our dummy scanning generator. monkeypatch.setattr(scan.fuzzer, "scan_router", dummy_scan_router) # Patch get_stop_event to return a dummy Event. dummy_event = Event() monkeypatch.setattr(scan, "get_stop_event", lambda: dummy_event) # Patch get_tools_inbox to return None. monkeypatch.setattr(scan, "get_tools_inbox", lambda: None) # Patch set_current_run to be a no-op. monkeypatch.setattr(scan, "set_current_run", lambda x: None) # Patch get_in_memory_secrets to return a DummySecrets instance. monkeypatch.setattr(scan, "get_in_memory_secrets", lambda: DummySecrets()) # Ensure Scan.with_secrets is a no-op if not already implemented. if not hasattr(scan.Scan, "with_secrets"): monkeypatch.setattr(scan.Scan, "with_secrets", lambda self, secrets: None) def test_verify_success(client): """Test /verify endpoint for a successful verification.""" data = {"spec": "dummy"} response = client.post("/verify", json=data) res_json = response.json() assert response.status_code == 200 assert res_json["status_code"] == 200 assert res_json["body"] == "verification succeeded" assert "elapsed" in res_json assert "timestamp" in res_json def test_verify_failure(client, monkeypatch): """Test /verify endpoint when verification fails.""" class DummyLLMSpecFailure: def __init__(self, spec_string): self.spec_string = spec_string async def verify(self): raise Exception("verification error") @classmethod def from_string(cls, spec_string): return DummyLLMSpecFailure(spec_string) monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpecFailure) data = {"spec": "bad"} response = client.post("/verify", json=data) assert response.status_code == 400 assert "verification error" in response.text def test_scan(client): """Test /scan endpoint to ensure streaming response works.""" data = {"llmSpec": "dummy", "optimize": False, "maxBudget": 10, "enableMultiStepAttack": False} response = client.post("/scan", json=data) assert response.status_code == 200 content = list(response.iter_lines()) expected = ["result 0", "result 1"] assert content == expected def test_stop_scan(client): """Test /stop endpoint to ensure scan stopping functionality.""" dummy_event = scan.get_stop_event() dummy_event.clear() response = client.post("/stop") assert response.status_code == 200 assert response.json() == {"status": "Scan stopped"} assert dummy_event.is_set() def test_scan_csv(client): """Test /scan-csv endpoint with CSV file and llmSpec upload.""" csv_content = b"col1,col2\nvalue1,value2" llm_spec_content = b"dummy" files = { "file": ("dummy.csv", csv_content, "text/csv"), "llmSpec": ("spec.txt", llm_spec_content, "text/plain"), } response = client.post( "/scan-csv", files=files, data={"optimize": "false", "maxBudget": "10", "enableMultiStepAttack": "false"}, ) assert response.status_code == 200 content = list(response.iter_lines()) expected = ["result 0", "result 1"] assert content == expected