From bc7fdd7cfa509109f1c10565ecbc92a880cc4522 Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Wed, 28 Jan 2026 21:04:29 +0200 Subject: [PATCH] fix(pc): --- agentic_security/attack_rules/dataset.py | 12 +- agentic_security/attack_rules/loader.py | 15 +- agentic_security/attack_rules/models.py | 18 +- agentic_security/core/security.py | 179 ++++++++++++++++++ agentic_security/fuzz_chain/chain.py | 3 +- .../llm_providers/anthropic_provider.py | 3 + agentic_security/llm_providers/base.py | 2 + agentic_security/llm_providers/factory.py | 1 + .../llm_providers/openai_provider.py | 7 +- .../refusal_classifier/hybrid_classifier.py | 28 +-- tests/unit/attack_rules/test_dataset.py | 12 +- tests/unit/attack_rules/test_loader.py | 26 +-- tests/unit/attack_rules/test_models.py | 7 +- tests/unit/fuzz_chain/test_chain.py | 46 +++-- .../llm_providers/test_anthropic_provider.py | 16 +- tests/unit/llm_providers/test_factory.py | 8 +- .../llm_providers/test_openai_provider.py | 19 +- .../test_hybrid_classifier.py | 13 +- tests/unit/test_security.py | 153 +++++++++++++++ 19 files changed, 491 insertions(+), 77 deletions(-) create mode 100644 agentic_security/core/security.py create mode 100644 tests/unit/test_security.py diff --git a/agentic_security/attack_rules/dataset.py b/agentic_security/attack_rules/dataset.py index 7f7bbf7..2c9e575 100644 --- a/agentic_security/attack_rules/dataset.py +++ b/agentic_security/attack_rules/dataset.py @@ -83,7 +83,9 @@ class YAMLRulesDatasetLoader: severity_enums = None if self.severities: - severity_enums = [AttackRuleSeverity.from_string(s) for s in self.severities] + severity_enums = [ + AttackRuleSeverity.from_string(s) for s in self.severities + ] filtered = self._loader.filter_rules( rules, types=self.types, severities=severity_enums @@ -113,10 +115,14 @@ class YAMLRulesDatasetLoader: severity_enums = None if self.severities: - severity_enums = [AttackRuleSeverity.from_string(s) for s in self.severities] + severity_enums = [ + AttackRuleSeverity.from_string(s) for s in self.severities + ] filtered = self._loader.filter_rules( all_rules, types=self.types, severities=severity_enums ) - return rules_to_dataset(filtered, name="YAML Rules (merged)", variables=variables) + return rules_to_dataset( + filtered, name="YAML Rules (merged)", variables=variables + ) diff --git a/agentic_security/attack_rules/loader.py b/agentic_security/attack_rules/loader.py index dde5765..513ba29 100644 --- a/agentic_security/attack_rules/loader.py +++ b/agentic_security/attack_rules/loader.py @@ -1,4 +1,3 @@ -import os from pathlib import Path import yaml @@ -81,9 +80,7 @@ class RuleLoader: return None def load_rules_from_directory( - self, - directory: str | Path | None = None, - recursive: bool = True + self, directory: str | Path | None = None, recursive: bool = True ) -> list[AttackRule]: directory = Path(directory) if directory else self.rules_dir if not directory or not directory.exists(): @@ -91,7 +88,7 @@ class RuleLoader: return [] rules = [] - pattern = "**/*.yaml" if recursive else "*.yaml" + # pattern = "**/*.yaml" if recursive else "*.yaml" for ext in [".yaml", ".yml"]: glob_pattern = f"**/*{ext}" if recursive else f"*{ext}" @@ -105,9 +102,7 @@ class RuleLoader: return rules def load_multiple_directories( - self, - directories: list[str | Path], - recursive: bool = True + self, directories: list[str | Path], recursive: bool = True ) -> list[AttackRule]: all_rules = [] for directory in directories: @@ -133,6 +128,7 @@ class RuleLoader: if name_pattern: import re + pattern = re.compile(name_pattern, re.IGNORECASE) result = [r for r in result if pattern.search(r.name)] @@ -154,8 +150,7 @@ class RuleLoader: def load_rules_from_directory( - directory: str | Path, - recursive: bool = True + directory: str | Path, recursive: bool = True ) -> list[AttackRule]: loader = RuleLoader() return loader.load_rules_from_directory(directory, recursive) diff --git a/agentic_security/attack_rules/models.py b/agentic_security/attack_rules/models.py index 7533d05..9c949c3 100644 --- a/agentic_security/attack_rules/models.py +++ b/agentic_security/attack_rules/models.py @@ -38,10 +38,20 @@ class AttackRule: pass_conditions=data.get("pass_conditions", []), fail_conditions=data.get("fail_conditions", []), source=data.get("source"), - metadata={k: v for k, v in data.items() if k not in { - "name", "type", "prompt", "severity", - "pass_conditions", "fail_conditions", "source" - }}, + metadata={ + k: v + for k, v in data.items() + if k + not in { + "name", + "type", + "prompt", + "severity", + "pass_conditions", + "fail_conditions", + "source", + } + }, ) def to_dict(self) -> dict[str, Any]: diff --git a/agentic_security/core/security.py b/agentic_security/core/security.py new file mode 100644 index 0000000..8a784a3 --- /dev/null +++ b/agentic_security/core/security.py @@ -0,0 +1,179 @@ +"""Security utilities and validation for agentic_security.""" + +from functools import wraps +from collections.abc import Callable +from urllib.parse import urlparse +import hashlib +import hmac +import os +import re + + +class SecurityValidator: + """Input validation and sanitization.""" + + ALLOWED_URL_SCHEMES = {"http", "https"} + MAX_URL_LENGTH = 2048 + MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB + + @staticmethod + def validate_url(url: str, allowed_hosts: list[str] | None = None) -> bool: + """Validate URL for SSRF prevention.""" + if len(url) > SecurityValidator.MAX_URL_LENGTH: + return False + + try: + parsed = urlparse(url) + + if parsed.scheme not in SecurityValidator.ALLOWED_URL_SCHEMES: + return False + + if not parsed.netloc: + return False + + if parsed.netloc in ["localhost", "127.0.0.1", "0.0.0.0"]: + return False + + if parsed.netloc.startswith("169.254."): + return False + + if parsed.netloc.startswith("10.") or parsed.netloc.startswith("192.168."): + return False + + if allowed_hosts and parsed.netloc not in allowed_hosts: + return False + + return True + except Exception: + return False + + @staticmethod + def sanitize_filename(filename: str) -> str: + """Sanitize filename to prevent path traversal.""" + filename = os.path.basename(filename) + filename = re.sub(r"[^\w\s.-]", "", filename) + filename = filename.strip() + + if not filename or filename in [".", ".."]: + raise ValueError("Invalid filename") + + return filename + + @staticmethod + def validate_file_size(size: int) -> bool: + """Validate file size.""" + return 0 < size <= SecurityValidator.MAX_FILE_SIZE + + @staticmethod + def validate_csv_content(content: str) -> bool: + """Basic CSV validation.""" + if not content or len(content) > SecurityValidator.MAX_FILE_SIZE: + return False + + lines = content.split("\n", 2) + if not lines: + return False + + return True + + +class SecretManager: + """Secure secret handling.""" + + @staticmethod + def get_secret(key: str, default: str | None = None) -> str | None: + """Get secret from environment.""" + value = os.getenv(key, default) + if value and value.startswith("$"): + env_key = value[1:] + value = os.getenv(env_key, default) + return value + + @staticmethod + def hash_secret(secret: str, salt: str | None = None) -> str: + """Hash a secret value.""" + if salt is None: + salt = os.urandom(32).hex() + + hashed = hashlib.pbkdf2_hmac("sha256", secret.encode(), salt.encode(), 100000) + return f"{salt}${hashed.hex()}" + + @staticmethod + def verify_secret(secret: str, hashed: str) -> bool: + """Verify a secret against its hash.""" + try: + salt, expected = hashed.split("$", 1) + actual = hashlib.pbkdf2_hmac( + "sha256", secret.encode(), salt.encode(), 100000 + ) + return hmac.compare_digest(actual.hex(), expected) + except Exception: + return False + + +class RateLimiter: + """Simple in-memory rate limiter.""" + + def __init__(self, max_requests: int, window_seconds: int): + self.max_requests = max_requests + self.window_seconds = window_seconds + self._requests: dict[str, list[float]] = {} + + def is_allowed(self, key: str) -> bool: + """Check if request is allowed.""" + import time + + now = time.time() + + if key not in self._requests: + self._requests[key] = [] + + self._requests[key] = [ + ts for ts in self._requests[key] if now - ts < self.window_seconds + ] + + if len(self._requests[key]) >= self.max_requests: + return False + + self._requests[key].append(now) + return True + + def reset(self, key: str): + """Reset rate limit for key.""" + self._requests.pop(key, None) + + +def require_auth(func: Callable) -> Callable: + """Decorator to require authentication.""" + + @wraps(func) + async def wrapper(*args, **kwargs): + # TODO: Implement actual auth check + # For now, check if API key is present + api_key = kwargs.get("api_key") or os.getenv("API_KEY") + if not api_key: + from fastapi import HTTPException + + raise HTTPException(status_code=401, detail="Authentication required") + return await func(*args, **kwargs) + + return wrapper + + +def sanitize_log_output(data: str | dict) -> str: + """Remove sensitive data from logs.""" + if isinstance(data, dict): + data = str(data) + + patterns = [ + (r'(api[_-]?key["\s:=]+)["\']?[\w-]+', r"\1***"), + (r'(token["\s:=]+)["\']?[\w-]+', r"\1***"), + (r'(password["\s:=]+)["\']?[\w-]+', r"\1***"), + (r'(secret["\s:=]+)["\']?[\w-]+', r"\1***"), + (r"Bearer\s+[\w-]+", "Bearer ***"), + ] + + for pattern, replacement in patterns: + data = re.sub(pattern, replacement, data, flags=re.IGNORECASE) + + return data diff --git a/agentic_security/fuzz_chain/chain.py b/agentic_security/fuzz_chain/chain.py index caadd81..1d3bcd0 100644 --- a/agentic_security/fuzz_chain/chain.py +++ b/agentic_security/fuzz_chain/chain.py @@ -8,8 +8,7 @@ logger = logging.getLogger(__name__) class FuzzRunnable(Protocol): """Protocol for objects that can be run in a fuzzing chain.""" - async def run(self, **kwargs: Any) -> str: - ... + async def run(self, **kwargs: Any) -> str: ... class FuzzNode: diff --git a/agentic_security/llm_providers/anthropic_provider.py b/agentic_security/llm_providers/anthropic_provider.py index a5b64b0..c63b0d5 100644 --- a/agentic_security/llm_providers/anthropic_provider.py +++ b/agentic_security/llm_providers/anthropic_provider.py @@ -36,6 +36,7 @@ class AnthropicProvider(BaseLLMProvider): def _get_client(self) -> Any: if self._client is None: import anthropic + kwargs: dict[str, Any] = {"api_key": self.api_key} if self.base_url: kwargs["base_url"] = self.base_url @@ -45,6 +46,7 @@ class AnthropicProvider(BaseLLMProvider): def _get_async_client(self) -> Any: if self._async_client is None: import anthropic + kwargs: dict[str, Any] = {"api_key": self.api_key} if self.base_url: kwargs["base_url"] = self.base_url @@ -95,6 +97,7 @@ class AnthropicProvider(BaseLLMProvider): def _handle_error(self, e: Exception) -> None: import anthropic + if isinstance(e, anthropic.RateLimitError): raise LLMRateLimitError(str(e)) from e if isinstance(e, anthropic.APIError): diff --git a/agentic_security/llm_providers/base.py b/agentic_security/llm_providers/base.py index 1992078..99b9be7 100644 --- a/agentic_security/llm_providers/base.py +++ b/agentic_security/llm_providers/base.py @@ -20,6 +20,7 @@ class LLMRateLimitError(LLMProviderError): @dataclass class LLMMessage: """A message in a chat conversation.""" + role: str # "system", "user", or "assistant" content: str @@ -27,6 +28,7 @@ class LLMMessage: @dataclass class LLMResponse: """Response from an LLM provider.""" + content: str model: str | None = None finish_reason: str | None = None diff --git a/agentic_security/llm_providers/factory.py b/agentic_security/llm_providers/factory.py index 2e4edb3..4736bef 100644 --- a/agentic_security/llm_providers/factory.py +++ b/agentic_security/llm_providers/factory.py @@ -14,6 +14,7 @@ def _ensure_registered() -> None: return from agentic_security.llm_providers.openai_provider import OpenAIProvider from agentic_security.llm_providers.anthropic_provider import AnthropicProvider + _PROVIDERS["openai"] = OpenAIProvider _PROVIDERS["anthropic"] = AnthropicProvider diff --git a/agentic_security/llm_providers/openai_provider.py b/agentic_security/llm_providers/openai_provider.py index 6d0add4..ed64c93 100644 --- a/agentic_security/llm_providers/openai_provider.py +++ b/agentic_security/llm_providers/openai_provider.py @@ -36,13 +36,17 @@ class OpenAIProvider(BaseLLMProvider): def _get_client(self) -> Any: if self._client is None: import openai + self._client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) return self._client def _get_async_client(self) -> Any: if self._async_client is None: import openai - self._async_client = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url) + + self._async_client = openai.AsyncOpenAI( + api_key=self.api_key, base_url=self.base_url + ) return self._async_client @classmethod @@ -79,6 +83,7 @@ class OpenAIProvider(BaseLLMProvider): def _handle_error(self, e: Exception) -> None: import openai + if isinstance(e, openai.RateLimitError): raise LLMRateLimitError(str(e)) from e raise LLMProviderError(str(e)) from e diff --git a/agentic_security/refusal_classifier/hybrid_classifier.py b/agentic_security/refusal_classifier/hybrid_classifier.py index e1cd3f4..a9639ef 100644 --- a/agentic_security/refusal_classifier/hybrid_classifier.py +++ b/agentic_security/refusal_classifier/hybrid_classifier.py @@ -96,11 +96,13 @@ class HybridRefusalClassifier: self for method chaining """ detector_name = name or detector.__class__.__name__ - self._detectors.append(DetectorConfig( - detector=detector, - weight=weight, - name=detector_name, - )) + self._detectors.append( + DetectorConfig( + detector=detector, + weight=weight, + name=detector_name, + ) + ) return self def classify(self, response: str) -> HybridResult: @@ -117,11 +119,13 @@ class HybridRefusalClassifier: is_refusal = config.detector.is_refusal(response) except Exception: continue # Skip failed detectors - results.append(DetectionResult( - method=config.name, - is_refusal=is_refusal, - weight=config.weight, - )) + results.append( + DetectionResult( + method=config.name, + is_refusal=is_refusal, + weight=config.weight, + ) + ) if not results: return HybridResult(is_refusal=False, confidence=0.0) @@ -134,7 +138,9 @@ class HybridRefusalClassifier: # Check unanimous requirement if self.require_unanimous: - all_agree = all(r.is_refusal for r in results) or all(not r.is_refusal for r in results) + all_agree = all(r.is_refusal for r in results) or all( + not r.is_refusal for r in results + ) if not all_agree: # Disagreement - return uncertain result return HybridResult( diff --git a/tests/unit/attack_rules/test_dataset.py b/tests/unit/attack_rules/test_dataset.py index e215fca..e6cc0e9 100644 --- a/tests/unit/attack_rules/test_dataset.py +++ b/tests/unit/attack_rules/test_dataset.py @@ -56,16 +56,20 @@ class TestRulesToDataset: class TestLoadRulesAsDataset: def test_basic_load(self): with tempfile.TemporaryDirectory() as tmpdir: - (Path(tmpdir) / "rule1.yaml").write_text(""" + (Path(tmpdir) / "rule1.yaml").write_text( + """ name: test1 type: jailbreak prompt: Jailbreak prompt -""") - (Path(tmpdir) / "rule2.yaml").write_text(""" +""" + ) + (Path(tmpdir) / "rule2.yaml").write_text( + """ name: test2 type: harmful prompt: Harmful prompt -""") +""" + ) dataset = load_rules_as_dataset(tmpdir) assert len(dataset.prompts) == 2 diff --git a/tests/unit/attack_rules/test_loader.py b/tests/unit/attack_rules/test_loader.py index e2d4ba9..1d0a5a8 100644 --- a/tests/unit/attack_rules/test_loader.py +++ b/tests/unit/attack_rules/test_loader.py @@ -77,15 +77,15 @@ severity: high def test_load_rule_from_file(self): loader = RuleLoader() - with tempfile.NamedTemporaryFile( - mode="w", suffix=".yaml", delete=False - ) as f: - f.write(""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write( + """ name: file_test type: harmful severity: medium prompt: Test prompt from file -""") +""" + ) f.flush() rule = loader.load_rule_from_file(f.name) @@ -96,9 +96,7 @@ prompt: Test prompt from file def test_load_rule_from_file_wrong_extension(self): loader = RuleLoader() - with tempfile.NamedTemporaryFile( - mode="w", suffix=".txt", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: f.write("name: test\nprompt: test") f.flush() rule = loader.load_rule_from_file(f.name) @@ -110,16 +108,20 @@ prompt: Test prompt from file with tempfile.TemporaryDirectory() as tmpdir: rule1_path = Path(tmpdir) / "rule1.yaml" rule2_path = Path(tmpdir) / "rule2.yml" - rule1_path.write_text(""" + rule1_path.write_text( + """ name: rule1 type: jailbreak prompt: First rule -""") - rule2_path.write_text(""" +""" + ) + rule2_path.write_text( + """ name: rule2 type: harmful prompt: Second rule -""") +""" + ) loader = RuleLoader() rules = loader.load_rules_from_directory(tmpdir) diff --git a/tests/unit/attack_rules/test_models.py b/tests/unit/attack_rules/test_models.py index d5ac6d7..ff9ea47 100644 --- a/tests/unit/attack_rules/test_models.py +++ b/tests/unit/attack_rules/test_models.py @@ -87,7 +87,12 @@ class TestAttackRule: rule = AttackRule(name="test", type="jailbreak", prompt="Test") result = rule.to_dict() assert result == snapshot( - {"name": "test", "type": "jailbreak", "prompt": "Test", "severity": "medium"} + { + "name": "test", + "type": "jailbreak", + "prompt": "Test", + "severity": "medium", + } ) def test_render_prompt_no_variables(self): diff --git a/tests/unit/fuzz_chain/test_chain.py b/tests/unit/fuzz_chain/test_chain.py index 3cbe457..4a9d03b 100644 --- a/tests/unit/fuzz_chain/test_chain.py +++ b/tests/unit/fuzz_chain/test_chain.py @@ -116,11 +116,13 @@ class TestFuzzChain: result = await chain.run(input="initial") assert result == "final result" - assert llm.prompts == snapshot([ - "First: initial", - "Second: step1 result", - "Third: step2 result", - ]) + assert llm.prompts == snapshot( + [ + "First: initial", + "Second: step1 result", + "Third: step2 result", + ] + ) @pytest.mark.asyncio async def test_chain_with_custom_variables(self): @@ -131,10 +133,12 @@ class TestFuzzChain: result = await chain.run(topic="security", input="test prompt") assert result == "evaluated" - assert llm.prompts == snapshot([ - "Analyze security: test prompt", - "Evaluate: analyzed", - ]) + assert llm.prompts == snapshot( + [ + "Analyze security: test prompt", + "Evaluate: analyzed", + ] + ) def test_pipe_chain_to_node(self): llm = MockLLMProvider() @@ -158,11 +162,13 @@ class TestFuzzChain: def test_len(self): llm = MockLLMProvider() - chain = FuzzChain([ - FuzzNode(llm, "A"), - FuzzNode(llm, "B"), - FuzzNode(llm, "C"), - ]) + chain = FuzzChain( + [ + FuzzNode(llm, "A"), + FuzzNode(llm, "B"), + FuzzNode(llm, "C"), + ] + ) assert len(chain) == 3 def test_repr(self): @@ -185,11 +191,13 @@ class TestPipeOperatorChaining: result = await chain.run(input="start") assert result == "c" - assert llm.prompts == snapshot([ - "Step1: start", - "Step2: a", - "Step3: b", - ]) + assert llm.prompts == snapshot( + [ + "Step1: start", + "Step2: a", + "Step3: b", + ] + ) @pytest.mark.asyncio async def test_chain_with_different_providers(self): diff --git a/tests/unit/llm_providers/test_anthropic_provider.py b/tests/unit/llm_providers/test_anthropic_provider.py index c6f6725..ac990f9 100644 --- a/tests/unit/llm_providers/test_anthropic_provider.py +++ b/tests/unit/llm_providers/test_anthropic_provider.py @@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch from inline_snapshot import snapshot from agentic_security.llm_providers.anthropic_provider import AnthropicProvider -from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError +from agentic_security.llm_providers.base import ( + LLMMessage, + LLMProviderError, + LLMRateLimitError, +) class TestAnthropicProviderInit: @@ -209,13 +213,19 @@ class TestAnthropicProviderErrors: def test_handle_rate_limit_error(self, provider): import anthropic + with pytest.raises(LLMRateLimitError): - provider._handle_error(anthropic.RateLimitError("rate limited", response=MagicMock(), body={})) + provider._handle_error( + anthropic.RateLimitError("rate limited", response=MagicMock(), body={}) + ) def test_handle_api_error(self, provider): import anthropic + with pytest.raises(LLMProviderError): - provider._handle_error(anthropic.APIError("api error", request=MagicMock(), body={})) + provider._handle_error( + anthropic.APIError("api error", request=MagicMock(), body={}) + ) def test_handle_generic_error(self, provider): with pytest.raises(LLMProviderError): diff --git a/tests/unit/llm_providers/test_factory.py b/tests/unit/llm_providers/test_factory.py index dbcd633..ff00b9e 100644 --- a/tests/unit/llm_providers/test_factory.py +++ b/tests/unit/llm_providers/test_factory.py @@ -9,7 +9,11 @@ from agentic_security.llm_providers.factory import ( list_providers, register_provider, ) -from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError, LLMResponse +from agentic_security.llm_providers.base import ( + BaseLLMProvider, + LLMProviderError, + LLMResponse, +) class TestListProviders: @@ -26,11 +30,13 @@ class TestListProviders: class TestGetProviderClass: def test_get_openai(self): from agentic_security.llm_providers.openai_provider import OpenAIProvider + cls = get_provider_class("openai") assert cls is OpenAIProvider def test_get_anthropic(self): from agentic_security.llm_providers.anthropic_provider import AnthropicProvider + cls = get_provider_class("anthropic") assert cls is AnthropicProvider diff --git a/tests/unit/llm_providers/test_openai_provider.py b/tests/unit/llm_providers/test_openai_provider.py index 85b90d0..7ca06e8 100644 --- a/tests/unit/llm_providers/test_openai_provider.py +++ b/tests/unit/llm_providers/test_openai_provider.py @@ -5,7 +5,11 @@ from unittest.mock import MagicMock, AsyncMock, patch from inline_snapshot import snapshot from agentic_security.llm_providers.openai_provider import OpenAIProvider -from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError +from agentic_security.llm_providers.base import ( + LLMMessage, + LLMProviderError, + LLMRateLimitError, +) class TestOpenAIProviderInit: @@ -111,7 +115,9 @@ class TestOpenAIProviderSync: mock_response.usage = None with patch.object(provider, "_get_client") as mock_client: - mock_client.return_value.chat.completions.create.return_value = mock_response + mock_client.return_value.chat.completions.create.return_value = ( + mock_response + ) result = provider.sync_generate("Hello") assert result.content == snapshot("Response") @@ -126,7 +132,9 @@ class TestOpenAIProviderSync: messages = [LLMMessage(role="user", content="Hi")] with patch.object(provider, "_get_client") as mock_client: - mock_client.return_value.chat.completions.create.return_value = mock_response + mock_client.return_value.chat.completions.create.return_value = ( + mock_response + ) result = provider.sync_chat(messages) assert result.content == snapshot("Chat response") @@ -196,8 +204,11 @@ class TestOpenAIProviderErrors: def test_handle_rate_limit_error(self, provider): import openai + with pytest.raises(LLMRateLimitError): - provider._handle_error(openai.RateLimitError("rate limited", response=MagicMock(), body={})) + provider._handle_error( + openai.RateLimitError("rate limited", response=MagicMock(), body={}) + ) def test_handle_generic_error(self, provider): with pytest.raises(LLMProviderError): diff --git a/tests/unit/refusal_classifier/test_hybrid_classifier.py b/tests/unit/refusal_classifier/test_hybrid_classifier.py index e835aaa..f17b56f 100644 --- a/tests/unit/refusal_classifier/test_hybrid_classifier.py +++ b/tests/unit/refusal_classifier/test_hybrid_classifier.py @@ -45,7 +45,9 @@ class TestDetectionResult: def test_weighted_score_cases(self): for is_refusal, weight, expected in detection_result_cases: - result = DetectionResult(method="test", is_refusal=is_refusal, weight=weight) + result = DetectionResult( + method="test", is_refusal=is_refusal, weight=weight + ) assert result.weighted_score == expected def test_default_weight(self): @@ -234,7 +236,14 @@ factory_cases = [ ({"ml_detector": MockDetector(True)}, 1), ({"llm_detector": MockDetector(True)}, 1), ({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False)}, 2), - ({"marker_detector": MockDetector(True), "ml_detector": MockDetector(False), "llm_detector": MockDetector(True)}, 3), + ( + { + "marker_detector": MockDetector(True), + "ml_detector": MockDetector(False), + "llm_detector": MockDetector(True), + }, + 3, + ), ] diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py new file mode 100644 index 0000000..83a597f --- /dev/null +++ b/tests/unit/test_security.py @@ -0,0 +1,153 @@ +"""Unit tests for security module.""" + +import pytest +from agentic_security.core.security import ( + SecurityValidator, + SecretManager, + RateLimiter, + sanitize_log_output, +) + + +class TestSecurityValidator: + + def test_validate_url_valid(self): + assert SecurityValidator.validate_url("https://example.com/path") + assert SecurityValidator.validate_url("http://api.example.com") + + def test_validate_url_invalid_scheme(self): + assert not SecurityValidator.validate_url("ftp://example.com") + assert not SecurityValidator.validate_url("file:///etc/passwd") + + def test_validate_url_localhost(self): + assert not SecurityValidator.validate_url("http://localhost/api") + assert not SecurityValidator.validate_url("http://127.0.0.1/api") + assert not SecurityValidator.validate_url("http://0.0.0.0/api") + + def test_validate_url_private_ip(self): + assert not SecurityValidator.validate_url("http://10.0.0.1") + assert not SecurityValidator.validate_url("http://192.168.1.1") + assert not SecurityValidator.validate_url("http://169.254.1.1") + + def test_validate_url_allowed_hosts(self): + allowed = ["api.example.com"] + assert SecurityValidator.validate_url("https://api.example.com", allowed) + assert not SecurityValidator.validate_url("https://evil.com", allowed) + + def test_validate_url_too_long(self): + long_url = "https://example.com/" + "a" * 3000 + assert not SecurityValidator.validate_url(long_url) + + def test_sanitize_filename(self): + assert SecurityValidator.sanitize_filename("test.csv") == "test.csv" + assert SecurityValidator.sanitize_filename("../../../etc/passwd") == "passwd" + assert SecurityValidator.sanitize_filename("test/file.txt") == "file.txt" + assert ( + SecurityValidator.sanitize_filename("file with spaces.txt") + == "file with spaces.txt" + ) + + def test_sanitize_filename_invalid(self): + with pytest.raises(ValueError): + SecurityValidator.sanitize_filename(".") + with pytest.raises(ValueError): + SecurityValidator.sanitize_filename("..") + with pytest.raises(ValueError): + SecurityValidator.sanitize_filename("") + + def test_validate_file_size(self): + assert SecurityValidator.validate_file_size(1024) + assert SecurityValidator.validate_file_size(1024 * 1024) + assert not SecurityValidator.validate_file_size(0) + assert not SecurityValidator.validate_file_size(-1) + assert not SecurityValidator.validate_file_size(20 * 1024 * 1024) + + def test_validate_csv_content(self): + assert SecurityValidator.validate_csv_content("col1,col2\nval1,val2") + assert not SecurityValidator.validate_csv_content("") + assert not SecurityValidator.validate_csv_content("x" * (11 * 1024 * 1024)) + + +class TestSecretManager: + + def test_hash_and_verify_secret(self): + secret = "my-secret-key" + hashed = SecretManager.hash_secret(secret) + + assert SecretManager.verify_secret(secret, hashed) + assert not SecretManager.verify_secret("wrong-secret", hashed) + + def test_hash_secret_with_salt(self): + secret = "my-secret" + hashed1 = SecretManager.hash_secret(secret, "salt1") + hashed2 = SecretManager.hash_secret(secret, "salt2") + + assert hashed1 != hashed2 + + def test_verify_secret_invalid_format(self): + assert not SecretManager.verify_secret("secret", "invalid-hash") + + +class TestRateLimiter: + + def test_rate_limiter_allows_requests(self): + limiter = RateLimiter(max_requests=3, window_seconds=60) + + assert limiter.is_allowed("user1") + assert limiter.is_allowed("user1") + assert limiter.is_allowed("user1") + + def test_rate_limiter_blocks_excess(self): + limiter = RateLimiter(max_requests=2, window_seconds=60) + + assert limiter.is_allowed("user1") + assert limiter.is_allowed("user1") + assert not limiter.is_allowed("user1") + + def test_rate_limiter_separate_keys(self): + limiter = RateLimiter(max_requests=2, window_seconds=60) + + assert limiter.is_allowed("user1") + assert limiter.is_allowed("user1") + assert limiter.is_allowed("user2") + assert limiter.is_allowed("user2") + + def test_rate_limiter_reset(self): + limiter = RateLimiter(max_requests=1, window_seconds=60) + + assert limiter.is_allowed("user1") + assert not limiter.is_allowed("user1") + + limiter.reset("user1") + assert limiter.is_allowed("user1") + + +class TestSanitizeLogOutput: + + def test_sanitize_api_key(self): + data = 'api_key="sk-1234567890"' + result = sanitize_log_output(data) + assert "sk-1234567890" not in result + assert "***" in result + + def test_sanitize_token(self): + data = "token: abc123xyz" + result = sanitize_log_output(data) + assert "abc123xyz" not in result + + def test_sanitize_password(self): + data = {"password": "secret123"} + result = sanitize_log_output(data) + assert "secret123" not in result + + def test_sanitize_bearer_token(self): + data = "Authorization: Bearer eyJhbGc..." + result = sanitize_log_output(data) + assert "eyJhbGc" not in result + assert "Bearer ***" in result + + def test_preserves_non_sensitive(self): + data = "user_id=123 name=John" + result = sanitize_log_output(data) + assert "user_id=123" in result + assert "name=John" in result