mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-23 21:59:57 +02:00
fix(pc):
This commit is contained in:
@@ -83,7 +83,9 @@ class YAMLRulesDatasetLoader:
|
||||
|
||||
severity_enums = None
|
||||
if self.severities:
|
||||
severity_enums = [AttackRuleSeverity.from_string(s) for s in self.severities]
|
||||
severity_enums = [
|
||||
AttackRuleSeverity.from_string(s) for s in self.severities
|
||||
]
|
||||
|
||||
filtered = self._loader.filter_rules(
|
||||
rules, types=self.types, severities=severity_enums
|
||||
@@ -113,10 +115,14 @@ class YAMLRulesDatasetLoader:
|
||||
|
||||
severity_enums = None
|
||||
if self.severities:
|
||||
severity_enums = [AttackRuleSeverity.from_string(s) for s in self.severities]
|
||||
severity_enums = [
|
||||
AttackRuleSeverity.from_string(s) for s in self.severities
|
||||
]
|
||||
|
||||
filtered = self._loader.filter_rules(
|
||||
all_rules, types=self.types, severities=severity_enums
|
||||
)
|
||||
|
||||
return rules_to_dataset(filtered, name="YAML Rules (merged)", variables=variables)
|
||||
return rules_to_dataset(
|
||||
filtered, name="YAML Rules (merged)", variables=variables
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
@@ -81,9 +80,7 @@ class RuleLoader:
|
||||
return None
|
||||
|
||||
def load_rules_from_directory(
|
||||
self,
|
||||
directory: str | Path | None = None,
|
||||
recursive: bool = True
|
||||
self, directory: str | Path | None = None, recursive: bool = True
|
||||
) -> list[AttackRule]:
|
||||
directory = Path(directory) if directory else self.rules_dir
|
||||
if not directory or not directory.exists():
|
||||
@@ -91,7 +88,7 @@ class RuleLoader:
|
||||
return []
|
||||
|
||||
rules = []
|
||||
pattern = "**/*.yaml" if recursive else "*.yaml"
|
||||
# pattern = "**/*.yaml" if recursive else "*.yaml"
|
||||
|
||||
for ext in [".yaml", ".yml"]:
|
||||
glob_pattern = f"**/*{ext}" if recursive else f"*{ext}"
|
||||
@@ -105,9 +102,7 @@ class RuleLoader:
|
||||
return rules
|
||||
|
||||
def load_multiple_directories(
|
||||
self,
|
||||
directories: list[str | Path],
|
||||
recursive: bool = True
|
||||
self, directories: list[str | Path], recursive: bool = True
|
||||
) -> list[AttackRule]:
|
||||
all_rules = []
|
||||
for directory in directories:
|
||||
@@ -133,6 +128,7 @@ class RuleLoader:
|
||||
|
||||
if name_pattern:
|
||||
import re
|
||||
|
||||
pattern = re.compile(name_pattern, re.IGNORECASE)
|
||||
result = [r for r in result if pattern.search(r.name)]
|
||||
|
||||
@@ -154,8 +150,7 @@ class RuleLoader:
|
||||
|
||||
|
||||
def load_rules_from_directory(
|
||||
directory: str | Path,
|
||||
recursive: bool = True
|
||||
directory: str | Path, recursive: bool = True
|
||||
) -> list[AttackRule]:
|
||||
loader = RuleLoader()
|
||||
return loader.load_rules_from_directory(directory, recursive)
|
||||
|
||||
@@ -38,10 +38,20 @@ class AttackRule:
|
||||
pass_conditions=data.get("pass_conditions", []),
|
||||
fail_conditions=data.get("fail_conditions", []),
|
||||
source=data.get("source"),
|
||||
metadata={k: v for k, v in data.items() if k not in {
|
||||
"name", "type", "prompt", "severity",
|
||||
"pass_conditions", "fail_conditions", "source"
|
||||
}},
|
||||
metadata={
|
||||
k: v
|
||||
for k, v in data.items()
|
||||
if k
|
||||
not in {
|
||||
"name",
|
||||
"type",
|
||||
"prompt",
|
||||
"severity",
|
||||
"pass_conditions",
|
||||
"fail_conditions",
|
||||
"source",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
"""Security utilities and validation for agentic_security."""
|
||||
|
||||
from functools import wraps
|
||||
from collections.abc import Callable
|
||||
from urllib.parse import urlparse
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
class SecurityValidator:
|
||||
"""Input validation and sanitization."""
|
||||
|
||||
ALLOWED_URL_SCHEMES = {"http", "https"}
|
||||
MAX_URL_LENGTH = 2048
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
@staticmethod
|
||||
def validate_url(url: str, allowed_hosts: list[str] | None = None) -> bool:
|
||||
"""Validate URL for SSRF prevention."""
|
||||
if len(url) > SecurityValidator.MAX_URL_LENGTH:
|
||||
return False
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
if parsed.scheme not in SecurityValidator.ALLOWED_URL_SCHEMES:
|
||||
return False
|
||||
|
||||
if not parsed.netloc:
|
||||
return False
|
||||
|
||||
if parsed.netloc in ["localhost", "127.0.0.1", "0.0.0.0"]:
|
||||
return False
|
||||
|
||||
if parsed.netloc.startswith("169.254."):
|
||||
return False
|
||||
|
||||
if parsed.netloc.startswith("10.") or parsed.netloc.startswith("192.168."):
|
||||
return False
|
||||
|
||||
if allowed_hosts and parsed.netloc not in allowed_hosts:
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""Sanitize filename to prevent path traversal."""
|
||||
filename = os.path.basename(filename)
|
||||
filename = re.sub(r"[^\w\s.-]", "", filename)
|
||||
filename = filename.strip()
|
||||
|
||||
if not filename or filename in [".", ".."]:
|
||||
raise ValueError("Invalid filename")
|
||||
|
||||
return filename
|
||||
|
||||
@staticmethod
|
||||
def validate_file_size(size: int) -> bool:
|
||||
"""Validate file size."""
|
||||
return 0 < size <= SecurityValidator.MAX_FILE_SIZE
|
||||
|
||||
@staticmethod
|
||||
def validate_csv_content(content: str) -> bool:
|
||||
"""Basic CSV validation."""
|
||||
if not content or len(content) > SecurityValidator.MAX_FILE_SIZE:
|
||||
return False
|
||||
|
||||
lines = content.split("\n", 2)
|
||||
if not lines:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class SecretManager:
|
||||
"""Secure secret handling."""
|
||||
|
||||
@staticmethod
|
||||
def get_secret(key: str, default: str | None = None) -> str | None:
|
||||
"""Get secret from environment."""
|
||||
value = os.getenv(key, default)
|
||||
if value and value.startswith("$"):
|
||||
env_key = value[1:]
|
||||
value = os.getenv(env_key, default)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def hash_secret(secret: str, salt: str | None = None) -> str:
|
||||
"""Hash a secret value."""
|
||||
if salt is None:
|
||||
salt = os.urandom(32).hex()
|
||||
|
||||
hashed = hashlib.pbkdf2_hmac("sha256", secret.encode(), salt.encode(), 100000)
|
||||
return f"{salt}${hashed.hex()}"
|
||||
|
||||
@staticmethod
|
||||
def verify_secret(secret: str, hashed: str) -> bool:
|
||||
"""Verify a secret against its hash."""
|
||||
try:
|
||||
salt, expected = hashed.split("$", 1)
|
||||
actual = hashlib.pbkdf2_hmac(
|
||||
"sha256", secret.encode(), salt.encode(), 100000
|
||||
)
|
||||
return hmac.compare_digest(actual.hex(), expected)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Simple in-memory rate limiter."""
|
||||
|
||||
def __init__(self, max_requests: int, window_seconds: int):
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
self._requests: dict[str, list[float]] = {}
|
||||
|
||||
def is_allowed(self, key: str) -> bool:
|
||||
"""Check if request is allowed."""
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
|
||||
if key not in self._requests:
|
||||
self._requests[key] = []
|
||||
|
||||
self._requests[key] = [
|
||||
ts for ts in self._requests[key] if now - ts < self.window_seconds
|
||||
]
|
||||
|
||||
if len(self._requests[key]) >= self.max_requests:
|
||||
return False
|
||||
|
||||
self._requests[key].append(now)
|
||||
return True
|
||||
|
||||
def reset(self, key: str):
|
||||
"""Reset rate limit for key."""
|
||||
self._requests.pop(key, None)
|
||||
|
||||
|
||||
def require_auth(func: Callable) -> Callable:
|
||||
"""Decorator to require authentication."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# TODO: Implement actual auth check
|
||||
# For now, check if API key is present
|
||||
api_key = kwargs.get("api_key") or os.getenv("API_KEY")
|
||||
if not api_key:
|
||||
from fastapi import HTTPException
|
||||
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def sanitize_log_output(data: str | dict) -> str:
|
||||
"""Remove sensitive data from logs."""
|
||||
if isinstance(data, dict):
|
||||
data = str(data)
|
||||
|
||||
patterns = [
|
||||
(r'(api[_-]?key["\s:=]+)["\']?[\w-]+', r"\1***"),
|
||||
(r'(token["\s:=]+)["\']?[\w-]+', r"\1***"),
|
||||
(r'(password["\s:=]+)["\']?[\w-]+', r"\1***"),
|
||||
(r'(secret["\s:=]+)["\']?[\w-]+', r"\1***"),
|
||||
(r"Bearer\s+[\w-]+", "Bearer ***"),
|
||||
]
|
||||
|
||||
for pattern, replacement in patterns:
|
||||
data = re.sub(pattern, replacement, data, flags=re.IGNORECASE)
|
||||
|
||||
return data
|
||||
@@ -8,8 +8,7 @@ logger = logging.getLogger(__name__)
|
||||
class FuzzRunnable(Protocol):
|
||||
"""Protocol for objects that can be run in a fuzzing chain."""
|
||||
|
||||
async def run(self, **kwargs: Any) -> str:
|
||||
...
|
||||
async def run(self, **kwargs: Any) -> str: ...
|
||||
|
||||
|
||||
class FuzzNode:
|
||||
|
||||
@@ -36,6 +36,7 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
def _get_client(self) -> Any:
|
||||
if self._client is None:
|
||||
import anthropic
|
||||
|
||||
kwargs: dict[str, Any] = {"api_key": self.api_key}
|
||||
if self.base_url:
|
||||
kwargs["base_url"] = self.base_url
|
||||
@@ -45,6 +46,7 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
def _get_async_client(self) -> Any:
|
||||
if self._async_client is None:
|
||||
import anthropic
|
||||
|
||||
kwargs: dict[str, Any] = {"api_key": self.api_key}
|
||||
if self.base_url:
|
||||
kwargs["base_url"] = self.base_url
|
||||
@@ -95,6 +97,7 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
|
||||
def _handle_error(self, e: Exception) -> None:
|
||||
import anthropic
|
||||
|
||||
if isinstance(e, anthropic.RateLimitError):
|
||||
raise LLMRateLimitError(str(e)) from e
|
||||
if isinstance(e, anthropic.APIError):
|
||||
|
||||
@@ -20,6 +20,7 @@ class LLMRateLimitError(LLMProviderError):
|
||||
@dataclass
|
||||
class LLMMessage:
|
||||
"""A message in a chat conversation."""
|
||||
|
||||
role: str # "system", "user", or "assistant"
|
||||
content: str
|
||||
|
||||
@@ -27,6 +28,7 @@ class LLMMessage:
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM provider."""
|
||||
|
||||
content: str
|
||||
model: str | None = None
|
||||
finish_reason: str | None = None
|
||||
|
||||
@@ -14,6 +14,7 @@ def _ensure_registered() -> None:
|
||||
return
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
_PROVIDERS["openai"] = OpenAIProvider
|
||||
_PROVIDERS["anthropic"] = AnthropicProvider
|
||||
|
||||
|
||||
@@ -36,13 +36,17 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
def _get_client(self) -> Any:
|
||||
if self._client is None:
|
||||
import openai
|
||||
|
||||
self._client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
return self._client
|
||||
|
||||
def _get_async_client(self) -> Any:
|
||||
if self._async_client is None:
|
||||
import openai
|
||||
self._async_client = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
self._async_client = openai.AsyncOpenAI(
|
||||
api_key=self.api_key, base_url=self.base_url
|
||||
)
|
||||
return self._async_client
|
||||
|
||||
@classmethod
|
||||
@@ -79,6 +83,7 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
|
||||
def _handle_error(self, e: Exception) -> None:
|
||||
import openai
|
||||
|
||||
if isinstance(e, openai.RateLimitError):
|
||||
raise LLMRateLimitError(str(e)) from e
|
||||
raise LLMProviderError(str(e)) from e
|
||||
|
||||
@@ -96,11 +96,13 @@ class HybridRefusalClassifier:
|
||||
self for method chaining
|
||||
"""
|
||||
detector_name = name or detector.__class__.__name__
|
||||
self._detectors.append(DetectorConfig(
|
||||
detector=detector,
|
||||
weight=weight,
|
||||
name=detector_name,
|
||||
))
|
||||
self._detectors.append(
|
||||
DetectorConfig(
|
||||
detector=detector,
|
||||
weight=weight,
|
||||
name=detector_name,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def classify(self, response: str) -> HybridResult:
|
||||
@@ -117,11 +119,13 @@ class HybridRefusalClassifier:
|
||||
is_refusal = config.detector.is_refusal(response)
|
||||
except Exception:
|
||||
continue # Skip failed detectors
|
||||
results.append(DetectionResult(
|
||||
method=config.name,
|
||||
is_refusal=is_refusal,
|
||||
weight=config.weight,
|
||||
))
|
||||
results.append(
|
||||
DetectionResult(
|
||||
method=config.name,
|
||||
is_refusal=is_refusal,
|
||||
weight=config.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if not results:
|
||||
return HybridResult(is_refusal=False, confidence=0.0)
|
||||
@@ -134,7 +138,9 @@ class HybridRefusalClassifier:
|
||||
|
||||
# Check unanimous requirement
|
||||
if self.require_unanimous:
|
||||
all_agree = all(r.is_refusal for r in results) or all(not r.is_refusal for r in results)
|
||||
all_agree = all(r.is_refusal for r in results) or all(
|
||||
not r.is_refusal for r in results
|
||||
)
|
||||
if not all_agree:
|
||||
# Disagreement - return uncertain result
|
||||
return HybridResult(
|
||||
|
||||
@@ -56,16 +56,20 @@ class TestRulesToDataset:
|
||||
class TestLoadRulesAsDataset:
|
||||
def test_basic_load(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "rule1.yaml").write_text("""
|
||||
(Path(tmpdir) / "rule1.yaml").write_text(
|
||||
"""
|
||||
name: test1
|
||||
type: jailbreak
|
||||
prompt: Jailbreak prompt
|
||||
""")
|
||||
(Path(tmpdir) / "rule2.yaml").write_text("""
|
||||
"""
|
||||
)
|
||||
(Path(tmpdir) / "rule2.yaml").write_text(
|
||||
"""
|
||||
name: test2
|
||||
type: harmful
|
||||
prompt: Harmful prompt
|
||||
""")
|
||||
"""
|
||||
)
|
||||
dataset = load_rules_as_dataset(tmpdir)
|
||||
assert len(dataset.prompts) == 2
|
||||
|
||||
|
||||
@@ -77,15 +77,15 @@ severity: high
|
||||
|
||||
def test_load_rule_from_file(self):
|
||||
loader = RuleLoader()
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
f.write("""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(
|
||||
"""
|
||||
name: file_test
|
||||
type: harmful
|
||||
severity: medium
|
||||
prompt: Test prompt from file
|
||||
""")
|
||||
"""
|
||||
)
|
||||
f.flush()
|
||||
rule = loader.load_rule_from_file(f.name)
|
||||
|
||||
@@ -96,9 +96,7 @@ prompt: Test prompt from file
|
||||
|
||||
def test_load_rule_from_file_wrong_extension(self):
|
||||
loader = RuleLoader()
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".txt", delete=False
|
||||
) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
|
||||
f.write("name: test\nprompt: test")
|
||||
f.flush()
|
||||
rule = loader.load_rule_from_file(f.name)
|
||||
@@ -110,16 +108,20 @@ prompt: Test prompt from file
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
rule1_path = Path(tmpdir) / "rule1.yaml"
|
||||
rule2_path = Path(tmpdir) / "rule2.yml"
|
||||
rule1_path.write_text("""
|
||||
rule1_path.write_text(
|
||||
"""
|
||||
name: rule1
|
||||
type: jailbreak
|
||||
prompt: First rule
|
||||
""")
|
||||
rule2_path.write_text("""
|
||||
"""
|
||||
)
|
||||
rule2_path.write_text(
|
||||
"""
|
||||
name: rule2
|
||||
type: harmful
|
||||
prompt: Second rule
|
||||
""")
|
||||
"""
|
||||
)
|
||||
loader = RuleLoader()
|
||||
rules = loader.load_rules_from_directory(tmpdir)
|
||||
|
||||
|
||||
@@ -87,7 +87,12 @@ class TestAttackRule:
|
||||
rule = AttackRule(name="test", type="jailbreak", prompt="Test")
|
||||
result = rule.to_dict()
|
||||
assert result == snapshot(
|
||||
{"name": "test", "type": "jailbreak", "prompt": "Test", "severity": "medium"}
|
||||
{
|
||||
"name": "test",
|
||||
"type": "jailbreak",
|
||||
"prompt": "Test",
|
||||
"severity": "medium",
|
||||
}
|
||||
)
|
||||
|
||||
def test_render_prompt_no_variables(self):
|
||||
|
||||
@@ -116,11 +116,13 @@ class TestFuzzChain:
|
||||
|
||||
result = await chain.run(input="initial")
|
||||
assert result == "final result"
|
||||
assert llm.prompts == snapshot([
|
||||
"First: initial",
|
||||
"Second: step1 result",
|
||||
"Third: step2 result",
|
||||
])
|
||||
assert llm.prompts == snapshot(
|
||||
[
|
||||
"First: initial",
|
||||
"Second: step1 result",
|
||||
"Third: step2 result",
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chain_with_custom_variables(self):
|
||||
@@ -131,10 +133,12 @@ class TestFuzzChain:
|
||||
|
||||
result = await chain.run(topic="security", input="test prompt")
|
||||
assert result == "evaluated"
|
||||
assert llm.prompts == snapshot([
|
||||
"Analyze security: test prompt",
|
||||
"Evaluate: analyzed",
|
||||
])
|
||||
assert llm.prompts == snapshot(
|
||||
[
|
||||
"Analyze security: test prompt",
|
||||
"Evaluate: analyzed",
|
||||
]
|
||||
)
|
||||
|
||||
def test_pipe_chain_to_node(self):
|
||||
llm = MockLLMProvider()
|
||||
@@ -158,11 +162,13 @@ class TestFuzzChain:
|
||||
|
||||
def test_len(self):
|
||||
llm = MockLLMProvider()
|
||||
chain = FuzzChain([
|
||||
FuzzNode(llm, "A"),
|
||||
FuzzNode(llm, "B"),
|
||||
FuzzNode(llm, "C"),
|
||||
])
|
||||
chain = FuzzChain(
|
||||
[
|
||||
FuzzNode(llm, "A"),
|
||||
FuzzNode(llm, "B"),
|
||||
FuzzNode(llm, "C"),
|
||||
]
|
||||
)
|
||||
assert len(chain) == 3
|
||||
|
||||
def test_repr(self):
|
||||
@@ -185,11 +191,13 @@ class TestPipeOperatorChaining:
|
||||
result = await chain.run(input="start")
|
||||
|
||||
assert result == "c"
|
||||
assert llm.prompts == snapshot([
|
||||
"Step1: start",
|
||||
"Step2: a",
|
||||
"Step3: b",
|
||||
])
|
||||
assert llm.prompts == snapshot(
|
||||
[
|
||||
"Step1: start",
|
||||
"Step2: a",
|
||||
"Step3: b",
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chain_with_different_providers(self):
|
||||
|
||||
@@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
|
||||
from agentic_security.llm_providers.base import (
|
||||
LLMMessage,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
)
|
||||
|
||||
|
||||
class TestAnthropicProviderInit:
|
||||
@@ -209,13 +213,19 @@ class TestAnthropicProviderErrors:
|
||||
|
||||
def test_handle_rate_limit_error(self, provider):
|
||||
import anthropic
|
||||
|
||||
with pytest.raises(LLMRateLimitError):
|
||||
provider._handle_error(anthropic.RateLimitError("rate limited", response=MagicMock(), body={}))
|
||||
provider._handle_error(
|
||||
anthropic.RateLimitError("rate limited", response=MagicMock(), body={})
|
||||
)
|
||||
|
||||
def test_handle_api_error(self, provider):
|
||||
import anthropic
|
||||
|
||||
with pytest.raises(LLMProviderError):
|
||||
provider._handle_error(anthropic.APIError("api error", request=MagicMock(), body={}))
|
||||
provider._handle_error(
|
||||
anthropic.APIError("api error", request=MagicMock(), body={})
|
||||
)
|
||||
|
||||
def test_handle_generic_error(self, provider):
|
||||
with pytest.raises(LLMProviderError):
|
||||
|
||||
@@ -9,7 +9,11 @@ from agentic_security.llm_providers.factory import (
|
||||
list_providers,
|
||||
register_provider,
|
||||
)
|
||||
from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError, LLMResponse
|
||||
from agentic_security.llm_providers.base import (
|
||||
BaseLLMProvider,
|
||||
LLMProviderError,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestListProviders:
|
||||
@@ -26,11 +30,13 @@ class TestListProviders:
|
||||
class TestGetProviderClass:
|
||||
def test_get_openai(self):
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
|
||||
cls = get_provider_class("openai")
|
||||
assert cls is OpenAIProvider
|
||||
|
||||
def test_get_anthropic(self):
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
cls = get_provider_class("anthropic")
|
||||
assert cls is AnthropicProvider
|
||||
|
||||
|
||||
@@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
|
||||
from agentic_security.llm_providers.base import (
|
||||
LLMMessage,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIProviderInit:
|
||||
@@ -111,7 +115,9 @@ class TestOpenAIProviderSync:
|
||||
mock_response.usage = None
|
||||
|
||||
with patch.object(provider, "_get_client") as mock_client:
|
||||
mock_client.return_value.chat.completions.create.return_value = mock_response
|
||||
mock_client.return_value.chat.completions.create.return_value = (
|
||||
mock_response
|
||||
)
|
||||
result = provider.sync_generate("Hello")
|
||||
assert result.content == snapshot("Response")
|
||||
|
||||
@@ -126,7 +132,9 @@ class TestOpenAIProviderSync:
|
||||
messages = [LLMMessage(role="user", content="Hi")]
|
||||
|
||||
with patch.object(provider, "_get_client") as mock_client:
|
||||
mock_client.return_value.chat.completions.create.return_value = mock_response
|
||||
mock_client.return_value.chat.completions.create.return_value = (
|
||||
mock_response
|
||||
)
|
||||
result = provider.sync_chat(messages)
|
||||
assert result.content == snapshot("Chat response")
|
||||
|
||||
@@ -196,8 +204,11 @@ class TestOpenAIProviderErrors:
|
||||
|
||||
def test_handle_rate_limit_error(self, provider):
|
||||
import openai
|
||||
|
||||
with pytest.raises(LLMRateLimitError):
|
||||
provider._handle_error(openai.RateLimitError("rate limited", response=MagicMock(), body={}))
|
||||
provider._handle_error(
|
||||
openai.RateLimitError("rate limited", response=MagicMock(), body={})
|
||||
)
|
||||
|
||||
def test_handle_generic_error(self, provider):
|
||||
with pytest.raises(LLMProviderError):
|
||||
|
||||
@@ -45,7 +45,9 @@ class TestDetectionResult:
|
||||
|
||||
def test_weighted_score_cases(self):
|
||||
for is_refusal, weight, expected in detection_result_cases:
|
||||
result = DetectionResult(method="test", is_refusal=is_refusal, weight=weight)
|
||||
result = DetectionResult(
|
||||
method="test", is_refusal=is_refusal, weight=weight
|
||||
)
|
||||
assert result.weighted_score == expected
|
||||
|
||||
def test_default_weight(self):
|
||||
@@ -234,7 +236,14 @@ factory_cases = [
|
||||
({"ml_detector": MockDetector(True)}, 1),
|
||||
({"llm_detector": MockDetector(True)}, 1),
|
||||
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False)}, 2),
|
||||
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False), "llm_detector": MockDetector(True)}, 3),
|
||||
(
|
||||
{
|
||||
"marker_detector": MockDetector(True),
|
||||
"ml_detector": MockDetector(False),
|
||||
"llm_detector": MockDetector(True),
|
||||
},
|
||||
3,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Unit tests for security module."""
|
||||
|
||||
import pytest
|
||||
from agentic_security.core.security import (
|
||||
SecurityValidator,
|
||||
SecretManager,
|
||||
RateLimiter,
|
||||
sanitize_log_output,
|
||||
)
|
||||
|
||||
|
||||
class TestSecurityValidator:
|
||||
|
||||
def test_validate_url_valid(self):
|
||||
assert SecurityValidator.validate_url("https://example.com/path")
|
||||
assert SecurityValidator.validate_url("http://api.example.com")
|
||||
|
||||
def test_validate_url_invalid_scheme(self):
|
||||
assert not SecurityValidator.validate_url("ftp://example.com")
|
||||
assert not SecurityValidator.validate_url("file:///etc/passwd")
|
||||
|
||||
def test_validate_url_localhost(self):
|
||||
assert not SecurityValidator.validate_url("http://localhost/api")
|
||||
assert not SecurityValidator.validate_url("http://127.0.0.1/api")
|
||||
assert not SecurityValidator.validate_url("http://0.0.0.0/api")
|
||||
|
||||
def test_validate_url_private_ip(self):
|
||||
assert not SecurityValidator.validate_url("http://10.0.0.1")
|
||||
assert not SecurityValidator.validate_url("http://192.168.1.1")
|
||||
assert not SecurityValidator.validate_url("http://169.254.1.1")
|
||||
|
||||
def test_validate_url_allowed_hosts(self):
|
||||
allowed = ["api.example.com"]
|
||||
assert SecurityValidator.validate_url("https://api.example.com", allowed)
|
||||
assert not SecurityValidator.validate_url("https://evil.com", allowed)
|
||||
|
||||
def test_validate_url_too_long(self):
|
||||
long_url = "https://example.com/" + "a" * 3000
|
||||
assert not SecurityValidator.validate_url(long_url)
|
||||
|
||||
def test_sanitize_filename(self):
|
||||
assert SecurityValidator.sanitize_filename("test.csv") == "test.csv"
|
||||
assert SecurityValidator.sanitize_filename("../../../etc/passwd") == "passwd"
|
||||
assert SecurityValidator.sanitize_filename("test/file.txt") == "file.txt"
|
||||
assert (
|
||||
SecurityValidator.sanitize_filename("file with spaces.txt")
|
||||
== "file with spaces.txt"
|
||||
)
|
||||
|
||||
def test_sanitize_filename_invalid(self):
|
||||
with pytest.raises(ValueError):
|
||||
SecurityValidator.sanitize_filename(".")
|
||||
with pytest.raises(ValueError):
|
||||
SecurityValidator.sanitize_filename("..")
|
||||
with pytest.raises(ValueError):
|
||||
SecurityValidator.sanitize_filename("")
|
||||
|
||||
def test_validate_file_size(self):
|
||||
assert SecurityValidator.validate_file_size(1024)
|
||||
assert SecurityValidator.validate_file_size(1024 * 1024)
|
||||
assert not SecurityValidator.validate_file_size(0)
|
||||
assert not SecurityValidator.validate_file_size(-1)
|
||||
assert not SecurityValidator.validate_file_size(20 * 1024 * 1024)
|
||||
|
||||
def test_validate_csv_content(self):
|
||||
assert SecurityValidator.validate_csv_content("col1,col2\nval1,val2")
|
||||
assert not SecurityValidator.validate_csv_content("")
|
||||
assert not SecurityValidator.validate_csv_content("x" * (11 * 1024 * 1024))
|
||||
|
||||
|
||||
class TestSecretManager:
|
||||
|
||||
def test_hash_and_verify_secret(self):
|
||||
secret = "my-secret-key"
|
||||
hashed = SecretManager.hash_secret(secret)
|
||||
|
||||
assert SecretManager.verify_secret(secret, hashed)
|
||||
assert not SecretManager.verify_secret("wrong-secret", hashed)
|
||||
|
||||
def test_hash_secret_with_salt(self):
|
||||
secret = "my-secret"
|
||||
hashed1 = SecretManager.hash_secret(secret, "salt1")
|
||||
hashed2 = SecretManager.hash_secret(secret, "salt2")
|
||||
|
||||
assert hashed1 != hashed2
|
||||
|
||||
def test_verify_secret_invalid_format(self):
|
||||
assert not SecretManager.verify_secret("secret", "invalid-hash")
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
|
||||
def test_rate_limiter_allows_requests(self):
|
||||
limiter = RateLimiter(max_requests=3, window_seconds=60)
|
||||
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
|
||||
def test_rate_limiter_blocks_excess(self):
|
||||
limiter = RateLimiter(max_requests=2, window_seconds=60)
|
||||
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
assert not limiter.is_allowed("user1")
|
||||
|
||||
def test_rate_limiter_separate_keys(self):
|
||||
limiter = RateLimiter(max_requests=2, window_seconds=60)
|
||||
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user2")
|
||||
assert limiter.is_allowed("user2")
|
||||
|
||||
def test_rate_limiter_reset(self):
|
||||
limiter = RateLimiter(max_requests=1, window_seconds=60)
|
||||
|
||||
assert limiter.is_allowed("user1")
|
||||
assert not limiter.is_allowed("user1")
|
||||
|
||||
limiter.reset("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
|
||||
|
||||
class TestSanitizeLogOutput:
|
||||
|
||||
def test_sanitize_api_key(self):
|
||||
data = 'api_key="sk-1234567890"'
|
||||
result = sanitize_log_output(data)
|
||||
assert "sk-1234567890" not in result
|
||||
assert "***" in result
|
||||
|
||||
def test_sanitize_token(self):
|
||||
data = "token: abc123xyz"
|
||||
result = sanitize_log_output(data)
|
||||
assert "abc123xyz" not in result
|
||||
|
||||
def test_sanitize_password(self):
|
||||
data = {"password": "secret123"}
|
||||
result = sanitize_log_output(data)
|
||||
assert "secret123" not in result
|
||||
|
||||
def test_sanitize_bearer_token(self):
|
||||
data = "Authorization: Bearer eyJhbGc..."
|
||||
result = sanitize_log_output(data)
|
||||
assert "eyJhbGc" not in result
|
||||
assert "Bearer ***" in result
|
||||
|
||||
def test_preserves_non_sensitive(self):
|
||||
data = "user_id=123 name=John"
|
||||
result = sanitize_log_output(data)
|
||||
assert "user_id=123" in result
|
||||
assert "name=John" in result
|
||||
Reference in New Issue
Block a user