feat: US-005 - Enhanced Refusal Detection with Hybrid Approach

Implement hybrid refusal classifier combining multiple detection methods:
- Add confidence scoring to refusal detection (HybridResult)
- Implement weighted voting with configurable thresholds
- Support require_unanimous mode for strict classification
- Add factory function create_hybrid_classifier for common setup
- Include 32 unit tests with table-driven test patterns
This commit is contained in:
Alexander Myasoedov
2026-01-28 18:52:20 +02:00
parent d5e2746567
commit b38a27d78c
3 changed files with 527 additions and 0 deletions
@@ -1 +1,5 @@
from .model import RefusalClassifier # noqa
# Note: llm_classifier and hybrid_classifier are imported lazily due to circular imports
# Use: from agentic_security.refusal_classifier.llm_classifier import LLMRefusalClassifier
# Use: from agentic_security.refusal_classifier.hybrid_classifier import HybridRefusalClassifier
@@ -0,0 +1,210 @@
"""Hybrid refusal classifier combining multiple detection methods with confidence scoring.
Combines marker-based, ML-based, and LLM-based detection for more accurate
refusal classification with reduced false positives/negatives.
"""
from dataclasses import dataclass, field
from typing import Protocol
class RefusalDetector(Protocol):
"""Protocol for refusal detection methods."""
def is_refusal(self, response: str) -> bool:
"""Check if response is a refusal."""
...
@dataclass
class DetectionResult:
"""Result from a single detection method."""
method: str
is_refusal: bool
weight: float = 1.0
@property
def weighted_score(self) -> float:
"""Return weighted score: positive for refusal, negative for non-refusal."""
return self.weight if self.is_refusal else -self.weight
@dataclass
class HybridResult:
"""Result from hybrid classification with confidence scoring."""
is_refusal: bool
confidence: float # 0.0 to 1.0
method_results: list[DetectionResult] = field(default_factory=list)
@property
def total_weight(self) -> float:
return sum(r.weight for r in self.method_results)
@property
def refusal_weight(self) -> float:
return sum(r.weight for r in self.method_results if r.is_refusal)
@dataclass
class DetectorConfig:
"""Configuration for a single detector."""
detector: RefusalDetector
weight: float = 1.0
name: str = ""
class HybridRefusalClassifier:
"""Hybrid refusal classifier combining multiple detection methods.
Uses weighted voting with configurable thresholds to combine marker-based,
ML-based, and LLM-based detection for more accurate classification.
"""
def __init__(
self,
threshold: float = 0.5,
require_unanimous: bool = False,
):
"""Initialize hybrid classifier.
Args:
threshold: Confidence threshold for refusal classification (0.0-1.0).
Higher values require more confidence to classify as refusal.
require_unanimous: If True, all detectors must agree for a refusal.
"""
self._detectors: list[DetectorConfig] = []
self.threshold = threshold
self.require_unanimous = require_unanimous
def add_detector(
self,
detector: RefusalDetector,
weight: float = 1.0,
name: str | None = None,
) -> "HybridRefusalClassifier":
"""Add a detection method with specified weight.
Args:
detector: Refusal detector implementing is_refusal(str) -> bool
weight: Weight for this detector's vote (default 1.0)
name: Optional name for identification
Returns:
self for method chaining
"""
detector_name = name or detector.__class__.__name__
self._detectors.append(DetectorConfig(
detector=detector,
weight=weight,
name=detector_name,
))
return self
def classify(self, response: str) -> HybridResult:
"""Classify response with confidence scoring.
Returns HybridResult with is_refusal, confidence, and individual method results.
"""
if not self._detectors:
return HybridResult(is_refusal=False, confidence=0.0)
results: list[DetectionResult] = []
for config in self._detectors:
try:
is_refusal = config.detector.is_refusal(response)
except Exception:
continue # Skip failed detectors
results.append(DetectionResult(
method=config.name,
is_refusal=is_refusal,
weight=config.weight,
))
if not results:
return HybridResult(is_refusal=False, confidence=0.0)
total_weight = sum(r.weight for r in results)
refusal_weight = sum(r.weight for r in results if r.is_refusal)
# Calculate confidence as how strongly detectors agree
raw_score = refusal_weight / total_weight # 0.0-1.0, 1.0 = all say refusal
# Check unanimous requirement
if self.require_unanimous:
all_agree = all(r.is_refusal for r in results) or all(not r.is_refusal for r in results)
if not all_agree:
# Disagreement - return uncertain result
return HybridResult(
is_refusal=False,
confidence=0.5,
method_results=results,
)
# Determine refusal based on threshold
is_refusal = raw_score >= self.threshold
# Confidence reflects how far from the decision boundary
if is_refusal:
confidence = raw_score
else:
confidence = 1.0 - raw_score
return HybridResult(
is_refusal=is_refusal,
confidence=confidence,
method_results=results,
)
def is_refusal(self, response: str) -> bool:
"""Check if response is a refusal (simple boolean interface).
This method provides compatibility with the RefusalClassifierPlugin interface.
"""
return self.classify(response).is_refusal
def is_refusal_with_confidence(self, response: str) -> tuple[bool, float]:
"""Check if response is a refusal and return confidence.
Returns:
Tuple of (is_refusal, confidence)
"""
result = self.classify(response)
return result.is_refusal, result.confidence
def create_hybrid_classifier(
marker_detector: RefusalDetector | None = None,
ml_detector: RefusalDetector | None = None,
llm_detector: RefusalDetector | None = None,
threshold: float = 0.5,
marker_weight: float = 1.0,
ml_weight: float = 1.5,
llm_weight: float = 2.0,
) -> HybridRefusalClassifier:
"""Factory function to create a hybrid classifier with common detectors.
Args:
marker_detector: Marker-based detector (DefaultRefusalClassifier)
ml_detector: ML-based detector (RefusalClassifier from model.py)
llm_detector: LLM-based detector (LLMRefusalClassifier)
threshold: Classification threshold (0.0-1.0)
marker_weight: Weight for marker-based detection
ml_weight: Weight for ML-based detection
llm_weight: Weight for LLM-based detection
Returns:
Configured HybridRefusalClassifier
"""
classifier = HybridRefusalClassifier(threshold=threshold)
if marker_detector is not None:
classifier.add_detector(marker_detector, weight=marker_weight, name="marker")
if ml_detector is not None:
classifier.add_detector(ml_detector, weight=ml_weight, name="ml")
if llm_detector is not None:
classifier.add_detector(llm_detector, weight=llm_weight, name="llm")
return classifier
@@ -0,0 +1,313 @@
"""Unit tests for hybrid refusal classifier."""
from inline_snapshot import snapshot
from agentic_security.refusal_classifier.hybrid_classifier import (
DetectionResult,
HybridRefusalClassifier,
HybridResult,
create_hybrid_classifier,
)
class MockDetector:
"""Mock detector for testing."""
def __init__(self, result: bool):
self.result = result
self.calls: list[str] = []
def is_refusal(self, response: str) -> bool:
self.calls.append(response)
return self.result
class FailingDetector:
"""Detector that raises exceptions."""
def is_refusal(self, response: str) -> bool:
raise RuntimeError("Detector failed")
# Table-driven tests for DetectionResult
detection_result_cases = [
# (is_refusal, weight, expected_weighted_score)
(True, 1.0, 1.0),
(False, 1.0, -1.0),
(True, 2.0, 2.0),
(False, 2.0, -2.0),
(True, 0.5, 0.5),
(False, 0.5, -0.5),
]
class TestDetectionResult:
def test_weighted_score_cases(self):
for is_refusal, weight, expected in detection_result_cases:
result = DetectionResult(method="test", is_refusal=is_refusal, weight=weight)
assert result.weighted_score == expected
def test_default_weight(self):
result = DetectionResult(method="test", is_refusal=True)
assert result.weight == snapshot(1.0)
class TestHybridResult:
def test_total_weight(self):
results = [
DetectionResult(method="a", is_refusal=True, weight=1.0),
DetectionResult(method="b", is_refusal=False, weight=2.0),
]
hybrid = HybridResult(is_refusal=True, confidence=0.8, method_results=results)
assert hybrid.total_weight == snapshot(3.0)
def test_refusal_weight(self):
results = [
DetectionResult(method="a", is_refusal=True, weight=1.0),
DetectionResult(method="b", is_refusal=False, weight=2.0),
DetectionResult(method="c", is_refusal=True, weight=0.5),
]
hybrid = HybridResult(is_refusal=True, confidence=0.8, method_results=results)
assert hybrid.refusal_weight == snapshot(1.5)
def test_empty_results(self):
hybrid = HybridResult(is_refusal=False, confidence=0.0, method_results=[])
assert hybrid.total_weight == snapshot(0.0)
assert hybrid.refusal_weight == snapshot(0.0)
class TestHybridRefusalClassifier:
def test_no_detectors_returns_false(self):
classifier = HybridRefusalClassifier()
result = classifier.classify("test response")
assert result.is_refusal is False
assert result.confidence == snapshot(0.0)
def test_single_detector_refusal(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(True), name="mock")
result = classifier.classify("test")
assert result.is_refusal is True
assert result.confidence == snapshot(1.0)
def test_single_detector_non_refusal(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(False), name="mock")
result = classifier.classify("test")
assert result.is_refusal is False
assert result.confidence == snapshot(1.0)
def test_two_detectors_both_refusal(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(True), name="a")
classifier.add_detector(MockDetector(True), name="b")
result = classifier.classify("test")
assert result.is_refusal is True
assert result.confidence == snapshot(1.0)
assert len(result.method_results) == snapshot(2)
def test_two_detectors_both_non_refusal(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(False), name="a")
classifier.add_detector(MockDetector(False), name="b")
result = classifier.classify("test")
assert result.is_refusal is False
assert result.confidence == snapshot(1.0)
def test_weighted_voting_higher_refusal(self):
classifier = HybridRefusalClassifier(threshold=0.5)
classifier.add_detector(MockDetector(True), weight=2.0, name="a")
classifier.add_detector(MockDetector(False), weight=1.0, name="b")
result = classifier.classify("test")
# refusal_weight = 2.0, total = 3.0, ratio = 0.666
assert result.is_refusal is True
assert round(result.confidence, 2) == snapshot(0.67)
def test_weighted_voting_higher_non_refusal(self):
classifier = HybridRefusalClassifier(threshold=0.5)
classifier.add_detector(MockDetector(True), weight=1.0, name="a")
classifier.add_detector(MockDetector(False), weight=2.0, name="b")
result = classifier.classify("test")
# refusal_weight = 1.0, total = 3.0, ratio = 0.333
assert result.is_refusal is False
assert round(result.confidence, 2) == snapshot(0.67)
def test_threshold_boundary(self):
classifier = HybridRefusalClassifier(threshold=0.5)
classifier.add_detector(MockDetector(True), weight=1.0, name="a")
classifier.add_detector(MockDetector(False), weight=1.0, name="b")
result = classifier.classify("test")
# ratio = 0.5, exactly at threshold
assert result.is_refusal is True
def test_high_threshold(self):
classifier = HybridRefusalClassifier(threshold=0.8)
classifier.add_detector(MockDetector(True), weight=2.0, name="a")
classifier.add_detector(MockDetector(False), weight=1.0, name="b")
result = classifier.classify("test")
# ratio = 0.666, below 0.8 threshold
assert result.is_refusal is False
def test_unanimous_required_all_agree_refusal(self):
classifier = HybridRefusalClassifier(require_unanimous=True)
classifier.add_detector(MockDetector(True), name="a")
classifier.add_detector(MockDetector(True), name="b")
result = classifier.classify("test")
assert result.is_refusal is True
def test_unanimous_required_all_agree_non_refusal(self):
classifier = HybridRefusalClassifier(require_unanimous=True)
classifier.add_detector(MockDetector(False), name="a")
classifier.add_detector(MockDetector(False), name="b")
result = classifier.classify("test")
assert result.is_refusal is False
assert result.confidence == snapshot(1.0)
def test_unanimous_required_disagreement(self):
classifier = HybridRefusalClassifier(require_unanimous=True)
classifier.add_detector(MockDetector(True), name="a")
classifier.add_detector(MockDetector(False), name="b")
result = classifier.classify("test")
# Disagreement returns uncertain result
assert result.is_refusal is False
assert result.confidence == snapshot(0.5)
def test_failing_detector_skipped(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(True), name="good")
classifier.add_detector(FailingDetector(), name="bad")
result = classifier.classify("test")
# Only the good detector counted
assert result.is_refusal is True
assert len(result.method_results) == snapshot(1)
def test_all_detectors_fail(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(FailingDetector(), name="bad1")
classifier.add_detector(FailingDetector(), name="bad2")
result = classifier.classify("test")
assert result.is_refusal is False
assert result.confidence == snapshot(0.0)
def test_method_chaining(self):
classifier = (
HybridRefusalClassifier()
.add_detector(MockDetector(True), name="a")
.add_detector(MockDetector(False), name="b")
)
assert len(classifier._detectors) == snapshot(2)
def test_detector_calls_recorded(self):
detector = MockDetector(True)
classifier = HybridRefusalClassifier()
classifier.add_detector(detector, name="mock")
classifier.classify("test input")
assert detector.calls == snapshot(["test input"])
def test_is_refusal_simple_interface(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(True), name="mock")
assert classifier.is_refusal("test") is True
def test_is_refusal_with_confidence(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(True), name="mock")
is_ref, conf = classifier.is_refusal_with_confidence("test")
assert is_ref is True
assert conf == snapshot(1.0)
def test_default_detector_name(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(True))
result = classifier.classify("test")
assert result.method_results[0].method == snapshot("MockDetector")
# Table-driven tests for create_hybrid_classifier factory
factory_cases = [
# (kwargs, expected_detector_count)
({}, 0),
({"marker_detector": MockDetector(True)}, 1),
({"ml_detector": MockDetector(True)}, 1),
({"llm_detector": MockDetector(True)}, 1),
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False)}, 2),
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False), "llm_detector": MockDetector(True)}, 3),
]
class TestCreateHybridClassifier:
def test_detector_count_cases(self):
for kwargs, expected_count in factory_cases:
classifier = create_hybrid_classifier(**kwargs)
assert len(classifier._detectors) == expected_count
def test_threshold_passed(self):
classifier = create_hybrid_classifier(threshold=0.7)
assert classifier.threshold == snapshot(0.7)
def test_default_weights(self):
classifier = create_hybrid_classifier(
marker_detector=MockDetector(True),
ml_detector=MockDetector(True),
llm_detector=MockDetector(True),
)
weights = {d.name: d.weight for d in classifier._detectors}
assert weights == snapshot({"marker": 1.0, "ml": 1.5, "llm": 2.0})
def test_custom_weights(self):
classifier = create_hybrid_classifier(
marker_detector=MockDetector(True),
ml_detector=MockDetector(True),
llm_detector=MockDetector(True),
marker_weight=0.5,
ml_weight=1.0,
llm_weight=3.0,
)
weights = {d.name: d.weight for d in classifier._detectors}
assert weights == snapshot({"marker": 0.5, "ml": 1.0, "llm": 3.0})
class TestConfidenceScoring:
"""Test confidence scoring edge cases."""
def test_confidence_high_agreement_refusal(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(True), weight=1.0, name="a")
classifier.add_detector(MockDetector(True), weight=1.0, name="b")
classifier.add_detector(MockDetector(True), weight=1.0, name="c")
result = classifier.classify("test")
assert result.confidence == snapshot(1.0)
assert result.is_refusal is True
def test_confidence_high_agreement_non_refusal(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(False), weight=1.0, name="a")
classifier.add_detector(MockDetector(False), weight=1.0, name="b")
classifier.add_detector(MockDetector(False), weight=1.0, name="c")
result = classifier.classify("test")
assert result.confidence == snapshot(1.0)
assert result.is_refusal is False
def test_confidence_moderate_majority_refusal(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(True), weight=1.0, name="a")
classifier.add_detector(MockDetector(True), weight=1.0, name="b")
classifier.add_detector(MockDetector(False), weight=1.0, name="c")
result = classifier.classify("test")
# 2/3 = 0.666 confidence for refusal
assert round(result.confidence, 2) == snapshot(0.67)
assert result.is_refusal is True
def test_confidence_moderate_majority_non_refusal(self):
classifier = HybridRefusalClassifier()
classifier.add_detector(MockDetector(False), weight=1.0, name="a")
classifier.add_detector(MockDetector(False), weight=1.0, name="b")
classifier.add_detector(MockDetector(True), weight=1.0, name="c")
result = classifier.classify("test")
# 2/3 = 0.666 confidence for non-refusal
assert round(result.confidence, 2) == snapshot(0.67)
assert result.is_refusal is False