Files
agentic_security/tests/unit/refusal_classifier/test_llm_classifier.py
T
Alexander Myasoedov 32f103acbc feat: US-001 - Dual-LLM Evaluation for Attack Success Detection
Add LLM-based refusal classifier inspired by Promptmap's dual-LLM
architecture. The controller LLM evaluates whether an attack succeeded
by analyzing the target's response against pass/fail conditions.

- Create LLMRefusalClassifier plugin integrating with existing system
- Support OpenAI and Anthropic providers with lazy initialization
- Add configurable system prompts and pass/fail conditions
- Include 20 unit tests for comprehensive coverage
2026-01-28 18:18:09 +02:00

178 lines
6.5 KiB
Python

"""Unit tests for LLM-based refusal classifier."""
from unittest.mock import MagicMock
import pytest
from inline_snapshot import snapshot
from agentic_security.refusal_classifier.llm_classifier import (
LLMRefusalClassifier,
build_evaluation_prompt,
create_llm_classifier,
DEFAULT_PASS_CONDITIONS,
DEFAULT_FAIL_CONDITIONS,
)
class MockLLMClient:
    """Fake LLM client that logs each request and answers with a fixed reply.

    It exposes the same ``complete(system_prompt, user_prompt) -> str`` call
    the classifier makes. Tests can therefore choose the controller verdict
    and later inspect exactly which prompts were sent.
    """

    def __init__(self, response: str = "fail"):
        # Reply handed back unchanged from every ``complete`` call.
        self.response = response
        # (system_prompt, user_prompt) pairs, in the order they were received.
        self.calls: list[tuple[str, str]] = []

    def complete(self, system_prompt: str, user_prompt: str) -> str:
        """Log the prompt pair, then return the canned ``response``."""
        request = (system_prompt, user_prompt)
        self.calls.append(request)
        return self.response
class TestBuildEvaluationPrompt:
    """Tests for ``build_evaluation_prompt`` and the default condition lists."""

    def test_build_with_defaults(self):
        """The default prompt has both condition headers and the raw output."""
        target_output = "Hello, I can help with that."
        prompt = build_evaluation_prompt(target_output)
        for fragment in (
            "Pass conditions (attack succeeded):",
            "Fail conditions (model refused):",
            target_output,
        ):
            assert fragment in prompt

    def test_build_with_custom_conditions(self):
        """Caller-supplied pass/fail conditions are embedded in the prompt."""
        prompt = build_evaluation_prompt(
            "test output",
            pass_conditions=["Custom pass 1"],
            fail_conditions=["Custom fail 1"],
        )
        for condition in ("Custom pass 1", "Custom fail 1"):
            assert condition in prompt

    def test_build_with_empty_output(self):
        """An empty target response is rendered as a visible placeholder."""
        assert "<empty response>" in build_evaluation_prompt("")

    def test_default_conditions_exist(self):
        """Pin how many built-in pass and fail conditions ship by default."""
        pass_count = len(DEFAULT_PASS_CONDITIONS)
        fail_count = len(DEFAULT_FAIL_CONDITIONS)
        # Keep the ``value == snapshot(...)`` shape so inline-snapshot can
        # update these literals in place.
        assert pass_count == snapshot(4)
        assert fail_count == snapshot(6)
class TestLLMRefusalClassifier:
    """Behavioral tests for ``LLMRefusalClassifier.is_refusal``.

    Verdict convention used throughout: a controller reply of ``"pass"`` means
    the attack succeeded, so it is not a refusal. A reply of ``"fail"`` means
    the model refused.
    """

    def test_is_refusal_pass_response(self):
        """A "pass" verdict maps to ``is_refusal() is False``."""
        classifier = LLMRefusalClassifier(client=MockLLMClient(response="pass"))
        verdict = classifier.is_refusal("Here's how to do that...")
        assert verdict is False

    def test_is_refusal_fail_response(self):
        """A "fail" verdict maps to ``is_refusal() is True``."""
        classifier = LLMRefusalClassifier(client=MockLLMClient(response="fail"))
        verdict = classifier.is_refusal("I cannot help with that.")
        assert verdict is True

    def test_is_refusal_with_extra_text(self):
        """Whitespace around the controller's verdict is tolerated."""
        cases = (("fail\n", True), (" pass ", False))
        for raw_reply, expected in cases:
            classifier = LLMRefusalClassifier(
                client=MockLLMClient(response=raw_reply)
            )
            assert classifier.is_refusal("test") is expected

    def test_no_client_returns_false(self):
        """With no client available, the classifier abstains (returns False)."""
        classifier = LLMRefusalClassifier(client=None)
        # Mark lazy initialization as already done, so no client gets built
        # from API-key environment variables.
        classifier._initialized = True
        # Abstaining (False) leaves the decision to other classifiers.
        assert classifier.is_refusal("anything") is False

    def test_client_error_returns_false(self):
        """Client exceptions do not propagate; the classifier abstains."""
        failing_client = MagicMock()
        failing_client.complete.side_effect = Exception("API error")
        classifier = LLMRefusalClassifier(client=failing_client)
        # False lets the other refusal classifiers make the decision.
        assert classifier.is_refusal("test") is False

    def test_custom_system_prompt(self):
        """A custom system prompt is forwarded to the client unchanged."""
        client = MockLLMClient(response="pass")
        custom_prompt = "Custom system prompt"
        classifier = LLMRefusalClassifier(client=client, system_prompt=custom_prompt)
        classifier.is_refusal("test output")
        sent_system_prompt, _ = client.calls[0]
        assert sent_system_prompt == custom_prompt

    def test_custom_conditions(self):
        """Custom pass/fail conditions end up in the user prompt."""
        client = MockLLMClient(response="pass")
        classifier = LLMRefusalClassifier(
            client=client,
            pass_conditions=["Custom pass"],
            fail_conditions=["Custom fail"],
        )
        classifier.is_refusal("test output")
        _, sent_user_prompt = client.calls[0]
        for condition in ("Custom pass", "Custom fail"):
            assert condition in sent_user_prompt

    def test_calls_client_with_correct_prompt(self):
        """One client call carries the controller prompt and the target output."""
        client = MockLLMClient(response="fail")
        target_output = "Model says: I can help!"
        classifier = LLMRefusalClassifier(client=client)
        classifier.is_refusal(target_output)
        assert len(client.calls) == 1
        system_prompt, user_prompt = client.calls[0]
        assert "security evaluation controller" in system_prompt.lower()
        assert target_output in user_prompt
class TestCreateLLMClassifier:
    """Tests for the ``create_llm_classifier`` factory."""

    @staticmethod
    def _expect_missing_key(monkeypatch, env_var, provider, message):
        """Unset *env_var*, then check the factory rejects *provider* with *message*."""
        monkeypatch.delenv(env_var, raising=False)
        with pytest.raises(ValueError, match=message):
            create_llm_classifier(provider=provider)

    def test_create_openai_missing_key(self, monkeypatch):
        """The OpenAI provider without any API key raises ValueError."""
        self._expect_missing_key(
            monkeypatch, "OPENAI_API_KEY", "openai", "OpenAI API key required"
        )

    def test_create_anthropic_missing_key(self, monkeypatch):
        """The Anthropic provider without any API key raises ValueError."""
        self._expect_missing_key(
            monkeypatch, "ANTHROPIC_API_KEY", "anthropic", "Anthropic API key required"
        )

    def test_create_unknown_provider(self):
        """An unrecognized provider name is rejected."""
        with pytest.raises(ValueError, match="Unknown provider"):
            create_llm_classifier(provider="unknown")

    def test_create_with_custom_model(self, monkeypatch):
        """The ``model`` argument is passed through to the constructed client."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key")
        built = create_llm_classifier(provider="openai", model="gpt-4")
        assert built.client.model == "gpt-4"

    def test_create_with_api_key(self, monkeypatch):
        """An explicit ``api_key`` works even when the env var is unset."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        explicit_key = "direct-key"
        built = create_llm_classifier(provider="openai", api_key=explicit_key)
        assert built.client.api_key == explicit_key
class TestLazyInitialization:
    """Tests for building a client from API-key environment variables on demand."""

    @staticmethod
    def _configure_env(monkeypatch, openai_key=None, anthropic_key=None):
        """Set each provider key when one is given; otherwise make sure it is unset."""
        for env_var, value in (
            ("OPENAI_API_KEY", openai_key),
            ("ANTHROPIC_API_KEY", anthropic_key),
        ):
            if value is None:
                monkeypatch.delenv(env_var, raising=False)
            else:
                monkeypatch.setenv(env_var, value)

    def test_lazy_init_openai(self, monkeypatch):
        """With only OPENAI_API_KEY set, no client exists until ``_ensure_client()``."""
        self._configure_env(monkeypatch, openai_key="test-key")
        classifier = LLMRefusalClassifier()
        assert classifier.client is None
        classifier._ensure_client()
        assert classifier.client is not None
        assert hasattr(classifier.client, "api_key")

    def test_lazy_init_anthropic(self, monkeypatch):
        """With only ANTHROPIC_API_KEY set, ``_ensure_client()`` builds a client."""
        self._configure_env(monkeypatch, anthropic_key="test-key")
        classifier = LLMRefusalClassifier()
        classifier._ensure_client()
        assert classifier.client is not None

    def test_lazy_init_no_keys(self, monkeypatch):
        """With no API keys available, the client stays ``None``."""
        self._configure_env(monkeypatch)
        classifier = LLMRefusalClassifier()
        classifier._ensure_client()
        assert classifier.client is None