mirror of
https://github.com/0xsrb/AASRT.git
synced 2026-04-24 00:15:58 +02:00
126 lines
4.9 KiB
Python
126 lines
4.9 KiB
Python
"""
Unit Tests for Risk Scorer Module

Tests for src/core/risk_scorer.py
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock
|
|
|
|
|
|
class TestRiskScorer:
    """Exercise RiskScorer.calculate_score and RiskScorer.score_result."""

    @pytest.fixture
    def risk_scorer(self):
        """Provide a freshly constructed RiskScorer."""
        # Imported lazily so an import failure surfaces per test rather than
        # at collection time.
        from src.core.risk_scorer import RiskScorer

        scorer = RiskScorer()
        return scorer

    @pytest.fixture
    def sample_vulnerabilities(self):
        """Provide a one-element list holding a single HIGH Vulnerability."""
        from src.core.vulnerability_assessor import Vulnerability

        dashboard_finding = Vulnerability(
            check_name='exposed_dashboard',
            severity='HIGH',
            cvss_score=7.5,
            description='Dashboard exposed without authentication',
        )
        return [dashboard_finding]

    @pytest.fixture
    def sample_result(self, sample_shodan_result):
        """Provide a SearchResult built from the sample_shodan_result fixture.

        sample_shodan_result is not defined in this file (presumably it comes
        from a shared conftest); it is indexed by 'ip_str', 'port' and 'data'
        and is also passed whole as the result's metadata.
        """
        from src.engines.base import SearchResult

        return SearchResult(
            ip=sample_shodan_result['ip_str'],
            port=sample_shodan_result['port'],
            banner=sample_shodan_result['data'],
            metadata=sample_shodan_result,
        )

    def test_calculate_score_returns_dict(self, risk_scorer, sample_vulnerabilities):
        """calculate_score returns a dict that has an 'overall_score' key."""
        outcome = risk_scorer.calculate_score(sample_vulnerabilities)

        assert isinstance(outcome, dict)
        assert 'overall_score' in outcome

    def test_calculate_score_range_valid(self, risk_scorer, sample_vulnerabilities):
        """The overall score lies in the inclusive range 0-10."""
        outcome = risk_scorer.calculate_score(sample_vulnerabilities)
        overall = outcome['overall_score']

        assert 0 <= overall <= 10

    def test_high_risk_vulnerabilities_increase_score(self, risk_scorer):
        """Two CRITICAL findings score strictly higher than one LOW finding."""
        from src.core.vulnerability_assessor import Vulnerability

        critical_findings = [
            Vulnerability(
                check_name='api_key_exposure',
                severity='CRITICAL',
                cvss_score=9.0,
                description='API key exposed',
            ),
            Vulnerability(
                check_name='no_authentication',
                severity='CRITICAL',
                cvss_score=9.5,
                description='No auth',
            ),
        ]
        minor_findings = [
            Vulnerability(
                check_name='version_exposed',
                severity='LOW',
                cvss_score=2.0,
                description='Version info',
            ),
        ]

        high_score = risk_scorer.calculate_score(critical_findings)['overall_score']
        low_score = risk_scorer.calculate_score(minor_findings)['overall_score']

        assert high_score > low_score

    def test_empty_vulnerabilities_zero_score(self, risk_scorer):
        """An empty vulnerability list yields an overall score of zero."""
        assert risk_scorer.calculate_score([])['overall_score'] == 0

    def test_score_result_updates_search_result(self, risk_scorer, sample_result, sample_vulnerabilities):
        """score_result returns an object carrying a score and an assessment."""
        scored = risk_scorer.score_result(sample_result, sample_vulnerabilities)

        assert scored.risk_score >= 0
        assert 'risk_assessment' in scored.metadata

    def test_context_multipliers_applied(self, risk_scorer, sample_vulnerabilities):
        """Passing a context never lowers the score for the same findings."""
        # Baseline: no context argument supplied.
        without_context = risk_scorer.calculate_score(sample_vulnerabilities)
        base_score = without_context['overall_score']

        # Same findings, scored again with every context flag switched on.
        context = {'public_internet': True, 'no_waf': True, 'ai_agent': True}
        with_context = risk_scorer.calculate_score(sample_vulnerabilities, context)
        context_score = with_context['overall_score']

        # Equality is tolerated, so this only rules out a decrease.
        assert context_score >= base_score

    def test_severity_breakdown_included(self, risk_scorer, sample_vulnerabilities):
        """The result carries a dict under the 'severity_breakdown' key."""
        outcome = risk_scorer.calculate_score(sample_vulnerabilities)

        assert 'severity_breakdown' in outcome
        breakdown = outcome['severity_breakdown']
        assert isinstance(breakdown, dict)
|
|
|
|
|
|
class TestRiskCategories:
    """Tests for risk categorization."""

    @pytest.fixture
    def risk_scorer(self):
        """Create a RiskScorer instance."""
        # Imported lazily so an import failure surfaces per test rather than
        # at collection time.
        from src.core.risk_scorer import RiskScorer

        return RiskScorer()

    def test_get_risk_level(self, risk_scorer):
        """Check that get_risk_level maps sample scores to acceptable labels.

        The test is skipped (instead of silently passing) when RiskScorer
        does not provide a get_risk_level method, so the missing coverage is
        visible in the test report.
        """
        if not hasattr(risk_scorer, 'get_risk_level'):
            pytest.skip('RiskScorer does not implement get_risk_level')

        # NOTE(review): other tests in this file assert overall_score lies in
        # 0-10, while these inputs go up to 95 -- confirm which scale
        # get_risk_level expects.
        expectations = (
            (95, ('CRITICAL', 'HIGH')),
            (75, ('HIGH', 'MEDIUM')),
            (50, ('MEDIUM',)),
            (25, ('LOW',)),
            (10, ('INFO', 'LOW')),
        )

        for score, labels in expectations:
            # Accept either all-upper or all-lower spellings of each label.
            acceptable = list(labels) + [label.lower() for label in labels]
            level = risk_scorer.get_risk_level(score)
            assert level in acceptable, (
                f'get_risk_level({score}) returned {level!r}; '
                f'expected one of {acceptable}'
            )
|
|
|