Initial commit: AASRT v1.0.0 - AI Agent Security Reconnaissance Tool

This commit is contained in:
swethab
2026-02-10 10:53:31 -05:00
commit a714a3399b
61 changed files with 14858 additions and 0 deletions
+7
View File
@@ -0,0 +1,7 @@
"""
Integration Tests for AASRT
Integration tests verify that components work together correctly.
These tests may use real (test) databases but mock external APIs.
"""
@@ -0,0 +1,207 @@
"""
Integration Tests for Database Operations
Tests database operations with real SQLite database.
"""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
from datetime import datetime, timedelta
class TestDatabaseIntegration:
"""Integration tests for database with real SQLite."""
@pytest.fixture
def real_db(self):
"""Create a real SQLite database for testing."""
from src.storage.database import Database
from src.utils.config import Config
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_scanner.db"
mock_config = MagicMock()
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(db_path),
}.get(args, kwargs.get('default'))
mock_config.get_shodan_key.return_value = 'test_key'
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_scan_lifecycle(self, real_db):
"""Test complete scan lifecycle: create, update, retrieve."""
# Create scan
scan_id = 'integration-test-scan-001'
real_db.save_scan({
'scan_id': scan_id,
'query': 'http.title:"ClawdBot"',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'running',
'total_results': 0
})
# Update scan status
real_db.update_scan(scan_id, {
'status': 'completed',
'total_results': 25,
'completed_at': datetime.utcnow()
})
# Retrieve scan
scan = real_db.get_scan(scan_id)
assert scan is not None
def test_findings_association(self, real_db):
"""Test findings are properly associated with scans."""
scan_id = 'integration-test-scan-002'
# Create scan
real_db.save_scan({
'scan_id': scan_id,
'query': 'test query',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'completed'
})
# Save multiple findings
for i in range(5):
real_db.save_finding({
'scan_id': scan_id,
'ip': f'192.0.2.{i+1}',
'port': 8080 + i,
'risk_score': 50 + i * 10,
'vulnerabilities': ['test_vuln']
})
# Retrieve findings
findings = real_db.get_findings_by_scan(scan_id)
assert len(findings) == 5
def test_scan_statistics(self, real_db):
"""Test scan statistics calculation."""
# Create multiple scans with different statuses
for i in range(10):
real_db.save_scan({
'scan_id': f'stats-test-{i:03d}',
'query': f'test query {i}',
'engine': 'shodan',
'started_at': datetime.utcnow() - timedelta(days=i),
'status': 'completed' if i % 2 == 0 else 'failed',
'total_results': i * 10
})
# Get statistics
if hasattr(real_db, 'get_scan_statistics'):
stats = real_db.get_scan_statistics()
assert 'total_scans' in stats or stats is not None
def test_concurrent_operations(self, real_db):
"""Test concurrent database operations."""
import threading
errors = []
def save_scan(scan_num):
try:
real_db.save_scan({
'scan_id': f'concurrent-test-{scan_num:03d}',
'query': f'test {scan_num}',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'completed'
})
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=save_scan, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
# Should complete without deadlocks
assert len(errors) == 0
def test_data_persistence(self):
"""Test that data persists across database connections."""
from src.storage.database import Database
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "persistence_test.db"
scan_id = 'persistence-test-001'
mock_config = MagicMock()
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(db_path),
}.get(args, kwargs.get('default'))
# First connection - create data
with patch('src.storage.database.Config', return_value=mock_config):
db1 = Database(mock_config)
db1.save_scan({
'scan_id': scan_id,
'query': 'test',
'engine': 'shodan',
'started_at': datetime.utcnow(),
'status': 'completed'
})
db1.close()
# Second connection - verify data exists
with patch('src.storage.database.Config', return_value=mock_config):
db2 = Database(mock_config)
scan = db2.get_scan(scan_id)
db2.close()
assert scan is not None
class TestDatabaseCleanup:
"""Tests for database cleanup and maintenance."""
@pytest.fixture
def real_db(self):
"""Create a real SQLite database for testing."""
from src.storage.database import Database
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "cleanup_test.db"
mock_config = MagicMock()
mock_config.get.side_effect = lambda *args, **kwargs: {
('database', 'type'): 'sqlite',
('database', 'sqlite', 'path'): str(db_path),
}.get(args, kwargs.get('default'))
with patch('src.storage.database.Config', return_value=mock_config):
db = Database(mock_config)
yield db
db.close()
def test_delete_old_scans(self, real_db):
"""Test deleting old scan records."""
# Create old scans
for i in range(5):
real_db.save_scan({
'scan_id': f'old-scan-{i:03d}',
'query': f'test {i}',
'engine': 'shodan',
'started_at': datetime.utcnow() - timedelta(days=365),
'status': 'completed'
})
# If cleanup method exists, test it
if hasattr(real_db, 'cleanup_old_scans'):
deleted = real_db.cleanup_old_scans(days=30)
assert deleted >= 0
+198
View File
@@ -0,0 +1,198 @@
"""
Integration Tests for Scan Workflow
Tests the complete scan workflow from query to report generation.
"""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
from datetime import datetime
class TestEndToEndScan:
"""Integration tests for complete scan workflow."""
@pytest.fixture
def mock_shodan_response(self, sample_search_results):
"""Mock Shodan API response."""
return {
'matches': sample_search_results,
'total': len(sample_search_results)
}
@pytest.fixture
def temp_workspace(self):
"""Create temporary workspace for test files."""
with tempfile.TemporaryDirectory() as tmpdir:
workspace = Path(tmpdir)
(workspace / 'reports').mkdir()
(workspace / 'data').mkdir()
(workspace / 'logs').mkdir()
yield workspace
def test_scan_template_workflow(self, mock_shodan_response, temp_workspace):
"""Test scanning using a template."""
from unittest.mock import patch, MagicMock
mock_client = MagicMock()
mock_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
mock_client.search.return_value = mock_shodan_response
with patch('shodan.Shodan', return_value=mock_client):
with patch.dict('os.environ', {
'SHODAN_API_KEY': 'test_key_12345',
'AASRT_REPORTS_DIR': str(temp_workspace / 'reports'),
'AASRT_DATA_DIR': str(temp_workspace / 'data'),
}):
# Import after patching
from src.core.query_manager import QueryManager
from src.utils.config import Config
config = Config()
qm = QueryManager(config)
# Check templates are available
templates = qm.get_available_templates()
assert len(templates) > 0
def test_custom_query_workflow(self, mock_shodan_response, temp_workspace):
"""Test scanning with a custom query."""
mock_client = MagicMock()
mock_client.info.return_value = {'plan': 'dev', 'query_credits': 100}
mock_client.search.return_value = mock_shodan_response
with patch('shodan.Shodan', return_value=mock_client):
with patch.dict('os.environ', {
'SHODAN_API_KEY': 'test_key_12345',
}):
from src.engines.shodan_engine import ShodanEngine
from src.utils.config import Config
config = Config()
engine = ShodanEngine(config=config)
engine._client = mock_client
results = engine.search('http.title:"Test"')
assert len(results) > 0
class TestVulnerabilityAssessmentIntegration:
"""Integration tests for vulnerability assessment pipeline."""
def test_assess_search_results(self, sample_search_results):
"""Test vulnerability assessment on search results."""
from src.core.vulnerability_assessor import VulnerabilityAssessor
from src.engines.base import SearchResult
assessor = VulnerabilityAssessor()
# Convert sample data to SearchResult
result = SearchResult(
ip=sample_search_results[0]['ip_str'],
port=sample_search_results[0]['port'],
protocol='tcp',
banner=sample_search_results[0].get('data', ''),
metadata=sample_search_results[0]
)
vulns = assessor.assess(result)
# Should return a list (may be empty if no vulns detected)
assert isinstance(vulns, list)
def test_risk_scoring_integration(self, sample_search_results):
"""Test risk scoring on assessed results."""
from src.core.risk_scorer import RiskScorer
from src.core.vulnerability_assessor import VulnerabilityAssessor
from src.engines.base import SearchResult
assessor = VulnerabilityAssessor()
scorer = RiskScorer()
result = SearchResult(
ip=sample_search_results[0]['ip_str'],
port=sample_search_results[0]['port'],
protocol='tcp',
banner=sample_search_results[0].get('data', ''),
metadata=sample_search_results[0]
)
vulns = assessor.assess(result)
score = scorer.score(result)
assert 0 <= score <= 100
class TestReportGenerationIntegration:
"""Integration tests for report generation."""
@pytest.fixture
def temp_reports_dir(self):
"""Create temporary reports directory."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def test_json_report_generation(self, temp_reports_dir, sample_search_results):
"""Test JSON report generation."""
from src.reporting import JSONReporter, ScanReport
from src.engines.base import SearchResult
# Create scan report data
results = [
SearchResult(
ip=r['ip_str'],
port=r['port'],
protocol='tcp',
banner=r.get('data', ''),
metadata=r
) for r in sample_search_results
]
report = ScanReport(
scan_id='test-scan-001',
query='test query',
engine='shodan',
started_at=datetime.utcnow(),
completed_at=datetime.utcnow(),
results=results,
total_results=len(results)
)
reporter = JSONReporter(output_dir=str(temp_reports_dir))
output_path = reporter.generate(report)
assert Path(output_path).exists()
assert output_path.endswith('.json')
def test_csv_report_generation(self, temp_reports_dir, sample_search_results):
"""Test CSV report generation."""
from src.reporting import CSVReporter, ScanReport
from src.engines.base import SearchResult
results = [
SearchResult(
ip=r['ip_str'],
port=r['port'],
protocol='tcp',
banner=r.get('data', ''),
metadata=r
) for r in sample_search_results
]
report = ScanReport(
scan_id='test-scan-002',
query='test query',
engine='shodan',
started_at=datetime.utcnow(),
completed_at=datetime.utcnow(),
results=results,
total_results=len(results)
)
reporter = CSVReporter(output_dir=str(temp_reports_dir))
output_path = reporter.generate(report)
assert Path(output_path).exists()
assert output_path.endswith('.csv')