mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-23 21:59:57 +02:00
feat: US-005 - Enhanced Refusal Detection with Hybrid Approach
Implement hybrid refusal classifier combining multiple detection methods:
- Add confidence scoring to refusal detection (HybridResult)
- Implement weighted voting with configurable thresholds
- Support require_unanimous mode for strict classification
- Add factory function create_hybrid_classifier for common setup
- Include 32 unit tests with table-driven test patterns
This commit is contained in:
@@ -1 +1,5 @@
|
||||
from .model import RefusalClassifier # noqa
|
||||
|
||||
# Note: llm_classifier and hybrid_classifier are imported lazily due to circular imports
|
||||
# Use: from agentic_security.refusal_classifier.llm_classifier import LLMRefusalClassifier
|
||||
# Use: from agentic_security.refusal_classifier.hybrid_classifier import HybridRefusalClassifier
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Hybrid refusal classifier combining multiple detection methods with confidence scoring.
|
||||
|
||||
Combines marker-based, ML-based, and LLM-based detection for more accurate
|
||||
refusal classification with reduced false positives/negatives.
|
||||
"""
|
||||
|
||||
import logging
from dataclasses import dataclass, field
from typing import Protocol
|
||||
|
||||
|
||||
class RefusalDetector(Protocol):
    """Structural interface satisfied by any object exposing ``is_refusal``."""

    def is_refusal(self, response: str) -> bool:
        """Report whether ``response`` should be treated as a refusal."""
        ...
|
||||
|
||||
|
||||
@dataclass
class DetectionResult:
    """Vote cast by one detection method, together with the weight it carries."""

    method: str
    is_refusal: bool
    weight: float = 1.0

    @property
    def weighted_score(self) -> float:
        """Signed vote: ``+weight`` for a refusal, ``-weight`` for a non-refusal."""
        if self.is_refusal:
            return self.weight
        return -self.weight
|
||||
|
||||
|
||||
@dataclass
class HybridResult:
    """Combined verdict of a hybrid classification, with the votes behind it.

    ``confidence`` is stored exactly as supplied by the producer (intended
    range 0.0 to 1.0); ``method_results`` holds the per-detector votes.
    """

    is_refusal: bool
    confidence: float  # 0.0 to 1.0
    method_results: list[DetectionResult] = field(default_factory=list)

    @property
    def total_weight(self) -> float:
        """Sum of the weights of every recorded vote."""
        weights = [vote.weight for vote in self.method_results]
        return sum(weights)

    @property
    def refusal_weight(self) -> float:
        """Sum of the weights of the votes that flagged a refusal."""
        refusal_votes = [vote for vote in self.method_results if vote.is_refusal]
        return sum(vote.weight for vote in refusal_votes)
|
||||
|
||||
|
||||
@dataclass
class DetectorConfig:
    """Registration record pairing a detector with its vote weight and label."""

    detector: RefusalDetector
    weight: float = 1.0
    name: str = ""
|
||||
|
||||
|
||||
class HybridRefusalClassifier:
    """Hybrid refusal classifier combining multiple detection methods.

    Every registered detector casts a boolean vote carrying its configured
    weight. The share of the total weight that voted "refusal" is compared
    with ``threshold`` to reach the verdict, and the share of the weight that
    sides with that verdict is reported as the confidence.
    """

    def __init__(
        self,
        threshold: float = 0.5,
        require_unanimous: bool = False,
    ):
        """Initialize hybrid classifier.

        Args:
            threshold: Minimum refusal share of the total weight (0.0-1.0)
                required to classify a response as a refusal. Higher values
                require more agreement to classify as refusal.
            require_unanimous: If True, a split vote is not resolved by the
                threshold; it is reported as a non-refusal with confidence 0.5.
        """
        self._detectors: list[DetectorConfig] = []
        self.threshold = threshold
        self.require_unanimous = require_unanimous

    def add_detector(
        self,
        detector: RefusalDetector,
        weight: float = 1.0,
        name: str | None = None,
    ) -> "HybridRefusalClassifier":
        """Add a detection method with specified weight.

        Args:
            detector: Refusal detector implementing is_refusal(str) -> bool
            weight: Weight for this detector's vote (default 1.0)
            name: Optional name for identification; falls back to the
                detector's class name when omitted or empty.

        Returns:
            self for method chaining
        """
        detector_name = name or detector.__class__.__name__
        self._detectors.append(
            DetectorConfig(detector=detector, weight=weight, name=detector_name)
        )
        return self

    def classify(self, response: str) -> HybridResult:
        """Classify response with confidence scoring.

        Detectors that raise are logged and left out of the vote. When no
        vote is available, or the votes carry no positive total weight, the
        result is a non-refusal with confidence 0.0.

        Returns:
            HybridResult with is_refusal, confidence, and the individual
            method results that took part in the vote.
        """
        results: list[DetectionResult] = []
        for config in self._detectors:
            try:
                verdict = config.detector.is_refusal(response)
            except Exception:
                # Best effort: one broken detector must not sink the whole
                # classification, but the failure should not vanish silently.
                logging.getLogger(__name__).warning(
                    "Refusal detector %r failed; skipping its vote",
                    config.name,
                    exc_info=True,
                )
                continue
            results.append(
                DetectionResult(
                    method=config.name,
                    is_refusal=verdict,
                    weight=config.weight,
                )
            )

        total_weight = sum(r.weight for r in results)
        if not results or total_weight <= 0:
            # No usable votes (none registered, all failed, or only zero /
            # non-positive weights): dividing would raise or give a
            # meaningless ratio, so report "no evidence".
            return HybridResult(
                is_refusal=False,
                confidence=0.0,
                method_results=results,
            )

        refusal_weight = sum(r.weight for r in results if r.is_refusal)

        # Share of the weight that voted "refusal": 1.0 means all say refusal.
        raw_score = refusal_weight / total_weight

        if self.require_unanimous:
            distinct_votes = {bool(r.is_refusal) for r in results}
            if len(distinct_votes) > 1:
                # Disagreement - return uncertain result
                return HybridResult(
                    is_refusal=False,
                    confidence=0.5,
                    method_results=results,
                )

        is_refusal = raw_score >= self.threshold

        # Confidence is the share of the weight that agrees with the verdict.
        confidence = raw_score if is_refusal else 1.0 - raw_score

        return HybridResult(
            is_refusal=is_refusal,
            confidence=confidence,
            method_results=results,
        )

    def is_refusal(self, response: str) -> bool:
        """Check if response is a refusal (simple boolean interface).

        This method provides compatibility with the RefusalClassifierPlugin interface.
        """
        return self.classify(response).is_refusal

    def is_refusal_with_confidence(self, response: str) -> tuple[bool, float]:
        """Check if response is a refusal and return confidence.

        Returns:
            Tuple of (is_refusal, confidence)
        """
        result = self.classify(response)
        return result.is_refusal, result.confidence
|
||||
|
||||
|
||||
def create_hybrid_classifier(
    marker_detector: RefusalDetector | None = None,
    ml_detector: RefusalDetector | None = None,
    llm_detector: RefusalDetector | None = None,
    threshold: float = 0.5,
    marker_weight: float = 1.0,
    ml_weight: float = 1.5,
    llm_weight: float = 2.0,
) -> HybridRefusalClassifier:
    """Build a hybrid classifier from the commonly used detector kinds.

    Detectors passed as ``None`` are left out; the rest are registered in the
    order marker, ml, llm under the names "marker", "ml" and "llm".

    Args:
        marker_detector: Marker-based detector (DefaultRefusalClassifier)
        ml_detector: ML-based detector (RefusalClassifier from model.py)
        llm_detector: LLM-based detector (LLMRefusalClassifier)
        threshold: Classification threshold (0.0-1.0)
        marker_weight: Weight for marker-based detection
        ml_weight: Weight for ML-based detection
        llm_weight: Weight for LLM-based detection

    Returns:
        Configured HybridRefusalClassifier
    """
    candidates = (
        ("marker", marker_detector, marker_weight),
        ("ml", ml_detector, ml_weight),
        ("llm", llm_detector, llm_weight),
    )

    hybrid = HybridRefusalClassifier(threshold=threshold)
    for label, detector, weight in candidates:
        if detector is None:
            continue
        hybrid.add_detector(detector, weight=weight, name=label)
    return hybrid
|
||||
@@ -0,0 +1,313 @@
|
||||
"""Unit tests for hybrid refusal classifier."""
|
||||
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agentic_security.refusal_classifier.hybrid_classifier import (
|
||||
DetectionResult,
|
||||
HybridRefusalClassifier,
|
||||
HybridResult,
|
||||
create_hybrid_classifier,
|
||||
)
|
||||
|
||||
|
||||
class MockDetector:
    """Test double that returns a fixed verdict and records every input it sees."""

    def __init__(self, result: bool):
        self.calls: list[str] = []
        self.result = result

    def is_refusal(self, response: str) -> bool:
        """Record ``response`` and return the preconfigured verdict."""
        self.calls += [response]
        return self.result
|
||||
|
||||
|
||||
class FailingDetector:
    """Test double whose ``is_refusal`` always raises ``RuntimeError``."""

    def is_refusal(self, response: str) -> bool:
        """Raise unconditionally to simulate a broken detector."""
        failure = RuntimeError("Detector failed")
        raise failure
|
||||
|
||||
|
||||
# Table-driven tests for DetectionResult
|
||||
# Rows are (is_refusal, weight, expected_weighted_score): for each weight a
# refusal row (+weight) followed by a non-refusal row (-weight).
detection_result_cases = [
    (flagged, weight, weight if flagged else -weight)
    for weight in (1.0, 2.0, 0.5)
    for flagged in (True, False)
]
|
||||
|
||||
|
||||
class TestDetectionResult:
    """Checks for the DetectionResult value object."""

    def test_weighted_score_cases(self):
        # Each table row is (is_refusal, weight, expected_weighted_score).
        for flagged, vote_weight, want in detection_result_cases:
            vote = DetectionResult(
                method="test", is_refusal=flagged, weight=vote_weight
            )
            assert vote.weighted_score == want

    def test_default_weight(self):
        vote = DetectionResult(method="test", is_refusal=True)
        assert vote.weight == snapshot(1.0)
|
||||
|
||||
|
||||
class TestHybridResult:
    """Checks for the weight aggregation properties on HybridResult."""

    def test_total_weight(self):
        rows = (("a", True, 1.0), ("b", False, 2.0))
        votes = [
            DetectionResult(method=label, is_refusal=flagged, weight=weight)
            for label, flagged, weight in rows
        ]
        summary = HybridResult(is_refusal=True, confidence=0.8, method_results=votes)
        assert summary.total_weight == snapshot(3.0)

    def test_refusal_weight(self):
        rows = (("a", True, 1.0), ("b", False, 2.0), ("c", True, 0.5))
        votes = [
            DetectionResult(method=label, is_refusal=flagged, weight=weight)
            for label, flagged, weight in rows
        ]
        summary = HybridResult(is_refusal=True, confidence=0.8, method_results=votes)
        assert summary.refusal_weight == snapshot(1.5)

    def test_empty_results(self):
        summary = HybridResult(is_refusal=False, confidence=0.0, method_results=[])
        assert summary.total_weight == snapshot(0.0)
        assert summary.refusal_weight == snapshot(0.0)
|
||||
|
||||
|
||||
class TestHybridRefusalClassifier:
    """Checks for voting, thresholds, unanimity and failure handling."""

    @staticmethod
    def _hybrid(*votes, **options):
        """Build a classifier with one MockDetector per vote.

        Each vote is ``(name, verdict)`` or ``(name, verdict, weight)``; the
        weight is only forwarded when given so the default stays in effect.
        ``options`` go to the HybridRefusalClassifier constructor.
        """
        hybrid = HybridRefusalClassifier(**options)
        for name, verdict, *weight in votes:
            add_kwargs = {"name": name}
            if weight:
                add_kwargs["weight"] = weight[0]
            hybrid.add_detector(MockDetector(verdict), **add_kwargs)
        return hybrid

    def test_no_detectors_returns_false(self):
        outcome = HybridRefusalClassifier().classify("test response")
        assert outcome.is_refusal is False
        assert outcome.confidence == snapshot(0.0)

    def test_single_detector_refusal(self):
        outcome = self._hybrid(("mock", True)).classify("test")
        assert outcome.is_refusal is True
        assert outcome.confidence == snapshot(1.0)

    def test_single_detector_non_refusal(self):
        outcome = self._hybrid(("mock", False)).classify("test")
        assert outcome.is_refusal is False
        assert outcome.confidence == snapshot(1.0)

    def test_two_detectors_both_refusal(self):
        outcome = self._hybrid(("a", True), ("b", True)).classify("test")
        assert outcome.is_refusal is True
        assert outcome.confidence == snapshot(1.0)
        assert len(outcome.method_results) == snapshot(2)

    def test_two_detectors_both_non_refusal(self):
        outcome = self._hybrid(("a", False), ("b", False)).classify("test")
        assert outcome.is_refusal is False
        assert outcome.confidence == snapshot(1.0)

    def test_weighted_voting_higher_refusal(self):
        hybrid = self._hybrid(("a", True, 2.0), ("b", False, 1.0), threshold=0.5)
        outcome = hybrid.classify("test")
        # refusal_weight = 2.0, total = 3.0, ratio = 0.666
        assert outcome.is_refusal is True
        assert round(outcome.confidence, 2) == snapshot(0.67)

    def test_weighted_voting_higher_non_refusal(self):
        hybrid = self._hybrid(("a", True, 1.0), ("b", False, 2.0), threshold=0.5)
        outcome = hybrid.classify("test")
        # refusal_weight = 1.0, total = 3.0, ratio = 0.333
        assert outcome.is_refusal is False
        assert round(outcome.confidence, 2) == snapshot(0.67)

    def test_threshold_boundary(self):
        hybrid = self._hybrid(("a", True, 1.0), ("b", False, 1.0), threshold=0.5)
        outcome = hybrid.classify("test")
        # ratio = 0.5, exactly at threshold
        assert outcome.is_refusal is True

    def test_high_threshold(self):
        hybrid = self._hybrid(("a", True, 2.0), ("b", False, 1.0), threshold=0.8)
        outcome = hybrid.classify("test")
        # ratio = 0.666, below 0.8 threshold
        assert outcome.is_refusal is False

    def test_unanimous_required_all_agree_refusal(self):
        hybrid = self._hybrid(("a", True), ("b", True), require_unanimous=True)
        outcome = hybrid.classify("test")
        assert outcome.is_refusal is True

    def test_unanimous_required_all_agree_non_refusal(self):
        hybrid = self._hybrid(("a", False), ("b", False), require_unanimous=True)
        outcome = hybrid.classify("test")
        assert outcome.is_refusal is False
        assert outcome.confidence == snapshot(1.0)

    def test_unanimous_required_disagreement(self):
        hybrid = self._hybrid(("a", True), ("b", False), require_unanimous=True)
        outcome = hybrid.classify("test")
        # Disagreement returns uncertain result
        assert outcome.is_refusal is False
        assert outcome.confidence == snapshot(0.5)

    def test_failing_detector_skipped(self):
        hybrid = HybridRefusalClassifier()
        hybrid.add_detector(MockDetector(True), name="good")
        hybrid.add_detector(FailingDetector(), name="bad")
        outcome = hybrid.classify("test")
        # Only the good detector counted
        assert outcome.is_refusal is True
        assert len(outcome.method_results) == snapshot(1)

    def test_all_detectors_fail(self):
        hybrid = HybridRefusalClassifier()
        for label in ("bad1", "bad2"):
            hybrid.add_detector(FailingDetector(), name=label)
        outcome = hybrid.classify("test")
        assert outcome.is_refusal is False
        assert outcome.confidence == snapshot(0.0)

    def test_method_chaining(self):
        first_link = HybridRefusalClassifier().add_detector(
            MockDetector(True), name="a"
        )
        chained = first_link.add_detector(MockDetector(False), name="b")
        assert len(chained._detectors) == snapshot(2)

    def test_detector_calls_recorded(self):
        recorder = MockDetector(True)
        hybrid = HybridRefusalClassifier()
        hybrid.add_detector(recorder, name="mock")
        hybrid.classify("test input")
        assert recorder.calls == snapshot(["test input"])

    def test_is_refusal_simple_interface(self):
        hybrid = self._hybrid(("mock", True))
        assert hybrid.is_refusal("test") is True

    def test_is_refusal_with_confidence(self):
        hybrid = self._hybrid(("mock", True))
        flagged, score = hybrid.is_refusal_with_confidence("test")
        assert flagged is True
        assert score == snapshot(1.0)

    def test_default_detector_name(self):
        hybrid = HybridRefusalClassifier()
        hybrid.add_detector(MockDetector(True))
        first_vote = hybrid.classify("test").method_results[0]
        assert first_vote.method == snapshot("MockDetector")
|
||||
|
||||
|
||||
# Table-driven tests for create_hybrid_classifier factory
|
||||
# Rows are (factory kwargs, number of detectors the factory should register).
factory_cases = [
    ({}, 0),
    ({"marker_detector": MockDetector(True)}, 1),
    ({"ml_detector": MockDetector(True)}, 1),
    ({"llm_detector": MockDetector(True)}, 1),
    (
        {
            "marker_detector": MockDetector(True),
            "ml_detector": MockDetector(False),
        },
        2,
    ),
    (
        {
            "marker_detector": MockDetector(True),
            "ml_detector": MockDetector(False),
            "llm_detector": MockDetector(True),
        },
        3,
    ),
]
|
||||
|
||||
|
||||
class TestCreateHybridClassifier:
    """Checks for the create_hybrid_classifier factory."""

    def test_detector_count_cases(self):
        # Each table row is (factory kwargs, expected detector count).
        for factory_kwargs, want_count in factory_cases:
            built = create_hybrid_classifier(**factory_kwargs)
            assert len(built._detectors) == want_count

    def test_threshold_passed(self):
        built = create_hybrid_classifier(threshold=0.7)
        assert built.threshold == snapshot(0.7)

    def test_default_weights(self):
        built = create_hybrid_classifier(
            marker_detector=MockDetector(True),
            ml_detector=MockDetector(True),
            llm_detector=MockDetector(True),
        )
        weights = {}
        for entry in built._detectors:
            weights[entry.name] = entry.weight
        assert weights == snapshot({"marker": 1.0, "ml": 1.5, "llm": 2.0})

    def test_custom_weights(self):
        built = create_hybrid_classifier(
            marker_detector=MockDetector(True),
            ml_detector=MockDetector(True),
            llm_detector=MockDetector(True),
            marker_weight=0.5,
            ml_weight=1.0,
            llm_weight=3.0,
        )
        weights = {}
        for entry in built._detectors:
            weights[entry.name] = entry.weight
        assert weights == snapshot({"marker": 0.5, "ml": 1.0, "llm": 3.0})
|
||||
|
||||
|
||||
class TestConfidenceScoring:
    """Test confidence scoring edge cases."""

    def test_confidence_high_agreement_refusal(self):
        hybrid = HybridRefusalClassifier()
        for label in ("a", "b", "c"):
            hybrid.add_detector(MockDetector(True), weight=1.0, name=label)
        outcome = hybrid.classify("test")
        assert outcome.confidence == snapshot(1.0)
        assert outcome.is_refusal is True

    def test_confidence_high_agreement_non_refusal(self):
        hybrid = HybridRefusalClassifier()
        for label in ("a", "b", "c"):
            hybrid.add_detector(MockDetector(False), weight=1.0, name=label)
        outcome = hybrid.classify("test")
        assert outcome.confidence == snapshot(1.0)
        assert outcome.is_refusal is False

    def test_confidence_moderate_majority_refusal(self):
        hybrid = HybridRefusalClassifier()
        for label, verdict in (("a", True), ("b", True), ("c", False)):
            hybrid.add_detector(MockDetector(verdict), weight=1.0, name=label)
        outcome = hybrid.classify("test")
        # 2/3 = 0.666 confidence for refusal
        assert round(outcome.confidence, 2) == snapshot(0.67)
        assert outcome.is_refusal is True

    def test_confidence_moderate_majority_non_refusal(self):
        hybrid = HybridRefusalClassifier()
        for label, verdict in (("a", False), ("b", False), ("c", True)):
            hybrid.add_detector(MockDetector(verdict), weight=1.0, name=label)
        outcome = hybrid.classify("test")
        # 2/3 = 0.666 confidence for non-refusal
        assert round(outcome.confidence, 2) == snapshot(0.67)
        assert outcome.is_refusal is False
|
||||
Reference in New Issue
Block a user