This commit is contained in:
Alexander Myasoedov
2026-01-28 21:04:29 +02:00
parent 8d42a84a9d
commit bc7fdd7cfa
19 changed files with 491 additions and 77 deletions
+8 -4
View File
@@ -56,16 +56,20 @@ class TestRulesToDataset:
class TestLoadRulesAsDataset:
def test_basic_load(self):
with tempfile.TemporaryDirectory() as tmpdir:
(Path(tmpdir) / "rule1.yaml").write_text("""
(Path(tmpdir) / "rule1.yaml").write_text(
"""
name: test1
type: jailbreak
prompt: Jailbreak prompt
""")
(Path(tmpdir) / "rule2.yaml").write_text("""
"""
)
(Path(tmpdir) / "rule2.yaml").write_text(
"""
name: test2
type: harmful
prompt: Harmful prompt
""")
"""
)
dataset = load_rules_as_dataset(tmpdir)
assert len(dataset.prompts) == 2
+14 -12
View File
@@ -77,15 +77,15 @@ severity: high
def test_load_rule_from_file(self):
loader = RuleLoader()
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
f.write("""
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write(
"""
name: file_test
type: harmful
severity: medium
prompt: Test prompt from file
""")
"""
)
f.flush()
rule = loader.load_rule_from_file(f.name)
@@ -96,9 +96,7 @@ prompt: Test prompt from file
def test_load_rule_from_file_wrong_extension(self):
loader = RuleLoader()
with tempfile.NamedTemporaryFile(
mode="w", suffix=".txt", delete=False
) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
f.write("name: test\nprompt: test")
f.flush()
rule = loader.load_rule_from_file(f.name)
@@ -110,16 +108,20 @@ prompt: Test prompt from file
with tempfile.TemporaryDirectory() as tmpdir:
rule1_path = Path(tmpdir) / "rule1.yaml"
rule2_path = Path(tmpdir) / "rule2.yml"
rule1_path.write_text("""
rule1_path.write_text(
"""
name: rule1
type: jailbreak
prompt: First rule
""")
rule2_path.write_text("""
"""
)
rule2_path.write_text(
"""
name: rule2
type: harmful
prompt: Second rule
""")
"""
)
loader = RuleLoader()
rules = loader.load_rules_from_directory(tmpdir)
+6 -1
View File
@@ -87,7 +87,12 @@ class TestAttackRule:
rule = AttackRule(name="test", type="jailbreak", prompt="Test")
result = rule.to_dict()
assert result == snapshot(
{"name": "test", "type": "jailbreak", "prompt": "Test", "severity": "medium"}
{
"name": "test",
"type": "jailbreak",
"prompt": "Test",
"severity": "medium",
}
)
def test_render_prompt_no_variables(self):
+27 -19
View File
@@ -116,11 +116,13 @@ class TestFuzzChain:
result = await chain.run(input="initial")
assert result == "final result"
assert llm.prompts == snapshot([
"First: initial",
"Second: step1 result",
"Third: step2 result",
])
assert llm.prompts == snapshot(
[
"First: initial",
"Second: step1 result",
"Third: step2 result",
]
)
@pytest.mark.asyncio
async def test_chain_with_custom_variables(self):
@@ -131,10 +133,12 @@ class TestFuzzChain:
result = await chain.run(topic="security", input="test prompt")
assert result == "evaluated"
assert llm.prompts == snapshot([
"Analyze security: test prompt",
"Evaluate: analyzed",
])
assert llm.prompts == snapshot(
[
"Analyze security: test prompt",
"Evaluate: analyzed",
]
)
def test_pipe_chain_to_node(self):
llm = MockLLMProvider()
@@ -158,11 +162,13 @@ class TestFuzzChain:
def test_len(self):
llm = MockLLMProvider()
chain = FuzzChain([
FuzzNode(llm, "A"),
FuzzNode(llm, "B"),
FuzzNode(llm, "C"),
])
chain = FuzzChain(
[
FuzzNode(llm, "A"),
FuzzNode(llm, "B"),
FuzzNode(llm, "C"),
]
)
assert len(chain) == 3
def test_repr(self):
@@ -185,11 +191,13 @@ class TestPipeOperatorChaining:
result = await chain.run(input="start")
assert result == "c"
assert llm.prompts == snapshot([
"Step1: start",
"Step2: a",
"Step3: b",
])
assert llm.prompts == snapshot(
[
"Step1: start",
"Step2: a",
"Step3: b",
]
)
@pytest.mark.asyncio
async def test_chain_with_different_providers(self):
@@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch
from inline_snapshot import snapshot
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
from agentic_security.llm_providers.base import (
LLMMessage,
LLMProviderError,
LLMRateLimitError,
)
class TestAnthropicProviderInit:
@@ -209,13 +213,19 @@ class TestAnthropicProviderErrors:
def test_handle_rate_limit_error(self, provider):
import anthropic
with pytest.raises(LLMRateLimitError):
provider._handle_error(anthropic.RateLimitError("rate limited", response=MagicMock(), body={}))
provider._handle_error(
anthropic.RateLimitError("rate limited", response=MagicMock(), body={})
)
def test_handle_api_error(self, provider):
import anthropic
with pytest.raises(LLMProviderError):
provider._handle_error(anthropic.APIError("api error", request=MagicMock(), body={}))
provider._handle_error(
anthropic.APIError("api error", request=MagicMock(), body={})
)
def test_handle_generic_error(self, provider):
with pytest.raises(LLMProviderError):
+7 -1
View File
@@ -9,7 +9,11 @@ from agentic_security.llm_providers.factory import (
list_providers,
register_provider,
)
from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError, LLMResponse
from agentic_security.llm_providers.base import (
BaseLLMProvider,
LLMProviderError,
LLMResponse,
)
class TestListProviders:
@@ -26,11 +30,13 @@ class TestListProviders:
class TestGetProviderClass:
def test_get_openai(self):
from agentic_security.llm_providers.openai_provider import OpenAIProvider
cls = get_provider_class("openai")
assert cls is OpenAIProvider
def test_get_anthropic(self):
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
cls = get_provider_class("anthropic")
assert cls is AnthropicProvider
@@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch
from inline_snapshot import snapshot
from agentic_security.llm_providers.openai_provider import OpenAIProvider
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
from agentic_security.llm_providers.base import (
LLMMessage,
LLMProviderError,
LLMRateLimitError,
)
class TestOpenAIProviderInit:
@@ -111,7 +115,9 @@ class TestOpenAIProviderSync:
mock_response.usage = None
with patch.object(provider, "_get_client") as mock_client:
mock_client.return_value.chat.completions.create.return_value = mock_response
mock_client.return_value.chat.completions.create.return_value = (
mock_response
)
result = provider.sync_generate("Hello")
assert result.content == snapshot("Response")
@@ -126,7 +132,9 @@ class TestOpenAIProviderSync:
messages = [LLMMessage(role="user", content="Hi")]
with patch.object(provider, "_get_client") as mock_client:
mock_client.return_value.chat.completions.create.return_value = mock_response
mock_client.return_value.chat.completions.create.return_value = (
mock_response
)
result = provider.sync_chat(messages)
assert result.content == snapshot("Chat response")
@@ -196,8 +204,11 @@ class TestOpenAIProviderErrors:
def test_handle_rate_limit_error(self, provider):
import openai
with pytest.raises(LLMRateLimitError):
provider._handle_error(openai.RateLimitError("rate limited", response=MagicMock(), body={}))
provider._handle_error(
openai.RateLimitError("rate limited", response=MagicMock(), body={})
)
def test_handle_generic_error(self, provider):
with pytest.raises(LLMProviderError):
@@ -45,7 +45,9 @@ class TestDetectionResult:
def test_weighted_score_cases(self):
for is_refusal, weight, expected in detection_result_cases:
result = DetectionResult(method="test", is_refusal=is_refusal, weight=weight)
result = DetectionResult(
method="test", is_refusal=is_refusal, weight=weight
)
assert result.weighted_score == expected
def test_default_weight(self):
@@ -234,7 +236,14 @@ factory_cases = [
({"ml_detector": MockDetector(True)}, 1),
({"llm_detector": MockDetector(True)}, 1),
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False)}, 2),
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False), "llm_detector": MockDetector(True)}, 3),
(
{
"marker_detector": MockDetector(True),
"ml_detector": MockDetector(False),
"llm_detector": MockDetector(True),
},
3,
),
]
+153
View File
@@ -0,0 +1,153 @@
"""Unit tests for security module."""
import pytest
from agentic_security.core.security import (
SecurityValidator,
SecretManager,
RateLimiter,
sanitize_log_output,
)
class TestSecurityValidator:
def test_validate_url_valid(self):
assert SecurityValidator.validate_url("https://example.com/path")
assert SecurityValidator.validate_url("http://api.example.com")
def test_validate_url_invalid_scheme(self):
assert not SecurityValidator.validate_url("ftp://example.com")
assert not SecurityValidator.validate_url("file:///etc/passwd")
def test_validate_url_localhost(self):
assert not SecurityValidator.validate_url("http://localhost/api")
assert not SecurityValidator.validate_url("http://127.0.0.1/api")
assert not SecurityValidator.validate_url("http://0.0.0.0/api")
def test_validate_url_private_ip(self):
assert not SecurityValidator.validate_url("http://10.0.0.1")
assert not SecurityValidator.validate_url("http://192.168.1.1")
assert not SecurityValidator.validate_url("http://169.254.1.1")
def test_validate_url_allowed_hosts(self):
allowed = ["api.example.com"]
assert SecurityValidator.validate_url("https://api.example.com", allowed)
assert not SecurityValidator.validate_url("https://evil.com", allowed)
def test_validate_url_too_long(self):
long_url = "https://example.com/" + "a" * 3000
assert not SecurityValidator.validate_url(long_url)
def test_sanitize_filename(self):
assert SecurityValidator.sanitize_filename("test.csv") == "test.csv"
assert SecurityValidator.sanitize_filename("../../../etc/passwd") == "passwd"
assert SecurityValidator.sanitize_filename("test/file.txt") == "file.txt"
assert (
SecurityValidator.sanitize_filename("file with spaces.txt")
== "file with spaces.txt"
)
def test_sanitize_filename_invalid(self):
with pytest.raises(ValueError):
SecurityValidator.sanitize_filename(".")
with pytest.raises(ValueError):
SecurityValidator.sanitize_filename("..")
with pytest.raises(ValueError):
SecurityValidator.sanitize_filename("")
def test_validate_file_size(self):
assert SecurityValidator.validate_file_size(1024)
assert SecurityValidator.validate_file_size(1024 * 1024)
assert not SecurityValidator.validate_file_size(0)
assert not SecurityValidator.validate_file_size(-1)
assert not SecurityValidator.validate_file_size(20 * 1024 * 1024)
def test_validate_csv_content(self):
assert SecurityValidator.validate_csv_content("col1,col2\nval1,val2")
assert not SecurityValidator.validate_csv_content("")
assert not SecurityValidator.validate_csv_content("x" * (11 * 1024 * 1024))
class TestSecretManager:
def test_hash_and_verify_secret(self):
secret = "my-secret-key"
hashed = SecretManager.hash_secret(secret)
assert SecretManager.verify_secret(secret, hashed)
assert not SecretManager.verify_secret("wrong-secret", hashed)
def test_hash_secret_with_salt(self):
secret = "my-secret"
hashed1 = SecretManager.hash_secret(secret, "salt1")
hashed2 = SecretManager.hash_secret(secret, "salt2")
assert hashed1 != hashed2
def test_verify_secret_invalid_format(self):
assert not SecretManager.verify_secret("secret", "invalid-hash")
class TestRateLimiter:
def test_rate_limiter_allows_requests(self):
limiter = RateLimiter(max_requests=3, window_seconds=60)
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user1")
def test_rate_limiter_blocks_excess(self):
limiter = RateLimiter(max_requests=2, window_seconds=60)
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user1")
assert not limiter.is_allowed("user1")
def test_rate_limiter_separate_keys(self):
limiter = RateLimiter(max_requests=2, window_seconds=60)
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user1")
assert limiter.is_allowed("user2")
assert limiter.is_allowed("user2")
def test_rate_limiter_reset(self):
limiter = RateLimiter(max_requests=1, window_seconds=60)
assert limiter.is_allowed("user1")
assert not limiter.is_allowed("user1")
limiter.reset("user1")
assert limiter.is_allowed("user1")
class TestSanitizeLogOutput:
def test_sanitize_api_key(self):
data = 'api_key="sk-1234567890"'
result = sanitize_log_output(data)
assert "sk-1234567890" not in result
assert "***" in result
def test_sanitize_token(self):
data = "token: abc123xyz"
result = sanitize_log_output(data)
assert "abc123xyz" not in result
def test_sanitize_password(self):
data = {"password": "secret123"}
result = sanitize_log_output(data)
assert "secret123" not in result
def test_sanitize_bearer_token(self):
data = "Authorization: Bearer eyJhbGc..."
result = sanitize_log_output(data)
assert "eyJhbGc" not in result
assert "Bearer ***" in result
def test_preserves_non_sensitive(self):
data = "user_id=123 name=John"
result = sanitize_log_output(data)
assert "user_id=123" in result
assert "name=John" in result