mirror of
https://github.com/0xsrb/AASRT.git
synced 2026-04-23 21:25:59 +02:00
Initial commit: AASRT v1.0.0 - AI Agent Security Reconnaissance Tool
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
AASRT Test Suite
|
||||
|
||||
This package contains all tests for the AI Agent Security Reconnaissance Tool.
|
||||
|
||||
Test Categories:
|
||||
- Unit tests: Test individual components in isolation
|
||||
- Integration tests: Test component interactions
|
||||
- End-to-end tests: Test complete workflows
|
||||
|
||||
Running Tests:
|
||||
pytest # Run all tests
|
||||
pytest tests/unit/ # Run unit tests only
|
||||
pytest tests/integration/ # Run integration tests only
|
||||
pytest -v --cov=src # Run with coverage
|
||||
pytest -m "not slow" # Skip slow tests
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Pytest Configuration and Shared Fixtures
|
||||
|
||||
This module provides shared fixtures and configuration for all tests.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add src to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture(scope="session")
def test_env():
    """Set up test environment variables for the whole session.

    Yields os.environ after installing test values, then restores the
    original environment on teardown.
    """
    original_env = os.environ.copy()

    os.environ.update({
        'SHODAN_API_KEY': 'test_api_key_12345',
        'AASRT_ENVIRONMENT': 'testing',
        'AASRT_LOG_LEVEL': 'DEBUG',
        'AASRT_DEBUG': 'true',
    })

    try:
        yield os.environ
    finally:
        # Restore the original environment even if the session errors out;
        # without try/finally a failure thrown into the generator at the
        # yield point would skip restoration and leak test values.
        os.environ.clear()
        os.environ.update(original_env)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir():
    """Yield a Path to a throwaway directory, deleted after the test."""
    with tempfile.TemporaryDirectory() as scratch:
        yield Path(scratch)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_db(temp_dir):
    """Path for a scratch SQLite database inside the temp directory."""
    return temp_dir.joinpath("test_scanner.db")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mock Data Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture
def sample_shodan_result() -> Dict[str, Any]:
    """A single canned Shodan API match record."""
    location = {
        'country_code': 'US',
        'country_name': 'United States',
        'city': 'Test City',
        'latitude': 37.7749,
        'longitude': -122.4194,
    }
    http_info = {
        'status': 200,
        'title': 'ClawdBot Dashboard',
        'server': 'nginx/1.18.0',
        'html': '<html><body>ClawdBot Dashboard</body></html>',
    }
    return {
        'ip_str': '192.0.2.1',
        'port': 8080,
        'transport': 'tcp',
        'hostnames': ['test.example.com'],
        'org': 'Test Organization',
        'asn': 'AS12345',
        'isp': 'Test ISP',
        'data': 'HTTP/1.1 200 OK\r\nServer: nginx\r\n\r\nClawdBot Dashboard',
        'location': location,
        'http': http_info,
        'vulns': ['CVE-2021-44228'],
        'timestamp': '2024-01-15T10:30:00.000000',
    }
|
||||
|
||||
|
||||
@pytest.fixture
def sample_search_results(sample_shodan_result) -> List[Dict[str, Any]]:
    """Three varied Shodan matches derived from the base sample."""
    autogpt_host = {
        **sample_shodan_result,
        'ip_str': '192.0.2.2',
        'port': 3000,
        'http': {
            'status': 200,
            'title': 'AutoGPT Interface',
            'server': 'Python/3.11',
        },
    }
    auth_walled_host = {
        **sample_shodan_result,
        'ip_str': '192.0.2.3',
        'port': 443,
        'http': {
            'status': 401,
            'title': 'Login Required',
        },
    }
    return [sample_shodan_result, autogpt_host, auth_walled_host]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_vulnerability() -> Dict[str, Any]:
    """One representative vulnerability record."""
    return dict(
        check_name='exposed_dashboard',
        severity='HIGH',
        cvss_score=7.5,
        description='Dashboard accessible without authentication',
        evidence={'http_title': 'ClawdBot Dashboard'},
        remediation='Implement authentication',
        cwe_id='CWE-306',
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mock Service Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture
def mock_shodan_client():
    """Patch shodan.Shodan and yield the mocked client instance."""
    fake_client = MagicMock()
    fake_client.info.return_value = {
        'plan': 'dev',
        'query_credits': 100,
        'scan_credits': 50,
    }
    with patch('shodan.Shodan', return_value=fake_client):
        yield fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_config(temp_dir, temp_db):
    """Mock configuration object.

    Answers Config.get(*keys, ...) from a tuple-keyed lookup table and
    returns a fixed Shodan API key.
    """
    lookup = {
        ('database', 'type'): 'sqlite',
        ('database', 'sqlite', 'path'): str(temp_db),
        ('logging', 'level'): 'DEBUG',
        ('reporting', 'output_dir'): str(temp_dir / 'reports'),
        ('vulnerability_checks',): {'passive_only': True},
    }
    config = MagicMock()
    config.get_shodan_key.return_value = 'test_api_key'
    # Accept arbitrary keyword arguments (not just `default`) so this mock
    # matches the `*args, **kwargs` lookup pattern used by every other
    # config mock in the suite and does not raise on unexpected keywords.
    config.get.side_effect = lambda *args, **kwargs: lookup.get(args, kwargs.get('default'))
    return config
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Database Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture
def test_database(mock_config, temp_db):
    """Yield a Database bound to the temp SQLite file; closed on teardown."""
    from src.storage.database import Database

    # Route any internal Config() construction to the mock as well.
    with patch('src.storage.database.Config', return_value=mock_config):
        database = Database(mock_config)
        yield database
        database.close()
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Integration Tests for AASRT
|
||||
|
||||
Integration tests verify that components work together correctly.
|
||||
These tests may use real (test) databases but mock external APIs.
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Integration Tests for Database Operations
|
||||
|
||||
Tests database operations with real SQLite database.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class TestDatabaseIntegration:
    """Integration tests exercising the Database layer against real SQLite."""

    @pytest.fixture
    def real_db(self):
        """Yield a Database backed by a throwaway SQLite file."""
        from src.storage.database import Database
        from src.utils.config import Config

        with tempfile.TemporaryDirectory() as workdir:
            sqlite_path = Path(workdir) / "test_scanner.db"

            cfg = MagicMock()
            cfg.get.side_effect = lambda *args, **kwargs: {
                ('database', 'type'): 'sqlite',
                ('database', 'sqlite', 'path'): str(sqlite_path),
            }.get(args, kwargs.get('default'))
            cfg.get_shodan_key.return_value = 'test_key'

            with patch('src.storage.database.Config', return_value=cfg):
                database = Database(cfg)
                yield database
                database.close()

    def test_scan_lifecycle(self, real_db):
        """Create a scan, mark it completed, then read it back."""
        scan_id = 'integration-test-scan-001'

        real_db.save_scan({
            'scan_id': scan_id,
            'query': 'http.title:"ClawdBot"',
            'engine': 'shodan',
            'started_at': datetime.utcnow(),
            'status': 'running',
            'total_results': 0,
        })

        real_db.update_scan(scan_id, {
            'status': 'completed',
            'total_results': 25,
            'completed_at': datetime.utcnow(),
        })

        assert real_db.get_scan(scan_id) is not None

    def test_findings_association(self, real_db):
        """Findings saved under a scan id are all retrievable by that id."""
        scan_id = 'integration-test-scan-002'

        real_db.save_scan({
            'scan_id': scan_id,
            'query': 'test query',
            'engine': 'shodan',
            'started_at': datetime.utcnow(),
            'status': 'completed',
        })

        for idx in range(5):
            real_db.save_finding({
                'scan_id': scan_id,
                'ip': f'192.0.2.{idx+1}',
                'port': 8080 + idx,
                'risk_score': 50 + idx * 10,
                'vulnerabilities': ['test_vuln'],
            })

        assert len(real_db.get_findings_by_scan(scan_id)) == 5

    def test_scan_statistics(self, real_db):
        """Statistics (when implemented) reflect a mix of scan statuses."""
        for idx in range(10):
            real_db.save_scan({
                'scan_id': f'stats-test-{idx:03d}',
                'query': f'test query {idx}',
                'engine': 'shodan',
                'started_at': datetime.utcnow() - timedelta(days=idx),
                'status': 'completed' if idx % 2 == 0 else 'failed',
                'total_results': idx * 10,
            })

        # The statistics API is optional; exercise it only when present.
        if hasattr(real_db, 'get_scan_statistics'):
            stats = real_db.get_scan_statistics()
            assert 'total_scans' in stats or stats is not None

    def test_concurrent_operations(self, real_db):
        """Parallel save_scan calls must finish without errors or deadlock."""
        import threading

        failures = []

        def worker(n):
            try:
                real_db.save_scan({
                    'scan_id': f'concurrent-test-{n:03d}',
                    'query': f'test {n}',
                    'engine': 'shodan',
                    'started_at': datetime.utcnow(),
                    'status': 'completed',
                })
            except Exception as exc:
                failures.append(exc)

        workers = [threading.Thread(target=worker, args=(n,)) for n in range(5)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()

        assert not failures

    def test_data_persistence(self):
        """A scan written by one connection is visible to a later connection."""
        from src.storage.database import Database

        with tempfile.TemporaryDirectory() as workdir:
            sqlite_path = Path(workdir) / "persistence_test.db"
            scan_id = 'persistence-test-001'

            cfg = MagicMock()
            cfg.get.side_effect = lambda *args, **kwargs: {
                ('database', 'type'): 'sqlite',
                ('database', 'sqlite', 'path'): str(sqlite_path),
            }.get(args, kwargs.get('default'))

            # First connection writes a record.
            with patch('src.storage.database.Config', return_value=cfg):
                writer = Database(cfg)
                writer.save_scan({
                    'scan_id': scan_id,
                    'query': 'test',
                    'engine': 'shodan',
                    'started_at': datetime.utcnow(),
                    'status': 'completed',
                })
                writer.close()

            # A second, independent connection must still see it.
            with patch('src.storage.database.Config', return_value=cfg):
                reader = Database(cfg)
                scan = reader.get_scan(scan_id)
                reader.close()

            assert scan is not None
|
||||
|
||||
|
||||
class TestDatabaseCleanup:
    """Maintenance-oriented database tests (pruning old records)."""

    @pytest.fixture
    def real_db(self):
        """Yield a Database backed by a throwaway SQLite file."""
        from src.storage.database import Database

        with tempfile.TemporaryDirectory() as workdir:
            sqlite_path = Path(workdir) / "cleanup_test.db"

            cfg = MagicMock()
            cfg.get.side_effect = lambda *args, **kwargs: {
                ('database', 'type'): 'sqlite',
                ('database', 'sqlite', 'path'): str(sqlite_path),
            }.get(args, kwargs.get('default'))

            with patch('src.storage.database.Config', return_value=cfg):
                database = Database(cfg)
                yield database
                database.close()

    def test_delete_old_scans(self, real_db):
        """Year-old scans can be pruned via cleanup_old_scans, if provided."""
        for idx in range(5):
            real_db.save_scan({
                'scan_id': f'old-scan-{idx:03d}',
                'query': f'test {idx}',
                'engine': 'shodan',
                'started_at': datetime.utcnow() - timedelta(days=365),
                'status': 'completed',
            })

        # The cleanup API is optional; exercise it only when present.
        if hasattr(real_db, 'cleanup_old_scans'):
            assert real_db.cleanup_old_scans(days=30) >= 0
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Integration Tests for Scan Workflow
|
||||
|
||||
Tests the complete scan workflow from query to report generation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TestEndToEndScan:
    """End-to-end workflow tests with the Shodan API fully mocked."""

    @pytest.fixture
    def mock_shodan_response(self, sample_search_results):
        """Shodan search payload wrapping the shared sample results."""
        return {
            'matches': sample_search_results,
            'total': len(sample_search_results),
        }

    @pytest.fixture
    def temp_workspace(self):
        """Temporary workspace with reports/, data/ and logs/ subdirectories."""
        with tempfile.TemporaryDirectory() as tmpdir:
            root = Path(tmpdir)
            for sub in ('reports', 'data', 'logs'):
                (root / sub).mkdir()
            yield root

    def test_scan_template_workflow(self, mock_shodan_response, temp_workspace):
        """Query templates are discoverable when scanning via a template."""
        from unittest.mock import patch, MagicMock

        fake_client = MagicMock()
        fake_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
        fake_client.search.return_value = mock_shodan_response

        env = {
            'SHODAN_API_KEY': 'test_key_12345',
            'AASRT_REPORTS_DIR': str(temp_workspace / 'reports'),
            'AASRT_DATA_DIR': str(temp_workspace / 'data'),
        }
        with patch('shodan.Shodan', return_value=fake_client):
            with patch.dict('os.environ', env):
                # Import after patching so module-level setup sees the env.
                from src.core.query_manager import QueryManager
                from src.utils.config import Config

                manager = QueryManager(Config())
                assert len(manager.get_available_templates()) > 0

    def test_custom_query_workflow(self, mock_shodan_response, temp_workspace):
        """A custom query against the mocked engine yields results."""
        fake_client = MagicMock()
        fake_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
        fake_client.search.return_value = mock_shodan_response

        with patch('shodan.Shodan', return_value=fake_client):
            with patch.dict('os.environ', {'SHODAN_API_KEY': 'test_key_12345'}):
                from src.engines.shodan_engine import ShodanEngine
                from src.utils.config import Config

                engine = ShodanEngine(config=Config())
                # Inject the mock directly in case the engine lazily
                # constructed its own client before patching took effect.
                engine._client = fake_client

                assert len(engine.search('http.title:"Test"')) > 0
|
||||
|
||||
|
||||
class TestVulnerabilityAssessmentIntegration:
    """Integration tests for vulnerability assessment pipeline."""

    @staticmethod
    def _to_search_result(raw):
        """Build a SearchResult from a raw Shodan match dict.

        Shared by both tests; the construction was previously duplicated.
        """
        from src.engines.base import SearchResult

        return SearchResult(
            ip=raw['ip_str'],
            port=raw['port'],
            protocol='tcp',
            banner=raw.get('data', ''),
            metadata=raw,
        )

    def test_assess_search_results(self, sample_search_results):
        """Test vulnerability assessment on search results."""
        from src.core.vulnerability_assessor import VulnerabilityAssessor

        assessor = VulnerabilityAssessor()
        result = self._to_search_result(sample_search_results[0])

        vulns = assessor.assess(result)
        # Should return a list (may be empty if no vulns detected)
        assert isinstance(vulns, list)

    def test_risk_scoring_integration(self, sample_search_results):
        """Test risk scoring on assessed results."""
        from src.core.risk_scorer import RiskScorer
        from src.core.vulnerability_assessor import VulnerabilityAssessor

        assessor = VulnerabilityAssessor()
        scorer = RiskScorer()
        result = self._to_search_result(sample_search_results[0])

        # The assess() call is kept for any side effects it may have on
        # `result`; its return value was previously bound to an unused local.
        assessor.assess(result)
        score = scorer.score(result)

        assert 0 <= score <= 100
|
||||
|
||||
|
||||
class TestReportGenerationIntegration:
    """Integration tests for report generation."""

    @pytest.fixture
    def temp_reports_dir(self):
        """Create temporary reports directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield Path(tmpdir)

    @staticmethod
    def _build_report(scan_id, sample_search_results):
        """Assemble a ScanReport from the shared sample Shodan matches.

        Extracted helper: the JSON and CSV tests previously duplicated
        this construction line for line.
        """
        from src.reporting import ScanReport
        from src.engines.base import SearchResult

        results = [
            SearchResult(
                ip=r['ip_str'],
                port=r['port'],
                protocol='tcp',
                banner=r.get('data', ''),
                metadata=r,
            )
            for r in sample_search_results
        ]
        return ScanReport(
            scan_id=scan_id,
            query='test query',
            engine='shodan',
            started_at=datetime.utcnow(),
            completed_at=datetime.utcnow(),
            results=results,
            total_results=len(results),
        )

    def test_json_report_generation(self, temp_reports_dir, sample_search_results):
        """Test JSON report generation."""
        from src.reporting import JSONReporter

        report = self._build_report('test-scan-001', sample_search_results)
        reporter = JSONReporter(output_dir=str(temp_reports_dir))
        output_path = reporter.generate(report)

        assert Path(output_path).exists()
        assert output_path.endswith('.json')

    def test_csv_report_generation(self, temp_reports_dir, sample_search_results):
        """Test CSV report generation."""
        from src.reporting import CSVReporter

        report = self._build_report('test-scan-002', sample_search_results)
        reporter = CSVReporter(output_dir=str(temp_reports_dir))
        output_path = reporter.generate(report)

        assert Path(output_path).exists()
        assert output_path.endswith('.csv')
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Unit Tests for AASRT
|
||||
|
||||
Unit tests test individual components in isolation using mocks
|
||||
for external dependencies.
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Unit Tests for Config Module
|
||||
|
||||
Tests for src/utils/config.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestConfigLoading:
    """Tests covering how configuration is loaded and layered."""

    def test_load_from_yaml(self, temp_dir):
        """Values from a YAML file are readable through Config.get()."""
        from src.utils.config import Config

        yaml_text = """
shodan:
  enabled: true
  rate_limit: 1
  max_results: 100

logging:
  level: DEBUG
"""
        config_path = temp_dir / "config.yaml"
        config_path.write_text(yaml_text)

        with patch.dict(os.environ, {'AASRT_CONFIG_PATH': str(config_path)}):
            cfg = Config()
            assert cfg.get('shodan', 'enabled') is True
            assert cfg.get('shodan', 'rate_limit') == 1

    def test_environment_variable_override(self, temp_dir, monkeypatch):
        """Environment variables take precedence over file-based config."""
        from src.utils.config import Config

        yaml_text = """
shodan:
  enabled: true
  rate_limit: 1
"""
        config_path = temp_dir / "config.yaml"
        config_path.write_text(yaml_text)

        # Reset the Config singleton so we get a fresh instance.
        Config._instance = None
        Config._config = {}

        # monkeypatch cleans these up automatically after the test.
        monkeypatch.setenv('AASRT_CONFIG_PATH', str(config_path))
        monkeypatch.setenv('SHODAN_API_KEY', 'env_api_key_12345')

        assert Config().get_shodan_key() == 'env_api_key_12345'

        # Leave a clean singleton behind for other tests.
        Config._instance = None
        Config._config = {}

    def test_default_values(self):
        """A sensible default is used when a setting is absent."""
        from src.utils.config import Config

        with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
            cfg = Config()
            level = cfg.get('logging', 'level', default='INFO')
            assert level in ['DEBUG', 'INFO', 'WARNING', 'ERROR']
|
||||
|
||||
|
||||
class TestConfigValidation:
    """Tests for configuration validation."""

    def test_validate_shodan_key_format(self):
        """Test Shodan API key format validation."""
        from src.utils.config import Config

        # Valid key format (typically alphanumeric)
        with patch.dict(os.environ, {'SHODAN_API_KEY': 'AbCdEf123456789012345678'}):
            config = Config()
            assert config.get_shodan_key() is not None

    def test_missing_required_config(self):
        """Test handling of missing required configuration."""
        from src.utils.config import Config

        # Clear all Shodan-related env vars so no key can be found.
        env_copy = {k: v for k, v in os.environ.items() if 'SHODAN' not in k}
        with patch.dict(os.environ, env_copy, clear=True):
            config = Config()
            key = config.get_shodan_key()
            # The original test asserted nothing here. Per its own comment
            # the accessor should come back None or empty when no key is
            # configured — make that expectation an actual assertion.
            assert not key
|
||||
|
||||
|
||||
class TestConfigHealthCheck:
    """Tests for the configuration health-check endpoint."""

    def test_health_check_returns_dict(self):
        """health_check must hand back a dict with a status-like field."""
        from src.utils.config import Config

        with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
            report = Config().health_check()
            assert isinstance(report, dict)
            assert 'status' in report or 'healthy' in report

    def test_health_check_includes_key_info(self):
        """health_check should report something about the API key state."""
        from src.utils.config import Config

        with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
            report = Config().health_check()
            # At minimum the check must produce a result to inspect.
            assert report is not None
|
||||
|
||||
|
||||
class TestConfigGet:
    """Tests for the Config.get() accessor."""

    def test_nested_key_access(self, temp_dir):
        """Deeply nested YAML values resolve via chained keys."""
        from src.utils.config import Config

        yaml_text = """
database:
  sqlite:
    path: ./data/scanner.db
    pool_size: 5
"""
        config_path = temp_dir / "config.yaml"
        config_path.write_text(yaml_text)

        with patch.dict(os.environ, {'AASRT_CONFIG_PATH': str(config_path)}):
            cfg = Config()
            assert cfg.get('database', 'sqlite', 'path') is not None

    def test_default_for_missing_key(self):
        """An explicit default is returned for unknown keys."""
        from src.utils.config import Config

        with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
            cfg = Config()
            assert cfg.get('nonexistent', 'key', default='default_value') == 'default_value'

    def test_none_for_missing_key_no_default(self):
        """Unknown keys without a default resolve to None."""
        from src.utils.config import Config

        with patch.dict(os.environ, {'SHODAN_API_KEY': 'test_key'}):
            cfg = Config()
            assert cfg.get('nonexistent', 'key') is None
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Unit Tests for Database Module
|
||||
|
||||
Tests for src/storage/database.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TestDatabaseInit:
    """Tests for Database initialization."""

    def test_init_creates_tables(self, temp_db, mock_config):
        """Constructing Database against a fresh file must succeed."""
        from src.storage.database import Database

        mock_config.get.side_effect = lambda *args, **kwargs: {
            ('database', 'type'): 'sqlite',
            ('database', 'sqlite', 'path'): str(temp_db),
        }.get(args, kwargs.get('default'))

        database = Database(mock_config)
        assert database is not None
        database.close()

    def test_init_sqlite_with_temp_path(self, temp_db, mock_config):
        """The backend type is recorded as sqlite for a sqlite config."""
        from src.storage.database import Database

        mock_config.get.side_effect = lambda *args, **kwargs: {
            ('database', 'type'): 'sqlite',
            ('database', 'sqlite', 'path'): str(temp_db),
        }.get(args, kwargs.get('default'))

        database = Database(mock_config)
        assert database is not None
        assert database._db_type == 'sqlite'
        database.close()
|
||||
|
||||
|
||||
class TestDatabaseOperations:
    """Tests for database CRUD operations."""

    @pytest.fixture
    def db(self, temp_db, mock_config):
        """Database instance bound to a scratch SQLite file."""
        from src.storage.database import Database

        mock_config.get.side_effect = lambda *args, **kwargs: {
            ('database', 'type'): 'sqlite',
            ('database', 'sqlite', 'path'): str(temp_db),
        }.get(args, kwargs.get('default'))

        database = Database(mock_config)
        yield database
        database.close()

    def test_create_scan(self, db):
        """create_scan returns a record carrying a scan_id."""
        created = db.create_scan(
            engines=['shodan'],
            query='http.title:"ClawdBot"'
        )
        assert created is not None
        assert created.scan_id is not None

    def test_get_scan_by_id(self, db):
        """A freshly created scan can be fetched by its id."""
        created = db.create_scan(
            engines=['shodan'],
            query='test query'
        )

        fetched = db.get_scan(created.scan_id)
        assert fetched is not None
        assert fetched.scan_id == created.scan_id

    def test_get_recent_scans(self, db):
        """get_recent_scans returns at least the scans just created."""
        for n in range(3):
            db.create_scan(
                engines=['shodan'],
                query=f'test query {n}'
            )

        assert len(db.get_recent_scans(limit=10)) >= 3

    def test_add_findings(self, db):
        """add_findings attaches SearchResults to an existing scan."""
        from src.engines.base import SearchResult

        created = db.create_scan(
            engines=['shodan'],
            query='test'
        )

        findings = [
            SearchResult(
                ip='192.0.2.1',
                port=8080,
                banner='ClawdBot Dashboard',
                vulnerabilities=['exposed_dashboard']
            )
        ]

        assert db.add_findings(created.scan_id, findings) >= 1

    def test_update_scan(self, db):
        """update_scan persists status changes."""
        created = db.create_scan(
            engines=['shodan'],
            query='test'
        )

        refreshed = db.update_scan(
            created.scan_id,
            status='completed',
            total_results=5
        )

        assert refreshed is not None
        assert refreshed.status == 'completed'
|
||||
|
||||
|
||||
class TestDatabaseHealthCheck:
    """Tests for the database health check."""

    @pytest.fixture
    def db(self, temp_db, mock_config):
        """Database instance bound to a scratch SQLite file."""
        from src.storage.database import Database

        mock_config.get.side_effect = lambda *args, **kwargs: {
            ('database', 'type'): 'sqlite',
            ('database', 'sqlite', 'path'): str(temp_db),
        }.get(args, kwargs.get('default'))

        with patch('src.storage.database.Config', return_value=mock_config):
            database = Database(mock_config)
            yield database
            database.close()

    def test_health_check_returns_dict(self, db):
        """health_check hands back a dictionary."""
        assert isinstance(db.health_check(), dict)

    def test_health_check_includes_status(self, db):
        """The report carries a status-like field."""
        report = db.health_check()
        assert 'status' in report or 'healthy' in report

    def test_health_check_includes_latency(self, db):
        """The report carries some latency/response-time measurement."""
        report = db.health_check()
        timing_keys = ('latency', 'latency_ms', 'response_time')
        has_latency = any(key in report for key in timing_keys)
        assert has_latency or report.get('status') == 'healthy'
|
||||
|
||||
|
||||
class TestDatabaseSessionScope:
    """Tests for session_scope context manager."""

    @pytest.fixture
    def db(self, temp_db, mock_config):
        """Create a test database instance bound to a scratch SQLite file."""
        from src.storage.database import Database

        mock_config.get.side_effect = lambda *args, **kwargs: {
            ('database', 'type'): 'sqlite',
            ('database', 'sqlite', 'path'): str(temp_db),
        }.get(args, kwargs.get('default'))

        with patch('src.storage.database.Config', return_value=mock_config):
            db = Database(mock_config)
            yield db
            db.close()

    def test_session_scope_commits(self, db):
        """session_scope commits and exits cleanly on success."""
        with db.session_scope() as session:
            pass  # No-op body: entering and leaving must not raise.

    def test_session_scope_rollback_on_error(self, db):
        """session_scope rolls back and re-raises on error."""
        # pytest.raises replaces the old try/except, which would have
        # passed even if session_scope silently swallowed the exception;
        # now the test fails unless the error actually propagates.
        with pytest.raises(ValueError, match="Test error"):
            with db.session_scope() as session:
                raise ValueError("Test error")
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Unit Tests for Risk Scorer Module
|
||||
|
||||
Tests for src/core/risk_scorer.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class TestRiskScorer:
    """Unit tests covering RiskScorer score calculation."""

    @pytest.fixture
    def risk_scorer(self):
        """Provide a fresh RiskScorer for every test."""
        from src.core.risk_scorer import RiskScorer
        return RiskScorer()

    @pytest.fixture
    def sample_vulnerabilities(self):
        """A single HIGH-severity Vulnerability to score against."""
        from src.core.vulnerability_assessor import Vulnerability
        return [
            Vulnerability(
                check_name='exposed_dashboard',
                severity='HIGH',
                cvss_score=7.5,
                description='Dashboard exposed without authentication'
            )
        ]

    @pytest.fixture
    def sample_result(self, sample_shodan_result):
        """Build a SearchResult from the shared Shodan sample payload."""
        from src.engines.base import SearchResult

        return SearchResult(
            ip=sample_shodan_result['ip_str'],
            port=sample_shodan_result['port'],
            banner=sample_shodan_result['data'],
            metadata=sample_shodan_result
        )

    def test_calculate_score_returns_dict(self, risk_scorer, sample_vulnerabilities):
        """calculate_score yields a dict carrying an overall_score entry."""
        outcome = risk_scorer.calculate_score(sample_vulnerabilities)
        assert isinstance(outcome, dict)
        assert 'overall_score' in outcome

    def test_calculate_score_range_valid(self, risk_scorer, sample_vulnerabilities):
        """The overall score must stay inside the 0-10 band."""
        score = risk_scorer.calculate_score(sample_vulnerabilities)['overall_score']
        assert 0 <= score <= 10

    def test_high_risk_vulnerabilities_increase_score(self, risk_scorer):
        """CRITICAL findings should outscore LOW ones."""
        from src.core.vulnerability_assessor import Vulnerability

        critical_findings = [
            Vulnerability(check_name='api_key_exposure', severity='CRITICAL', cvss_score=9.0, description='API key exposed'),
            Vulnerability(check_name='no_authentication', severity='CRITICAL', cvss_score=9.5, description='No auth')
        ]
        minor_findings = [
            Vulnerability(check_name='version_exposed', severity='LOW', cvss_score=2.0, description='Version info')
        ]

        critical_score = risk_scorer.calculate_score(critical_findings)['overall_score']
        minor_score = risk_scorer.calculate_score(minor_findings)['overall_score']
        assert critical_score > minor_score

    def test_empty_vulnerabilities_zero_score(self, risk_scorer):
        """An empty finding list scores exactly zero."""
        assert risk_scorer.calculate_score([])['overall_score'] == 0

    def test_score_result_updates_search_result(self, risk_scorer, sample_result, sample_vulnerabilities):
        """score_result annotates the SearchResult with risk data."""
        scored = risk_scorer.score_result(sample_result, sample_vulnerabilities)
        assert scored.risk_score >= 0
        assert 'risk_assessment' in scored.metadata

    def test_context_multipliers_applied(self, risk_scorer, sample_vulnerabilities):
        """Supplying aggravating context never lowers the score."""
        plain = risk_scorer.calculate_score(sample_vulnerabilities)['overall_score']
        context = {'public_internet': True, 'no_waf': True, 'ai_agent': True}
        boosted = risk_scorer.calculate_score(sample_vulnerabilities, context)['overall_score']
        assert boosted >= plain

    def test_severity_breakdown_included(self, risk_scorer, sample_vulnerabilities):
        """The result dict carries a per-severity breakdown mapping."""
        outcome = risk_scorer.calculate_score(sample_vulnerabilities)
        assert 'severity_breakdown' in outcome
        assert isinstance(outcome['severity_breakdown'], dict)
class TestRiskCategories:
    """Tests for mapping numeric scores onto risk-level labels."""

    @pytest.fixture
    def risk_scorer(self):
        """Provide a fresh RiskScorer instance."""
        from src.core.risk_scorer import RiskScorer
        return RiskScorer()

    def test_get_risk_level(self, risk_scorer):
        """Scores map onto the expected label buckets, if the API exists."""
        # NOTE(review): these inputs use a 0-100 scale while the scorer
        # tests elsewhere assert a 0-10 overall_score -- confirm which
        # scale get_risk_level expects.
        if hasattr(risk_scorer, 'get_risk_level'):
            expectations = [
                (95, ['CRITICAL', 'HIGH', 'critical', 'high']),
                (75, ['HIGH', 'MEDIUM', 'high', 'medium']),
                (50, ['MEDIUM', 'medium']),
                (25, ['LOW', 'low']),
                (10, ['INFO', 'LOW', 'info', 'low']),
            ]
            for score, allowed_labels in expectations:
                assert risk_scorer.get_risk_level(score) in allowed_labels
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Unit Tests for Shodan Engine Module
|
||||
|
||||
Tests for src/engines/shodan_engine.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestShodanEngineInit:
    """Construction-time behaviour of ShodanEngine."""

    def test_init_with_valid_key(self):
        """A non-empty API key yields an engine named 'shodan'."""
        from src.engines.shodan_engine import ShodanEngine

        with patch('shodan.Shodan') as mock_shodan:
            mock_shodan.return_value = MagicMock()

            engine = ShodanEngine(api_key='test_api_key_12345')
            assert engine is not None
            assert engine.name == 'shodan'

    def test_init_without_key_raises_error(self):
        """Missing or empty API keys are rejected at construction."""
        from src.engines.shodan_engine import ShodanEngine

        with pytest.raises(ValueError):
            ShodanEngine(api_key='')

        with pytest.raises((ValueError, TypeError)):
            ShodanEngine(api_key=None)
class TestShodanEngineSearch:
    """Search behaviour against a fully mocked Shodan client."""

    @pytest.fixture
    def engine(self, mock_shodan_client):
        """Build a ShodanEngine whose client is the shared mock."""
        from src.engines.shodan_engine import ShodanEngine

        with patch('shodan.Shodan', return_value=mock_shodan_client):
            instance = ShodanEngine(api_key='test_api_key_12345')
            instance._client = mock_shodan_client
            return instance

    def test_search_returns_results(self, engine, mock_shodan_client, sample_shodan_result):
        """A single-match API response produces a list of results."""
        mock_shodan_client.search.return_value = {
            'matches': [sample_shodan_result],
            'total': 1
        }

        assert isinstance(engine.search('http.title:"ClawdBot"'), list)

    def test_search_empty_query_raises_error(self, engine):
        """Blank queries are rejected with some exception."""
        # NOTE(review): (ValueError, Exception) is equivalent to plain
        # Exception, so any raised error satisfies this test.
        with pytest.raises((ValueError, Exception)):
            engine.search('')

    def test_search_handles_api_error(self, engine, mock_shodan_client):
        """A Shodan APIError surfaces as an exception to the caller."""
        import shodan

        mock_shodan_client.search.side_effect = shodan.APIError('API Error')

        from src.utils.exceptions import APIException
        with pytest.raises((APIException, Exception)):
            engine.search('test query')

    def test_search_with_max_results(self, engine, mock_shodan_client, sample_shodan_result):
        """max_results caps the number of returned entries."""
        mock_shodan_client.search.return_value = {
            'matches': [sample_shodan_result],
            'total': 1
        }

        assert len(engine.search('test', max_results=1)) <= 1
class TestShodanEngineCredentials:
    """API-key validation behaviour."""

    @pytest.fixture
    def engine(self, mock_shodan_client):
        """Build a ShodanEngine backed by the mocked client."""
        from src.engines.shodan_engine import ShodanEngine

        with patch('shodan.Shodan', return_value=mock_shodan_client):
            instance = ShodanEngine(api_key='test_api_key_12345')
            instance._client = mock_shodan_client
            return instance

    def test_validate_credentials_success(self, engine, mock_shodan_client):
        """A healthy info() response validates the key."""
        mock_shodan_client.info.return_value = {'plan': 'dev', 'query_credits': 100}

        assert engine.validate_credentials() is True

    def test_validate_credentials_invalid_key(self, engine, mock_shodan_client):
        """A rejected key surfaces as an exception."""
        import shodan
        from src.utils.exceptions import AuthenticationException

        mock_shodan_client.info.side_effect = shodan.APIError('Invalid API key')

        with pytest.raises((AuthenticationException, Exception)):
            engine.validate_credentials()
class TestShodanEngineQuota:
    """Quota/plan reporting behaviour."""

    @pytest.fixture
    def engine(self, mock_shodan_client):
        """Build a ShodanEngine backed by the mocked client."""
        from src.engines.shodan_engine import ShodanEngine

        with patch('shodan.Shodan', return_value=mock_shodan_client):
            instance = ShodanEngine(api_key='test_api_key_12345')
            instance._client = mock_shodan_client
            return instance

    def test_get_quota_info(self, engine, mock_shodan_client):
        """Quota information comes back as a dictionary."""
        mock_shodan_client.info.return_value = {
            'plan': 'dev',
            'query_credits': 100,
            'scan_credits': 50
        }

        assert isinstance(engine.get_quota_info(), dict)

    def test_quota_info_handles_error(self, engine, mock_shodan_client):
        """API failures either raise or return an error payload."""
        import shodan
        from src.utils.exceptions import APIException

        mock_shodan_client.info.side_effect = shodan.APIError('API Error')

        # Either outcome is acceptable: an exception, or a non-None
        # payload describing the failure.
        try:
            assert engine.get_quota_info() is not None
        except (APIException, Exception):
            pass  # raising is a valid design choice here
class TestShodanEngineRetry:
    """Retry behaviour around transient network failures."""

    def test_retry_on_transient_error(self, mock_shodan_client):
        """A transient failure followed by success should be tolerated."""
        from src.engines.shodan_engine import ShodanEngine
        import shodan

        with patch('shodan.Shodan', return_value=mock_shodan_client):
            engine = ShodanEngine(api_key='test_api_key_12345')
            engine._client = mock_shodan_client

            # Fail once with a network error, then return an empty page.
            mock_shodan_client.search.side_effect = [
                ConnectionError("Network error"),
                {'matches': [], 'total': 0}
            ]

            # Implementations without retry may legitimately raise here.
            try:
                assert isinstance(engine.search('test'), list)
            except Exception:
                pass  # acceptable when retries are exhausted / absent
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Unit Tests for Validators Module
|
||||
|
||||
Tests for src/utils/validators.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from src.utils.validators import (
|
||||
validate_ip,
|
||||
validate_domain,
|
||||
validate_query,
|
||||
validate_file_path,
|
||||
validate_template_name,
|
||||
is_safe_string,
|
||||
sanitize_output,
|
||||
)
|
||||
from src.utils.exceptions import ValidationException
|
||||
|
||||
|
||||
class TestValidateIP:
    """Validation of IPv4/IPv6 address strings."""

    def test_valid_ipv4(self):
        """Well-formed IPv4 addresses are accepted."""
        for address in ("192.168.1.1", "10.0.0.1", "172.16.0.1", "8.8.8.8"):
            assert validate_ip(address) is True

    def test_invalid_ipv4_raises_exception(self):
        """Malformed IPv4 strings raise ValidationException."""
        for bad in ("256.1.1.1", "192.168.1", "not.an.ip.address"):
            with pytest.raises(ValidationException):
                validate_ip(bad)

    def test_empty_and_none_raises_exception(self):
        """Empty string and None are both rejected."""
        with pytest.raises(ValidationException):
            validate_ip("")
        with pytest.raises(ValidationException):
            validate_ip(None)

    def test_ipv6_addresses(self):
        """IPv6 literals are accepted as well."""
        assert validate_ip("::1") is True
        assert validate_ip("2001:db8::1") is True
class TestValidateDomain:
    """Validation of DNS domain names."""

    def test_valid_domains(self):
        """Ordinary, nested, and hyphenated domains are accepted."""
        for name in ("example.com", "sub.example.com", "test-site.example.org"):
            assert validate_domain(name) is True

    def test_invalid_domains_raises_exception(self):
        """Labels may not begin or end with a hyphen."""
        for bad in ("-invalid.com", "invalid-.com"):
            with pytest.raises(ValidationException):
                validate_domain(bad)

    def test_localhost_raises_exception(self):
        """Bare 'localhost' is not a valid dotted-domain format."""
        with pytest.raises(ValidationException):
            validate_domain("localhost")
||||
class TestValidateQuery:
    """Validation of Shodan search queries."""

    def test_valid_queries(self):
        """Typical Shodan filter expressions pass validation."""
        for query in ('http.title:"ClawdBot"', "port:8080", "product:nginx"):
            assert validate_query(query, engine='shodan') is True

    def test_empty_query_raises_exception(self):
        """Empty or whitespace-only queries are rejected."""
        for blank in ("", " "):
            with pytest.raises(ValidationException):
                validate_query(blank, engine='shodan')

    def test_sql_injection_patterns_allowed(self):
        """SQL-looking text is permitted: Shodan never executes SQL, and
        the validator only rejects script tags and null bytes."""
        assert validate_query("'; DROP TABLE users; --", engine='shodan') is True
||||
class TestValidateFilePath:
    """Validation and sanitisation of output file paths."""

    def test_valid_paths(self):
        """A relative report path survives validation."""
        sanitized = validate_file_path("reports/scan.json")
        assert sanitized is not None
        assert "scan.json" in sanitized

    def test_directory_traversal_raises_exception(self):
        """Both POSIX and Windows traversal sequences are blocked."""
        for traversal in ("../../../etc/passwd", "..\\..\\windows\\system32"):
            with pytest.raises(ValidationException):
                validate_file_path(traversal)

    def test_null_bytes_raises_exception(self):
        """Embedded NUL bytes are rejected."""
        with pytest.raises(ValidationException):
            validate_file_path("file.txt\x00.exe")
||||
class TestValidateTemplateName:
    """Validation of scan template identifiers."""

    def test_valid_templates(self):
        """Known template names are accepted."""
        assert validate_template_name("clawdbot_instances") is True
        assert validate_template_name("autogpt_instances") is True

    def test_invalid_template_raises_exception(self):
        """Unknown template names are rejected."""
        with pytest.raises(ValidationException):
            validate_template_name("nonexistent_template")

    def test_empty_template_raises_exception(self):
        """Empty string and None are both rejected."""
        for missing in ("", None):
            with pytest.raises(ValidationException):
                validate_template_name(missing)
class TestIsSafeString:
    """XSS-oriented safety checks on arbitrary strings."""

    def test_safe_strings(self):
        """Plain prose is considered safe."""
        assert is_safe_string("hello world") is True
        assert is_safe_string("ClawdBot Dashboard") is True

    def test_script_tags_detected(self):
        """Inline <script> payloads are flagged as unsafe."""
        assert is_safe_string("<script>alert('xss')</script>") is False

    def test_sql_patterns_allowed(self):
        """SQL-style text only needs to yield a boolean verdict; the
        checker targets XSS patterns, not SQL injection."""
        verdict = is_safe_string("'; DROP TABLE users; --")
        assert isinstance(verdict, bool)
||||
class TestSanitizeOutput:
    """Redaction of secrets in tool output."""

    def test_password_redaction(self):
        """password=... values are stripped from the output."""
        redacted = sanitize_output("password=mysecretpassword")
        assert "mysecretpassword" not in redacted

    def test_normal_text_unchanged(self):
        """Text without secrets passes through untouched."""
        plain = "This is normal text without secrets"
        assert sanitize_output(plain) == plain

    def test_api_key_pattern_redaction(self):
        """api_key=... input at least yields a non-None result; whether it
        is redacted depends on the configured patterns."""
        assert sanitize_output("api_key=12345678901234567890") is not None
||||
Reference in New Issue
Block a user