mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-25 06:39:57 +02:00
test: Add coverage improvement test for tests/test_scan.py
This commit is contained in:
committed by
GitHub
parent
594f8960e8
commit
56e3c91af2
@@ -0,0 +1,126 @@
|
||||
import io
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Event
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from agentic_security.routes import scan
|
||||
|
||||
# Dummy LLMSpec for success tests
|
||||
class DummyLLMSpec:
|
||||
def __init__(self, spec_string):
|
||||
self.spec_string = spec_string
|
||||
async def verify(self):
|
||||
class DummyResponse:
|
||||
status_code = 200
|
||||
text = "verification succeeded"
|
||||
elapsed = timedelta(seconds=0.5)
|
||||
return DummyResponse()
|
||||
@classmethod
|
||||
def from_string(cls, spec_string):
|
||||
return DummyLLMSpec(spec_string)
|
||||
|
||||
# Dummy scan_router generator to simulate streaming responses
|
||||
async def dummy_scan_router(request_factory, scan_parameters, tools_inbox, stop_event):
|
||||
for i in range(2):
|
||||
yield f"result {i}"
|
||||
|
||||
# Define a dummy Secrets class for testing purposes.
|
||||
class DummySecrets:
|
||||
def __init__(self):
|
||||
self.secrets = {}
|
||||
|
||||
# Create FastAPI app for testing and include the scan router.
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = FastAPI()
|
||||
app.include_router(scan.router)
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_dependencies(monkeypatch):
|
||||
# Patch LLMSpec used in the routes with our dummy implementation.
|
||||
monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpec)
|
||||
# Patch fuzzer.scan_router to use our dummy scanning generator.
|
||||
monkeypatch.setattr(scan.fuzzer, "scan_router", dummy_scan_router)
|
||||
# Patch get_stop_event to return a dummy Event.
|
||||
dummy_event = Event()
|
||||
monkeypatch.setattr(scan, "get_stop_event", lambda: dummy_event)
|
||||
# Patch get_tools_inbox to return None.
|
||||
monkeypatch.setattr(scan, "get_tools_inbox", lambda: None)
|
||||
# Patch set_current_run to be a no-op.
|
||||
monkeypatch.setattr(scan, "set_current_run", lambda x: None)
|
||||
# Patch get_in_memory_secrets to return a DummySecrets instance.
|
||||
monkeypatch.setattr(scan, "get_in_memory_secrets", lambda: DummySecrets())
|
||||
# Ensure Scan.with_secrets is a no-op if not already implemented.
|
||||
if not hasattr(scan.Scan, "with_secrets"):
|
||||
monkeypatch.setattr(scan.Scan, "with_secrets", lambda self, secrets: None)
|
||||
|
||||
def test_verify_success(client):
|
||||
"""Test /verify endpoint for a successful verification."""
|
||||
data = {"spec": "dummy"}
|
||||
response = client.post("/verify", json=data)
|
||||
res_json = response.json()
|
||||
assert response.status_code == 200
|
||||
assert res_json["status_code"] == 200
|
||||
assert res_json["body"] == "verification succeeded"
|
||||
assert "elapsed" in res_json
|
||||
assert "timestamp" in res_json
|
||||
|
||||
def test_verify_failure(client, monkeypatch):
|
||||
"""Test /verify endpoint when verification fails."""
|
||||
class DummyLLMSpecFailure:
|
||||
def __init__(self, spec_string):
|
||||
self.spec_string = spec_string
|
||||
async def verify(self):
|
||||
raise Exception("verification error")
|
||||
@classmethod
|
||||
def from_string(cls, spec_string):
|
||||
return DummyLLMSpecFailure(spec_string)
|
||||
monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpecFailure)
|
||||
data = {"spec": "bad"}
|
||||
response = client.post("/verify", json=data)
|
||||
assert response.status_code == 400
|
||||
assert "verification error" in response.text
|
||||
|
||||
def test_scan(client):
|
||||
"""Test /scan endpoint to ensure streaming response works."""
|
||||
data = {"llmSpec": "dummy", "optimize": False, "maxBudget": 10, "enableMultiStepAttack": False}
|
||||
response = client.post("/scan", json=data)
|
||||
assert response.status_code == 200
|
||||
content = list(response.iter_lines())
|
||||
expected = ["result 0", "result 1"]
|
||||
assert content == expected
|
||||
|
||||
def test_stop_scan(client):
|
||||
"""Test /stop endpoint to ensure scan stopping functionality."""
|
||||
dummy_event = scan.get_stop_event()
|
||||
dummy_event.clear()
|
||||
response = client.post("/stop")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "Scan stopped"}
|
||||
assert dummy_event.is_set()
|
||||
|
||||
def test_scan_csv(client):
|
||||
"""Test /scan-csv endpoint with CSV file and llmSpec upload."""
|
||||
csv_content = b"col1,col2\nvalue1,value2"
|
||||
llm_spec_content = b"dummy"
|
||||
files = {
|
||||
"file": ("dummy.csv", csv_content, "text/csv"),
|
||||
"llmSpec": ("spec.txt", llm_spec_content, "text/plain"),
|
||||
}
|
||||
response = client.post(
|
||||
"/scan-csv",
|
||||
files=files,
|
||||
data={"optimize": "false", "maxBudget": "10", "enableMultiStepAttack": "false"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
content = list(response.iter_lines())
|
||||
expected = ["result 0", "result 1"]
|
||||
assert content == expected
|
||||
Reference in New Issue
Block a user