Initial commit: AASRT v1.0.0 - AI Agent Security Reconnaissance Tool

This commit is contained in:
swethab
2026-02-10 10:53:31 -05:00
commit a714a3399b
61 changed files with 14858 additions and 0 deletions
+7
View File
@@ -0,0 +1,7 @@
"""
Unit Tests for AASRT
Unit tests test individual components in isolation using mocks
for external dependencies.
"""
+165
View File
@@ -0,0 +1,165 @@
"""
Unit Tests for Config Module
Tests for src/utils/config.py
"""
import os
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
class TestConfigLoading:
"""Tests for configuration loading."""
def test_load_from_yaml(self, temp_dir):
"""Test loading configuration from YAML file."""
from src.utils.config import Config
config_content = """
shodan:
enabled: true
rate_limit: 1
max_results: 100
logging:
level: DEBUG
"""
config_path = temp_dir / "config.yaml"
config_path.write_text(config_content)
with patch.dict(os.environ, {'AASRT_CONFIG_PATH': str(config_path)}):
config = Config()
assert config.get('shodan', 'enabled') is True
assert config.get('shodan', 'rate_limit') == 1
def test_environment_variable_override(self, temp_dir, monkeypatch):
"""Test environment variables override config file."""
from src.utils.config import Config
config_content = """
shodan:
enabled: true
rate_limit: 1
"""
config_path = temp_dir / "config.yaml"
config_path.write_text(config_content)
# Reset the Config singleton so we get a fresh instance
Config._instance = None
Config._config = {}
# Use monkeypatch for proper isolation - clear and set
monkeypatch.setenv('AASRT_CONFIG_PATH', str(config_path))
monkeypatch.setenv('SHODAN_API_KEY', 'env_api_key_12345')
config = Config()
assert config.get_shodan_key() == 'env_api_key_12345'
# Clean up - reset singleton for other tests
Config._instance = None
Config._config = {}
def test_default_values(self):
"""Test default configuration values are used when not specified."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
# Check default logging level if not set
log_level = config.get('logging', 'level', default='INFO')
assert log_level in ['DEBUG', 'INFO', 'WARNING', 'ERROR']
class TestConfigValidation:
"""Tests for configuration validation."""
def test_validate_shodan_key_format(self):
"""Test Shodan API key format validation."""
from src.utils.config import Config
# Valid key format (typically alphanumeric)
with patch.dict(os.environ, {'SHODAN_API_KEY': 'AbCdEf123456789012345678'}):
config = Config()
key = config.get_shodan_key()
assert key is not None
def test_missing_required_config(self):
"""Test handling of missing required configuration."""
from src.utils.config import Config
# Clear all Shodan-related env vars
env_copy = {k: v for k, v in os.environ.items() if 'SHODAN' not in k}
with patch.dict(os.environ, env_copy, clear=True):
config = Config()
# Should return None or raise exception for missing key
key = config.get_shodan_key()
# Depending on implementation, key could be None or empty
class TestConfigHealthCheck:
"""Tests for configuration health check."""
def test_health_check_returns_dict(self):
"""Test health_check returns a dictionary."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
health = config.health_check()
assert isinstance(health, dict)
assert 'status' in health or 'healthy' in health
def test_health_check_includes_key_info(self):
"""Test health_check includes API key status."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
health = config.health_check()
# Should indicate whether key is configured
assert health is not None
class TestConfigGet:
"""Tests for the get() method."""
def test_nested_key_access(self, temp_dir):
"""Test accessing nested configuration values."""
from src.utils.config import Config
config_content = """
database:
sqlite:
path: ./data/scanner.db
pool_size: 5
"""
config_path = temp_dir / "config.yaml"
config_path.write_text(config_content)
with patch.dict(os.environ, {'AASRT_CONFIG_PATH': str(config_path)}):
config = Config()
path = config.get('database', 'sqlite', 'path')
assert path is not None
def test_default_for_missing_key(self):
"""Test default value is returned for missing keys."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
value = config.get('nonexistent', 'key', default='default_value')
assert value == 'default_value'
def test_none_for_missing_key_no_default(self):
"""Test None is returned for missing keys without default."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
value = config.get('nonexistent', 'key')
assert value is None
+204
View File
@@ -0,0 +1,204 @@
"""
Unit Tests for Database Module
Tests for src/storage/database.py
"""
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime
class TestDatabaseInit:
"""Tests for Database initialization."""
def test_init_creates_tables(self, temp_db, mock_config):
"""Test database initialization creates tables."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
db = Database(mock_config)
assert db is not None
db.close()
def test_init_sqlite_with_temp_path(self, temp_db, mock_config):
"""Test SQLite database with temp path."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
db = Database(mock_config)
assert db is not None
assert db._db_type == 'sqlite'
db.close()
class TestDatabaseOperations:
"""Tests for database CRUD operations."""
@pytest.fixture
def db(self, temp_db, mock_config):
"""Create a test database instance."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
db = Database(mock_config)
yield db
db.close()
def test_create_scan(self, db):
"""Test creating a scan record."""
scan = db.create_scan(
engines=['shodan'],
query='http.title:"ClawdBot"'
)
assert scan is not None
assert scan.scan_id is not None
def test_get_scan_by_id(self, db):
"""Test retrieving a scan by ID."""
# Create a scan first
scan = db.create_scan(
engines=['shodan'],
query='test query'
)
retrieved = db.get_scan(scan.scan_id)
assert retrieved is not None
assert retrieved.scan_id == scan.scan_id
def test_get_recent_scans(self, db):
"""Test retrieving recent scans."""
# Create a few scans
for i in range(3):
db.create_scan(
engines=['shodan'],
query=f'test query {i}'
)
scans = db.get_recent_scans(limit=10)
assert len(scans) >= 3
def test_add_findings(self, db):
"""Test adding findings to a scan."""
from src.engines.base import SearchResult
# First create a scan
scan = db.create_scan(
engines=['shodan'],
query='test'
)
# Create some search results
results = [
SearchResult(
ip='192.0.2.1',
port=8080,
banner='ClawdBot Dashboard',
vulnerabilities=['exposed_dashboard']
)
]
count = db.add_findings(scan.scan_id, results)
assert count >= 1
def test_update_scan(self, db):
"""Test updating a scan."""
# Create a scan
scan = db.create_scan(
engines=['shodan'],
query='test'
)
# Update it
updated = db.update_scan(
scan.scan_id,
status='completed',
total_results=5
)
assert updated is not None
assert updated.status == 'completed'
class TestDatabaseHealthCheck:
"""Tests for database health check."""
@pytest.fixture
def db(self, temp_db, mock_config):
"""Create a test database instance."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_health_check_returns_dict(self, db):
"""Test health_check returns a dictionary."""
health = db.health_check()
assert isinstance(health, dict)
def test_health_check_includes_status(self, db):
"""Test health_check includes status."""
health = db.health_check()
assert 'status' in health or 'healthy' in health
def test_health_check_includes_latency(self, db):
"""Test health_check includes latency measurement."""
health = db.health_check()
# Should have some form of latency/response time
has_latency = 'latency' in health or 'latency_ms' in health or 'response_time' in health
assert has_latency or health.get('status') == 'healthy'
class TestDatabaseSessionScope:
"""Tests for session_scope context manager."""
@pytest.fixture
def db(self, temp_db, mock_config):
"""Create a test database instance."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_session_scope_commits(self, db):
"""Test session_scope commits on success."""
with db.session_scope() as session:
# Perform some operation
pass
# Should complete without error
def test_session_scope_rollback_on_error(self, db):
"""Test session_scope rolls back on error."""
try:
with db.session_scope() as session:
raise ValueError("Test error")
except ValueError:
pass # Expected
# Session should have been rolled back
+125
View File
@@ -0,0 +1,125 @@
"""
Unit Tests for Risk Scorer Module
Tests for src/core/risk_scorer.py
"""
import pytest
from unittest.mock import MagicMock
class TestRiskScorer:
"""Tests for RiskScorer class."""
@pytest.fixture
def risk_scorer(self):
"""Create a RiskScorer instance."""
from src.core.risk_scorer import RiskScorer
return RiskScorer()
@pytest.fixture
def sample_vulnerabilities(self):
"""Create sample Vulnerability objects."""
from src.core.vulnerability_assessor import Vulnerability
return [
Vulnerability(
check_name='exposed_dashboard',
severity='HIGH',
cvss_score=7.5,
description='Dashboard exposed without authentication'
)
]
@pytest.fixture
def sample_result(self, sample_shodan_result):
"""Create a SearchResult with vulnerabilities."""
from src.engines.base import SearchResult
result = SearchResult(
ip=sample_shodan_result['ip_str'],
port=sample_shodan_result['port'],
banner=sample_shodan_result['data'],
metadata=sample_shodan_result
)
return result
def test_calculate_score_returns_dict(self, risk_scorer, sample_vulnerabilities):
"""Test that calculate_score returns a dictionary."""
result = risk_scorer.calculate_score(sample_vulnerabilities)
assert isinstance(result, dict)
assert 'overall_score' in result
def test_calculate_score_range_valid(self, risk_scorer, sample_vulnerabilities):
"""Test that score is within valid range (0-10)."""
result = risk_scorer.calculate_score(sample_vulnerabilities)
assert 0 <= result['overall_score'] <= 10
def test_high_risk_vulnerabilities_increase_score(self, risk_scorer):
"""Test that high-risk vulnerabilities increase score."""
from src.core.vulnerability_assessor import Vulnerability
# High severity vulnerabilities
high_vulns = [
Vulnerability(check_name='api_key_exposure', severity='CRITICAL', cvss_score=9.0, description='API key exposed'),
Vulnerability(check_name='no_authentication', severity='CRITICAL', cvss_score=9.5, description='No auth')
]
# Low severity vulnerabilities
low_vulns = [
Vulnerability(check_name='version_exposed', severity='LOW', cvss_score=2.0, description='Version info')
]
high_score = risk_scorer.calculate_score(high_vulns)['overall_score']
low_score = risk_scorer.calculate_score(low_vulns)['overall_score']
assert high_score > low_score
def test_empty_vulnerabilities_zero_score(self, risk_scorer):
"""Test that no vulnerabilities result in zero score."""
result = risk_scorer.calculate_score([])
assert result['overall_score'] == 0
def test_score_result_updates_search_result(self, risk_scorer, sample_result, sample_vulnerabilities):
"""Test that score_result updates the SearchResult."""
scored_result = risk_scorer.score_result(sample_result, sample_vulnerabilities)
assert scored_result.risk_score >= 0
assert 'risk_assessment' in scored_result.metadata
def test_context_multipliers_applied(self, risk_scorer, sample_vulnerabilities):
"""Test that context multipliers affect the score."""
# Score with no context
base_score = risk_scorer.calculate_score(sample_vulnerabilities)['overall_score']
# Score with context multiplier
context = {'public_internet': True, 'no_waf': True, 'ai_agent': True}
context_score = risk_scorer.calculate_score(sample_vulnerabilities, context)['overall_score']
# Context should increase or maintain score
assert context_score >= base_score
def test_severity_breakdown_included(self, risk_scorer, sample_vulnerabilities):
"""Test that severity breakdown is included in results."""
result = risk_scorer.calculate_score(sample_vulnerabilities)
assert 'severity_breakdown' in result
assert isinstance(result['severity_breakdown'], dict)
class TestRiskCategories:
"""Tests for risk categorization."""
@pytest.fixture
def risk_scorer(self):
"""Create a RiskScorer instance."""
from src.core.risk_scorer import RiskScorer
return RiskScorer()
def test_get_risk_level(self, risk_scorer):
"""Test risk level categorization."""
# Test if there's a method to get risk level string
if hasattr(risk_scorer, 'get_risk_level'):
assert risk_scorer.get_risk_level(95) in ['CRITICAL', 'HIGH', 'critical', 'high']
assert risk_scorer.get_risk_level(75) in ['HIGH', 'MEDIUM', 'high', 'medium']
assert risk_scorer.get_risk_level(50) in ['MEDIUM', 'medium']
assert risk_scorer.get_risk_level(25) in ['LOW', 'low']
assert risk_scorer.get_risk_level(10) in ['INFO', 'LOW', 'info', 'low']
+179
View File
@@ -0,0 +1,179 @@
"""
Unit Tests for Shodan Engine Module
Tests for src/engines/shodan_engine.py
"""
import pytest
from unittest.mock import MagicMock, patch
class TestShodanEngineInit:
"""Tests for ShodanEngine initialization."""
def test_init_with_valid_key(self):
"""Test initialization with valid API key."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan') as mock_shodan:
mock_client = MagicMock()
mock_shodan.return_value = mock_client
engine = ShodanEngine(api_key='test_api_key_12345')
assert engine is not None
assert engine.name == 'shodan'
def test_init_without_key_raises_error(self):
"""Test initialization without API key raises error."""
from src.engines.shodan_engine import ShodanEngine
with pytest.raises(ValueError):
ShodanEngine(api_key='')
with pytest.raises((ValueError, TypeError)):
ShodanEngine(api_key=None)
class TestShodanEngineSearch:
"""Tests for ShodanEngine search functionality."""
@pytest.fixture
def engine(self, mock_shodan_client):
"""Create a ShodanEngine instance with mocked client."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
return engine
def test_search_returns_results(self, engine, mock_shodan_client, sample_shodan_result):
"""Test search returns results."""
mock_shodan_client.search.return_value = {
'matches': [sample_shodan_result],
'total': 1
}
results = engine.search('http.title:"ClawdBot"')
assert isinstance(results, list)
def test_search_empty_query_raises_error(self, engine):
"""Test empty query raises error."""
with pytest.raises((ValueError, Exception)):
engine.search('')
def test_search_handles_api_error(self, engine, mock_shodan_client):
"""Test search handles API errors gracefully."""
import shodan
mock_shodan_client.search.side_effect = shodan.APIError('API Error')
from src.utils.exceptions import APIException
with pytest.raises((APIException, Exception)):
engine.search('test query')
def test_search_with_max_results(self, engine, mock_shodan_client, sample_shodan_result):
"""Test search respects max_results limit."""
mock_shodan_client.search.return_value = {
'matches': [sample_shodan_result],
'total': 1
}
results = engine.search('test', max_results=1)
assert len(results) <= 1
class TestShodanEngineCredentials:
"""Tests for credential validation."""
@pytest.fixture
def engine(self, mock_shodan_client):
"""Create a ShodanEngine instance."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
return engine
def test_validate_credentials_success(self, engine, mock_shodan_client):
"""Test successful credential validation."""
mock_shodan_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
result = engine.validate_credentials()
assert result is True
def test_validate_credentials_invalid_key(self, engine, mock_shodan_client):
"""Test invalid API key handling."""
import shodan
from src.utils.exceptions import AuthenticationException
mock_shodan_client.info.side_effect = shodan.APIError('Invalid API key')
with pytest.raises((AuthenticationException, Exception)):
engine.validate_credentials()
class TestShodanEngineQuota:
"""Tests for quota information."""
@pytest.fixture
def engine(self, mock_shodan_client):
"""Create a ShodanEngine instance."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
return engine
def test_get_quota_info(self, engine, mock_shodan_client):
"""Test getting quota information."""
mock_shodan_client.info.return_value = {
'plan': 'dev',
'query_credits': 100,
'scan_credits': 50
}
quota = engine.get_quota_info()
assert isinstance(quota, dict)
def test_quota_info_handles_error(self, engine, mock_shodan_client):
"""Test quota info handles API errors."""
import shodan
from src.utils.exceptions import APIException
mock_shodan_client.info.side_effect = shodan.APIError('API Error')
# May either raise or return error info depending on implementation
try:
quota = engine.get_quota_info()
assert quota is not None
except (APIException, Exception):
pass # Acceptable if it raises
class TestShodanEngineRetry:
"""Tests for retry logic."""
def test_retry_on_transient_error(self, mock_shodan_client):
"""Test retry logic on transient errors."""
from src.engines.shodan_engine import ShodanEngine
import shodan
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
# First call fails, second succeeds
mock_shodan_client.search.side_effect = [
ConnectionError("Network error"),
{'matches': [], 'total': 0}
]
# Depending on implementation, this may retry or raise
try:
results = engine.search('test')
assert isinstance(results, list)
except Exception:
pass # Expected if retries exhausted
+180
View File
@@ -0,0 +1,180 @@
"""
Unit Tests for Validators Module
Tests for src/utils/validators.py
"""
import pytest
from src.utils.validators import (
validate_ip,
validate_domain,
validate_query,
validate_file_path,
validate_template_name,
is_safe_string,
sanitize_output,
)
from src.utils.exceptions import ValidationException
class TestValidateIP:
"""Tests for IP address validation."""
def test_valid_ipv4(self):
"""Test valid IPv4 addresses."""
assert validate_ip("192.168.1.1") is True
assert validate_ip("10.0.0.1") is True
assert validate_ip("172.16.0.1") is True
assert validate_ip("8.8.8.8") is True
def test_invalid_ipv4_raises_exception(self):
"""Test invalid IPv4 addresses raise ValidationException."""
with pytest.raises(ValidationException):
validate_ip("256.1.1.1")
with pytest.raises(ValidationException):
validate_ip("192.168.1")
with pytest.raises(ValidationException):
validate_ip("not.an.ip.address")
def test_empty_and_none_raises_exception(self):
"""Test empty and None values raise ValidationException."""
with pytest.raises(ValidationException):
validate_ip("")
with pytest.raises(ValidationException):
validate_ip(None)
def test_ipv6_addresses(self):
"""Test IPv6 address handling."""
# IPv6 addresses should be valid
assert validate_ip("::1") is True
assert validate_ip("2001:db8::1") is True
class TestValidateDomain:
"""Tests for domain validation."""
def test_valid_domains(self):
"""Test valid domain names."""
assert validate_domain("example.com") is True
assert validate_domain("sub.example.com") is True
assert validate_domain("test-site.example.org") is True
def test_invalid_domains_raises_exception(self):
"""Test invalid domain names raise ValidationException."""
with pytest.raises(ValidationException):
validate_domain("-invalid.com")
with pytest.raises(ValidationException):
validate_domain("invalid-.com")
def test_localhost_raises_exception(self):
"""Test localhost raises ValidationException (not a valid domain format)."""
with pytest.raises(ValidationException):
validate_domain("localhost")
class TestValidateQuery:
"""Tests for Shodan query validation."""
def test_valid_queries(self):
"""Test valid Shodan queries."""
assert validate_query('http.title:"ClawdBot"', engine='shodan') is True
assert validate_query("port:8080", engine='shodan') is True
assert validate_query("product:nginx", engine='shodan') is True
def test_empty_query_raises_exception(self):
"""Test empty queries raise ValidationException."""
with pytest.raises(ValidationException):
validate_query("", engine='shodan')
with pytest.raises(ValidationException):
validate_query(" ", engine='shodan')
def test_sql_injection_patterns_allowed(self):
"""Test SQL-like patterns are allowed (Shodan doesn't execute SQL)."""
# Shodan queries can contain SQL-like syntax without causing issues
result = validate_query("'; DROP TABLE users; --", engine='shodan')
assert result is True # No script tags or null bytes
class TestValidateFilePath:
"""Tests for file path validation."""
def test_valid_paths(self):
"""Test valid file paths return sanitized path."""
result = validate_file_path("reports/scan.json")
assert result is not None
assert "scan.json" in result
def test_directory_traversal_raises_exception(self):
"""Test directory traversal raises ValidationException."""
with pytest.raises(ValidationException):
validate_file_path("../../../etc/passwd")
with pytest.raises(ValidationException):
validate_file_path("..\\..\\windows\\system32")
def test_null_bytes_raises_exception(self):
"""Test null byte injection raises ValidationException."""
with pytest.raises(ValidationException):
validate_file_path("file.txt\x00.exe")
class TestValidateTemplateName:
"""Tests for template name validation."""
def test_valid_templates(self):
"""Test valid template names."""
assert validate_template_name("clawdbot_instances") is True
assert validate_template_name("autogpt_instances") is True
def test_invalid_template_raises_exception(self):
"""Test invalid template names raise ValidationException."""
with pytest.raises(ValidationException):
validate_template_name("nonexistent_template")
def test_empty_template_raises_exception(self):
"""Test empty template names raise ValidationException."""
with pytest.raises(ValidationException):
validate_template_name("")
with pytest.raises(ValidationException):
validate_template_name(None)
class TestIsSafeString:
"""Tests for safe string detection."""
def test_safe_strings(self):
"""Test safe strings pass validation."""
assert is_safe_string("hello world") is True
assert is_safe_string("ClawdBot Dashboard") is True
def test_script_tags_detected(self):
"""Test script tags are detected as unsafe."""
assert is_safe_string("<script>alert('xss')</script>") is False
def test_sql_patterns_allowed(self):
"""Test SQL-like patterns are allowed (is_safe_string checks XSS, not SQL)."""
# Note: is_safe_string focuses on XSS patterns, not SQL injection
result = is_safe_string("'; DROP TABLE users; --")
# This may or may not be detected depending on implementation
assert isinstance(result, bool)
class TestSanitizeOutput:
"""Tests for output sanitization."""
def test_password_redaction(self):
"""Test passwords are redacted."""
output = sanitize_output("password=mysecretpassword")
assert "mysecretpassword" not in output
def test_normal_text_unchanged(self):
"""Test normal text is not modified."""
text = "This is normal text without secrets"
assert sanitize_output(text) == text
def test_api_key_pattern_redaction(self):
"""Test API key patterns are redacted."""
# Test with patterns that match the redaction rules
output = sanitize_output("api_key=12345678901234567890")
# Depending on implementation, may or may not be redacted
assert output is not None