diff --git a/tests/test_scan.py b/tests/test_scan.py new file mode 100644 index 0000000..4781962 --- /dev/null +++ b/tests/test_scan.py @@ -0,0 +1,126 @@ +import io +import asyncio +import json +from datetime import datetime, timedelta +from threading import Event +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from agentic_security.routes import scan + +# Dummy LLMSpec for success tests +class DummyLLMSpec: + def __init__(self, spec_string): + self.spec_string = spec_string + async def verify(self): + class DummyResponse: + status_code = 200 + text = "verification succeeded" + elapsed = timedelta(seconds=0.5) + return DummyResponse() + @classmethod + def from_string(cls, spec_string): + return DummyLLMSpec(spec_string) + +# Dummy scan_router generator to simulate streaming responses +async def dummy_scan_router(request_factory, scan_parameters, tools_inbox, stop_event): + for i in range(2): + yield f"result {i}" + +# Define a dummy Secrets class for testing purposes. +class DummySecrets: + def __init__(self): + self.secrets = {} + +# Create FastAPI app for testing and include the scan router. +@pytest.fixture +def app(): + app = FastAPI() + app.include_router(scan.router) + return app + +@pytest.fixture +def client(app): + return TestClient(app) + +@pytest.fixture(autouse=True) +def patch_dependencies(monkeypatch): + # Patch LLMSpec used in the routes with our dummy implementation. + monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpec) + # Patch fuzzer.scan_router to use our dummy scanning generator. + monkeypatch.setattr(scan.fuzzer, "scan_router", dummy_scan_router) + # Patch get_stop_event to return a dummy Event. + dummy_event = Event() + monkeypatch.setattr(scan, "get_stop_event", lambda: dummy_event) + # Patch get_tools_inbox to return None. + monkeypatch.setattr(scan, "get_tools_inbox", lambda: None) + # Patch set_current_run to be a no-op. + monkeypatch.setattr(scan, "set_current_run", lambda x: None) + # Patch get_in_memory_secrets to return a DummySecrets instance. + monkeypatch.setattr(scan, "get_in_memory_secrets", lambda: DummySecrets()) + # Ensure Scan.with_secrets is a no-op if not already implemented. + if not hasattr(scan.Scan, "with_secrets"): + monkeypatch.setattr(scan.Scan, "with_secrets", lambda self, secrets: None) + +def test_verify_success(client): + """Test /verify endpoint for a successful verification.""" + data = {"spec": "dummy"} + response = client.post("/verify", json=data) + res_json = response.json() + assert response.status_code == 200 + assert res_json["status_code"] == 200 + assert res_json["body"] == "verification succeeded" + assert "elapsed" in res_json + assert "timestamp" in res_json + +def test_verify_failure(client, monkeypatch): + """Test /verify endpoint when verification fails.""" + class DummyLLMSpecFailure: + def __init__(self, spec_string): + self.spec_string = spec_string + async def verify(self): + raise Exception("verification error") + @classmethod + def from_string(cls, spec_string): + return DummyLLMSpecFailure(spec_string) + monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpecFailure) + data = {"spec": "bad"} + response = client.post("/verify", json=data) + assert response.status_code == 400 + assert "verification error" in response.text + +def test_scan(client): + """Test /scan endpoint to ensure streaming response works.""" + data = {"llmSpec": "dummy", "optimize": False, "maxBudget": 10, "enableMultiStepAttack": False} + response = client.post("/scan", json=data) + assert response.status_code == 200 + content = list(response.iter_lines()) + expected = ["result 0", "result 1"] + assert content == expected + +def test_stop_scan(client): + """Test /stop endpoint to ensure scan stopping functionality.""" + dummy_event = scan.get_stop_event() + dummy_event.clear() + response = client.post("/stop") + assert response.status_code == 200 + assert response.json() == {"status": "Scan stopped"} + assert dummy_event.is_set() + +def test_scan_csv(client): + """Test /scan-csv endpoint with CSV file and llmSpec upload.""" + csv_content = b"col1,col2\nvalue1,value2" + llm_spec_content = b"dummy" + files = { + "file": ("dummy.csv", csv_content, "text/csv"), + "llmSpec": ("spec.txt", llm_spec_content, "text/plain"), + } + response = client.post( + "/scan-csv", + files=files, + data={"optimize": "false", "maxBudget": "10", "enableMultiStepAttack": "false"}, + ) + assert response.status_code == 200 + content = list(response.iter_lines()) + expected = ["result 0", "result 1"] + assert content == expected \ No newline at end of file