NeuroSploit/core/knowledge_augmentor.py
#!/usr/bin/env python3
"""
Knowledge Augmentor - Adversarial pattern recognition from bug bounty data.

Loads the bug bounty finetuning dataset and provides retrieval-based
context enrichment for agent prompts. This is for PATTERN RECOGNITION
and adversarial intuition -- NOT for replaying exploits.

The augmentor:
- Builds a keyword index by vulnerability type
- Retrieves relevant patterns matching the current testing context
- Injects formatted reference material into agent prompts
- Explicitly instructs the model to adapt, not copy
"""
import json
import logging
from typing import Dict, List, Optional
from pathlib import Path

logger = logging.getLogger(__name__)


class KnowledgeAugmentor:
"""Retrieval-based knowledge augmentation from bug bounty dataset."""
# Vulnerability type keyword mappings
VULN_KEYWORDS = {
'xss': ['xss', 'cross-site scripting', 'reflected xss', 'stored xss', 'dom xss',
'script injection', 'html injection'],
'sqli': ['sql injection', 'sqli', 'union select', 'blind sql', 'error-based sql',
'time-based sql', 'second-order sql'],
'ssrf': ['ssrf', 'server-side request forgery', 'internal service'],
'idor': ['idor', 'insecure direct object', 'broken object level',
'bola', 'horizontal privilege'],
'rce': ['rce', 'remote code execution', 'command injection', 'os command',
'code execution', 'shell injection'],
'lfi': ['lfi', 'local file inclusion', 'path traversal', 'directory traversal',
'file read', 'file disclosure'],
'auth_bypass': ['authentication bypass', 'broken authentication', 'auth bypass',
'session fixation', 'jwt', 'token manipulation'],
'csrf': ['csrf', 'cross-site request forgery', 'state-changing'],
'open_redirect': ['open redirect', 'url redirect', 'redirect vulnerability'],
'xxe': ['xxe', 'xml external entity', 'xml injection'],
'ssti': ['ssti', 'server-side template injection', 'template injection'],
'race_condition': ['race condition', 'toctou', 'concurrency'],
'graphql': ['graphql', 'introspection', 'batching attack'],
'api': ['api', 'rest api', 'broken api', 'api key', 'rate limiting'],
'deserialization': ['deserialization', 'insecure deserialization', 'pickle',
'object injection'],
'upload': ['file upload', 'unrestricted upload', 'web shell', 'upload bypass'],
'cors': ['cors', 'cross-origin', 'origin validation'],
'subdomain_takeover': ['subdomain takeover', 'dangling dns', 'cname'],
'information_disclosure': ['information disclosure', 'sensitive data', 'data exposure',
'directory listing', 'source code disclosure'],
}

    def __init__(self, dataset_path: str = "models/bug-bounty/bugbounty_finetuning_dataset.json",
                 max_patterns: int = 3):
        """
        Args:
            dataset_path: Path to the bug bounty finetuning dataset (JSON list).
            max_patterns: Default maximum number of patterns returned per query.
        """
        self.dataset_path = Path(dataset_path)
        self.max_patterns = max_patterns
        self.entries: List[Dict] = []
        self.index: Dict[str, List[int]] = {}  # vuln_type -> list of entry indices
        self._loaded = False

    def _ensure_loaded(self):
        """Lazy load and index the dataset on first use."""
        if self._loaded:
            return
        if not self.dataset_path.exists():
            logger.warning(f"Bug bounty dataset not found: {self.dataset_path}")
            self._loaded = True
            return
        try:
            with open(self.dataset_path, 'r', encoding='utf-8') as f:
                self.entries = json.load(f)
            logger.info(f"Loaded {len(self.entries)} entries from bug bounty dataset")
            self._build_index()
        except Exception as e:
            logger.error(f"Failed to load bug bounty dataset: {e}")
        self._loaded = True

    def _build_index(self):
        """Build keyword index over the dataset entries."""
        for i, entry in enumerate(self.entries):
            text = (
                entry.get('instruction', '') + ' ' +
                entry.get('input', '') + ' ' +
                entry.get('output', '')
            ).lower()
            for vuln_type, keywords in self.VULN_KEYWORDS.items():
                for kw in keywords:
                    if kw in text:
                        self.index.setdefault(vuln_type, []).append(i)
                        break  # One match per vuln_type per entry
        indexed_types = {k: len(v) for k, v in self.index.items()}
        logger.info(f"Knowledge index built: {indexed_types}")

    def get_relevant_patterns(self, vulnerability_type: str,
                              technologies: Optional[List[str]] = None,
                              max_entries: Optional[int] = None) -> str:
        """Retrieve relevant bug bounty patterns for context enrichment.

        Args:
            vulnerability_type: Type of vulnerability being tested (e.g., 'xss', 'sqli')
            technologies: Optional list of detected technologies for relevance boosting
            max_entries: Override default max patterns count

        Returns:
            Formatted string for injection into LLM prompts as cognitive augmentation.
            Returns empty string if no relevant patterns found.
        """
        self._ensure_loaded()
        limit = max_entries or self.max_patterns
        vuln_key = vulnerability_type.lower().replace(' ', '_').replace('-', '_')

        # Try exact match first, then partial
        candidates = self.index.get(vuln_key, [])
        if not candidates:
            # Try partial matching
            for key, indices in self.index.items():
                if vuln_key in key or key in vuln_key:
                    candidates = indices
                    break
        if not candidates:
            return ""

        # Deduplicate
        candidates = list(dict.fromkeys(candidates))

        # Score by technology relevance if technologies provided
        if technologies:
            scored = []
            for idx in candidates:
                entry = self.entries[idx]
                text = (entry.get('output', '') + ' ' + entry.get('instruction', '')).lower()
                tech_score = sum(1 for t in technologies if t.lower() in text)
                scored.append((tech_score, idx))
            scored.sort(key=lambda x: x[0], reverse=True)
            candidates = [idx for _, idx in scored]

        selected = candidates[:limit]

        # Build augmentation context
        augmentation = (
            "\n\n=== ADVERSARIAL PATTERN CONTEXT (Bug Bounty Knowledge) ===\n"
            "These are REFERENCE PATTERNS for understanding attack vectors and methodology.\n"
            "ADAPT the approach to the current target. Do NOT replay exact exploits.\n"
            "Use these as cognitive anchors for creative hypothesis generation.\n\n"
        )
        for i, idx in enumerate(selected, 1):
            entry = self.entries[idx]
            instruction = entry.get('instruction', '')[:300]
            output = entry.get('output', '')
            # Extract methodology-relevant sections, truncate for context budget
            methodology = self._extract_methodology(output, max_chars=1500)
            augmentation += f"--- Pattern {i} ---\n"
            augmentation += f"Context: {instruction}\n"
            augmentation += f"Methodology:\n{methodology}\n\n"
        augmentation += "=== END ADVERSARIAL PATTERN CONTEXT ===\n"
        return augmentation

    def _extract_methodology(self, text: str, max_chars: int = 1500) -> str:
        """Extract the most methodology-relevant portion of a writeup."""
        # Look for methodology/steps/approach sections
        markers = ['### steps', '### methodology', '### approach', '### exploitation',
                   '## steps', '## methodology', '## approach', '## exploitation',
                   'steps to reproduce', 'reproduction steps', 'proof of concept']
        text_lower = text.lower()
        for marker in markers:
            idx = text_lower.find(marker)
            if idx != -1:
                return text[idx:idx + max_chars]
        # Fall back to first max_chars of the output
        return text[:max_chars]

    def get_available_types(self) -> List[str]:
        """Return list of vulnerability types that have indexed entries."""
        self._ensure_loaded()
        return sorted(self.index.keys())

    def get_entry_count(self, vulnerability_type: str) -> int:
        """Return count of indexed entries for a vulnerability type."""
        self._ensure_loaded()
        vuln_key = vulnerability_type.lower().replace(' ', '_').replace('-', '_')
        return len(self.index.get(vuln_key, []))
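

# Minimal usage sketch -- assumption: the default dataset path exists and its
# entries follow the instruction/input/output schema sketched above. The
# target values below ('xss', 'php', 'nginx') are illustrative only, not
# taken from the pipeline.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    augmentor = KnowledgeAugmentor(max_patterns=2)

    # Show which vulnerability types have indexed reference material.
    print("Indexed types:", augmentor.get_available_types())
    print("XSS entries:", augmentor.get_entry_count('xss'))

    # Retrieve formatted pattern context for an XSS test against a PHP/nginx target.
    context = augmentor.get_relevant_patterns('xss', technologies=['php', 'nginx'])
    print(context or "No XSS patterns available in the dataset.")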