Initial commit: AASRT v1.0.0 - AI Agent Security Reconnaissance Tool

This commit is contained in:
swethab
2026-02-10 10:53:31 -05:00
commit a714a3399b
61 changed files with 14858 additions and 0 deletions
+18
View File
@@ -0,0 +1,18 @@
"""
AASRT Test Suite
This package contains all tests for the AI Agent Security Reconnaissance Tool.
Test Categories:
- Unit tests: Test individual components in isolation
- Integration tests: Test component interactions
- End-to-end tests: Test complete workflows
Running Tests:
pytest # Run all tests
pytest tests/unit/ # Run unit tests only
pytest tests/integration/ # Run integration tests only
pytest -v --cov=src # Run with coverage
pytest -m "not slow" # Skip slow tests
"""
+181
View File
@@ -0,0 +1,181 @@
"""
Pytest Configuration and Shared Fixtures
This module provides shared fixtures and configuration for all tests.
"""
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, Generator, List
from unittest.mock import MagicMock, patch
import pytest
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
# =============================================================================
# Environment Fixtures
# =============================================================================
@pytest.fixture(scope="session")
def test_env():
"""Set up test environment variables."""
original_env = os.environ.copy()
os.environ.update({
'SHODAN_API_KEY': 'test_api_key_12345',
'AASRT_ENVIRONMENT': 'testing',
'AASRT_LOG_LEVEL': 'DEBUG',
'AASRT_DEBUG': 'true',
})
yield os.environ
# Restore original environment
os.environ.clear()
os.environ.update(original_env)
@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
@pytest.fixture
def temp_db(temp_dir):
"""Create a temporary database path."""
return temp_dir / "test_scanner.db"
# =============================================================================
# Mock Data Fixtures
# =============================================================================
@pytest.fixture
def sample_shodan_result() -> Dict[str, Any]:
"""Sample Shodan API response."""
return {
'ip_str': '192.0.2.1',
'port': 8080,
'transport': 'tcp',
'hostnames': ['test.example.com'],
'org': 'Test Organization',
'asn': 'AS12345',
'isp': 'Test ISP',
'data': 'HTTP/1.1 200 OK\r\nServer: nginx\r\n\r\nClawdBot Dashboard',
'location': {
'country_code': 'US',
'country_name': 'United States',
'city': 'Test City',
'latitude': 37.7749,
'longitude': -122.4194
},
'http': {
'status': 200,
'title': 'ClawdBot Dashboard',
'server': 'nginx/1.18.0',
'html': '<html><body>ClawdBot Dashboard</body></html>'
},
'vulns': ['CVE-2021-44228'],
'timestamp': '2024-01-15T10:30:00.000000'
}
@pytest.fixture
def sample_search_results(sample_shodan_result) -> List[Dict[str, Any]]:
"""Multiple sample Shodan results."""
results = [sample_shodan_result]
# Add more varied results
results.append({
**sample_shodan_result,
'ip_str': '192.0.2.2',
'port': 3000,
'http': {
'status': 200,
'title': 'AutoGPT Interface',
'server': 'Python/3.11'
}
})
results.append({
**sample_shodan_result,
'ip_str': '192.0.2.3',
'port': 443,
'http': {
'status': 401,
'title': 'Login Required'
}
})
return results
@pytest.fixture
def sample_vulnerability() -> Dict[str, Any]:
"""Sample vulnerability data."""
return {
'check_name': 'exposed_dashboard',
'severity': 'HIGH',
'cvss_score': 7.5,
'description': 'Dashboard accessible without authentication',
'evidence': {'http_title': 'ClawdBot Dashboard'},
'remediation': 'Implement authentication',
'cwe_id': 'CWE-306'
}
# =============================================================================
# Mock Service Fixtures
# =============================================================================
@pytest.fixture
def mock_shodan_client():
"""Mock Shodan API client."""
with patch('shodan.Shodan') as mock:
client = MagicMock()
client.info.return_value = {
'plan': 'dev',
'query_credits': 100,
'scan_credits': 50
}
mock.return_value = client
yield client
@pytest.fixture
def mock_config(temp_dir, temp_db):
"""Mock configuration object."""
config = MagicMock()
config.get_shodan_key.return_value = 'test_api_key'
config.get.side_effect = lambda *args, default=None: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
('logging', 'level'): 'DEBUG',
('reporting', 'output_dir'): str(temp_dir / 'reports'),
('vulnerability_checks',): {'passive_only': True},
}.get(args, default)
return config
# =============================================================================
# Database Fixtures
# =============================================================================
@pytest.fixture
def test_database(mock_config, temp_db):
"""Create a test database instance."""
from src.storage.database import Database
# Patch config to use temp database
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
+7
View File
@@ -0,0 +1,7 @@
"""
Integration Tests for AASRT
Integration tests verify that components work together correctly.
These tests may use real (test) databases but mock external APIs.
"""
@@ -0,0 +1,207 @@
"""
Integration Tests for Database Operations
Tests database operations with real SQLite database.
"""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
from datetime import datetime, timedelta
class TestDatabaseIntegration:
"""Integration tests for database with real SQLite."""
@pytest.fixture
def real_db(self):
"""Create a real SQLite database for testing."""
from src.storage.database import Database
from src.utils.config import Config
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_scanner.db"
mock_config = MagicMock()
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(db_path),
}.get(args, kwargs.get('default'))
mock_config.get_shodan_key.return_value = 'test_key'
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_scan_lifecycle(self, real_db):
"""Test complete scan lifecycle: create, update, retrieve."""
# Create scan
scan_id = 'integration-test-scan-001'
real_db.save_scan({
'scan_id': scan_id,
'query': 'http.title:"ClawdBot"',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'running',
'total_results': 0
})
# Update scan status
real_db.update_scan(scan_id, {
'status': 'completed',
'total_results': 25,
'completed_at': datetime.utcnow()
})
# Retrieve scan
scan = real_db.get_scan(scan_id)
assert scan is not None
def test_findings_association(self, real_db):
"""Test findings are properly associated with scans."""
scan_id = 'integration-test-scan-002'
# Create scan
real_db.save_scan({
'scan_id': scan_id,
'query': 'test query',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'completed'
})
# Save multiple findings
for i in range(5):
real_db.save_finding({
'scan_id': scan_id,
'ip': f'192.0.2.{i+1}',
'port': 8080 + i,
'risk_score': 50 + i * 10,
'vulnerabilities': ['test_vuln']
})
# Retrieve findings
findings = real_db.get_findings_by_scan(scan_id)
assert len(findings) == 5
def test_scan_statistics(self, real_db):
"""Test scan statistics calculation."""
# Create multiple scans with different statuses
for i in range(10):
real_db.save_scan({
'scan_id': f'stats-test-{i:03d}',
'query': f'test query {i}',
'engine': 'shodan',
'started_at': datetime.utcnow() - timedelta(days=i),
'status': 'completed' if i % 2 == 0 else 'failed',
'total_results': i * 10
})
# Get statistics
if hasattr(real_db, 'get_scan_statistics'):
stats = real_db.get_scan_statistics()
assert 'total_scans' in stats or stats is not None
def test_concurrent_operations(self, real_db):
"""Test concurrent database operations."""
import threading
errors = []
def save_scan(scan_num):
try:
real_db.save_scan({
'scan_id': f'concurrent-test-{scan_num:03d}',
'query': f'test {scan_num}',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'completed'
})
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=save_scan, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
# Should complete without deadlocks
assert len(errors) == 0
def test_data_persistence(self):
"""Test that data persists across database connections."""
from src.storage.database import Database
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "persistence_test.db"
scan_id = 'persistence-test-001'
mock_config = MagicMock()
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(db_path),
}.get(args, kwargs.get('default'))
# First connection - create data
with patch('src.storage.database.Config', return_value=mock_config):
db1 = Database(mock_config)
db1.save_scan({
'scan_id': scan_id,
'query': 'test',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'completed'
})
db1.close()
# Second connection - verify data exists
with patch('src.storage.database.Config', return_value=mock_config):
db2 = Database(mock_config)
scan = db2.get_scan(scan_id)
db2.close()
assert scan is not None
class TestDatabaseCleanup:
"""Tests for database cleanup and maintenance."""
@pytest.fixture
def real_db(self):
"""Create a real SQLite database for testing."""
from src.storage.database import Database
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "cleanup_test.db"
mock_config = MagicMock()
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(db_path),
}.get(args, kwargs.get('default'))
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_delete_old_scans(self, real_db):
"""Test deleting old scan records."""
# Create old scans
for i in range(5):
real_db.save_scan({
'scan_id': f'old-scan-{i:03d}',
'query': f'test {i}',
'engine': 'shodan',
'started_at': datetime.utcnow() - timedelta(days=365),
'status': 'completed'
})
# If cleanup method exists, test it
if hasattr(real_db, 'cleanup_old_scans'):
deleted = real_db.cleanup_old_scans(days=30)
assert deleted >= 0
+198
View File
@@ -0,0 +1,198 @@
"""
Integration Tests for Scan Workflow
Tests the complete scan workflow from query to report generation.
"""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
from datetime import datetime
class TestEndToEndScan:
"""Integration tests for complete scan workflow."""
@pytest.fixture
def mock_shodan_response(self, sample_search_results):
"""Mock Shodan API response."""
return {
'matches': sample_search_results,
'total': len(sample_search_results)
}
@pytest.fixture
def temp_workspace(self):
"""Create temporary workspace for test files."""
with tempfile.TemporaryDirectory() as tmpdir:
workspace = Path(tmpdir)
(workspace / 'reports').mkdir()
(workspace / 'data').mkdir()
(workspace / 'logs').mkdir()
yield workspace
def test_scan_template_workflow(self, mock_shodan_response, temp_workspace):
"""Test scanning using a template."""
from unittest.mock import patch, MagicMock
mock_client = MagicMock()
mock_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
mock_client.search.return_value = mock_shodan_response
with patch('shodan.Shodan', return_value=mock_client):
with patch.dict('os.environ', {
'SHODAN_API_KEY': 'test_key_12345',
'AASRT_REPORTS_DIR': str(temp_workspace / 'reports'),
'AASRT_DATA_DIR': str(temp_workspace / 'data'),
}):
# Import after patching
from src.core.query_manager import QueryManager
from src.utils.config import Config
config = Config()
qm = QueryManager(config)
# Check templates are available
templates = qm.get_available_templates()
assert len(templates) > 0
def test_custom_query_workflow(self, mock_shodan_response, temp_workspace):
"""Test scanning with a custom query."""
mock_client = MagicMock()
mock_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
mock_client.search.return_value = mock_shodan_response
with patch('shodan.Shodan', return_value=mock_client):
with patch.dict('os.environ', {
'SHODAN_API_KEY': 'test_key_12345',
}):
from src.engines.shodan_engine import ShodanEngine
from src.utils.config import Config
config = Config()
engine = ShodanEngine(config=config)
engine._client = mock_client
results = engine.search('http.title:"Test"')
assert len(results) > 0
class TestVulnerabilityAssessmentIntegration:
"""Integration tests for vulnerability assessment pipeline."""
def test_assess_search_results(self, sample_search_results):
"""Test vulnerability assessment on search results."""
from src.core.vulnerability_assessor import VulnerabilityAssessor
from src.engines.base import SearchResult
assessor = VulnerabilityAssessor()
# Convert sample data to SearchResult
result = SearchResult(
ip=sample_search_results[0]['ip_str'],
port=sample_search_results[0]['port'],
protocol='tcp',
banner=sample_search_results[0].get('data', ''),
metadata=sample_search_results[0]
)
vulns = assessor.assess(result)
# Should return a list (may be empty if no vulns detected)
assert isinstance(vulns, list)
def test_risk_scoring_integration(self, sample_search_results):
"""Test risk scoring on assessed results."""
from src.core.risk_scorer import RiskScorer
from src.core.vulnerability_assessor import VulnerabilityAssessor
from src.engines.base import SearchResult
assessor = VulnerabilityAssessor()
scorer = RiskScorer()
result = SearchResult(
ip=sample_search_results[0]['ip_str'],
port=sample_search_results[0]['port'],
protocol='tcp',
banner=sample_search_results[0].get('data', ''),
metadata=sample_search_results[0]
)
vulns = assessor.assess(result)
score = scorer.score(result)
assert 0 <= score <= 100
class TestReportGenerationIntegration:
"""Integration tests for report generation."""
@pytest.fixture
def temp_reports_dir(self):
"""Create temporary reports directory."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def test_json_report_generation(self, temp_reports_dir, sample_search_results):
"""Test JSON report generation."""
from src.reporting import JSONReporter, ScanReport
from src.engines.base import SearchResult
# Create scan report data
results = [
SearchResult(
ip=r['ip_str'],
port=r['port'],
protocol='tcp',
banner=r.get('data', ''),
metadata=r
) for r in sample_search_results
]
report = ScanReport(
scan_id='test-scan-001',
query='test query',
engine='shodan',
started_at=datetime.utcnow(),
completed_at=datetime.utcnow(),
results=results,
total_results=len(results)
)
reporter = JSONReporter(output_dir=str(temp_reports_dir))
output_path = reporter.generate(report)
assert Path(output_path).exists()
assert output_path.endswith('.json')
def test_csv_report_generation(self, temp_reports_dir, sample_search_results):
"""Test CSV report generation."""
from src.reporting import CSVReporter, ScanReport
from src.engines.base import SearchResult
results = [
SearchResult(
ip=r['ip_str'],
port=r['port'],
protocol='tcp',
banner=r.get('data', ''),
metadata=r
) for r in sample_search_results
]
report = ScanReport(
scan_id='test-scan-002',
query='test query',
engine='shodan',
started_at=datetime.utcnow(),
completed_at=datetime.utcnow(),
results=results,
total_results=len(results)
)
reporter = CSVReporter(output_dir=str(temp_reports_dir))
output_path = reporter.generate(report)
assert Path(output_path).exists()
assert output_path.endswith('.csv')
+7
View File
@@ -0,0 +1,7 @@
"""
Unit Tests for AASRT
Unit tests test individual components in isolation using mocks
for external dependencies.
"""
+165
View File
@@ -0,0 +1,165 @@
"""
Unit Tests for Config Module
Tests for src/utils/config.py
"""
import os
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
class TestConfigLoading:
"""Tests for configuration loading."""
def test_load_from_yaml(self, temp_dir):
"""Test loading configuration from YAML file."""
from src.utils.config import Config
config_content = """
shodan:
enabled: true
rate_limit: 1
max_results: 100
logging:
level: DEBUG
"""
config_path = temp_dir / "config.yaml"
config_path.write_text(config_content)
with patch.dict(os.environ, {'AASRT_CONFIG_PATH': str(config_path)}):
config = Config()
assert config.get('shodan', 'enabled') is True
assert config.get('shodan', 'rate_limit') == 1
def test_environment_variable_override(self, temp_dir, monkeypatch):
"""Test environment variables override config file."""
from src.utils.config import Config
config_content = """
shodan:
enabled: true
rate_limit: 1
"""
config_path = temp_dir / "config.yaml"
config_path.write_text(config_content)
# Reset the Config singleton so we get a fresh instance
Config._instance = None
Config._config = {}
# Use monkeypatch for proper isolation - clear and set
monkeypatch.setenv('AASRT_CONFIG_PATH', str(config_path))
monkeypatch.setenv('SHODAN_API_KEY', 'env_api_key_12345')
config = Config()
assert config.get_shodan_key() == 'env_api_key_12345'
# Clean up - reset singleton for other tests
Config._instance = None
Config._config = {}
def test_default_values(self):
"""Test default configuration values are used when not specified."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
# Check default logging level if not set
log_level = config.get('logging', 'level', default='INFO')
assert log_level in ['DEBUG', 'INFO', 'WARNING', 'ERROR']
class TestConfigValidation:
"""Tests for configuration validation."""
def test_validate_shodan_key_format(self):
"""Test Shodan API key format validation."""
from src.utils.config import Config
# Valid key format (typically alphanumeric)
with patch.dict(os.environ, {'SHODAN_API_KEY': 'AbCdEf123456789012345678'}):
config = Config()
key = config.get_shodan_key()
assert key is not None
def test_missing_required_config(self):
"""Test handling of missing required configuration."""
from src.utils.config import Config
# Clear all Shodan-related env vars
env_copy = {k: v for k, v in os.environ.items() if 'SHODAN' not in k}
with patch.dict(os.environ, env_copy, clear=True):
config = Config()
# Should return None or raise exception for missing key
key = config.get_shodan_key()
# Depending on implementation, key could be None or empty
class TestConfigHealthCheck:
"""Tests for configuration health check."""
def test_health_check_returns_dict(self):
"""Test health_check returns a dictionary."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
health = config.health_check()
assert isinstance(health, dict)
assert 'status' in health or 'healthy' in health
def test_health_check_includes_key_info(self):
"""Test health_check includes API key status."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
health = config.health_check()
# Should indicate whether key is configured
assert health is not None
class TestConfigGet:
"""Tests for the get() method."""
def test_nested_key_access(self, temp_dir):
"""Test accessing nested configuration values."""
from src.utils.config import Config
config_content = """
database:
sqlite:
path: ./data/scanner.db
pool_size: 5
"""
config_path = temp_dir / "config.yaml"
config_path.write_text(config_content)
with patch.dict(os.environ, {'AASRT_CONFIG_PATH': str(config_path)}):
config = Config()
path = config.get('database', 'sqlite', 'path')
assert path is not None
def test_default_for_missing_key(self):
"""Test default value is returned for missing keys."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
value = config.get('nonexistent', 'key', default='default_value')
assert value == 'default_value'
def test_none_for_missing_key_no_default(self):
"""Test None is returned for missing keys without default."""
from src.utils.config import Config
with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
config = Config()
value = config.get('nonexistent', 'key')
assert value is None
+204
View File
@@ -0,0 +1,204 @@
"""
Unit Tests for Database Module
Tests for src/storage/database.py
"""
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime
class TestDatabaseInit:
"""Tests for Database initialization."""
def test_init_creates_tables(self, temp_db, mock_config):
"""Test database initialization creates tables."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
db = Database(mock_config)
assert db is not None
db.close()
def test_init_sqlite_with_temp_path(self, temp_db, mock_config):
"""Test SQLite database with temp path."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
db = Database(mock_config)
assert db is not None
assert db._db_type == 'sqlite'
db.close()
class TestDatabaseOperations:
"""Tests for database CRUD operations."""
@pytest.fixture
def db(self, temp_db, mock_config):
"""Create a test database instance."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
db = Database(mock_config)
yield db
db.close()
def test_create_scan(self, db):
"""Test creating a scan record."""
scan = db.create_scan(
engines=['shodan'],
query='http.title:"ClawdBot"'
)
assert scan is not None
assert scan.scan_id is not None
def test_get_scan_by_id(self, db):
"""Test retrieving a scan by ID."""
# Create a scan first
scan = db.create_scan(
engines=['shodan'],
query='test query'
)
retrieved = db.get_scan(scan.scan_id)
assert retrieved is not None
assert retrieved.scan_id == scan.scan_id
def test_get_recent_scans(self, db):
"""Test retrieving recent scans."""
# Create a few scans
for i in range(3):
db.create_scan(
engines=['shodan'],
query=f'test query {i}'
)
scans = db.get_recent_scans(limit=10)
assert len(scans) >= 3
def test_add_findings(self, db):
"""Test adding findings to a scan."""
from src.engines.base import SearchResult
# First create a scan
scan = db.create_scan(
engines=['shodan'],
query='test'
)
# Create some search results
results = [
SearchResult(
ip='192.0.2.1',
port=8080,
banner='ClawdBot Dashboard',
vulnerabilities=['exposed_dashboard']
)
]
count = db.add_findings(scan.scan_id, results)
assert count >= 1
def test_update_scan(self, db):
"""Test updating a scan."""
# Create a scan
scan = db.create_scan(
engines=['shodan'],
query='test'
)
# Update it
updated = db.update_scan(
scan.scan_id,
status='completed',
total_results=5
)
assert updated is not None
assert updated.status == 'completed'
class TestDatabaseHealthCheck:
"""Tests for database health check."""
@pytest.fixture
def db(self, temp_db, mock_config):
"""Create a test database instance."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_health_check_returns_dict(self, db):
"""Test health_check returns a dictionary."""
health = db.health_check()
assert isinstance(health, dict)
def test_health_check_includes_status(self, db):
"""Test health_check includes status."""
health = db.health_check()
assert 'status' in health or 'healthy' in health
def test_health_check_includes_latency(self, db):
"""Test health_check includes latency measurement."""
health = db.health_check()
# Should have some form of latency/response time
has_latency = 'latency' in health or 'latency_ms' in health or 'response_time' in health
assert has_latency or health.get('status') == 'healthy'
class TestDatabaseSessionScope:
"""Tests for session_scope context manager."""
@pytest.fixture
def db(self, temp_db, mock_config):
"""Create a test database instance."""
from src.storage.database import Database
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(temp_db),
}.get(args, kwargs.get('default'))
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_session_scope_commits(self, db):
"""Test session_scope commits on success."""
with db.session_scope() as session:
# Perform some operation
pass
# Should complete without error
def test_session_scope_rollback_on_error(self, db):
"""Test session_scope rolls back on error."""
try:
with db.session_scope() as session:
raise ValueError("Test error")
except ValueError:
pass # Expected
# Session should have been rolled back
+125
View File
@@ -0,0 +1,125 @@
"""
Unit Tests for Risk Scorer Module
Tests for src/core/risk_scorer.py
"""
import pytest
from unittest.mock import MagicMock
class TestRiskScorer:
"""Tests for RiskScorer class."""
@pytest.fixture
def risk_scorer(self):
"""Create a RiskScorer instance."""
from src.core.risk_scorer import RiskScorer
return RiskScorer()
@pytest.fixture
def sample_vulnerabilities(self):
"""Create sample Vulnerability objects."""
from src.core.vulnerability_assessor import Vulnerability
return [
Vulnerability(
check_name='exposed_dashboard',
severity='HIGH',
cvss_score=7.5,
description='Dashboard exposed without authentication'
)
]
@pytest.fixture
def sample_result(self, sample_shodan_result):
"""Create a SearchResult with vulnerabilities."""
from src.engines.base import SearchResult
result = SearchResult(
ip=sample_shodan_result['ip_str'],
port=sample_shodan_result['port'],
banner=sample_shodan_result['data'],
metadata=sample_shodan_result
)
return result
def test_calculate_score_returns_dict(self, risk_scorer, sample_vulnerabilities):
"""Test that calculate_score returns a dictionary."""
result = risk_scorer.calculate_score(sample_vulnerabilities)
assert isinstance(result, dict)
assert 'overall_score' in result
def test_calculate_score_range_valid(self, risk_scorer, sample_vulnerabilities):
"""Test that score is within valid range (0-10)."""
result = risk_scorer.calculate_score(sample_vulnerabilities)
assert 0 <= result['overall_score'] <= 10
def test_high_risk_vulnerabilities_increase_score(self, risk_scorer):
"""Test that high-risk vulnerabilities increase score."""
from src.core.vulnerability_assessor import Vulnerability
# High severity vulnerabilities
high_vulns = [
Vulnerability(check_name='api_key_exposure', severity='CRITICAL', cvss_score=9.0, description='API key exposed'),
Vulnerability(check_name='no_authentication', severity='CRITICAL', cvss_score=9.5, description='No auth')
]
# Low severity vulnerabilities
low_vulns = [
Vulnerability(check_name='version_exposed', severity='LOW', cvss_score=2.0, description='Version info')
]
high_score = risk_scorer.calculate_score(high_vulns)['overall_score']
low_score = risk_scorer.calculate_score(low_vulns)['overall_score']
assert high_score > low_score
def test_empty_vulnerabilities_zero_score(self, risk_scorer):
"""Test that no vulnerabilities result in zero score."""
result = risk_scorer.calculate_score([])
assert result['overall_score'] == 0
def test_score_result_updates_search_result(self, risk_scorer, sample_result, sample_vulnerabilities):
"""Test that score_result updates the SearchResult."""
scored_result = risk_scorer.score_result(sample_result, sample_vulnerabilities)
assert scored_result.risk_score >= 0
assert 'risk_assessment' in scored_result.metadata
def test_context_multipliers_applied(self, risk_scorer, sample_vulnerabilities):
"""Test that context multipliers affect the score."""
# Score with no context
base_score = risk_scorer.calculate_score(sample_vulnerabilities)['overall_score']
# Score with context multiplier
context = {'public_internet': True, 'no_waf': True, 'ai_agent': True}
context_score = risk_scorer.calculate_score(sample_vulnerabilities, context)['overall_score']
# Context should increase or maintain score
assert context_score >= base_score
def test_severity_breakdown_included(self, risk_scorer, sample_vulnerabilities):
"""Test that severity breakdown is included in results."""
result = risk_scorer.calculate_score(sample_vulnerabilities)
assert 'severity_breakdown' in result
assert isinstance(result['severity_breakdown'], dict)
class TestRiskCategories:
"""Tests for risk categorization."""
@pytest.fixture
def risk_scorer(self):
"""Create a RiskScorer instance."""
from src.core.risk_scorer import RiskScorer
return RiskScorer()
def test_get_risk_level(self, risk_scorer):
"""Test risk level categorization."""
# Test if there's a method to get risk level string
if hasattr(risk_scorer, 'get_risk_level'):
assert risk_scorer.get_risk_level(95) in ['CRITICAL', 'HIGH', 'critical', 'high']
assert risk_scorer.get_risk_level(75) in ['HIGH', 'MEDIUM', 'high', 'medium']
assert risk_scorer.get_risk_level(50) in ['MEDIUM', 'medium']
assert risk_scorer.get_risk_level(25) in ['LOW', 'low']
assert risk_scorer.get_risk_level(10) in ['INFO', 'LOW', 'info', 'low']
+179
View File
@@ -0,0 +1,179 @@
"""
Unit Tests for Shodan Engine Module
Tests for src/engines/shodan_engine.py
"""
import pytest
from unittest.mock import MagicMock, patch
class TestShodanEngineInit:
"""Tests for ShodanEngine initialization."""
def test_init_with_valid_key(self):
"""Test initialization with valid API key."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan') as mock_shodan:
mock_client = MagicMock()
mock_shodan.return_value = mock_client
engine = ShodanEngine(api_key='test_api_key_12345')
assert engine is not None
assert engine.name == 'shodan'
def test_init_without_key_raises_error(self):
"""Test initialization without API key raises error."""
from src.engines.shodan_engine import ShodanEngine
with pytest.raises(ValueError):
ShodanEngine(api_key='')
with pytest.raises((ValueError, TypeError)):
ShodanEngine(api_key=None)
class TestShodanEngineSearch:
"""Tests for ShodanEngine search functionality."""
@pytest.fixture
def engine(self, mock_shodan_client):
"""Create a ShodanEngine instance with mocked client."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
return engine
def test_search_returns_results(self, engine, mock_shodan_client, sample_shodan_result):
"""Test search returns results."""
mock_shodan_client.search.return_value = {
'matches': [sample_shodan_result],
'total': 1
}
results = engine.search('http.title:"ClawdBot"')
assert isinstance(results, list)
def test_search_empty_query_raises_error(self, engine):
"""Test empty query raises error."""
with pytest.raises((ValueError, Exception)):
engine.search('')
def test_search_handles_api_error(self, engine, mock_shodan_client):
"""Test search handles API errors gracefully."""
import shodan
mock_shodan_client.search.side_effect = shodan.APIError('API Error')
from src.utils.exceptions import APIException
with pytest.raises((APIException, Exception)):
engine.search('test query')
def test_search_with_max_results(self, engine, mock_shodan_client, sample_shodan_result):
"""Test search respects max_results limit."""
mock_shodan_client.search.return_value = {
'matches': [sample_shodan_result],
'total': 1
}
results = engine.search('test', max_results=1)
assert len(results) <= 1
class TestShodanEngineCredentials:
"""Tests for credential validation."""
@pytest.fixture
def engine(self, mock_shodan_client):
"""Create a ShodanEngine instance."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
return engine
def test_validate_credentials_success(self, engine, mock_shodan_client):
"""Test successful credential validation."""
mock_shodan_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
result = engine.validate_credentials()
assert result is True
def test_validate_credentials_invalid_key(self, engine, mock_shodan_client):
"""Test invalid API key handling."""
import shodan
from src.utils.exceptions import AuthenticationException
mock_shodan_client.info.side_effect = shodan.APIError('Invalid API key')
with pytest.raises((AuthenticationException, Exception)):
engine.validate_credentials()
class TestShodanEngineQuota:
"""Tests for quota information."""
@pytest.fixture
def engine(self, mock_shodan_client):
"""Create a ShodanEngine instance."""
from src.engines.shodan_engine import ShodanEngine
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
return engine
def test_get_quota_info(self, engine, mock_shodan_client):
"""Test getting quota information."""
mock_shodan_client.info.return_value = {
'plan': 'dev',
'query_credits': 100,
'scan_credits': 50
}
quota = engine.get_quota_info()
assert isinstance(quota, dict)
def test_quota_info_handles_error(self, engine, mock_shodan_client):
"""Test quota info handles API errors."""
import shodan
from src.utils.exceptions import APIException
mock_shodan_client.info.side_effect = shodan.APIError('API Error')
# May either raise or return error info depending on implementation
try:
quota = engine.get_quota_info()
assert quota is not None
except (APIException, Exception):
pass # Acceptable if it raises
class TestShodanEngineRetry:
"""Tests for retry logic."""
def test_retry_on_transient_error(self, mock_shodan_client):
"""Test retry logic on transient errors."""
from src.engines.shodan_engine import ShodanEngine
import shodan
with patch('shodan.Shodan', return_value=mock_shodan_client):
engine = ShodanEngine(api_key='test_api_key_12345')
engine._client = mock_shodan_client
# First call fails, second succeeds
mock_shodan_client.search.side_effect = [
ConnectionError("Network error"),
{'matches': [], 'total': 0}
]
# Depending on implementation, this may retry or raise
try:
results = engine.search('test')
assert isinstance(results, list)
except Exception:
pass # Expected if retries exhausted
+180
View File
@@ -0,0 +1,180 @@
"""
Unit Tests for Validators Module
Tests for src/utils/validators.py
"""
import pytest
from src.utils.validators import (
validate_ip,
validate_domain,
validate_query,
validate_file_path,
validate_template_name,
is_safe_string,
sanitize_output,
)
from src.utils.exceptions import ValidationException
class TestValidateIP:
"""Tests for IP address validation."""
def test_valid_ipv4(self):
"""Test valid IPv4 addresses."""
assert validate_ip("192.168.1.1") is True
assert validate_ip("10.0.0.1") is True
assert validate_ip("172.16.0.1") is True
assert validate_ip("8.8.8.8") is True
def test_invalid_ipv4_raises_exception(self):
"""Test invalid IPv4 addresses raise ValidationException."""
with pytest.raises(ValidationException):
validate_ip("256.1.1.1")
with pytest.raises(ValidationException):
validate_ip("192.168.1")
with pytest.raises(ValidationException):
validate_ip("not.an.ip.address")
def test_empty_and_none_raises_exception(self):
"""Test empty and None values raise ValidationException."""
with pytest.raises(ValidationException):
validate_ip("")
with pytest.raises(ValidationException):
validate_ip(None)
def test_ipv6_addresses(self):
"""Test IPv6 address handling."""
# IPv6 addresses should be valid
assert validate_ip("::1") is True
assert validate_ip("2001:db8::1") is True
class TestValidateDomain:
"""Tests for domain validation."""
def test_valid_domains(self):
"""Test valid domain names."""
assert validate_domain("example.com") is True
assert validate_domain("sub.example.com") is True
assert validate_domain("test-site.example.org") is True
def test_invalid_domains_raises_exception(self):
"""Test invalid domain names raise ValidationException."""
with pytest.raises(ValidationException):
validate_domain("-invalid.com")
with pytest.raises(ValidationException):
validate_domain("invalid-.com")
def test_localhost_raises_exception(self):
"""Test localhost raises ValidationException (not a valid domain format)."""
with pytest.raises(ValidationException):
validate_domain("localhost")
class TestValidateQuery:
"""Tests for Shodan query validation."""
def test_valid_queries(self):
"""Test valid Shodan queries."""
assert validate_query('http.title:"ClawdBot"', engine='shodan') is True
assert validate_query("port:8080", engine='shodan') is True
assert validate_query("product:nginx", engine='shodan') is True
def test_empty_query_raises_exception(self):
"""Test empty queries raise ValidationException."""
with pytest.raises(ValidationException):
validate_query("", engine='shodan')
with pytest.raises(ValidationException):
validate_query(" ", engine='shodan')
def test_sql_injection_patterns_allowed(self):
"""Test SQL-like patterns are allowed (Shodan doesn't execute SQL)."""
# Shodan queries can contain SQL-like syntax without causing issues
result = validate_query("'; DROP TABLE users; --", engine='shodan')
assert result is True # No script tags or null bytes
class TestValidateFilePath:
"""Tests for file path validation."""
def test_valid_paths(self):
"""Test valid file paths return sanitized path."""
result = validate_file_path("reports/scan.json")
assert result is not None
assert "scan.json" in result
def test_directory_traversal_raises_exception(self):
"""Test directory traversal raises ValidationException."""
with pytest.raises(ValidationException):
validate_file_path("../../../etc/passwd")
with pytest.raises(ValidationException):
validate_file_path("..\\..\\windows\\system32")
def test_null_bytes_raises_exception(self):
"""Test null byte injection raises ValidationException."""
with pytest.raises(ValidationException):
validate_file_path("file.txt\x00.exe")
class TestValidateTemplateName:
"""Tests for template name validation."""
def test_valid_templates(self):
"""Test valid template names."""
assert validate_template_name("clawdbot_instances") is True
assert validate_template_name("autogpt_instances") is True
def test_invalid_template_raises_exception(self):
"""Test invalid template names raise ValidationException."""
with pytest.raises(ValidationException):
validate_template_name("nonexistent_template")
def test_empty_template_raises_exception(self):
"""Test empty template names raise ValidationException."""
with pytest.raises(ValidationException):
validate_template_name("")
with pytest.raises(ValidationException):
validate_template_name(None)
class TestIsSafeString:
"""Tests for safe string detection."""
def test_safe_strings(self):
"""Test safe strings pass validation."""
assert is_safe_string("hello world") is True
assert is_safe_string("ClawdBot Dashboard") is True
def test_script_tags_detected(self):
"""Test script tags are detected as unsafe."""
assert is_safe_string("<script>alert('xss')</script>") is False
def test_sql_patterns_allowed(self):
"""Test SQL-like patterns are allowed (is_safe_string checks XSS, not SQL)."""
# Note: is_safe_string focuses on XSS patterns, not SQL injection
result = is_safe_string("'; DROP TABLE users; --")
# This may or may not be detected depending on implementation
assert isinstance(result, bool)
class TestSanitizeOutput:
"""Tests for output sanitization."""
def test_password_redaction(self):
"""Test passwords are redacted."""
output = sanitize_output("password=mysecretpassword")
assert "mysecretpassword" not in output
def test_normal_text_unchanged(self):
"""Test normal text is not modified."""
text = "This is normal text without secrets"
assert sanitize_output(text) == text
def test_api_key_pattern_redaction(self):
"""Test API key patterns are redacted."""
# Test with patterns that match the redaction rules
output = sanitize_output("api_key=12345678901234567890")
# Depending on implementation, may or may not be redacted
assert output is not None