Files
NeuroSploit/backend/db/database.py
2026-01-23 15:46:05 -03:00

162 lines
6.3 KiB
Python

"""
NeuroSploit v3 - Database Configuration
"""
import logging
from typing import AsyncGenerator

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase

from backend.config import settings
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
    """Common declarative base that every ORM model inherits from."""
# Shared async engine, configured from application settings:
# the connection URL comes from DATABASE_URL and SQL echoing follows DEBUG.
engine = create_async_engine(settings.DATABASE_URL, echo=settings.DEBUG, future=True)
# Factory producing AsyncSession objects bound to the shared engine.
# expire_on_commit=False keeps loaded attributes on objects after a commit.
async_session_maker = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Second name for the same factory, used by background tasks.
async_session_factory = async_session_maker
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency that provides a database session for one unit of work.

    This is an async generator: it opens a session from
    ``async_session_maker``, yields it to the consumer, and resumes once the
    consumer is done.

    Yields:
        The open ``AsyncSession``.

    Raises:
        Exception: Any exception raised by the consumer or by the commit is
            re-raised after the session has been rolled back.
    """
    # The ``async with`` block closes the session on exit (success or error),
    # so no explicit ``session.close()`` is needed here.
    async with async_session_maker() as session:
        try:
            yield session
            # Only reached when the consumer finished without raising.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
# Columns that an older database file may be missing, keyed by table name.
# Each entry is (column name, definition appended after ``ADD COLUMN <name>``).
_COLUMN_MIGRATIONS = {
    "scans": (
        ("duration", "INTEGER"),
    ),
    "reports": (
        ("auto_generated", "BOOLEAN DEFAULT 0"),
        ("is_partial", "BOOLEAN DEFAULT 0"),
    ),
    "vulnerabilities": (
        ("test_id", "VARCHAR(36)"),
        ("poc_parameter", "VARCHAR(500)"),
        ("poc_evidence", "TEXT"),
    ),
}

# Tables created by hand when absent, keyed by table name.
# Each entry is (CREATE TABLE statement, index statements run right after it).
_TABLE_MIGRATIONS = {
    "agent_tasks": (
        """
        CREATE TABLE agent_tasks (
            id VARCHAR(36) PRIMARY KEY,
            scan_id VARCHAR(36) NOT NULL,
            task_type VARCHAR(50) NOT NULL,
            task_name VARCHAR(255) NOT NULL,
            description TEXT,
            tool_name VARCHAR(100),
            tool_category VARCHAR(100),
            status VARCHAR(20) DEFAULT 'pending',
            started_at DATETIME,
            completed_at DATETIME,
            duration_ms INTEGER,
            items_processed INTEGER DEFAULT 0,
            items_found INTEGER DEFAULT 0,
            result_summary TEXT,
            error_message TEXT,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE
        )
        """,
        (
            "CREATE INDEX IF NOT EXISTS idx_agent_tasks_scan_id ON agent_tasks(scan_id)",
            "CREATE INDEX IF NOT EXISTS idx_agent_tasks_status ON agent_tasks(status)",
        ),
    ),
    "vulnerability_tests": (
        """
        CREATE TABLE vulnerability_tests (
            id VARCHAR(36) PRIMARY KEY,
            scan_id VARCHAR(36) NOT NULL,
            endpoint_id VARCHAR(36),
            vulnerability_type VARCHAR(100) NOT NULL,
            payload TEXT,
            request_data JSON DEFAULT '{}',
            response_data JSON DEFAULT '{}',
            is_vulnerable BOOLEAN DEFAULT 0,
            confidence FLOAT,
            evidence TEXT,
            tested_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (scan_id) REFERENCES scans(id) ON DELETE CASCADE,
            FOREIGN KEY (endpoint_id) REFERENCES endpoints(id) ON DELETE SET NULL
        )
        """,
        (
            "CREATE INDEX IF NOT EXISTS idx_vulnerability_tests_scan_id ON vulnerability_tests(scan_id)",
        ),
    ),
}


async def _table_columns(conn, table):
    """Return the set of column names of *table*; empty when the table is missing."""
    # Index 1 of each ``PRAGMA table_info`` row is the column name.
    result = await conn.execute(text(f"PRAGMA table_info({table})"))
    return {row[1] for row in result.fetchall()}


async def _add_missing_columns(conn, table, wanted):
    """Add every (name, definition) pair in *wanted* that *table* does not have yet."""
    existing = await _table_columns(conn, table)
    if not existing:
        # No columns reported means the table does not exist; there is nothing to alter.
        return
    for column, definition in wanted:
        if column in existing:
            continue
        logger.info("Adding '%s' column to %s table...", column, table)
        # Identifiers come from the module-level constants above, never from user input.
        await conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {column} {definition}"))


async def _table_exists(conn, table):
    """Return True when SQLite's catalog lists a table named *table*."""
    result = await conn.execute(
        text("SELECT name FROM sqlite_master WHERE type='table' AND name=:name"),
        {"name": table},
    )
    return result.fetchone() is not None


async def _create_missing_table(conn, table, create_sql, index_sqls):
    """Create *table* and its indexes when the table is not present yet."""
    if await _table_exists(conn, table):
        return
    logger.info("Creating '%s' table...", table)
    await conn.execute(text(create_sql))
    for index_sql in index_sqls:
        await conn.execute(text(index_sql))


async def _run_migrations(conn):
    """Run schema migrations to add missing columns and tables.

    The statements used are SQLite-specific (``PRAGMA table_info`` and
    ``sqlite_master``), so other dialects are skipped instead of being sent
    SQL they cannot run.

    Args:
        conn: Connection exposing an awaitable ``execute()`` and a ``dialect``
            attribute; ``init_db`` passes the connection from ``engine.begin()``.

    Returns:
        None. Failures are logged as a warning and never propagated.
    """
    try:
        if conn.dialect.name != "sqlite":
            logger.info(
                "Skipping SQLite-specific migrations for dialect '%s'",
                conn.dialect.name,
            )
            return

        for table, wanted in _COLUMN_MIGRATIONS.items():
            await _add_missing_columns(conn, table, wanted)

        for table, (create_sql, index_sqls) in _TABLE_MIGRATIONS.items():
            await _create_missing_table(conn, table, create_sql, index_sqls)

        logger.info("Database migrations completed")
    except Exception as e:
        # Deliberately best-effort: a failed migration must not stop startup.
        logger.warning("Migration check failed (may be normal on first run): %s", e)
async def init_db():
    """Create the model tables, then apply the schema migrations.

    Both steps run on the same connection inside one ``engine.begin()`` block.
    """
    async with engine.begin() as connection:
        # create_all is invoked through run_sync on the async connection.
        await connection.run_sync(Base.metadata.create_all)
        # Bring an already-existing database up to date.
        await _run_migrations(connection)
async def close_db():
    """Dispose of the shared engine, closing the database connection."""
    await engine.dispose()