mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-28 16:19:58 +02:00
fix(pc):
This commit is contained in:
@@ -56,16 +56,20 @@ class TestRulesToDataset:
|
||||
class TestLoadRulesAsDataset:
|
||||
def test_basic_load(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "rule1.yaml").write_text("""
|
||||
(Path(tmpdir) / "rule1.yaml").write_text(
|
||||
"""
|
||||
name: test1
|
||||
type: jailbreak
|
||||
prompt: Jailbreak prompt
|
||||
""")
|
||||
(Path(tmpdir) / "rule2.yaml").write_text("""
|
||||
"""
|
||||
)
|
||||
(Path(tmpdir) / "rule2.yaml").write_text(
|
||||
"""
|
||||
name: test2
|
||||
type: harmful
|
||||
prompt: Harmful prompt
|
||||
""")
|
||||
"""
|
||||
)
|
||||
dataset = load_rules_as_dataset(tmpdir)
|
||||
assert len(dataset.prompts) == 2
|
||||
|
||||
|
||||
@@ -77,15 +77,15 @@ severity: high
|
||||
|
||||
def test_load_rule_from_file(self):
|
||||
loader = RuleLoader()
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
f.write("""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(
|
||||
"""
|
||||
name: file_test
|
||||
type: harmful
|
||||
severity: medium
|
||||
prompt: Test prompt from file
|
||||
""")
|
||||
"""
|
||||
)
|
||||
f.flush()
|
||||
rule = loader.load_rule_from_file(f.name)
|
||||
|
||||
@@ -96,9 +96,7 @@ prompt: Test prompt from file
|
||||
|
||||
def test_load_rule_from_file_wrong_extension(self):
|
||||
loader = RuleLoader()
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".txt", delete=False
|
||||
) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
|
||||
f.write("name: test\nprompt: test")
|
||||
f.flush()
|
||||
rule = loader.load_rule_from_file(f.name)
|
||||
@@ -110,16 +108,20 @@ prompt: Test prompt from file
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
rule1_path = Path(tmpdir) / "rule1.yaml"
|
||||
rule2_path = Path(tmpdir) / "rule2.yml"
|
||||
rule1_path.write_text("""
|
||||
rule1_path.write_text(
|
||||
"""
|
||||
name: rule1
|
||||
type: jailbreak
|
||||
prompt: First rule
|
||||
""")
|
||||
rule2_path.write_text("""
|
||||
"""
|
||||
)
|
||||
rule2_path.write_text(
|
||||
"""
|
||||
name: rule2
|
||||
type: harmful
|
||||
prompt: Second rule
|
||||
""")
|
||||
"""
|
||||
)
|
||||
loader = RuleLoader()
|
||||
rules = loader.load_rules_from_directory(tmpdir)
|
||||
|
||||
|
||||
@@ -87,7 +87,12 @@ class TestAttackRule:
|
||||
rule = AttackRule(name="test", type="jailbreak", prompt="Test")
|
||||
result = rule.to_dict()
|
||||
assert result == snapshot(
|
||||
{"name": "test", "type": "jailbreak", "prompt": "Test", "severity": "medium"}
|
||||
{
|
||||
"name": "test",
|
||||
"type": "jailbreak",
|
||||
"prompt": "Test",
|
||||
"severity": "medium",
|
||||
}
|
||||
)
|
||||
|
||||
def test_render_prompt_no_variables(self):
|
||||
|
||||
@@ -116,11 +116,13 @@ class TestFuzzChain:
|
||||
|
||||
result = await chain.run(input="initial")
|
||||
assert result == "final result"
|
||||
assert llm.prompts == snapshot([
|
||||
"First: initial",
|
||||
"Second: step1 result",
|
||||
"Third: step2 result",
|
||||
])
|
||||
assert llm.prompts == snapshot(
|
||||
[
|
||||
"First: initial",
|
||||
"Second: step1 result",
|
||||
"Third: step2 result",
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chain_with_custom_variables(self):
|
||||
@@ -131,10 +133,12 @@ class TestFuzzChain:
|
||||
|
||||
result = await chain.run(topic="security", input="test prompt")
|
||||
assert result == "evaluated"
|
||||
assert llm.prompts == snapshot([
|
||||
"Analyze security: test prompt",
|
||||
"Evaluate: analyzed",
|
||||
])
|
||||
assert llm.prompts == snapshot(
|
||||
[
|
||||
"Analyze security: test prompt",
|
||||
"Evaluate: analyzed",
|
||||
]
|
||||
)
|
||||
|
||||
def test_pipe_chain_to_node(self):
|
||||
llm = MockLLMProvider()
|
||||
@@ -158,11 +162,13 @@ class TestFuzzChain:
|
||||
|
||||
def test_len(self):
|
||||
llm = MockLLMProvider()
|
||||
chain = FuzzChain([
|
||||
FuzzNode(llm, "A"),
|
||||
FuzzNode(llm, "B"),
|
||||
FuzzNode(llm, "C"),
|
||||
])
|
||||
chain = FuzzChain(
|
||||
[
|
||||
FuzzNode(llm, "A"),
|
||||
FuzzNode(llm, "B"),
|
||||
FuzzNode(llm, "C"),
|
||||
]
|
||||
)
|
||||
assert len(chain) == 3
|
||||
|
||||
def test_repr(self):
|
||||
@@ -185,11 +191,13 @@ class TestPipeOperatorChaining:
|
||||
result = await chain.run(input="start")
|
||||
|
||||
assert result == "c"
|
||||
assert llm.prompts == snapshot([
|
||||
"Step1: start",
|
||||
"Step2: a",
|
||||
"Step3: b",
|
||||
])
|
||||
assert llm.prompts == snapshot(
|
||||
[
|
||||
"Step1: start",
|
||||
"Step2: a",
|
||||
"Step3: b",
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chain_with_different_providers(self):
|
||||
|
||||
@@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
|
||||
from agentic_security.llm_providers.base import (
|
||||
LLMMessage,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
)
|
||||
|
||||
|
||||
class TestAnthropicProviderInit:
|
||||
@@ -209,13 +213,19 @@ class TestAnthropicProviderErrors:
|
||||
|
||||
def test_handle_rate_limit_error(self, provider):
|
||||
import anthropic
|
||||
|
||||
with pytest.raises(LLMRateLimitError):
|
||||
provider._handle_error(anthropic.RateLimitError("rate limited", response=MagicMock(), body={}))
|
||||
provider._handle_error(
|
||||
anthropic.RateLimitError("rate limited", response=MagicMock(), body={})
|
||||
)
|
||||
|
||||
def test_handle_api_error(self, provider):
|
||||
import anthropic
|
||||
|
||||
with pytest.raises(LLMProviderError):
|
||||
provider._handle_error(anthropic.APIError("api error", request=MagicMock(), body={}))
|
||||
provider._handle_error(
|
||||
anthropic.APIError("api error", request=MagicMock(), body={})
|
||||
)
|
||||
|
||||
def test_handle_generic_error(self, provider):
|
||||
with pytest.raises(LLMProviderError):
|
||||
|
||||
@@ -9,7 +9,11 @@ from agentic_security.llm_providers.factory import (
|
||||
list_providers,
|
||||
register_provider,
|
||||
)
|
||||
from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError, LLMResponse
|
||||
from agentic_security.llm_providers.base import (
|
||||
BaseLLMProvider,
|
||||
LLMProviderError,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestListProviders:
|
||||
@@ -26,11 +30,13 @@ class TestListProviders:
|
||||
class TestGetProviderClass:
|
||||
def test_get_openai(self):
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
|
||||
cls = get_provider_class("openai")
|
||||
assert cls is OpenAIProvider
|
||||
|
||||
def test_get_anthropic(self):
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
cls = get_provider_class("anthropic")
|
||||
assert cls is AnthropicProvider
|
||||
|
||||
|
||||
@@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
|
||||
from agentic_security.llm_providers.base import (
|
||||
LLMMessage,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIProviderInit:
|
||||
@@ -111,7 +115,9 @@ class TestOpenAIProviderSync:
|
||||
mock_response.usage = None
|
||||
|
||||
with patch.object(provider, "_get_client") as mock_client:
|
||||
mock_client.return_value.chat.completions.create.return_value = mock_response
|
||||
mock_client.return_value.chat.completions.create.return_value = (
|
||||
mock_response
|
||||
)
|
||||
result = provider.sync_generate("Hello")
|
||||
assert result.content == snapshot("Response")
|
||||
|
||||
@@ -126,7 +132,9 @@ class TestOpenAIProviderSync:
|
||||
messages = [LLMMessage(role="user", content="Hi")]
|
||||
|
||||
with patch.object(provider, "_get_client") as mock_client:
|
||||
mock_client.return_value.chat.completions.create.return_value = mock_response
|
||||
mock_client.return_value.chat.completions.create.return_value = (
|
||||
mock_response
|
||||
)
|
||||
result = provider.sync_chat(messages)
|
||||
assert result.content == snapshot("Chat response")
|
||||
|
||||
@@ -196,8 +204,11 @@ class TestOpenAIProviderErrors:
|
||||
|
||||
def test_handle_rate_limit_error(self, provider):
|
||||
import openai
|
||||
|
||||
with pytest.raises(LLMRateLimitError):
|
||||
provider._handle_error(openai.RateLimitError("rate limited", response=MagicMock(), body={}))
|
||||
provider._handle_error(
|
||||
openai.RateLimitError("rate limited", response=MagicMock(), body={})
|
||||
)
|
||||
|
||||
def test_handle_generic_error(self, provider):
|
||||
with pytest.raises(LLMProviderError):
|
||||
|
||||
@@ -45,7 +45,9 @@ class TestDetectionResult:
|
||||
|
||||
def test_weighted_score_cases(self):
|
||||
for is_refusal, weight, expected in detection_result_cases:
|
||||
result = DetectionResult(method="test", is_refusal=is_refusal, weight=weight)
|
||||
result = DetectionResult(
|
||||
method="test", is_refusal=is_refusal, weight=weight
|
||||
)
|
||||
assert result.weighted_score == expected
|
||||
|
||||
def test_default_weight(self):
|
||||
@@ -234,7 +236,14 @@ factory_cases = [
|
||||
({"ml_detector": MockDetector(True)}, 1),
|
||||
({"llm_detector": MockDetector(True)}, 1),
|
||||
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False)}, 2),
|
||||
({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False), "llm_detector": MockDetector(True)}, 3),
|
||||
(
|
||||
{
|
||||
"marker_detector": MockDetector(True),
|
||||
"ml_detector": MockDetector(False),
|
||||
"llm_detector": MockDetector(True),
|
||||
},
|
||||
3,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Unit tests for security module."""
|
||||
|
||||
import pytest
|
||||
from agentic_security.core.security import (
|
||||
SecurityValidator,
|
||||
SecretManager,
|
||||
RateLimiter,
|
||||
sanitize_log_output,
|
||||
)
|
||||
|
||||
|
||||
class TestSecurityValidator:
|
||||
|
||||
def test_validate_url_valid(self):
|
||||
assert SecurityValidator.validate_url("https://example.com/path")
|
||||
assert SecurityValidator.validate_url("http://api.example.com")
|
||||
|
||||
def test_validate_url_invalid_scheme(self):
|
||||
assert not SecurityValidator.validate_url("ftp://example.com")
|
||||
assert not SecurityValidator.validate_url("file:///etc/passwd")
|
||||
|
||||
def test_validate_url_localhost(self):
|
||||
assert not SecurityValidator.validate_url("http://localhost/api")
|
||||
assert not SecurityValidator.validate_url("http://127.0.0.1/api")
|
||||
assert not SecurityValidator.validate_url("http://0.0.0.0/api")
|
||||
|
||||
def test_validate_url_private_ip(self):
|
||||
assert not SecurityValidator.validate_url("http://10.0.0.1")
|
||||
assert not SecurityValidator.validate_url("http://192.168.1.1")
|
||||
assert not SecurityValidator.validate_url("http://169.254.1.1")
|
||||
|
||||
def test_validate_url_allowed_hosts(self):
|
||||
allowed = ["api.example.com"]
|
||||
assert SecurityValidator.validate_url("https://api.example.com", allowed)
|
||||
assert not SecurityValidator.validate_url("https://evil.com", allowed)
|
||||
|
||||
def test_validate_url_too_long(self):
|
||||
long_url = "https://example.com/" + "a" * 3000
|
||||
assert not SecurityValidator.validate_url(long_url)
|
||||
|
||||
def test_sanitize_filename(self):
|
||||
assert SecurityValidator.sanitize_filename("test.csv") == "test.csv"
|
||||
assert SecurityValidator.sanitize_filename("../../../etc/passwd") == "passwd"
|
||||
assert SecurityValidator.sanitize_filename("test/file.txt") == "file.txt"
|
||||
assert (
|
||||
SecurityValidator.sanitize_filename("file with spaces.txt")
|
||||
== "file with spaces.txt"
|
||||
)
|
||||
|
||||
def test_sanitize_filename_invalid(self):
|
||||
with pytest.raises(ValueError):
|
||||
SecurityValidator.sanitize_filename(".")
|
||||
with pytest.raises(ValueError):
|
||||
SecurityValidator.sanitize_filename("..")
|
||||
with pytest.raises(ValueError):
|
||||
SecurityValidator.sanitize_filename("")
|
||||
|
||||
def test_validate_file_size(self):
|
||||
assert SecurityValidator.validate_file_size(1024)
|
||||
assert SecurityValidator.validate_file_size(1024 * 1024)
|
||||
assert not SecurityValidator.validate_file_size(0)
|
||||
assert not SecurityValidator.validate_file_size(-1)
|
||||
assert not SecurityValidator.validate_file_size(20 * 1024 * 1024)
|
||||
|
||||
def test_validate_csv_content(self):
|
||||
assert SecurityValidator.validate_csv_content("col1,col2\nval1,val2")
|
||||
assert not SecurityValidator.validate_csv_content("")
|
||||
assert not SecurityValidator.validate_csv_content("x" * (11 * 1024 * 1024))
|
||||
|
||||
|
||||
class TestSecretManager:
|
||||
|
||||
def test_hash_and_verify_secret(self):
|
||||
secret = "my-secret-key"
|
||||
hashed = SecretManager.hash_secret(secret)
|
||||
|
||||
assert SecretManager.verify_secret(secret, hashed)
|
||||
assert not SecretManager.verify_secret("wrong-secret", hashed)
|
||||
|
||||
def test_hash_secret_with_salt(self):
|
||||
secret = "my-secret"
|
||||
hashed1 = SecretManager.hash_secret(secret, "salt1")
|
||||
hashed2 = SecretManager.hash_secret(secret, "salt2")
|
||||
|
||||
assert hashed1 != hashed2
|
||||
|
||||
def test_verify_secret_invalid_format(self):
|
||||
assert not SecretManager.verify_secret("secret", "invalid-hash")
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
|
||||
def test_rate_limiter_allows_requests(self):
|
||||
limiter = RateLimiter(max_requests=3, window_seconds=60)
|
||||
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
|
||||
def test_rate_limiter_blocks_excess(self):
|
||||
limiter = RateLimiter(max_requests=2, window_seconds=60)
|
||||
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
assert not limiter.is_allowed("user1")
|
||||
|
||||
def test_rate_limiter_separate_keys(self):
|
||||
limiter = RateLimiter(max_requests=2, window_seconds=60)
|
||||
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
assert limiter.is_allowed("user2")
|
||||
assert limiter.is_allowed("user2")
|
||||
|
||||
def test_rate_limiter_reset(self):
|
||||
limiter = RateLimiter(max_requests=1, window_seconds=60)
|
||||
|
||||
assert limiter.is_allowed("user1")
|
||||
assert not limiter.is_allowed("user1")
|
||||
|
||||
limiter.reset("user1")
|
||||
assert limiter.is_allowed("user1")
|
||||
|
||||
|
||||
class TestSanitizeLogOutput:
|
||||
|
||||
def test_sanitize_api_key(self):
|
||||
data = 'api_key="sk-1234567890"'
|
||||
result = sanitize_log_output(data)
|
||||
assert "sk-1234567890" not in result
|
||||
assert "***" in result
|
||||
|
||||
def test_sanitize_token(self):
|
||||
data = "token: abc123xyz"
|
||||
result = sanitize_log_output(data)
|
||||
assert "abc123xyz" not in result
|
||||
|
||||
def test_sanitize_password(self):
|
||||
data = {"password": "secret123"}
|
||||
result = sanitize_log_output(data)
|
||||
assert "secret123" not in result
|
||||
|
||||
def test_sanitize_bearer_token(self):
|
||||
data = "Authorization: Bearer eyJhbGc..."
|
||||
result = sanitize_log_output(data)
|
||||
assert "eyJhbGc" not in result
|
||||
assert "Bearer ***" in result
|
||||
|
||||
def test_preserves_non_sensitive(self):
|
||||
data = "user_id=123 name=John"
|
||||
result = sanitize_log_output(data)
|
||||
assert "user_id=123" in result
|
||||
assert "name=John" in result
|
||||
Reference in New Issue
Block a user