mirror of
https://github.com/0xsrb/AASRT.git
synced 2026-04-23 17:36:04 +02:00
Initial commit: AASRT v1.0.0 - AI Agent Security Reconnaissance Tool
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""Utility modules for AASRT."""
|
||||
|
||||
from .config import Config
|
||||
from .logger import setup_logger, get_logger
|
||||
from .exceptions import (
|
||||
AASRTException,
|
||||
APIException,
|
||||
RateLimitException,
|
||||
ConfigurationException,
|
||||
ValidationException
|
||||
)
|
||||
from .validators import validate_ip, validate_domain, validate_query
|
||||
|
||||
__all__ = [
|
||||
'Config',
|
||||
'setup_logger',
|
||||
'get_logger',
|
||||
'AASRTException',
|
||||
'APIException',
|
||||
'RateLimitException',
|
||||
'ConfigurationException',
|
||||
'ValidationException',
|
||||
'validate_ip',
|
||||
'validate_domain',
|
||||
'validate_query'
|
||||
]
|
||||
@@ -0,0 +1,513 @@
|
||||
"""
|
||||
Configuration management for AASRT.
|
||||
|
||||
This module provides a production-ready configuration management system with:
|
||||
- Singleton pattern for global configuration access
|
||||
- YAML file loading with deep merging
|
||||
- Environment variable overrides
|
||||
- Validation of required settings
|
||||
- Support for structured logging configuration
|
||||
- Health check capabilities
|
||||
|
||||
Configuration priority (highest to lowest):
|
||||
1. Environment variables
|
||||
2. YAML configuration file
|
||||
3. Default values
|
||||
|
||||
Example:
|
||||
>>> from src.utils.config import Config
|
||||
>>> config = Config()
|
||||
>>> shodan_key = config.get_shodan_key()
|
||||
>>> log_level = config.get('logging', 'level', default='INFO')
|
||||
|
||||
Environment Variables:
|
||||
SHODAN_API_KEY: Required Shodan API key
|
||||
AASRT_LOG_LEVEL: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
AASRT_ENVIRONMENT: Deployment environment (development, staging, production)
|
||||
AASRT_DEBUG: Enable debug mode (true/false)
|
||||
DB_TYPE: Database type (sqlite, postgresql)
|
||||
DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD: PostgreSQL settings
|
||||
"""
|
||||
|
||||
import os
import secrets
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

import yaml
from dotenv import load_dotenv

from .exceptions import ConfigurationException
from .logger import get_logger
|
||||
|
||||
# Module-level logger for configuration loading events.
logger = get_logger(__name__)

# =============================================================================
# Validation Constants
# =============================================================================

# Accepted values for the logging 'level' setting.
VALID_LOG_LEVELS: Set[str] = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
# Accepted values for the AASRT_ENVIRONMENT variable.
VALID_ENVIRONMENTS: Set[str] = {"development", "staging", "production"}
# Accepted values for the database 'type' setting.
VALID_DB_TYPES: Set[str] = {"sqlite", "postgresql", "mysql"}
# Settings that must be present at startup (validated eagerly).
REQUIRED_SETTINGS: List[str] = []  # API key is optional until scan is run
|
||||
|
||||
|
||||
class Config:
    """
    Configuration manager for AASRT with singleton pattern.

    Provides centralized configuration management with:
    - Thread-safe singleton access (guarded by a class-level lock)
    - YAML file configuration, deep-merged over built-in defaults
    - Environment variable overrides
    - Validation of critical settings
    - Health check for configuration state

    Configuration priority (highest to lowest):
        1. Environment variables
        2. YAML configuration file
        3. Default values

    Attributes:
        _instance: Singleton instance.
        _config: Configuration dictionary.
        _initialized: Flag indicating initialization status.
        _config_path: Path to loaded configuration file.
        _environment: Current deployment environment.

    Example:
        >>> config = Config()
        >>> api_key = config.get_shodan_key()
        >>> if not api_key:
        ...     print("Warning: Shodan API key not configured")
    """

    _instance: Optional['Config'] = None
    _config: Dict[str, Any] = {}
    # Guards singleton creation and one-time initialization; without it two
    # threads racing through __new__ could each build an instance, breaking
    # the thread-safety the class contract promises.
    _lock = threading.Lock()

    def __new__(cls, config_path: Optional[str] = None):
        """
        Thread-safe singleton creation (double-checked locking).

        Args:
            config_path: Optional path to YAML configuration file
                (consumed by __init__, ignored here).

        Returns:
            Singleton Config instance.
        """
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have won the race.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self, config_path: Optional[str] = None) -> None:
        """
        Initialize configuration from multiple sources.

        Configuration is loaded in order of priority:
            1. Default values
            2. YAML configuration file
            3. Environment variables (highest priority)

        Args:
            config_path: Path to YAML configuration file.
                If not provided, searches common locations.

        Raises:
            ConfigurationException: If the YAML file is malformed.
        """
        if self._initialized:
            return
        with self._lock:
            # Another thread may have completed initialization while we
            # waited on the lock.
            if self._initialized:
                return
            self._initialize(config_path)

    def _initialize(self, config_path: Optional[str]) -> None:
        """Perform the one-time initialization. Must be called under _lock."""
        # Load environment variables from a .env file, if present.
        load_dotenv()

        # Metadata about where configuration came from.
        self._config_path: Optional[str] = None
        self._environment: str = os.getenv('AASRT_ENVIRONMENT', 'development')
        self._validation_errors: List[str] = []

        # 1. Built-in defaults.
        self._config = self._get_defaults()

        # 2. YAML file: explicit path wins, otherwise search common spots.
        if config_path:
            self._load_from_file(config_path)
        else:
            for path in ['config.yaml', 'config.yml', './config/config.yaml']:
                if os.path.exists(path):
                    self._load_from_file(path)
                    break

        # 3. Environment variables override everything.
        self._load_from_env()

        self._validate_config()

        self._initialized = True
        logger.info(f"Configuration initialized (environment: {self._environment})")

    def _get_defaults(self) -> Dict[str, Any]:
        """Return a fresh dictionary holding the built-in default configuration."""
        return {
            'shodan': {
                'enabled': True,
                'rate_limit': 1,
                'max_results': 100,
                'timeout': 30
            },
            'vulnerability_checks': {
                'enabled': True,
                'passive_only': True,
                'timeout_per_check': 10
            },
            'reporting': {
                'formats': ['json', 'csv'],
                'output_dir': './reports',
                'anonymize_by_default': False
            },
            'filtering': {
                'whitelist_ips': [],
                'whitelist_domains': [],
                'min_confidence_score': 70,
                'exclude_honeypots': True
            },
            'logging': {
                'level': 'INFO',
                'file': './logs/scanner.log',
                'max_size_mb': 100,
                'backup_count': 5
            },
            'database': {
                'type': 'sqlite',
                'sqlite': {
                    'path': './data/scanner.db'
                }
            },
            'api_keys': {},
            'clawsec': {
                'enabled': True,
                'feed_url': 'https://clawsec.prompt.security/advisories/feed.json',
                'cache_ttl_seconds': 86400,  # 24 hours
                'cache_file': './data/clawsec_cache.json',
                'offline_mode': False,
                'timeout': 30,
                'auto_refresh': True
            }
        }

    def _load_from_file(self, path: str) -> None:
        """
        Load configuration from a YAML file, deep-merging over current values.

        A missing file is logged and ignored (defaults remain in effect).

        Args:
            path: Path to YAML configuration file.

        Raises:
            ConfigurationException: If the YAML is malformed.
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                file_config = yaml.safe_load(f)
            if file_config:
                self._deep_merge(self._config, file_config)
            self._config_path = path
            logger.info(f"Loaded configuration from {path}")
        except FileNotFoundError:
            logger.warning(f"Configuration file not found: {path}")
        except yaml.YAMLError as e:
            raise ConfigurationException(f"Invalid YAML in configuration file: {e}")

    def _load_from_env(self) -> None:
        """
        Apply environment variable overrides on top of file/default config.

        Handles the API key, log level, deployment environment, debug flag,
        database settings, and the results limit. Invalid values are logged
        and skipped so the previous value stays in effect.
        """
        # Shodan API key.
        shodan_key = os.getenv('SHODAN_API_KEY')
        if shodan_key:
            self._set_nested(('api_keys', 'shodan'), shodan_key)

        # Log level, validated against the known set.
        log_level = os.getenv('AASRT_LOG_LEVEL', '').upper()
        if log_level and log_level in VALID_LOG_LEVELS:
            self._set_nested(('logging', 'level'), log_level)
        elif log_level:
            logger.warning(f"Invalid log level '{log_level}', using default")

        # Deployment environment.
        env = os.getenv('AASRT_ENVIRONMENT', '').lower()
        if env and env in VALID_ENVIRONMENTS:
            self._environment = env

        # Debug flag forces DEBUG logging.
        debug = os.getenv('AASRT_DEBUG', '').lower()
        if debug in ('true', '1', 'yes'):
            self._set_nested(('logging', 'level'), 'DEBUG')

        # Database type.
        db_type = os.getenv('DB_TYPE', '').lower()
        if db_type and db_type in VALID_DB_TYPES:
            self._set_nested(('database', 'type'), db_type)

        # PostgreSQL connection settings.
        if os.getenv('DB_HOST'):
            self._set_nested(('database', 'postgresql', 'host'), os.getenv('DB_HOST'))
        if os.getenv('DB_PORT'):
            try:
                port = int(os.getenv('DB_PORT'))
                self._set_nested(('database', 'postgresql', 'port'), port)
            except ValueError:
                logger.warning("Invalid DB_PORT, using default")
        if os.getenv('DB_NAME'):
            self._set_nested(('database', 'postgresql', 'database'), os.getenv('DB_NAME'))
        if os.getenv('DB_USER'):
            self._set_nested(('database', 'postgresql', 'user'), os.getenv('DB_USER'))
        if os.getenv('DB_PASSWORD'):
            self._set_nested(('database', 'postgresql', 'password'), os.getenv('DB_PASSWORD'))
        if os.getenv('DB_SSL_MODE'):
            self._set_nested(('database', 'postgresql', 'ssl_mode'), os.getenv('DB_SSL_MODE'))

        # Results limit.
        max_results = os.getenv('AASRT_MAX_RESULTS')
        if max_results:
            try:
                self._set_nested(('shodan', 'max_results'), int(max_results))
            except ValueError:
                logger.warning("Invalid AASRT_MAX_RESULTS, using default")

    def _validate_config(self) -> None:
        """
        Validate configuration settings.

        Records problems in self._validation_errors and logs warnings.
        Never raises, to allow graceful degradation.
        """
        self._validation_errors = []

        # str() guards against non-string values coming from YAML (e.g. a
        # bare `level: 10`), which would otherwise crash on .upper().
        log_level = self.get('logging', 'level', default='INFO')
        if str(log_level).upper() not in VALID_LOG_LEVELS:
            self._validation_errors.append(f"Invalid log level: {log_level}")

        db_type = self.get('database', 'type', default='sqlite')
        if str(db_type).lower() not in VALID_DB_TYPES:
            self._validation_errors.append(f"Invalid database type: {db_type}")

        max_results = self.get('shodan', 'max_results', default=100)
        if not isinstance(max_results, int) or max_results < 1:
            self._validation_errors.append(f"Invalid max_results: {max_results}")

        # A missing Shodan API key is only a debug note, not an error.
        if not self.get_shodan_key():
            logger.debug("Shodan API key not configured - scans will require it")

        for error in self._validation_errors:
            logger.warning(f"Configuration validation: {error}")

    def _deep_merge(self, base: Dict, overlay: Dict) -> None:
        """
        Deep merge overlay dictionary into base dictionary.

        Nested dicts are merged recursively; any other value in overlay
        replaces the value in base.

        Args:
            base: Base dictionary to merge into (modified in place).
            overlay: Overlay dictionary to merge from.
        """
        for key, value in overlay.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                self._deep_merge(base[key], value)
            else:
                base[key] = value

    def _set_nested(self, path: tuple, value: Any) -> None:
        """
        Set a nested configuration value by key path, creating
        intermediate dictionaries as needed.

        Args:
            path: Tuple of keys representing the path.
            value: Value to set at the path.
        """
        current = self._config
        for key in path[:-1]:
            if key not in current:
                current[key] = {}
            current = current[key]
        current[path[-1]] = value

    def get(self, *keys: str, default: Any = None) -> Any:
        """
        Get a configuration value by nested keys.

        Args:
            *keys: Nested keys to traverse (e.g. 'database', 'type').
            default: Default value if the path is not found.

        Returns:
            Configuration value or default.

        Example:
            >>> config.get('shodan', 'max_results', default=100)
            100
        """
        current = self._config
        for key in keys:
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                return default
        return current

    def get_shodan_key(self) -> Optional[str]:
        """
        Get the Shodan API key.

        Returns:
            Shodan API key string, or None if not configured.
        """
        return self.get('api_keys', 'shodan')

    def get_shodan_config(self) -> Dict[str, Any]:
        """
        Get the Shodan configuration dictionary.

        Returns:
            Dictionary with Shodan settings (enabled, rate_limit,
            max_results, timeout).
        """
        return self.get('shodan', default={})

    def get_clawsec_config(self) -> Dict[str, Any]:
        """
        Get the ClawSec configuration dictionary.

        Returns:
            Dictionary with ClawSec settings.
        """
        return self.get('clawsec', default={})

    def is_clawsec_enabled(self) -> bool:
        """
        Check if ClawSec integration is enabled.

        Returns:
            True if ClawSec vulnerability lookup is enabled.
        """
        return self.get('clawsec', 'enabled', default=True)

    def get_database_config(self) -> Dict[str, Any]:
        """
        Get database configuration.

        Returns:
            Dictionary with database settings.
        """
        return self.get('database', default={})

    def get_logging_config(self) -> Dict[str, Any]:
        """
        Get logging configuration.

        Returns:
            Dictionary with logging settings (level, file, max_size_mb,
            backup_count).
        """
        return self.get('logging', default={})

    @property
    def environment(self) -> str:
        """
        Get current deployment environment.

        Returns:
            Environment string (development, staging, production).
        """
        return self._environment

    @property
    def is_production(self) -> bool:
        """
        Check if running in the production environment.

        Returns:
            True if environment is 'production'.
        """
        return self._environment == 'production'

    @property
    def is_debug(self) -> bool:
        """
        Check if debug mode is enabled.

        Returns:
            True if the effective log level is DEBUG.
        """
        return str(self.get('logging', 'level', default='INFO')).upper() == 'DEBUG'

    @property
    def all(self) -> Dict[str, Any]:
        """
        Get all configuration as a dictionary.

        Returns:
            Shallow copy of the complete configuration dictionary.
        """
        return self._config.copy()

    def reload(self, config_path: Optional[str] = None) -> None:
        """
        Reload configuration from file and environment.

        Use this to refresh configuration without restarting the
        application. The reload happens under the instance lock so
        concurrent readers never see a half-built configuration.

        Args:
            config_path: Optional path to configuration file.
                If None, uses the previously loaded file path.
        """
        logger.info("Reloading configuration...")
        with self._lock:
            self._initialized = False
            self._config = self._get_defaults()

            # Use new path or fall back to the previously loaded path.
            path_to_load = config_path or self._config_path
            if path_to_load:
                self._load_from_file(path_to_load)

            self._load_from_env()
            self._validate_config()
            self._initialized = True
        logger.info("Configuration reloaded successfully")

    def health_check(self) -> Dict[str, Any]:
        """
        Perform a health check on configuration.

        Returns:
            Dictionary with health status:
            - healthy: bool indicating if configuration is valid
            - environment: Current deployment environment
            - config_file: Path to loaded config file (if any)
            - validation_errors: List of validation errors
            - shodan_configured: Whether the Shodan API key is set
            - clawsec_enabled: Whether ClawSec is enabled
            - log_level / database_type: Effective settings
        """
        return {
            "healthy": len(self._validation_errors) == 0,
            "environment": self._environment,
            "config_file": self._config_path,
            "validation_errors": self._validation_errors.copy(),
            "shodan_configured": bool(self.get_shodan_key()),
            "clawsec_enabled": self.is_clawsec_enabled(),
            "log_level": self.get('logging', 'level', default='INFO'),
            "database_type": self.get('database', 'type', default='sqlite')
        }

    @staticmethod
    def reset_instance() -> None:
        """
        Reset the singleton instance (for testing).

        Warning:
            This should only be used in tests. Any existing references
            to the old instance become stale.
        """
        with Config._lock:
            Config._instance = None
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Custom exceptions for AASRT."""
|
||||
|
||||
|
||||
class AASRTException(Exception):
    """Base class for all AASRT-specific errors."""
    pass


class APIException(AASRTException):
    """Raised when an API call fails.

    Attributes:
        engine: Name of the engine that failed, if known.
        status_code: HTTP status code returned, if any.
    """

    def __init__(self, message: str, engine: str = None, status_code: int = None):
        super().__init__(message)
        self.engine = engine
        self.status_code = status_code


class RateLimitException(AASRTException):
    """Raised when a rate limit is exceeded.

    Attributes:
        engine: Name of the engine that throttled the request, if known.
        retry_after: Suggested wait time in seconds, if provided.
    """

    def __init__(self, message: str, engine: str = None, retry_after: int = None):
        super().__init__(message)
        self.engine = engine
        self.retry_after = retry_after


class ConfigurationException(AASRTException):
    """Raised when configuration is invalid."""
    pass


class ValidationException(AASRTException):
    """Raised when input validation fails."""
    pass


class AuthenticationException(AASRTException):
    """Raised when authentication fails.

    Attributes:
        engine: Name of the engine that rejected the credentials, if known.
    """

    def __init__(self, message: str, engine: str = None):
        super().__init__(message)
        self.engine = engine


class TimeoutException(AASRTException):
    """Raised when a request times out.

    Attributes:
        engine: Name of the engine involved, if known.
        timeout: The timeout in seconds that was exceeded, if known.
    """

    def __init__(self, message: str, engine: str = None, timeout: int = None):
        super().__init__(message)
        self.engine = engine
        self.timeout = timeout
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Logging setup for AASRT."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Cache of configured loggers, keyed by name, so repeated setup/get calls
# return the same instance instead of re-attaching handlers.
_loggers = {}


def setup_logger(
    name: str = "aasrt",
    level: str = "INFO",
    log_file: Optional[str] = None,
    max_size_mb: int = 100,
    backup_count: int = 5
) -> logging.Logger:
    """
    Setup and configure a logger.

    Args:
        name: Logger name.
        level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL); any
            other value silently falls back to INFO.
        log_file: Path to log file (optional). Parent directories are
            created if missing.
        max_size_mb: Max log file size in MB before rotation.
        backup_count: Number of rotated backup files to keep.

    Returns:
        Configured logger instance.
    """
    if name in _loggers:
        return _loggers[name]

    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, level.upper(), logging.INFO))

    # Logger was already configured elsewhere (e.g. directly via logging):
    # cache it too, so future calls skip this work instead of re-running it.
    if logger.handlers:
        _loggers[name] = logger
        return logger

    # Console handler: handler level is wide open, the logger's own level
    # does the filtering.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_format = logging.Formatter(
        '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler.setFormatter(console_format)
    logger.addHandler(console_handler)

    # Rotating file handler, only when a log file is requested.
    if log_file:
        log_dir = os.path.dirname(log_file)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)

        file_handler = RotatingFileHandler(
            log_file,
            maxBytes=max_size_mb * 1024 * 1024,
            backupCount=backup_count
        )
        file_handler.setLevel(logging.DEBUG)

        # File format additionally records the source location for debugging.
        file_format = logging.Formatter(
            '%(asctime)s | %(levelname)-8s | %(name)s | %(filename)s:%(lineno)d | %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(file_format)
        logger.addHandler(file_handler)

    _loggers[name] = logger
    return logger
|
||||
|
||||
|
||||
def get_logger(name: str = "aasrt") -> logging.Logger:
|
||||
"""
|
||||
Get an existing logger or create a new one.
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
|
||||
Returns:
|
||||
Logger instance
|
||||
"""
|
||||
if name in _loggers:
|
||||
return _loggers[name]
|
||||
return setup_logger(name)
|
||||
@@ -0,0 +1,583 @@
|
||||
"""
|
||||
Input validation utilities for AASRT.
|
||||
|
||||
This module provides comprehensive input validation and sanitization functions
|
||||
for security-sensitive operations including:
|
||||
- IP address and domain validation
|
||||
- Port number and query string validation
|
||||
- File path sanitization (directory traversal prevention)
|
||||
- API key format validation
|
||||
- Template name whitelist validation
|
||||
- Configuration value validation
|
||||
|
||||
All validators raise ValidationException on invalid input with descriptive
|
||||
error messages for debugging.
|
||||
|
||||
Example:
|
||||
>>> from src.utils.validators import validate_ip, validate_file_path
|
||||
>>> validate_ip("192.168.1.1") # Returns True
|
||||
>>> validate_file_path("../../../etc/passwd") # Raises ValidationException
|
||||
"""
|
||||
|
||||
import re
|
||||
import os
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
import validators
|
||||
|
||||
from .exceptions import ValidationException
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# Valid log levels for configuration values.
VALID_LOG_LEVELS: Set[str] = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}

# Valid deployment environment names.
VALID_ENVIRONMENTS: Set[str] = {"development", "staging", "production"}

# Valid database backend types.
VALID_DB_TYPES: Set[str] = {"sqlite", "postgresql", "mysql"}

# Valid report output formats.
VALID_REPORT_FORMATS: Set[str] = {"json", "csv", "html", "pdf"}

# Valid query template names (whitelist): only names in this set are
# accepted by template-name validation.
VALID_TEMPLATES: Set[str] = {
    "clawdbot_instances",
    "autogpt_instances",
    "langchain_agents",
    "openai_agents",
    "anthropic_agents",
    "ai_agent_general",
    "agent_gpt",
    "babyagi_instances",
    "crewai_instances",
    "autogen_instances",
    "superagi_instances",
    "flowise_instances",
    "dify_instances",
}

# Maximum/minimum limits for various user-supplied inputs.
MAX_QUERY_LENGTH: int = 2000
MAX_RESULTS_LIMIT: int = 10000
MIN_RESULTS_LIMIT: int = 1
MAX_PORT: int = 65535
MIN_PORT: int = 1
MAX_FILE_PATH_LENGTH: int = 4096
MAX_API_KEY_LENGTH: int = 256
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# IP and Network Validators
|
||||
# =============================================================================
|
||||
|
||||
def validate_ip(ip: str) -> bool:
    """
    Validate an IP address (IPv4 or IPv6).

    Surrounding whitespace is stripped before validation.

    Args:
        ip: IP address string to validate.

    Returns:
        True if the IP address is valid.

    Raises:
        ValidationException: If ip is None, not a string, empty, or not
            a valid IPv4/IPv6 address.

    Example:
        >>> validate_ip("192.168.1.1")
        True
        >>> validate_ip("2001:db8::1")
        True
    """
    if ip is None:
        raise ValidationException("IP address cannot be None")

    if not isinstance(ip, str):
        raise ValidationException(f"IP address must be a string, got {type(ip).__name__}")

    candidate = ip.strip()
    if not candidate:
        raise ValidationException("IP address cannot be empty")

    try:
        ipaddress.ip_address(candidate)
    except ValueError:
        raise ValidationException(f"Invalid IP address: {candidate}")
    return True
|
||||
|
||||
|
||||
def validate_domain(domain: str) -> bool:
    """
    Validate a domain name.

    Applies the same None/type/empty guards as validate_ip (previously
    missing here, so e.g. None fell straight through to the validators
    library), then delegates syntax checking to validators.domain.

    Args:
        domain: Domain name string.

    Returns:
        True if valid.

    Raises:
        ValidationException: If domain is None, not a string, empty,
            or syntactically invalid.
    """
    if domain is None:
        raise ValidationException("Domain cannot be None")

    if not isinstance(domain, str):
        raise ValidationException(f"Domain must be a string, got {type(domain).__name__}")

    domain = domain.strip()
    if not domain:
        raise ValidationException("Domain cannot be empty")

    if validators.domain(domain):
        return True
    raise ValidationException(f"Invalid domain: {domain}")
|
||||
|
||||
|
||||
def validate_query(query: str, engine: str) -> bool:
    """
    Validate a search query for a specific engine.

    Rejects empty queries, queries containing dangerous characters
    (angle brackets, null bytes), and queries exceeding the engine's
    length limit. Unknown engines get no length check.

    Args:
        query: Search query string.
        engine: Search engine name ("shodan", "censys", ...).

    Returns:
        True if valid.

    Raises:
        ValidationException: If the query is invalid.
    """
    if not query or not query.strip():
        raise ValidationException("Query cannot be empty")

    # Potentially dangerous character classes: script-injection attempts
    # and null bytes.
    for pattern in (r'[<>]', r'\x00'):
        if re.search(pattern, query):
            raise ValidationException(f"Query contains invalid characters: {pattern}")

    # Per-engine length limits.
    limits = {
        "shodan": (1000, "Shodan query too long (max 1000 chars)"),
        "censys": (2000, "Censys query too long (max 2000 chars)"),
    }
    if engine in limits:
        max_len, message = limits[engine]
        if len(query) > max_len:
            raise ValidationException(message)

    return True
|
||||
|
||||
|
||||
def validate_port(port: int) -> bool:
    """
    Validate a TCP/UDP port number.

    Args:
        port: Port number (1-65535).

    Returns:
        True if valid.

    Raises:
        ValidationException: If port is not an int in [1, 65535].
            Booleans are rejected explicitly: bool is a subclass of int,
            so the previous check quietly accepted True as port 1.
    """
    if isinstance(port, bool) or not isinstance(port, int) or port < 1 or port > 65535:
        raise ValidationException(f"Invalid port number: {port}")
    return True
|
||||
|
||||
|
||||
def validate_api_key(api_key: str, engine: str) -> bool:
    """
    Validate API key format for a specific engine.

    This is a shallow format check only; actual key validity can only
    be confirmed by the API itself.

    Args:
        api_key: API key string.
        engine: Search engine name.

    Returns:
        True if the format looks plausible.

    Raises:
        ValidationException: If the API key format is invalid.
    """
    if not api_key or not api_key.strip():
        raise ValidationException(f"API key for {engine} cannot be empty")

    # Shodan keys are typically 32 characters; flag anything much shorter.
    if engine == "shodan" and len(api_key) < 20:
        raise ValidationException("Shodan API key appears too short")

    return True
|
||||
|
||||
|
||||
def sanitize_output(text: str) -> str:
    """
    Sanitize text for safe output (remove potential secrets).

    Redacts sensitive patterns such as API keys, passwords, and
    authentication tokens so they do not leak into logs or reports.

    Args:
        text: Text to sanitize. None becomes "", non-strings are
            converted with str().

    Returns:
        Sanitized text with sensitive data replaced by REDACTED markers.

    Example:
        >>> sanitize_output("key: sk-ant-" + "a" * 25)
        'key: sk-ant-***REDACTED***'
    """
    if text is None:
        return ""

    if not isinstance(text, str):
        text = str(text)

    # Patterns for sensitive data. Order matters: more specific prefixes
    # (e.g. sk-ant-) must come before generic ones (sk-).
    patterns = [
        # Anthropic API keys
        (r'sk-ant-[a-zA-Z0-9-_]{20,}', 'sk-ant-***REDACTED***'),
        # OpenAI API keys
        (r'sk-[a-zA-Z0-9]{40,}', 'sk-***REDACTED***'),
        # AWS Access Key
        (r'AKIA[0-9A-Z]{16}', 'AKIA***REDACTED***'),
        # AWS Secret Key
        (r'(?i)aws_secret_access_key["\s:=]+["\']?[A-Za-z0-9/+=]{40}', 'aws_secret_access_key=***REDACTED***'),
        # GitHub tokens
        (r'ghp_[a-zA-Z0-9]{36}', 'ghp_***REDACTED***'),
        (r'gho_[a-zA-Z0-9]{36}', 'gho_***REDACTED***'),
        # Google API keys
        (r'AIza[0-9A-Za-z-_]{35}', 'AIza***REDACTED***'),
        # Stripe keys
        (r'sk_live_[a-zA-Z0-9]{24,}', 'sk_live_***REDACTED***'),
        (r'sk_test_[a-zA-Z0-9]{24,}', 'sk_test_***REDACTED***'),
        # Shodan API key (exactly 32 hex chars). Word boundaries ensure
        # only whole 32-char tokens match; the previous unanchored pattern
        # also matched the first 32 chars of longer hex strings (e.g.
        # SHA-256 digests), mangling them while exposing the remainder.
        (r'\b[a-fA-F0-9]{32}\b', '***REDACTED_KEY***'),
        # Generic password patterns
        (r'password["\s:=]+["\']?[\w@#$%^&*!?]+', 'password=***REDACTED***'),
        (r'passwd["\s:=]+["\']?[\w@#$%^&*!?]+', 'passwd=***REDACTED***'),
        (r'secret["\s:=]+["\']?[\w@#$%^&*!?]+', 'secret=***REDACTED***'),
        # Bearer tokens
        (r'Bearer\s+[a-zA-Z0-9._-]+', 'Bearer ***REDACTED***'),
        # Basic auth
        (r'Basic\s+[a-zA-Z0-9+/=]+', 'Basic ***REDACTED***'),
    ]

    result = text
    for pattern, replacement in patterns:
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# File Path Validators
|
||||
# =============================================================================
|
||||
|
||||
def validate_file_path(
    path: str,
    must_exist: bool = False,
    allow_absolute: bool = True,
    base_dir: Optional[str] = None
) -> str:
    """
    Validate and sanitize a file path to prevent directory traversal attacks.

    Args:
        path: File path to validate.
        must_exist: If True, the file must exist.
        allow_absolute: If True, allow absolute paths.
        base_dir: If provided, ensure path is within this directory.

    Returns:
        Sanitized, normalized file path.

    Raises:
        ValidationException: If path is invalid or potentially dangerous.

    Example:
        >>> validate_file_path("reports/scan.json")
        'reports/scan.json'
        >>> validate_file_path("../../../etc/passwd")
        ValidationException: Path traversal detected
    """
    if path is None:
        raise ValidationException("File path cannot be None")

    if not isinstance(path, str):
        raise ValidationException(f"File path must be a string, got {type(path).__name__}")

    path = path.strip()
    if not path:
        raise ValidationException("File path cannot be empty")

    if len(path) > MAX_FILE_PATH_LENGTH:
        raise ValidationException(f"File path too long (max {MAX_FILE_PATH_LENGTH} chars)")

    # Null bytes can truncate the path in C-level file APIs -- reject outright.
    if '\x00' in path:
        raise ValidationException("File path contains null bytes")

    # Normalize the path (collapses "a/./b", "a//b", etc.)
    try:
        normalized = os.path.normpath(path)
    except Exception as e:
        raise ValidationException(f"Invalid file path: {e}") from e

    # Check for directory traversal components surviving normalization.
    if '..' in normalized.split(os.sep):
        raise ValidationException("Path traversal detected: '..' not allowed")

    # Check absolute path restriction
    if not allow_absolute and os.path.isabs(normalized):
        raise ValidationException("Absolute paths not allowed")

    # Check containment in base_dir.
    if base_dir:
        base_dir = os.path.abspath(base_dir)
        full_path = os.path.abspath(os.path.join(base_dir, normalized))
        # BUGFIX: a plain startswith() prefix test accepts sibling
        # directories such as "/base_evil" for base "/base".
        # os.path.commonpath compares whole path components.
        if os.path.commonpath([base_dir, full_path]) != base_dir:
            raise ValidationException("Path escapes base directory")

    # Check existence if required
    if must_exist and not os.path.exists(path):
        raise ValidationException(f"File does not exist: {path}")

    return normalized
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Template and Configuration Validators
|
||||
# =============================================================================
|
||||
|
||||
def validate_template_name(template: str) -> bool:
    """
    Check a query template name against the whitelist of known templates.

    Args:
        template: Template name to validate (case-insensitive).

    Returns:
        True if the normalized template name is whitelisted.

    Raises:
        ValidationException: If the name is None, empty, or not whitelisted.

    Example:
        >>> validate_template_name("clawdbot_instances")
        True
        >>> validate_template_name("malicious_query")
        ValidationException: Invalid template name
    """
    if template is None:
        raise ValidationException("Template name cannot be None")

    normalized = template.strip().lower()
    if not normalized:
        raise ValidationException("Template name cannot be empty")

    if normalized in VALID_TEMPLATES:
        return True

    allowed = ", ".join(sorted(VALID_TEMPLATES))
    raise ValidationException(
        f"Invalid template name: '{normalized}'. Valid templates: {allowed}"
    )
|
||||
|
||||
|
||||
def validate_max_results(max_results: Union[int, str]) -> int:
    """
    Validate and normalize max_results parameter.

    Args:
        max_results: Maximum number of results (int or numeric string).

    Returns:
        Validated integer value within [MIN_RESULTS_LIMIT, MAX_RESULTS_LIMIT].

    Raises:
        ValidationException: If value is invalid or out of range.

    Example:
        >>> validate_max_results(100)
        100
        >>> validate_max_results("50")
        50
        >>> validate_max_results(-1)
        ValidationException: max_results must be positive
    """
    if max_results is None:
        raise ValidationException("max_results cannot be None")

    # Convert string to int if needed
    if isinstance(max_results, str):
        try:
            max_results = int(max_results.strip())
        except ValueError as exc:
            # Chain the cause so debugging shows the original conversion error.
            raise ValidationException(
                f"max_results must be a number, got: '{max_results}'"
            ) from exc

    # BUGFIX: bool is a subclass of int, so True/False would otherwise
    # pass the isinstance check and silently validate as 1/0.
    if isinstance(max_results, bool) or not isinstance(max_results, int):
        raise ValidationException(
            f"max_results must be an integer, got {type(max_results).__name__}"
        )

    if max_results < MIN_RESULTS_LIMIT:
        raise ValidationException(f"max_results must be at least {MIN_RESULTS_LIMIT}")

    if max_results > MAX_RESULTS_LIMIT:
        raise ValidationException(f"max_results cannot exceed {MAX_RESULTS_LIMIT}")

    return max_results
|
||||
|
||||
|
||||
def validate_log_level(level: str) -> str:
    """
    Normalize and validate a logging level name.

    Args:
        level: Log level string (case-insensitive).

    Returns:
        Normalized uppercase log level.

    Raises:
        ValidationException: If log level is None or not recognized.
    """
    if level is None:
        raise ValidationException("Log level cannot be None")

    normalized = str(level).strip().upper()

    if normalized in VALID_LOG_LEVELS:
        return normalized

    allowed = ", ".join(sorted(VALID_LOG_LEVELS))
    raise ValidationException(f"Invalid log level: '{normalized}'. Valid levels: {allowed}")
|
||||
|
||||
|
||||
def validate_environment(env: str) -> str:
    """
    Normalize and validate a deployment environment name.

    Args:
        env: Environment name string (case-insensitive).

    Returns:
        Normalized lowercase environment name.

    Raises:
        ValidationException: If environment is None or not recognized.
    """
    if env is None:
        raise ValidationException("Environment cannot be None")

    normalized = str(env).strip().lower()

    if normalized in VALID_ENVIRONMENTS:
        return normalized

    allowed = ", ".join(sorted(VALID_ENVIRONMENTS))
    raise ValidationException(
        f"Invalid environment: '{normalized}'. Valid environments: {allowed}"
    )
|
||||
|
||||
|
||||
def validate_db_type(db_type: str) -> str:
    """
    Normalize and validate a database backend type.

    Args:
        db_type: Database type string (case-insensitive).

    Returns:
        Normalized lowercase database type.

    Raises:
        ValidationException: If database type is None or not recognized.
    """
    if db_type is None:
        raise ValidationException("Database type cannot be None")

    normalized = str(db_type).strip().lower()

    if normalized in VALID_DB_TYPES:
        return normalized

    allowed = ", ".join(sorted(VALID_DB_TYPES))
    raise ValidationException(
        f"Invalid database type: '{normalized}'. Valid types: {allowed}"
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Batch Validation Helpers
|
||||
# =============================================================================
|
||||
|
||||
def validate_config_dict(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate known fields of a configuration dictionary.

    Recognized fields are normalized (e.g. log level upper-cased) and
    written back into *config*; unrecognized fields pass through untouched.

    NOTE: this validates IN PLACE and returns the same dict object --
    existing callers rely on the mutation, so that contract is preserved.
    (The previous version also created an unused local `validated` dict,
    which has been removed.)

    Args:
        config: Configuration dictionary to validate.

    Returns:
        The same configuration dictionary, with normalized values.

    Raises:
        ValidationException: If any recognized configuration value is invalid.
    """
    # Validate log level if present
    if 'logging' in config and 'level' in config['logging']:
        config['logging']['level'] = validate_log_level(config['logging']['level'])

    # Validate database type if present
    if 'database' in config and 'type' in config['database']:
        config['database']['type'] = validate_db_type(config['database']['type'])

    # Validate max_results if present
    if 'shodan' in config and 'max_results' in config['shodan']:
        config['shodan']['max_results'] = validate_max_results(config['shodan']['max_results'])

    return config
|
||||
|
||||
|
||||
def is_safe_string(text: str, max_length: int = 1000) -> bool:
    """
    Heuristically decide whether a string is free of obvious injection markers.

    Args:
        text: Candidate string (None is treated as unsafe).
        max_length: Maximum allowed length.

    Returns:
        False for None, over-long input, null bytes, or anything that
        looks like script/markup injection; True otherwise.
    """
    if text is None:
        return False

    # Length and null-byte checks are cheap; do them before any regex work.
    if len(text) > max_length or '\x00' in text:
        return False

    # Common injection fingerprints (HTML/JS markup, event handlers,
    # comment delimiters), matched case-insensitively.
    suspicious = (
        r'<script',
        r'javascript:',
        r'on\w+\s*=',
        r'\x00',
        r'<!--',
        r'--\s*>',
    )
    return not any(re.search(marker, text, re.IGNORECASE) for marker in suspicious)
|
||||
Reference in New Issue
Block a user