This commit is contained in:
Alexander Myasoedov
2026-01-28 21:04:29 +02:00
parent 8d42a84a9d
commit bc7fdd7cfa
19 changed files with 491 additions and 77 deletions
+9 -3
View File
@@ -83,7 +83,9 @@ class YAMLRulesDatasetLoader:
severity_enums = None
if self.severities:
severity_enums = [AttackRuleSeverity.from_string(s) for s in self.severities]
severity_enums = [
AttackRuleSeverity.from_string(s) for s in self.severities
]
filtered = self._loader.filter_rules(
rules, types=self.types, severities=severity_enums
@@ -113,10 +115,14 @@ class YAMLRulesDatasetLoader:
severity_enums = None
if self.severities:
severity_enums = [AttackRuleSeverity.from_string(s) for s in self.severities]
severity_enums = [
AttackRuleSeverity.from_string(s) for s in self.severities
]
filtered = self._loader.filter_rules(
all_rules, types=self.types, severities=severity_enums
)
return rules_to_dataset(filtered, name="YAML Rules (merged)", variables=variables)
return rules_to_dataset(
filtered, name="YAML Rules (merged)", variables=variables
)
+5 -10
View File
@@ -1,4 +1,3 @@
import os
from pathlib import Path
import yaml
@@ -81,9 +80,7 @@ class RuleLoader:
return None
def load_rules_from_directory(
self,
directory: str | Path | None = None,
recursive: bool = True
self, directory: str | Path | None = None, recursive: bool = True
) -> list[AttackRule]:
directory = Path(directory) if directory else self.rules_dir
if not directory or not directory.exists():
@@ -91,7 +88,7 @@ class RuleLoader:
return []
rules = []
pattern = "**/*.yaml" if recursive else "*.yaml"
# pattern = "**/*.yaml" if recursive else "*.yaml"
for ext in [".yaml", ".yml"]:
glob_pattern = f"**/*{ext}" if recursive else f"*{ext}"
@@ -105,9 +102,7 @@ class RuleLoader:
return rules
def load_multiple_directories(
self,
directories: list[str | Path],
recursive: bool = True
self, directories: list[str | Path], recursive: bool = True
) -> list[AttackRule]:
all_rules = []
for directory in directories:
@@ -133,6 +128,7 @@ class RuleLoader:
if name_pattern:
import re
pattern = re.compile(name_pattern, re.IGNORECASE)
result = [r for r in result if pattern.search(r.name)]
@@ -154,8 +150,7 @@ class RuleLoader:
def load_rules_from_directory(
directory: str | Path,
recursive: bool = True
directory: str | Path, recursive: bool = True
) -> list[AttackRule]:
loader = RuleLoader()
return loader.load_rules_from_directory(directory, recursive)
+14 -4
View File
@@ -38,10 +38,20 @@ class AttackRule:
pass_conditions=data.get("pass_conditions", []),
fail_conditions=data.get("fail_conditions", []),
source=data.get("source"),
metadata={k: v for k, v in data.items() if k not in {
"name", "type", "prompt", "severity",
"pass_conditions", "fail_conditions", "source"
}},
metadata={
k: v
for k, v in data.items()
if k
not in {
"name",
"type",
"prompt",
"severity",
"pass_conditions",
"fail_conditions",
"source",
}
},
)
def to_dict(self) -> dict[str, Any]:
+179
View File
@@ -0,0 +1,179 @@
"""Security utilities and validation for agentic_security."""
from functools import wraps
from collections.abc import Callable
from urllib.parse import urlparse
import hashlib
import hmac
import os
import re
class SecurityValidator:
"""Input validation and sanitization."""
ALLOWED_URL_SCHEMES = {"http", "https"}
MAX_URL_LENGTH = 2048
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
@staticmethod
def validate_url(url: str, allowed_hosts: list[str] | None = None) -> bool:
"""Validate URL for SSRF prevention."""
if len(url) > SecurityValidator.MAX_URL_LENGTH:
return False
try:
parsed = urlparse(url)
if parsed.scheme not in SecurityValidator.ALLOWED_URL_SCHEMES:
return False
if not parsed.netloc:
return False
if parsed.netloc in ["localhost", "127.0.0.1", "0.0.0.0"]:
return False
if parsed.netloc.startswith("169.254."):
return False
if parsed.netloc.startswith("10.") or parsed.netloc.startswith("192.168."):
return False
if allowed_hosts and parsed.netloc not in allowed_hosts:
return False
return True
except Exception:
return False
@staticmethod
def sanitize_filename(filename: str) -> str:
"""Sanitize filename to prevent path traversal."""
filename = os.path.basename(filename)
filename = re.sub(r"[^\w\s.-]", "", filename)
filename = filename.strip()
if not filename or filename in [".", ".."]:
raise ValueError("Invalid filename")
return filename
@staticmethod
def validate_file_size(size: int) -> bool:
"""Validate file size."""
return 0 < size <= SecurityValidator.MAX_FILE_SIZE
@staticmethod
def validate_csv_content(content: str) -> bool:
"""Basic CSV validation."""
if not content or len(content) > SecurityValidator.MAX_FILE_SIZE:
return False
lines = content.split("\n", 2)
if not lines:
return False
return True
class SecretManager:
"""Secure secret handling."""
@staticmethod
def get_secret(key: str, default: str | None = None) -> str | None:
"""Get secret from environment."""
value = os.getenv(key, default)
if value and value.startswith("$"):
env_key = value[1:]
value = os.getenv(env_key, default)
return value
@staticmethod
def hash_secret(secret: str, salt: str | None = None) -> str:
"""Hash a secret value."""
if salt is None:
salt = os.urandom(32).hex()
hashed = hashlib.pbkdf2_hmac("sha256", secret.encode(), salt.encode(), 100000)
return f"{salt}${hashed.hex()}"
@staticmethod
def verify_secret(secret: str, hashed: str) -> bool:
"""Verify a secret against its hash."""
try:
salt, expected = hashed.split("$", 1)
actual = hashlib.pbkdf2_hmac(
"sha256", secret.encode(), salt.encode(), 100000
)
return hmac.compare_digest(actual.hex(), expected)
except Exception:
return False
class RateLimiter:
"""Simple in-memory rate limiter."""
def __init__(self, max_requests: int, window_seconds: int):
self.max_requests = max_requests
self.window_seconds = window_seconds
self._requests: dict[str, list[float]] = {}
def is_allowed(self, key: str) -> bool:
"""Check if request is allowed."""
import time
now = time.time()
if key not in self._requests:
self._requests[key] = []
self._requests[key] = [
ts for ts in self._requests[key] if now - ts < self.window_seconds
]
if len(self._requests[key]) >= self.max_requests:
return False
self._requests[key].append(now)
return True
def reset(self, key: str):
"""Reset rate limit for key."""
self._requests.pop(key, None)
def require_auth(func: Callable) -> Callable:
"""Decorator to require authentication."""
@wraps(func)
async def wrapper(*args, **kwargs):
# TODO: Implement actual auth check
# For now, check if API key is present
api_key = kwargs.get("api_key") or os.getenv("API_KEY")
if not api_key:
from fastapi import HTTPException
raise HTTPException(status_code=401, detail="Authentication required")
return await func(*args, **kwargs)
return wrapper
def sanitize_log_output(data: str | dict) -> str:
"""Remove sensitive data from logs."""
if isinstance(data, dict):
data = str(data)
patterns = [
(r'(api[_-]?key["\s:=]+)["\']?[\w-]+', r"\1***"),
(r'(token["\s:=]+)["\']?[\w-]+', r"\1***"),
(r'(password["\s:=]+)["\']?[\w-]+', r"\1***"),
(r'(secret["\s:=]+)["\']?[\w-]+', r"\1***"),
(r"Bearer\s+[\w-]+", "Bearer ***"),
]
for pattern, replacement in patterns:
data = re.sub(pattern, replacement, data, flags=re.IGNORECASE)
return data
+1 -2
View File
@@ -8,8 +8,7 @@ logger = logging.getLogger(__name__)
class FuzzRunnable(Protocol):
"""Protocol for objects that can be run in a fuzzing chain."""
async def run(self, **kwargs: Any) -> str:
...
async def run(self, **kwargs: Any) -> str: ...
class FuzzNode:
@@ -36,6 +36,7 @@ class AnthropicProvider(BaseLLMProvider):
def _get_client(self) -> Any:
if self._client is None:
import anthropic
kwargs: dict[str, Any] = {"api_key": self.api_key}
if self.base_url:
kwargs["base_url"] = self.base_url
@@ -45,6 +46,7 @@ class AnthropicProvider(BaseLLMProvider):
def _get_async_client(self) -> Any:
if self._async_client is None:
import anthropic
kwargs: dict[str, Any] = {"api_key": self.api_key}
if self.base_url:
kwargs["base_url"] = self.base_url
@@ -95,6 +97,7 @@ class AnthropicProvider(BaseLLMProvider):
def _handle_error(self, e: Exception) -> None:
import anthropic
if isinstance(e, anthropic.RateLimitError):
raise LLMRateLimitError(str(e)) from e
if isinstance(e, anthropic.APIError):
+2
View File
@@ -20,6 +20,7 @@ class LLMRateLimitError(LLMProviderError):
@dataclass
class LLMMessage:
"""A message in a chat conversation."""
role: str # "system", "user", or "assistant"
content: str
@@ -27,6 +28,7 @@ class LLMMessage:
@dataclass
class LLMResponse:
"""Response from an LLM provider."""
content: str
model: str | None = None
finish_reason: str | None = None
@@ -14,6 +14,7 @@ def _ensure_registered() -> None:
return
from agentic_security.llm_providers.openai_provider import OpenAIProvider
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
_PROVIDERS["openai"] = OpenAIProvider
_PROVIDERS["anthropic"] = AnthropicProvider
@@ -36,13 +36,17 @@ class OpenAIProvider(BaseLLMProvider):
def _get_client(self) -> Any:
if self._client is None:
import openai
self._client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
return self._client
def _get_async_client(self) -> Any:
if self._async_client is None:
import openai
self._async_client = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
self._async_client = openai.AsyncOpenAI(
api_key=self.api_key, base_url=self.base_url
)
return self._async_client
@classmethod
@@ -79,6 +83,7 @@ class OpenAIProvider(BaseLLMProvider):
def _handle_error(self, e: Exception) -> None:
import openai
if isinstance(e, openai.RateLimitError):
raise LLMRateLimitError(str(e)) from e
raise LLMProviderError(str(e)) from e
@@ -96,11 +96,13 @@ class HybridRefusalClassifier:
self for method chaining
"""
detector_name = name or detector.__class__.__name__
self._detectors.append(DetectorConfig(
detector=detector,
weight=weight,
name=detector_name,
))
self._detectors.append(
DetectorConfig(
detector=detector,
weight=weight,
name=detector_name,
)
)
return self
def classify(self, response: str) -> HybridResult:
@@ -117,11 +119,13 @@ class HybridRefusalClassifier:
is_refusal = config.detector.is_refusal(response)
except Exception:
continue # Skip failed detectors
results.append(DetectionResult(
method=config.name,
is_refusal=is_refusal,
weight=config.weight,
))
results.append(
DetectionResult(
method=config.name,
is_refusal=is_refusal,
weight=config.weight,
)
)
if not results:
return HybridResult(is_refusal=False, confidence=0.0)
@@ -134,7 +138,9 @@ class HybridRefusalClassifier:
# Check unanimous requirement
if self.require_unanimous:
all_agree = all(r.is_refusal for r in results) or all(not r.is_refusal for r in results)
all_agree = all(r.is_refusal for r in results) or all(
not r.is_refusal for r in results
)
if not all_agree:
# Disagreement - return uncertain result
return HybridResult(
+8 -4
View File
@@ -56,16 +56,20 @@ class TestRulesToDataset:
class TestLoadRulesAsDataset:
def test_basic_load(self):
with tempfile.TemporaryDirectory() as tmpdir:
(Path(tmpdir) / "rule1.yaml").write_text("""
(Path(tmpdir) / "rule1.yaml").write_text(
"""
name: test1
type: jailbreak
prompt: Jailbreak prompt
""")
(Path(tmpdir) / "rule2.yaml").write_text("""
"""
)
(Path(tmpdir) / "rule2.yaml").write_text(
"""
name: test2
type: harmful
prompt: Harmful prompt
""")
"""
)
dataset = load_rules_as_dataset(tmpdir)
assert len(dataset.prompts) == 2
+14 -12
View File
@@ -77,15 +77,15 @@ severity: high
def test_load_rule_from_file(self):
loader = RuleLoader()
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
f.write("""
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write(
"""
name: file_test
type: harmful
severity: medium
prompt: Test prompt from file
""")
"""
)
f.flush()
rule = loader.load_rule_from_file(f.name)
@@ -96,9 +96,7 @@ prompt: Test prompt from file
def test_load_rule_from_file_wrong_extension(self):
loader = RuleLoader()
with tempfile.NamedTemporaryFile(
mode="w", suffix=".txt", delete=False
) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
f.write("name: test\nprompt: test")
f.flush()
rule = loader.load_rule_from_file(f.name)
@@ -110,16 +108,20 @@ prompt: Test prompt from file
with tempfile.TemporaryDirectory() as tmpdir:
rule1_path = Path(tmpdir) / "rule1.yaml"
rule2_path = Path(tmpdir) / "rule2.yml"
rule1_path.write_text("""
rule1_path.write_text(
"""
name: rule1
type: jailbreak
prompt: First rule
""")
rule2_path.write_text("""
"""
)
rule2_path.write_text(
"""
name: rule2
type: harmful
prompt: Second rule
""")
"""
)
loader = RuleLoader()
rules = loader.load_rules_from_directory(tmpdir)
+6 -1
View File
@@ -87,7 +87,12 @@ class TestAttackRule:
rule = AttackRule(name="test", type="jailbreak", prompt="Test")
result = rule.to_dict()
assert result == snapshot(
{"name": "test", "type": "jailbreak", "prompt": "Test", "severity": "medium"}
{
"name": "test",
"type": "jailbreak",
"prompt": "Test",
"severity": "medium",
}
)
def test_render_prompt_no_variables(self):
+27 -19
View File
@@ -116,11 +116,13 @@ class TestFuzzChain:
result = await chain.run(input="initial")
assert result == "final result"
assert llm.prompts == snapshot([
"First: initial",
"Second: step1 result",
"Third: step2 result",
])
assert llm.prompts == snapshot(
[
"First: initial",
"Second: step1 result",
"Third: step2 result",
]
)
@pytest.mark.asyncio
async def test_chain_with_custom_variables(self):
@@ -131,10 +133,12 @@ class TestFuzzChain:
result = await chain.run(topic="security", input="test prompt")
assert result == "evaluated"
assert llm.prompts == snapshot([
"Analyze security: test prompt",
"Evaluate: analyzed",
])
assert llm.prompts == snapshot(
[
"Analyze security: test prompt",
"Evaluate: analyzed",
]
)
def test_pipe_chain_to_node(self):
llm = MockLLMProvider()
@@ -158,11 +162,13 @@ class TestFuzzChain:
def test_len(self):
llm = MockLLMProvider()
chain = FuzzChain([
FuzzNode(llm, "A"),
FuzzNode(llm, "B"),
FuzzNode(llm, "C"),
])
chain = FuzzChain(
[
FuzzNode(llm, "A"),
FuzzNode(llm, "B"),
FuzzNode(llm, "C"),
]
)
assert len(chain) == 3
def test_repr(self):
@@ -185,11 +191,13 @@ class TestPipeOperatorChaining:
result = await chain.run(input="start")
assert result == "c"
assert llm.prompts == snapshot([
"Step1: start",
"Step2: a",
"Step3: b",
])
assert llm.prompts == snapshot(
[
"Step1: start",
"Step2: a",
"Step3: b",
]
)
@pytest.mark.asyncio
async def test_chain_with_different_providers(self):
@@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch
from inline_snapshot import snapshot
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
from agentic_security.llm_providers.base import (
LLMMessage,
LLMProviderError,
LLMRateLimitError,
)
class TestAnthropicProviderInit:
@@ -209,13 +213,19 @@ class TestAnthropicProviderErrors:
def test_handle_rate_limit_error(self, provider):
import anthropic
with pytest.raises(LLMRateLimitError):
provider._handle_error(anthropic.RateLimitError("rate limited", response=MagicMock(), body={}))
provider._handle_error(
anthropic.RateLimitError("rate limited", response=MagicMock(), body={})
)
def test_handle_api_error(self, provider):
import anthropic
with pytest.raises(LLMProviderError):
provider._handle_error(anthropic.APIError("api error", request=MagicMock(), body={}))
provider._handle_error(
anthropic.APIError("api error", request=MagicMock(), body={})
)
def test_handle_generic_error(self, provider):
with pytest.raises(LLMProviderError):
+7 -1
View File
@@ -9,7 +9,11 @@ from agentic_security.llm_providers.factory import (
list_providers,
register_provider,
)
from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError, LLMResponse
from agentic_security.llm_providers.base import (
BaseLLMProvider,
LLMProviderError,
LLMResponse,
)
class TestListProviders:
@@ -26,11 +30,13 @@ class TestListProviders:
class TestGetProviderClass:
def test_get_openai(self):
from agentic_security.llm_providers.openai_provider import OpenAIProvider
cls = get_provider_class("openai")
assert cls is OpenAIProvider
def test_get_anthropic(self):
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
cls = get_provider_class("anthropic")
assert cls is AnthropicProvider
@@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch
from inline_snapshot import snapshot
from agentic_security.llm_providers.openai_provider import OpenAIProvider
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
from agentic_security.llm_providers.base import (
LLMMessage,
LLMProviderError,
LLMRateLimitError,
)
class TestOpenAIProviderInit:
@@ -111,7 +115,9 @@ class TestOpenAIProviderSync:
mock_response.usage = None
with patch.object(provider, "_get_client") as mock_client:
mock_client.return_value.chat.completions.create.return_value = mock_response
mock_client.return_value.chat.completions.create.return_value = (
mock_response
)
result = provider.sync_generate("Hello")
assert result.content == snapshot("Response")
@@ -126,7 +132,9 @@ class TestOpenAIProviderSync:
messages = [LLMMessage(role="user", content="Hi")]
with patch.object(provider, "_get_client") as mock_client:
mock_client.return_value.chat.completions.create.return_value = mock_response
mock_client.return_value.chat.completions.create.return_value = (
mock_response
)
result = provider.sync_chat(messages)
assert result.content == snapshot("Chat response")
@@ -196,8 +204,11 @@ class TestOpenAIProviderErrors:
def test_handle_rate_limit_error(self, provider):
import openai
with pytest.raises(LLMRateLimitError):
provider._handle_error(openai.RateLimitError("rate limited", response=MagicMock(), body={}))
provider._handle_error(
openai.RateLimitError("rate limited", response=MagicMock(), body={})
)
def test_handle_generic_error(self, provider):
with pytest.raises(LLMProviderError):
@@ -45,7 +45,9 @@ class TestDetectionResult:
def test_weighted_score_cases(self):
for is_refusal, weight, expected in detection_result_cases:
result = DetectionResult(method="test", is_refusal=is_refusal, weight=weight)
result = DetectionResult(
method="test", is_refusal=is_refusal, weight=weight
)
assert result.weighted_score == expected
def test_default_weight(self):
@@ -234,7 +236,14 @@ factory_cases = [
({"ml_detector": MockDetector(True)}, 1),
({"llm_detector": MockDetector(True)}, 1),
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False)}, 2),
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False), "llm_detector": MockDetector(True)}, 3),
(
{
"marker_detector": MockDetector(True),
"ml_detector": MockDetector(False),
"llm_detector": MockDetector(True),
},
3,
),
]
+153
View File
@@ -0,0 +1,153 @@
"""Unit tests for security module."""
import pytest
from agentic_security.core.security import (
SecurityValidator,
SecretManager,
RateLimiter,
sanitize_log_output,
)
class TestSecurityValidator:
def test_validate_url_valid(self):
assert SecurityValidator.validate_url("https://example.com/path")
assert SecurityValidator.validate_url("http://api.example.com")
def test_validate_url_invalid_scheme(self):
assert not SecurityValidator.validate_url("ftp://example.com")
assert not SecurityValidator.validate_url("file:///etc/passwd")
def test_validate_url_localhost(self):
assert not SecurityValidator.validate_url("http://localhost/api")
assert not SecurityValidator.validate_url("http://127.0.0.1/api")
assert not SecurityValidator.validate_url("http://0.0.0.0/api")
def test_validate_url_private_ip(self):
assert not SecurityValidator.validate_url("http://10.0.0.1")
assert not SecurityValidator.validate_url("http://192.168.1.1")
assert not SecurityValidator.validate_url("http://169.254.1.1")
def test_validate_url_allowed_hosts(self):
allowed = ["api.example.com"]
assert SecurityValidator.validate_url("https://api.example.com", allowed)
assert not SecurityValidator.validate_url("https://evil.com", allowed)
def test_validate_url_too_long(self):
long_url = "https://example.com/" + "a" * 3000
assert not SecurityValidator.validate_url(long_url)
def test_sanitize_filename(self):
assert SecurityValidator.sanitize_filename("test.csv") == "test.csv"
assert SecurityValidator.sanitize_filename("../../../etc/passwd") == "passwd"
assert SecurityValidator.sanitize_filename("test/file.txt") == "file.txt"
assert (
SecurityValidator.sanitize_filename("file with spaces.txt")
== "file with spaces.txt"
)
def test_sanitize_filename_invalid(self):
with pytest.raises(ValueError):
SecurityValidator.sanitize_filename(".")
with pytest.raises(ValueError):
SecurityValidator.sanitize_filename("..")
with pytest.raises(ValueError):
SecurityValidator.sanitize_filename("")
def test_validate_file_size(self):
assert SecurityValidator.validate_file_size(1024)
assert SecurityValidator.validate_file_size(1024 * 1024)
assert not SecurityValidator.validate_file_size(0)
assert not SecurityValidator.validate_file_size(-1)
assert not SecurityValidator.validate_file_size(20 * 1024 * 1024)
def test_validate_csv_content(self):
assert SecurityValidator.validate_csv_content("col1,col2\nval1,val2")
assert not SecurityValidator.validate_csv_content("")
assert not SecurityValidator.validate_csv_content("x" * (11 * 1024 * 1024))
class TestSecretManager:
def test_hash_and_verify_secret(self):
secret = "my-secret-key"
hashed = SecretManager.hash_secret(secret)
assert SecretManager.verify_secret(secret, hashed)
assert not SecretManager.verify_secret("wrong-secret", hashed)
def test_hash_secret_with_salt(self):
secret = "my-secret"
hashed1 = SecretManager.hash_secret(secret, "salt1")
hashed2 = SecretManager.hash_secret(secret, "salt2")
assert hashed1 != hashed2
def test_verify_secret_invalid_format(self):
assert not SecretManager.verify_secret("secret", "invalid-hash")
class TestRateLimiter:
def test_rate_limiter_allows_requests(self):
limiter = RateLimiter(max_requests=3, window_seconds=60)
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user1")
def test_rate_limiter_blocks_excess(self):
limiter = RateLimiter(max_requests=2, window_seconds=60)
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user1")
assert not limiter.is_allowed("user1")
def test_rate_limiter_separate_keys(self):
limiter = RateLimiter(max_requests=2, window_seconds=60)
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user2")
assert limiter.is_allowed("user2")
def test_rate_limiter_reset(self):
limiter = RateLimiter(max_requests=1, window_seconds=60)
assert limiter.is_allowed("user1")
assert not limiter.is_allowed("user1")
limiter.reset("user1")
assert limiter.is_allowed("user1")
class TestSanitizeLogOutput:
def test_sanitize_api_key(self):
data = 'api_key="sk-1234567890"'
result = sanitize_log_output(data)
assert "sk-1234567890" not in result
assert "***" in result
def test_sanitize_token(self):
data = "token: abc123xyz"
result = sanitize_log_output(data)
assert "abc123xyz" not in result
def test_sanitize_password(self):
data = {"password": "secret123"}
result = sanitize_log_output(data)
assert "secret123" not in result
def test_sanitize_bearer_token(self):
data = "Authorization: Bearer eyJhbGc..."
result = sanitize_log_output(data)
assert "eyJhbGc" not in result
assert "Bearer ***" in result
def test_preserves_non_sensitive(self):
data = "user_id=123 name=John"
result = sanitize_log_output(data)
assert "user_id=123" in result
assert "name=John" in result