mirror of
https://github.com/CyberSecurityUP/NeuroSploit.git
synced 2026-05-22 14:36:50 +02:00
NeuroSploit v3.2 - Autonomous AI Penetration Testing Platform
116 modules | 100 vuln types | 18 API routes | 18 frontend pages Major features: - VulnEngine: 100 vuln types, 526+ payloads, 12 testers, anti-hallucination prompts - Autonomous Agent: 3-stream auto pentest, multi-session (5 concurrent), pause/resume/stop - CLI Agent: Claude Code / Gemini CLI / Codex CLI inside Kali containers - Validation Pipeline: negative controls, proof of execution, confidence scoring, judge - AI Reasoning: ReACT engine, token budget, endpoint classifier, CVE hunter, deep recon - Multi-Agent: 5 specialists + orchestrator + researcher AI + vuln type agents - RAG System: BM25/TF-IDF/ChromaDB vectorstore, few-shot, reasoning templates - Smart Router: 20 providers (8 CLI OAuth + 12 API), tier failover, token refresh - Kali Sandbox: container-per-scan, 56 tools, VPN support, on-demand install - Full IA Testing: methodology-driven comprehensive pentest sessions - Notifications: Discord, Telegram, WhatsApp/Twilio multi-channel alerts - Frontend: React/TypeScript with 18 pages, real-time WebSocket updates
This commit is contained in:
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
RAG (Retrieval-Augmented Generation) system for NeuroSploitv2.
|
||||
|
||||
Enhances AI reasoning by providing relevant context from multiple knowledge
|
||||
sources without modifying the underlying model. This is a "reasoning amplifier"
|
||||
that teaches the AI HOW to think about vulnerabilities through:
|
||||
|
||||
1. Semantic retrieval from 9000+ bug bounty reports
|
||||
2. Few-shot examples showing successful exploitation reasoning
|
||||
3. Chain-of-Thought reasoning templates per vulnerability type
|
||||
4. Cross-scan reasoning memory (learning from past successes/failures)
|
||||
|
||||
Usage:
|
||||
from backend.core.rag import RAGEngine, FewShotSelector, ReasoningMemory
|
||||
from backend.core.rag.reasoning_templates import format_reasoning_prompt
|
||||
|
||||
# Initialize
|
||||
rag = RAGEngine(data_dir="data")
|
||||
rag.index_all() # One-time indexing
|
||||
|
||||
# Get testing context
|
||||
context = rag.get_testing_context("xss", technology="PHP")
|
||||
|
||||
# Get few-shot examples
|
||||
few_shot = FewShotSelector(rag_engine=rag)
|
||||
examples = few_shot.get_testing_examples("sqli", technology="MySQL")
|
||||
|
||||
# Get reasoning framework
|
||||
reasoning = format_reasoning_prompt("ssrf")
|
||||
|
||||
# Record success for future learning
|
||||
memory = ReasoningMemory()
|
||||
memory.record_success(trace)
|
||||
|
||||
Backends (auto-selected, best available):
|
||||
- ChromaDB + sentence-transformers: Semantic embeddings (best quality)
|
||||
- TF-IDF (scikit-learn): Statistical similarity (good quality)
|
||||
- BM25 (zero deps): Keyword ranking (works out of box)
|
||||
"""
|
||||
|
||||
from .engine import RAGEngine, RAGContext
|
||||
from .few_shot import FewShotSelector, FewShotExample
|
||||
from .reasoning_memory import ReasoningMemory, ReasoningTrace, FailureRecord
|
||||
from .reasoning_templates import (
|
||||
get_reasoning_template,
|
||||
format_reasoning_prompt,
|
||||
get_available_types,
|
||||
REASONING_TEMPLATES
|
||||
)
|
||||
from .vectorstore import (
|
||||
BaseVectorStore,
|
||||
BM25VectorStore,
|
||||
RetrievedChunk,
|
||||
Document,
|
||||
create_vectorstore
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core engine
|
||||
"RAGEngine",
|
||||
"RAGContext",
|
||||
|
||||
# Few-shot selection
|
||||
"FewShotSelector",
|
||||
"FewShotExample",
|
||||
|
||||
# Reasoning memory
|
||||
"ReasoningMemory",
|
||||
"ReasoningTrace",
|
||||
"FailureRecord",
|
||||
|
||||
# Reasoning templates
|
||||
"get_reasoning_template",
|
||||
"format_reasoning_prompt",
|
||||
"get_available_types",
|
||||
"REASONING_TEMPLATES",
|
||||
|
||||
# Vector store
|
||||
"BaseVectorStore",
|
||||
"BM25VectorStore",
|
||||
"RetrievedChunk",
|
||||
"Document",
|
||||
"create_vectorstore",
|
||||
]
|
||||
@@ -0,0 +1,877 @@
|
||||
"""
|
||||
RAG Engine - Retrieval-Augmented Generation for enhanced AI reasoning.
|
||||
|
||||
Indexes all knowledge sources (bug bounty reports, vuln KB, custom docs,
|
||||
reasoning traces) and provides semantic retrieval for context-enriched
|
||||
LLM prompts. Does NOT modify the model - only augments input context.
|
||||
|
||||
Collections:
|
||||
- bug_bounty_patterns: 9131 real-world vulnerability reports
|
||||
- vuln_methodologies: 100 vulnerability type methodologies
|
||||
- custom_knowledge: User-uploaded research documents
|
||||
- reasoning_traces: Successful reasoning chains from past scans
|
||||
- attack_patterns: Extracted attack patterns and techniques
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .vectorstore import (
|
||||
BaseVectorStore, Document, RetrievedChunk,
|
||||
create_vectorstore
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Collection names
|
||||
COL_BUG_BOUNTY = "bug_bounty_patterns"
|
||||
COL_VULN_METHODS = "vuln_methodologies"
|
||||
COL_CUSTOM = "custom_knowledge"
|
||||
COL_REASONING = "reasoning_traces"
|
||||
COL_ATTACK = "attack_patterns"
|
||||
|
||||
# Defaults
|
||||
DEFAULT_TOP_K = 5
|
||||
MAX_CONTEXT_CHARS = 4000
|
||||
INDEX_BATCH_SIZE = 200
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGContext:
|
||||
"""Assembled RAG context for a specific query."""
|
||||
query: str
|
||||
chunks: List[RetrievedChunk] = field(default_factory=list)
|
||||
total_score: float = 0.0
|
||||
sources_used: List[str] = field(default_factory=list)
|
||||
token_estimate: int = 0
|
||||
|
||||
def to_prompt_text(self, max_chars: int = MAX_CONTEXT_CHARS) -> str:
|
||||
"""Format retrieved context for injection into LLM prompt."""
|
||||
if not self.chunks:
|
||||
return ""
|
||||
|
||||
sections = []
|
||||
current_len = 0
|
||||
|
||||
for chunk in self.chunks:
|
||||
source_label = chunk.metadata.get("source_type", chunk.source)
|
||||
vuln_type = chunk.metadata.get("vuln_type", "")
|
||||
score_pct = int(chunk.score * 100) if chunk.score <= 1.0 else int(chunk.score)
|
||||
|
||||
header = f"[{source_label}]"
|
||||
if vuln_type:
|
||||
header += f" ({vuln_type})"
|
||||
header += f" [relevance: {score_pct}%]"
|
||||
|
||||
text = chunk.text.strip()
|
||||
section = f"{header}\n{text}\n"
|
||||
|
||||
if current_len + len(section) > max_chars:
|
||||
remaining = max_chars - current_len - len(header) - 20
|
||||
if remaining > 100:
|
||||
section = f"{header}\n{text[:remaining]}...\n"
|
||||
else:
|
||||
break
|
||||
|
||||
sections.append(section)
|
||||
current_len += len(section)
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
result = "=== RETRIEVED KNOWLEDGE (RAG) ===\n"
|
||||
result += "Use this knowledge to inform your analysis. Adapt techniques to the target.\n\n"
|
||||
result += "\n---\n".join(sections)
|
||||
result += "\n=== END RETRIEVED KNOWLEDGE ===\n"
|
||||
|
||||
self.token_estimate = len(result) // 4 # rough token estimate
|
||||
return result
|
||||
|
||||
|
||||
class RAGEngine:
|
||||
"""
|
||||
Main RAG orchestrator. Indexes knowledge sources and provides
|
||||
semantic retrieval for context-enriched AI reasoning.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: str = "data", backend: str = "auto",
|
||||
persist_dir: str = None):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.persist_dir = persist_dir or str(self.data_dir / "vectorstore")
|
||||
|
||||
self.store: BaseVectorStore = create_vectorstore(
|
||||
self.persist_dir, backend=backend
|
||||
)
|
||||
|
||||
self._indexed = False
|
||||
self._index_stats: Dict[str, int] = {}
|
||||
|
||||
logger.info(f"RAG Engine initialized with '{self.store.backend_name}' backend")
|
||||
|
||||
@property
|
||||
def backend_name(self) -> str:
|
||||
return self.store.backend_name
|
||||
|
||||
@property
|
||||
def is_indexed(self) -> bool:
|
||||
return self._indexed
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""Return indexing statistics."""
|
||||
stats = {
|
||||
"backend": self.store.backend_name,
|
||||
"indexed": self._indexed,
|
||||
"collections": {}
|
||||
}
|
||||
for col_name in [COL_BUG_BOUNTY, COL_VULN_METHODS, COL_CUSTOM,
|
||||
COL_REASONING, COL_ATTACK]:
|
||||
count = self.store.collection_count(col_name)
|
||||
if count > 0:
|
||||
stats["collections"][col_name] = count
|
||||
return stats
|
||||
|
||||
# ── Indexing ────────────────────────────────────────────────
|
||||
|
||||
def index_all(self, force: bool = False) -> Dict[str, int]:
|
||||
"""
|
||||
Index all available knowledge sources.
|
||||
Returns dict of collection_name -> documents_indexed.
|
||||
"""
|
||||
stats = {}
|
||||
|
||||
# Only re-index if forced or collections are empty
|
||||
if not force and self._all_collections_populated():
|
||||
logger.info("RAG: All collections already populated, skipping index")
|
||||
self._indexed = True
|
||||
return stats
|
||||
|
||||
start = time.time()
|
||||
|
||||
stats[COL_BUG_BOUNTY] = self._index_bug_bounty()
|
||||
stats[COL_VULN_METHODS] = self._index_vuln_knowledge_base()
|
||||
stats[COL_CUSTOM] = self._index_custom_knowledge()
|
||||
stats[COL_ATTACK] = self._index_attack_patterns()
|
||||
|
||||
elapsed = time.time() - start
|
||||
total = sum(stats.values())
|
||||
self._indexed = True
|
||||
self._index_stats = stats
|
||||
|
||||
logger.info(f"RAG: Indexed {total} documents across {len(stats)} collections in {elapsed:.1f}s")
|
||||
return stats
|
||||
|
||||
def _all_collections_populated(self) -> bool:
|
||||
"""Check if main collections already have data."""
|
||||
return (self.store.collection_exists(COL_BUG_BOUNTY) and
|
||||
self.store.collection_exists(COL_VULN_METHODS))
|
||||
|
||||
def _index_bug_bounty(self) -> int:
|
||||
"""Index the bug bounty finetuning dataset."""
|
||||
dataset_path = Path("models/bug-bounty/bugbounty_finetuning_dataset.json")
|
||||
if not dataset_path.exists():
|
||||
logger.warning(f"RAG: Bug bounty dataset not found at {dataset_path}")
|
||||
return 0
|
||||
|
||||
if self.store.collection_exists(COL_BUG_BOUNTY):
|
||||
existing = self.store.collection_count(COL_BUG_BOUNTY)
|
||||
if existing > 1000:
|
||||
logger.info(f"RAG: Bug bounty already indexed ({existing} docs)")
|
||||
return 0
|
||||
|
||||
try:
|
||||
with open(dataset_path, 'r', encoding='utf-8') as f:
|
||||
entries = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"RAG: Failed to load bug bounty dataset: {e}")
|
||||
return 0
|
||||
|
||||
if not isinstance(entries, list):
|
||||
return 0
|
||||
|
||||
documents = []
|
||||
for i, entry in enumerate(entries):
|
||||
instruction = entry.get("instruction", "")
|
||||
output = entry.get("output", "")
|
||||
|
||||
if not output or len(output) < 50:
|
||||
continue
|
||||
|
||||
# Extract vulnerability types from content
|
||||
vuln_types = self._detect_vuln_types(instruction + " " + output)
|
||||
|
||||
# Extract technologies
|
||||
technologies = self._detect_technologies(output)
|
||||
|
||||
# Chunk 1: Full methodology (primary chunk)
|
||||
methodology = self._extract_section(output, [
|
||||
"passos para reproduzir", "steps to reproduce",
|
||||
"methodology", "exploitation", "proof of concept",
|
||||
"como reproduzir", "reprodução"
|
||||
])
|
||||
|
||||
if methodology and len(methodology) > 100:
|
||||
documents.append(Document(
|
||||
text=methodology[:4000],
|
||||
metadata={
|
||||
"source_type": "bug_bounty",
|
||||
"vuln_type": vuln_types[0] if vuln_types else "unknown",
|
||||
"vuln_types": ",".join(vuln_types[:5]),
|
||||
"technologies": ",".join(technologies[:5]),
|
||||
"chunk_type": "methodology",
|
||||
"entry_index": i
|
||||
},
|
||||
doc_id=f"bb_method_{i}"
|
||||
))
|
||||
|
||||
# Chunk 2: Summary + Impact (secondary chunk)
|
||||
summary = self._extract_section(output, [
|
||||
"resumo", "summary", "descrição", "description",
|
||||
"overview"
|
||||
])
|
||||
impact = self._extract_section(output, [
|
||||
"impacto", "impact", "severity", "risco"
|
||||
])
|
||||
|
||||
summary_text = f"{instruction}\n\n{summary or output[:500]}"
|
||||
if impact:
|
||||
summary_text += f"\n\nImpact: {impact}"
|
||||
|
||||
documents.append(Document(
|
||||
text=summary_text[:3000],
|
||||
metadata={
|
||||
"source_type": "bug_bounty",
|
||||
"vuln_type": vuln_types[0] if vuln_types else "unknown",
|
||||
"vuln_types": ",".join(vuln_types[:5]),
|
||||
"technologies": ",".join(technologies[:5]),
|
||||
"chunk_type": "summary",
|
||||
"entry_index": i
|
||||
},
|
||||
doc_id=f"bb_summary_{i}"
|
||||
))
|
||||
|
||||
# Chunk 3: Payloads & PoC code (if present)
|
||||
payloads = self._extract_code_blocks(output)
|
||||
if payloads:
|
||||
payload_text = f"Vulnerability: {vuln_types[0] if vuln_types else 'unknown'}\n"
|
||||
payload_text += f"Technologies: {', '.join(technologies[:3])}\n\n"
|
||||
payload_text += "Payloads/PoC:\n" + "\n\n".join(payloads[:10])
|
||||
|
||||
documents.append(Document(
|
||||
text=payload_text[:3000],
|
||||
metadata={
|
||||
"source_type": "bug_bounty",
|
||||
"vuln_type": vuln_types[0] if vuln_types else "unknown",
|
||||
"vuln_types": ",".join(vuln_types[:5]),
|
||||
"technologies": ",".join(technologies[:5]),
|
||||
"chunk_type": "payload",
|
||||
"entry_index": i
|
||||
},
|
||||
doc_id=f"bb_payload_{i}"
|
||||
))
|
||||
|
||||
# Index in batches
|
||||
total_added = 0
|
||||
for start in range(0, len(documents), INDEX_BATCH_SIZE):
|
||||
batch = documents[start:start + INDEX_BATCH_SIZE]
|
||||
added = self.store.add(COL_BUG_BOUNTY, batch)
|
||||
total_added += added
|
||||
|
||||
logger.info(f"RAG: Indexed {total_added} bug bounty chunks from {len(entries)} entries")
|
||||
return total_added
|
||||
|
||||
def _index_vuln_knowledge_base(self) -> int:
|
||||
"""Index the 100-type vulnerability knowledge base."""
|
||||
kb_path = self.data_dir / "vuln_knowledge_base.json"
|
||||
if not kb_path.exists():
|
||||
return 0
|
||||
|
||||
if self.store.collection_exists(COL_VULN_METHODS):
|
||||
existing = self.store.collection_count(COL_VULN_METHODS)
|
||||
if existing >= 90:
|
||||
return 0
|
||||
|
||||
try:
|
||||
with open(kb_path, 'r', encoding='utf-8') as f:
|
||||
kb = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"RAG: Failed to load vuln KB: {e}")
|
||||
return 0
|
||||
|
||||
vuln_types = kb.get("vulnerability_types", {})
|
||||
if not vuln_types:
|
||||
return 0
|
||||
|
||||
documents = []
|
||||
for vuln_type, info in vuln_types.items():
|
||||
text = f"Vulnerability: {info.get('title', vuln_type)}\n"
|
||||
text += f"Type: {vuln_type}\n"
|
||||
text += f"CWE: {info.get('cwe_id', 'N/A')}\n"
|
||||
text += f"Severity: {info.get('severity', 'N/A')}\n\n"
|
||||
text += f"Description: {info.get('description', '')}\n\n"
|
||||
text += f"Impact: {info.get('impact', '')}\n\n"
|
||||
text += f"Remediation: {info.get('remediation', '')}\n"
|
||||
|
||||
fp_markers = info.get("false_positive_markers", [])
|
||||
if fp_markers:
|
||||
text += f"\nFalse Positive Indicators: {', '.join(fp_markers)}\n"
|
||||
|
||||
documents.append(Document(
|
||||
text=text,
|
||||
metadata={
|
||||
"source_type": "vuln_kb",
|
||||
"vuln_type": vuln_type,
|
||||
"severity": info.get("severity", "medium"),
|
||||
"cwe_id": info.get("cwe_id", ""),
|
||||
"chunk_type": "methodology"
|
||||
},
|
||||
doc_id=f"vkb_{vuln_type}"
|
||||
))
|
||||
|
||||
# Index XBOW insights if available
|
||||
xbow = kb.get("xbow_insights", {})
|
||||
if xbow:
|
||||
for category, insights in xbow.items():
|
||||
if isinstance(insights, str):
|
||||
text = f"XBOW Benchmark Insight - {category}:\n{insights}"
|
||||
elif isinstance(insights, dict):
|
||||
text = f"XBOW Benchmark Insight - {category}:\n{json.dumps(insights, indent=2)}"
|
||||
elif isinstance(insights, list):
|
||||
text = f"XBOW Benchmark Insight - {category}:\n" + "\n".join(str(i) for i in insights)
|
||||
else:
|
||||
continue
|
||||
|
||||
documents.append(Document(
|
||||
text=text[:3000],
|
||||
metadata={
|
||||
"source_type": "vuln_kb",
|
||||
"vuln_type": category,
|
||||
"chunk_type": "insight"
|
||||
},
|
||||
doc_id=f"xbow_{category}"
|
||||
))
|
||||
|
||||
added = self.store.add(COL_VULN_METHODS, documents)
|
||||
logger.info(f"RAG: Indexed {added} vuln KB entries")
|
||||
return added
|
||||
|
||||
def _index_custom_knowledge(self) -> int:
|
||||
"""Index user-uploaded custom knowledge documents."""
|
||||
index_path = self.data_dir / "custom-knowledge" / "index.json"
|
||||
if not index_path.exists():
|
||||
return 0
|
||||
|
||||
try:
|
||||
with open(index_path, 'r', encoding='utf-8') as f:
|
||||
index = json.load(f)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
documents = []
|
||||
for doc_entry in index.get("documents", []):
|
||||
for entry in doc_entry.get("knowledge_entries", []):
|
||||
vuln_type = entry.get("vuln_type", "unknown")
|
||||
text = f"Custom Knowledge - {vuln_type}\n"
|
||||
text += f"Source: {doc_entry.get('filename', 'unknown')}\n\n"
|
||||
|
||||
if entry.get("methodology"):
|
||||
text += f"Methodology: {entry['methodology']}\n\n"
|
||||
if entry.get("key_insights"):
|
||||
if isinstance(entry["key_insights"], list):
|
||||
text += "Key Insights:\n" + "\n".join(f"- {i}" for i in entry["key_insights"]) + "\n\n"
|
||||
else:
|
||||
text += f"Key Insights: {entry['key_insights']}\n\n"
|
||||
if entry.get("payloads"):
|
||||
payloads = entry["payloads"][:10]
|
||||
text += "Payloads:\n" + "\n".join(f" {p}" for p in payloads) + "\n\n"
|
||||
if entry.get("bypass_techniques"):
|
||||
techniques = entry["bypass_techniques"][:10]
|
||||
text += "Bypass Techniques:\n" + "\n".join(f"- {t}" for t in techniques) + "\n"
|
||||
|
||||
documents.append(Document(
|
||||
text=text[:4000],
|
||||
metadata={
|
||||
"source_type": "custom",
|
||||
"vuln_type": vuln_type,
|
||||
"filename": doc_entry.get("filename", ""),
|
||||
"chunk_type": "methodology"
|
||||
},
|
||||
doc_id=f"custom_{doc_entry.get('id', '')}_{vuln_type}"
|
||||
))
|
||||
|
||||
if not documents:
|
||||
return 0
|
||||
|
||||
added = self.store.add(COL_CUSTOM, documents)
|
||||
logger.info(f"RAG: Indexed {added} custom knowledge entries")
|
||||
return added
|
||||
|
||||
def _index_attack_patterns(self) -> int:
|
||||
"""Index extracted attack patterns from execution history."""
|
||||
hist_path = self.data_dir / "execution_history.json"
|
||||
if not hist_path.exists():
|
||||
return 0
|
||||
|
||||
try:
|
||||
with open(hist_path, 'r', encoding='utf-8') as f:
|
||||
history = json.load(f)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
attacks = history.get("attacks", [])
|
||||
if not attacks:
|
||||
return 0
|
||||
|
||||
# Group successful attacks by vuln_type + tech
|
||||
successes: Dict[str, List[Dict]] = {}
|
||||
for attack in attacks:
|
||||
if not attack.get("success"):
|
||||
continue
|
||||
key = f"{attack.get('vuln_type', '')}_{attack.get('tech', '')}"
|
||||
if key not in successes:
|
||||
successes[key] = []
|
||||
successes[key].append(attack)
|
||||
|
||||
documents = []
|
||||
for key, attack_list in successes.items():
|
||||
vuln_type = attack_list[0].get("vuln_type", "unknown")
|
||||
tech = attack_list[0].get("tech", "unknown")
|
||||
|
||||
text = f"Successful Attack Pattern: {vuln_type} on {tech}\n"
|
||||
text += f"Success count: {len(attack_list)}\n\n"
|
||||
|
||||
for atk in attack_list[:5]:
|
||||
evidence = atk.get("evidence_preview", "")
|
||||
domain = atk.get("target_domain", "")
|
||||
text += f"- Target: {domain}, Evidence: {evidence}\n"
|
||||
|
||||
documents.append(Document(
|
||||
text=text[:2000],
|
||||
metadata={
|
||||
"source_type": "attack_pattern",
|
||||
"vuln_type": vuln_type,
|
||||
"technology": tech,
|
||||
"success_count": len(attack_list),
|
||||
"chunk_type": "pattern"
|
||||
},
|
||||
doc_id=f"atk_{key}"
|
||||
))
|
||||
|
||||
if not documents:
|
||||
return 0
|
||||
|
||||
added = self.store.add(COL_ATTACK, documents)
|
||||
logger.info(f"RAG: Indexed {added} attack patterns")
|
||||
return added
|
||||
|
||||
def index_reasoning_trace(self, trace: Dict) -> bool:
|
||||
"""
|
||||
Index a successful reasoning trace for future retrieval.
|
||||
Called when a finding is confirmed.
|
||||
|
||||
trace = {
|
||||
"vuln_type": str,
|
||||
"technology": str,
|
||||
"endpoint": str,
|
||||
"reasoning_chain": List[str],
|
||||
"payload_used": str,
|
||||
"evidence": str,
|
||||
"confidence": float,
|
||||
"timestamp": float
|
||||
}
|
||||
"""
|
||||
vuln_type = trace.get("vuln_type", "unknown")
|
||||
tech = trace.get("technology", "unknown")
|
||||
|
||||
text = f"Confirmed Reasoning Trace - {vuln_type}\n"
|
||||
text += f"Technology: {tech}\n"
|
||||
text += f"Endpoint: {trace.get('endpoint', '')}\n"
|
||||
text += f"Confidence: {trace.get('confidence', 0):.0%}\n\n"
|
||||
|
||||
chain = trace.get("reasoning_chain", [])
|
||||
if chain:
|
||||
text += "Reasoning Chain:\n"
|
||||
for i, step in enumerate(chain, 1):
|
||||
text += f" {i}. {step}\n"
|
||||
text += "\n"
|
||||
|
||||
if trace.get("payload_used"):
|
||||
text += f"Payload Used: {trace['payload_used']}\n"
|
||||
if trace.get("evidence"):
|
||||
text += f"Evidence: {trace['evidence'][:500]}\n"
|
||||
|
||||
doc = Document(
|
||||
text=text[:3000],
|
||||
metadata={
|
||||
"source_type": "reasoning_trace",
|
||||
"vuln_type": vuln_type,
|
||||
"technology": tech,
|
||||
"confidence": trace.get("confidence", 0),
|
||||
"chunk_type": "reasoning",
|
||||
"timestamp": trace.get("timestamp", time.time())
|
||||
},
|
||||
doc_id=f"trace_{vuln_type}_{int(time.time())}"
|
||||
)
|
||||
|
||||
try:
|
||||
self.store.add(COL_REASONING, [doc])
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"RAG: Failed to index reasoning trace: {e}")
|
||||
return False
|
||||
|
||||
# ── Querying ────────────────────────────────────────────────
|
||||
|
||||
def query(self, query_text: str, collections: List[str] = None,
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
vuln_type: str = None,
|
||||
technology: str = None,
|
||||
chunk_type: str = None) -> RAGContext:
|
||||
"""
|
||||
Query across collections for relevant knowledge.
|
||||
|
||||
Args:
|
||||
query_text: The search query
|
||||
collections: Which collections to search (default: all)
|
||||
top_k: Number of results per collection
|
||||
vuln_type: Filter by vulnerability type
|
||||
technology: Filter by technology
|
||||
chunk_type: Filter by chunk type (methodology, payload, summary, etc.)
|
||||
|
||||
Returns:
|
||||
RAGContext with ranked, deduplicated results
|
||||
"""
|
||||
if not collections:
|
||||
collections = [COL_BUG_BOUNTY, COL_VULN_METHODS, COL_CUSTOM,
|
||||
COL_REASONING, COL_ATTACK]
|
||||
|
||||
# Build metadata filter
|
||||
meta_filter = {}
|
||||
if vuln_type:
|
||||
meta_filter["vuln_type"] = vuln_type
|
||||
if chunk_type:
|
||||
meta_filter["chunk_type"] = chunk_type
|
||||
|
||||
all_chunks: List[RetrievedChunk] = []
|
||||
sources_used = []
|
||||
|
||||
for col_name in collections:
|
||||
if not self.store.collection_exists(col_name):
|
||||
continue
|
||||
|
||||
chunks = self.store.query(
|
||||
collection=col_name,
|
||||
query_text=query_text,
|
||||
top_k=top_k,
|
||||
metadata_filter=meta_filter if meta_filter else None
|
||||
)
|
||||
|
||||
if chunks:
|
||||
all_chunks.extend(chunks)
|
||||
sources_used.append(col_name)
|
||||
|
||||
# Also search with technology-enhanced query if provided
|
||||
if technology and technology not in query_text.lower():
|
||||
enhanced_query = f"{query_text} {technology}"
|
||||
for col_name in collections:
|
||||
if not self.store.collection_exists(col_name):
|
||||
continue
|
||||
chunks = self.store.query(
|
||||
collection=col_name,
|
||||
query_text=enhanced_query,
|
||||
top_k=max(2, top_k // 2),
|
||||
metadata_filter=meta_filter if meta_filter else None
|
||||
)
|
||||
if chunks:
|
||||
all_chunks.extend(chunks)
|
||||
|
||||
# Deduplicate by chunk_id
|
||||
seen = set()
|
||||
unique_chunks = []
|
||||
for chunk in all_chunks:
|
||||
if chunk.chunk_id not in seen:
|
||||
seen.add(chunk.chunk_id)
|
||||
unique_chunks.append(chunk)
|
||||
|
||||
# Sort by relevance score
|
||||
unique_chunks.sort(key=lambda c: c.score, reverse=True)
|
||||
|
||||
# Limit total results
|
||||
max_results = top_k * 2
|
||||
unique_chunks = unique_chunks[:max_results]
|
||||
|
||||
total_score = sum(c.score for c in unique_chunks)
|
||||
|
||||
return RAGContext(
|
||||
query=query_text,
|
||||
chunks=unique_chunks,
|
||||
total_score=total_score,
|
||||
sources_used=sources_used
|
||||
)
|
||||
|
||||
def get_testing_context(self, vuln_type: str, target_url: str = "",
|
||||
technology: str = "", endpoint: str = "",
|
||||
parameter: str = "",
|
||||
max_chars: int = MAX_CONTEXT_CHARS) -> str:
|
||||
"""
|
||||
Get optimized RAG context for vulnerability testing.
|
||||
Combines methodology, real examples, and attack patterns.
|
||||
"""
|
||||
# Build a rich query
|
||||
query_parts = [vuln_type.replace("_", " ")]
|
||||
if technology:
|
||||
query_parts.append(technology)
|
||||
if endpoint:
|
||||
query_parts.append(f"endpoint {endpoint}")
|
||||
if parameter:
|
||||
query_parts.append(f"parameter {parameter}")
|
||||
|
||||
query = " ".join(query_parts)
|
||||
|
||||
# Query with vuln_type preference
|
||||
context = self.query(
|
||||
query_text=query,
|
||||
vuln_type=vuln_type,
|
||||
technology=technology,
|
||||
top_k=5
|
||||
)
|
||||
|
||||
# Also get broader results without vuln_type filter
|
||||
broad_context = self.query(
|
||||
query_text=query,
|
||||
technology=technology,
|
||||
top_k=3
|
||||
)
|
||||
|
||||
# Merge, preferring vuln-specific results
|
||||
seen = {c.chunk_id for c in context.chunks}
|
||||
for chunk in broad_context.chunks:
|
||||
if chunk.chunk_id not in seen:
|
||||
context.chunks.append(chunk)
|
||||
seen.add(chunk.chunk_id)
|
||||
|
||||
# Re-sort and limit
|
||||
context.chunks.sort(key=lambda c: c.score, reverse=True)
|
||||
context.chunks = context.chunks[:8]
|
||||
|
||||
return context.to_prompt_text(max_chars=max_chars)
|
||||
|
||||
def get_verification_context(self, vuln_type: str, evidence: str,
|
||||
technology: str = "",
|
||||
max_chars: int = 2000) -> str:
|
||||
"""
|
||||
Get RAG context for finding verification/judgment.
|
||||
Focuses on confirmed examples and false positive patterns.
|
||||
"""
|
||||
query = f"{vuln_type.replace('_', ' ')} verification proof confirmed {evidence[:200]}"
|
||||
if technology:
|
||||
query += f" {technology}"
|
||||
|
||||
# Get confirmed reasoning traces
|
||||
trace_ctx = self.query(
|
||||
query_text=query,
|
||||
collections=[COL_REASONING],
|
||||
vuln_type=vuln_type,
|
||||
top_k=3
|
||||
)
|
||||
|
||||
# Get methodology for verification criteria
|
||||
method_ctx = self.query(
|
||||
query_text=f"{vuln_type} false positive verification criteria proof",
|
||||
collections=[COL_VULN_METHODS, COL_BUG_BOUNTY],
|
||||
vuln_type=vuln_type,
|
||||
chunk_type="methodology",
|
||||
top_k=3
|
||||
)
|
||||
|
||||
# Combine
|
||||
all_chunks = trace_ctx.chunks + method_ctx.chunks
|
||||
all_chunks.sort(key=lambda c: c.score, reverse=True)
|
||||
|
||||
combined = RAGContext(
|
||||
query=query,
|
||||
chunks=all_chunks[:6],
|
||||
total_score=sum(c.score for c in all_chunks[:6]),
|
||||
sources_used=list(set(trace_ctx.sources_used + method_ctx.sources_used))
|
||||
)
|
||||
|
||||
return combined.to_prompt_text(max_chars=max_chars)
|
||||
|
||||
def get_strategy_context(self, technologies: List[str],
|
||||
endpoints: List[str] = None,
|
||||
max_chars: int = 3000) -> str:
|
||||
"""
|
||||
Get RAG context for attack strategy planning.
|
||||
Focuses on tech-specific patterns and successful attack history.
|
||||
"""
|
||||
query_parts = ["penetration testing attack strategy"]
|
||||
query_parts.extend(technologies[:3])
|
||||
if endpoints:
|
||||
query_parts.extend(endpoints[:3])
|
||||
|
||||
query = " ".join(query_parts)
|
||||
|
||||
# Get attack patterns
|
||||
attack_ctx = self.query(
|
||||
query_text=query,
|
||||
collections=[COL_ATTACK, COL_BUG_BOUNTY],
|
||||
top_k=5
|
||||
)
|
||||
|
||||
# Get methodology per technology
|
||||
for tech in technologies[:2]:
|
||||
tech_ctx = self.query(
|
||||
query_text=f"{tech} common vulnerabilities exploitation",
|
||||
collections=[COL_BUG_BOUNTY, COL_VULN_METHODS],
|
||||
technology=tech,
|
||||
top_k=3
|
||||
)
|
||||
attack_ctx.chunks.extend(tech_ctx.chunks)
|
||||
|
||||
# Deduplicate and sort
|
||||
seen = set()
|
||||
unique = []
|
||||
for c in attack_ctx.chunks:
|
||||
if c.chunk_id not in seen:
|
||||
seen.add(c.chunk_id)
|
||||
unique.append(c)
|
||||
unique.sort(key=lambda c: c.score, reverse=True)
|
||||
|
||||
combined = RAGContext(
|
||||
query=query,
|
||||
chunks=unique[:10],
|
||||
total_score=sum(c.score for c in unique[:10]),
|
||||
sources_used=attack_ctx.sources_used
|
||||
)
|
||||
|
||||
return combined.to_prompt_text(max_chars=max_chars)
|
||||
|
||||
# ── Helpers ─────────────────────────────────────────────────
|
||||
|
||||
def _detect_vuln_types(self, text: str) -> List[str]:
|
||||
"""Detect vulnerability types mentioned in text."""
|
||||
text_lower = text.lower()
|
||||
VULN_KEYWORDS = {
|
||||
"xss": ["xss", "cross-site scripting", "cross site scripting", "script injection", "reflected xss", "stored xss"],
|
||||
"sqli": ["sql injection", "sqli", "sql injeção", "union select", "sqlmap"],
|
||||
"ssrf": ["ssrf", "server-side request forgery", "server side request"],
|
||||
"idor": ["idor", "insecure direct object", "referência direta"],
|
||||
"rce": ["rce", "remote code execution", "command injection", "execução remota", "os command"],
|
||||
"lfi": ["lfi", "local file inclusion", "path traversal", "directory traversal", "inclusão de arquivo"],
|
||||
"ssti": ["ssti", "server-side template injection", "template injection", "jinja", "twig"],
|
||||
"xxe": ["xxe", "xml external entity", "xml injection"],
|
||||
"csrf": ["csrf", "cross-site request forgery", "request forgery"],
|
||||
"open_redirect": ["open redirect", "redirecionamento aberto", "redirect"],
|
||||
"auth_bypass": ["authentication bypass", "auth bypass", "bypass autenticação"],
|
||||
"race_condition": ["race condition", "condição de corrida", "toctou"],
|
||||
"deserialization": ["deserialization", "deserialização", "unserialize", "pickle"],
|
||||
"upload": ["file upload", "upload", "unrestricted upload"],
|
||||
"cors": ["cors", "cross-origin"],
|
||||
"prototype_pollution": ["prototype pollution", "poluição de protótipo"],
|
||||
"request_smuggling": ["request smuggling", "http smuggling", "cl.te", "te.cl"],
|
||||
"graphql": ["graphql", "introspection"],
|
||||
"jwt": ["jwt", "json web token"],
|
||||
"nosql": ["nosql injection", "mongodb injection", "nosql"],
|
||||
"crlf": ["crlf injection", "header injection", "injeção de cabeçalho"],
|
||||
"subdomain_takeover": ["subdomain takeover", "tomada de subdomínio"],
|
||||
"information_disclosure": ["information disclosure", "divulgação de informação", "sensitive data"],
|
||||
"bola": ["bola", "broken object level"],
|
||||
"bfla": ["bfla", "broken function level"],
|
||||
"privilege_escalation": ["privilege escalation", "escalação de privilégio"],
|
||||
}
|
||||
|
||||
detected = []
|
||||
for vuln_type, keywords in VULN_KEYWORDS.items():
|
||||
for kw in keywords:
|
||||
if kw in text_lower:
|
||||
detected.append(vuln_type)
|
||||
break
|
||||
|
||||
return detected if detected else ["unknown"]
|
||||
|
||||
def _detect_technologies(self, text: str) -> List[str]:
|
||||
"""Detect technologies mentioned in text."""
|
||||
text_lower = text.lower()
|
||||
TECH_KEYWORDS = {
|
||||
"php": ["php", "laravel", "wordpress", "drupal", "symfony", "codeigniter"],
|
||||
"python": ["python", "django", "flask", "fastapi", "tornado"],
|
||||
"java": ["java", "spring", "struts", "tomcat", "jboss", "wildfly"],
|
||||
"node": ["node.js", "nodejs", "express", "next.js", "nuxt"],
|
||||
"ruby": ["ruby", "rails", "sinatra"],
|
||||
"dotnet": [".net", "asp.net", "c#", "iis"],
|
||||
"go": ["golang", " go ", "gin", "echo"],
|
||||
"nginx": ["nginx"],
|
||||
"apache": ["apache", "httpd"],
|
||||
"react": ["react", "reactjs"],
|
||||
"angular": ["angular"],
|
||||
"vue": ["vue.js", "vuejs"],
|
||||
"graphql": ["graphql"],
|
||||
"docker": ["docker", "kubernetes", "k8s"],
|
||||
"aws": ["aws", "amazon", "s3", "lambda", "ec2"],
|
||||
"azure": ["azure", "microsoft cloud"],
|
||||
"mysql": ["mysql", "mariadb"],
|
||||
"postgres": ["postgresql", "postgres"],
|
||||
"mongodb": ["mongodb", "mongo"],
|
||||
"redis": ["redis"],
|
||||
}
|
||||
|
||||
detected = []
|
||||
for tech, keywords in TECH_KEYWORDS.items():
|
||||
for kw in keywords:
|
||||
if kw in text_lower:
|
||||
detected.append(tech)
|
||||
break
|
||||
|
||||
return detected
|
||||
|
||||
def _extract_section(self, text: str, markers: List[str],
|
||||
max_chars: int = 2000) -> Optional[str]:
|
||||
"""Extract a section from text based on header markers."""
|
||||
text_lower = text.lower()
|
||||
|
||||
for marker in markers:
|
||||
idx = text_lower.find(marker)
|
||||
if idx != -1:
|
||||
# Find section start (after the marker line)
|
||||
newline_after = text.find("\n", idx)
|
||||
if newline_after == -1:
|
||||
continue
|
||||
section_start = newline_after + 1
|
||||
|
||||
# Find section end (next ## header or end)
|
||||
next_header = re.search(r'\n#{1,3}\s', text[section_start:])
|
||||
if next_header:
|
||||
section_end = section_start + next_header.start()
|
||||
else:
|
||||
section_end = min(section_start + max_chars, len(text))
|
||||
|
||||
section = text[section_start:section_end].strip()
|
||||
if len(section) > 50:
|
||||
return section[:max_chars]
|
||||
|
||||
return None
|
||||
|
||||
def _extract_code_blocks(self, text: str) -> List[str]:
|
||||
"""Extract code blocks and payloads from text."""
|
||||
blocks = []
|
||||
|
||||
# Fenced code blocks
|
||||
for match in re.finditer(r'```[\w]*\n(.*?)```', text, re.DOTALL):
|
||||
code = match.group(1).strip()
|
||||
if len(code) > 20:
|
||||
blocks.append(code[:500])
|
||||
|
||||
# Inline code with attack indicators
|
||||
for match in re.finditer(r'`([^`]{10,500})`', text):
|
||||
code = match.group(1)
|
||||
attack_indicators = ['<script', 'alert(', 'SELECT', 'UNION',
|
||||
'../', 'curl ', 'wget ', '{{', '${',
|
||||
'eval(', 'exec(', 'system(']
|
||||
if any(ind in code for ind in attack_indicators):
|
||||
blocks.append(code)
|
||||
|
||||
return blocks[:20]
|
||||
@@ -0,0 +1,644 @@
|
||||
"""
|
||||
Few-Shot Example Selector for RAG-enhanced reasoning.
|
||||
|
||||
Selects the most relevant real-world bug bounty examples and formats
|
||||
them as few-shot reasoning demonstrations for the LLM. This teaches
|
||||
the model HOW to reason about vulnerabilities by showing worked examples.
|
||||
|
||||
The key insight: instead of just giving the AI information, we show it
|
||||
examples of successful reasoning chains, so it learns the PATTERN of
|
||||
good pentesting analysis.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FewShotExample:
|
||||
"""A formatted few-shot example for prompt injection."""
|
||||
vuln_type: str
|
||||
technology: str
|
||||
scenario: str
|
||||
reasoning_chain: List[str]
|
||||
outcome: str
|
||||
payload: str = ""
|
||||
proof: str = ""
|
||||
score: float = 0.0
|
||||
|
||||
def format(self, include_chain: bool = True) -> str:
|
||||
"""Format as a prompt-ready example."""
|
||||
text = f"--- Example: {self.vuln_type.upper()} in {self.technology} ---\n"
|
||||
text += f"Scenario: {self.scenario}\n"
|
||||
|
||||
if include_chain and self.reasoning_chain:
|
||||
text += "Reasoning:\n"
|
||||
for i, step in enumerate(self.reasoning_chain, 1):
|
||||
text += f" {i}. {step}\n"
|
||||
|
||||
if self.payload:
|
||||
text += f"Payload: {self.payload}\n"
|
||||
|
||||
text += f"Outcome: {self.outcome}\n"
|
||||
|
||||
if self.proof:
|
||||
text += f"Proof: {self.proof}\n"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class FewShotSelector:
|
||||
"""
|
||||
Selects and formats few-shot examples for LLM prompts.
|
||||
|
||||
Provides three types of examples:
|
||||
1. Testing examples: How to test for a specific vuln type
|
||||
2. Verification examples: How to verify if a finding is real
|
||||
3. Strategy examples: How to plan an attack approach
|
||||
"""
|
||||
|
||||
def __init__(self, rag_engine=None, dataset_path: str = None):
|
||||
"""
|
||||
Args:
|
||||
rag_engine: RAGEngine instance for semantic retrieval
|
||||
dataset_path: Path to bug bounty dataset (fallback if no RAG engine)
|
||||
"""
|
||||
self.rag_engine = rag_engine
|
||||
self.dataset_path = dataset_path or "models/bug-bounty/bugbounty_finetuning_dataset.json"
|
||||
self._example_cache: Dict[str, List[FewShotExample]] = {}
|
||||
self._curated_examples = self._build_curated_examples()
|
||||
|
||||
def get_testing_examples(self, vuln_type: str, technology: str = "",
|
||||
max_examples: int = 3) -> str:
|
||||
"""
|
||||
Get few-shot examples showing how to test for a vulnerability type.
|
||||
Demonstrates the reasoning chain: observe → hypothesize → test → verify.
|
||||
"""
|
||||
cache_key = f"test_{vuln_type}_{technology}"
|
||||
if cache_key in self._example_cache:
|
||||
examples = self._example_cache[cache_key][:max_examples]
|
||||
return self._format_examples(examples, "TESTING EXAMPLES")
|
||||
|
||||
examples = []
|
||||
|
||||
# 1. Try curated examples first (highest quality)
|
||||
curated = self._get_curated_for_type(vuln_type)
|
||||
examples.extend(curated)
|
||||
|
||||
# 2. Get RAG-retrieved examples
|
||||
if self.rag_engine:
|
||||
rag_examples = self._retrieve_rag_examples(
|
||||
vuln_type, technology, "testing", max_examples
|
||||
)
|
||||
examples.extend(rag_examples)
|
||||
|
||||
# 3. Deduplicate and rank
|
||||
examples = self._rank_examples(examples, vuln_type, technology)[:max_examples]
|
||||
|
||||
self._example_cache[cache_key] = examples
|
||||
return self._format_examples(examples, "TESTING EXAMPLES")
|
||||
|
||||
def get_verification_examples(self, vuln_type: str, evidence: str = "",
|
||||
max_examples: int = 2) -> str:
|
||||
"""
|
||||
Get few-shot examples showing how to verify/judge a finding.
|
||||
Includes both TRUE POSITIVE and FALSE POSITIVE examples.
|
||||
"""
|
||||
examples = []
|
||||
|
||||
# Get curated verification examples
|
||||
curated_tp = self._get_curated_verification(vuln_type, is_tp=True)
|
||||
curated_fp = self._get_curated_verification(vuln_type, is_tp=False)
|
||||
|
||||
examples.extend(curated_tp[:1])
|
||||
examples.extend(curated_fp[:1])
|
||||
|
||||
# RAG-retrieved
|
||||
if self.rag_engine and len(examples) < max_examples:
|
||||
rag_examples = self._retrieve_rag_examples(
|
||||
vuln_type, "", "verification", max_examples - len(examples)
|
||||
)
|
||||
examples.extend(rag_examples)
|
||||
|
||||
return self._format_examples(examples[:max_examples], "VERIFICATION EXAMPLES")
|
||||
|
||||
def get_strategy_examples(self, technologies: List[str],
|
||||
max_examples: int = 2) -> str:
|
||||
"""
|
||||
Get few-shot examples showing attack strategy planning.
|
||||
"""
|
||||
examples = []
|
||||
|
||||
for tech in technologies[:2]:
|
||||
tech_examples = self._get_curated_strategy(tech)
|
||||
examples.extend(tech_examples)
|
||||
|
||||
if self.rag_engine and len(examples) < max_examples:
|
||||
query = f"penetration testing strategy {' '.join(technologies[:3])}"
|
||||
rag_examples = self._retrieve_rag_examples(
|
||||
"strategy", " ".join(technologies[:3]), "strategy",
|
||||
max_examples - len(examples)
|
||||
)
|
||||
examples.extend(rag_examples)
|
||||
|
||||
return self._format_examples(examples[:max_examples], "STRATEGY EXAMPLES")
|
||||
|
||||
def _retrieve_rag_examples(self, vuln_type: str, technology: str,
|
||||
context: str, max_examples: int) -> List[FewShotExample]:
|
||||
"""Retrieve and convert RAG chunks into few-shot examples."""
|
||||
if not self.rag_engine:
|
||||
return []
|
||||
|
||||
query = f"{vuln_type.replace('_', ' ')} {technology} {context}"
|
||||
rag_ctx = self.rag_engine.query(
|
||||
query_text=query,
|
||||
vuln_type=vuln_type if vuln_type != "strategy" else None,
|
||||
technology=technology if technology else None,
|
||||
top_k=max_examples * 2
|
||||
)
|
||||
|
||||
examples = []
|
||||
for chunk in rag_ctx.chunks[:max_examples]:
|
||||
example = self._chunk_to_example(chunk, vuln_type)
|
||||
if example:
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
def _chunk_to_example(self, chunk, vuln_type: str) -> Optional[FewShotExample]:
|
||||
"""Convert a retrieved chunk into a few-shot example."""
|
||||
text = chunk.text
|
||||
meta = chunk.metadata
|
||||
|
||||
# Extract reasoning chain from the text
|
||||
chain = self._extract_reasoning_from_text(text)
|
||||
|
||||
# Extract payload
|
||||
payload = ""
|
||||
payload_match = re.search(r'(?:payload|exploit|poc)[:\s]*[`"]?([^\n`"]{10,200})', text, re.I)
|
||||
if payload_match:
|
||||
payload = payload_match.group(1).strip()
|
||||
|
||||
# Extract outcome
|
||||
outcome = "See methodology above for complete exploitation details."
|
||||
if "confirmed" in text.lower() or "success" in text.lower():
|
||||
outcome = "Vulnerability confirmed with proof of exploitation."
|
||||
elif "impacto" in text.lower() or "impact" in text.lower():
|
||||
impact_match = re.search(r'(?:impacto|impact)[:\s]*(.{20,200})', text, re.I)
|
||||
if impact_match:
|
||||
outcome = impact_match.group(1).strip()
|
||||
|
||||
technology = meta.get("technology", meta.get("technologies", "unknown"))
|
||||
if isinstance(technology, str) and "," in technology:
|
||||
technology = technology.split(",")[0]
|
||||
|
||||
scenario = text[:200].replace("\n", " ").strip()
|
||||
|
||||
return FewShotExample(
|
||||
vuln_type=meta.get("vuln_type", vuln_type),
|
||||
technology=str(technology),
|
||||
scenario=scenario,
|
||||
reasoning_chain=chain,
|
||||
outcome=outcome,
|
||||
payload=payload,
|
||||
score=chunk.score
|
||||
)
|
||||
|
||||
def _extract_reasoning_from_text(self, text: str) -> List[str]:
|
||||
"""Extract reasoning steps from a bug bounty report."""
|
||||
steps = []
|
||||
|
||||
# Try numbered steps
|
||||
numbered = re.findall(r'(?:^|\n)\s*(\d+)\.\s+(.{10,200})', text)
|
||||
if len(numbered) >= 2:
|
||||
for num, step in numbered[:6]:
|
||||
steps.append(step.strip())
|
||||
return steps
|
||||
|
||||
# Try bullet points
|
||||
bullets = re.findall(r'(?:^|\n)\s*[-*]\s+(.{10,200})', text)
|
||||
if len(bullets) >= 2:
|
||||
for bullet in bullets[:6]:
|
||||
steps.append(bullet.strip())
|
||||
return steps
|
||||
|
||||
# Try section-based extraction
|
||||
sections = re.findall(r'###?\s+(.+?)(?:\n|$)', text)
|
||||
for section in sections[:6]:
|
||||
steps.append(section.strip())
|
||||
|
||||
if not steps:
|
||||
# Fall back to sentence extraction
|
||||
sentences = re.split(r'[.!]\s+', text[:800])
|
||||
for sent in sentences[:4]:
|
||||
if len(sent.strip()) > 20:
|
||||
steps.append(sent.strip())
|
||||
|
||||
return steps
|
||||
|
||||
def _rank_examples(self, examples: List[FewShotExample],
|
||||
vuln_type: str, technology: str) -> List[FewShotExample]:
|
||||
"""Rank examples by relevance to the target vuln type and technology."""
|
||||
for example in examples:
|
||||
score = example.score
|
||||
|
||||
# Boost exact vuln type match
|
||||
if example.vuln_type == vuln_type:
|
||||
score += 2.0
|
||||
|
||||
# Boost technology match
|
||||
if technology and technology.lower() in example.technology.lower():
|
||||
score += 1.5
|
||||
|
||||
# Boost examples with reasoning chains
|
||||
if example.reasoning_chain and len(example.reasoning_chain) >= 3:
|
||||
score += 1.0
|
||||
|
||||
# Boost examples with payloads
|
||||
if example.payload:
|
||||
score += 0.5
|
||||
|
||||
# Boost examples with proof
|
||||
if example.proof:
|
||||
score += 0.5
|
||||
|
||||
example.score = score
|
||||
|
||||
examples.sort(key=lambda e: e.score, reverse=True)
|
||||
|
||||
# Deduplicate by scenario similarity
|
||||
seen_starts = set()
|
||||
unique = []
|
||||
for ex in examples:
|
||||
start = ex.scenario[:50].lower()
|
||||
if start not in seen_starts:
|
||||
seen_starts.add(start)
|
||||
unique.append(ex)
|
||||
|
||||
return unique
|
||||
|
||||
def _format_examples(self, examples: List[FewShotExample],
|
||||
header: str) -> str:
|
||||
"""Format examples into a prompt-ready string."""
|
||||
if not examples:
|
||||
return ""
|
||||
|
||||
text = f"\n=== {header} (Learn from these real-world cases) ===\n"
|
||||
text += "Study these examples to understand the REASONING PATTERN, then apply similar logic.\n\n"
|
||||
|
||||
for i, example in enumerate(examples, 1):
|
||||
text += f"[Example {i}]\n"
|
||||
text += example.format(include_chain=True)
|
||||
text += "\n"
|
||||
|
||||
text += f"=== END {header} ===\n"
|
||||
return text
|
||||
|
||||
def _get_curated_for_type(self, vuln_type: str) -> List[FewShotExample]:
|
||||
"""Get curated examples for a vulnerability type."""
|
||||
vtype = vuln_type.lower().replace("-", "_")
|
||||
examples = []
|
||||
|
||||
if vtype in self._curated_examples:
|
||||
for ex_data in self._curated_examples[vtype].get("testing", []):
|
||||
examples.append(FewShotExample(**ex_data, score=10.0))
|
||||
|
||||
# Also check parent types (e.g., xss_reflected -> xss)
|
||||
base_type = vtype.split("_")[0]
|
||||
if base_type != vtype and base_type in self._curated_examples:
|
||||
for ex_data in self._curated_examples[base_type].get("testing", []):
|
||||
examples.append(FewShotExample(**ex_data, score=8.0))
|
||||
|
||||
return examples
|
||||
|
||||
def _get_curated_verification(self, vuln_type: str,
|
||||
is_tp: bool) -> List[FewShotExample]:
|
||||
"""Get curated verification examples (TP or FP)."""
|
||||
vtype = vuln_type.lower().replace("-", "_")
|
||||
key = "verification_tp" if is_tp else "verification_fp"
|
||||
examples = []
|
||||
|
||||
if vtype in self._curated_examples:
|
||||
for ex_data in self._curated_examples[vtype].get(key, []):
|
||||
examples.append(FewShotExample(**ex_data, score=10.0))
|
||||
|
||||
base_type = vtype.split("_")[0]
|
||||
if base_type != vtype and base_type in self._curated_examples:
|
||||
for ex_data in self._curated_examples[base_type].get(key, []):
|
||||
examples.append(FewShotExample(**ex_data, score=8.0))
|
||||
|
||||
return examples
|
||||
|
||||
def _get_curated_strategy(self, technology: str) -> List[FewShotExample]:
|
||||
"""Get curated strategy examples for a technology."""
|
||||
tech = technology.lower()
|
||||
if tech in self._curated_examples.get("_strategies", {}):
|
||||
data = self._curated_examples["_strategies"][tech]
|
||||
return [FewShotExample(**data, score=10.0)]
|
||||
return []
|
||||
|
||||
def _build_curated_examples(self) -> Dict:
|
||||
"""
|
||||
Build curated high-quality few-shot examples.
|
||||
These are hand-crafted to demonstrate ideal reasoning patterns.
|
||||
"""
|
||||
return {
|
||||
"xss": {
|
||||
"testing": [
|
||||
{
|
||||
"vuln_type": "xss_reflected",
|
||||
"technology": "PHP",
|
||||
"scenario": "Search parameter reflected in HTML body without encoding",
|
||||
"reasoning_chain": [
|
||||
"OBSERVE: Parameter 'q' is reflected verbatim in <div class='results'>",
|
||||
"IDENTIFY CONTEXT: Reflection is inside HTML body (not attribute, not JS)",
|
||||
"TEST FILTERS: Sent <b>test</b> - HTML tags rendered, no encoding",
|
||||
"ESCALATE: Injected <script>alert(document.domain)</script>",
|
||||
"VERIFY: Script executed in browser, alert showed domain name",
|
||||
"PROVE: DOM inspection confirms injected <script> tag is live"
|
||||
],
|
||||
"outcome": "Confirmed: Reflected XSS via unencoded HTML body injection",
|
||||
"payload": "<script>alert(document.domain)</script>",
|
||||
"proof": "Playwright confirmed script execution, DOM shows injected tag"
|
||||
}
|
||||
],
|
||||
"verification_tp": [
|
||||
{
|
||||
"vuln_type": "xss_reflected",
|
||||
"technology": "generic",
|
||||
"scenario": "Verifying XSS finding is a TRUE POSITIVE",
|
||||
"reasoning_chain": [
|
||||
"CHECK 1: Payload appears in response body unencoded? YES",
|
||||
"CHECK 2: Payload is in executable context (inside HTML, not comment)? YES",
|
||||
"CHECK 3: No CSP header blocking inline scripts? CORRECT, no CSP",
|
||||
"CHECK 4: Browser actually executes the script? YES (Playwright confirms)",
|
||||
"VERDICT: All 4 checks pass → TRUE POSITIVE"
|
||||
],
|
||||
"outcome": "CONFIRMED: True positive - all verification checks passed",
|
||||
"proof": "Browser execution confirmed via Playwright"
|
||||
}
|
||||
],
|
||||
"verification_fp": [
|
||||
{
|
||||
"vuln_type": "xss_reflected",
|
||||
"technology": "generic",
|
||||
"scenario": "Verifying XSS finding that is a FALSE POSITIVE",
|
||||
"reasoning_chain": [
|
||||
"CHECK 1: Payload in response? YES, but only inside HTML comment <!-- -->",
|
||||
"CHECK 2: Executable context? NO - HTML comments are not executed",
|
||||
"CHECK 3: Even if we break out of comment, CSP blocks inline scripts",
|
||||
"VERDICT: Payload present but NOT executable → FALSE POSITIVE"
|
||||
],
|
||||
"outcome": "REJECTED: False positive - payload in non-executable context",
|
||||
"proof": "No browser execution possible due to HTML comment context + CSP"
|
||||
}
|
||||
]
|
||||
},
|
||||
"sqli": {
|
||||
"testing": [
|
||||
{
|
||||
"vuln_type": "sqli",
|
||||
"technology": "PHP/MySQL",
|
||||
"scenario": "Login form with username and password fields",
|
||||
"reasoning_chain": [
|
||||
"OBSERVE: Login form sends POST with username & password params",
|
||||
"PROBE: Single quote in username returns MySQL error: 'syntax error near '''",
|
||||
"IDENTIFY: Error-based SQL injection, MySQL backend confirmed",
|
||||
"TEST UNION: ' UNION SELECT 1,2,3-- - reveals 3 columns",
|
||||
"EXTRACT: ' UNION SELECT user(),database(),version()-- - shows root@localhost",
|
||||
"PROVE: Extracted real DB info (database name, MySQL version, current user)"
|
||||
],
|
||||
"outcome": "Confirmed: UNION-based SQL injection with full data extraction",
|
||||
"payload": "' UNION SELECT user(),database(),version()-- -",
|
||||
"proof": "Database name, version and user extracted from response"
|
||||
}
|
||||
],
|
||||
"verification_tp": [
|
||||
{
|
||||
"vuln_type": "sqli",
|
||||
"technology": "generic",
|
||||
"scenario": "Verifying SQL injection is TRUE POSITIVE",
|
||||
"reasoning_chain": [
|
||||
"CHECK 1: Database error message in response? YES (MySQL syntax error)",
|
||||
"CHECK 2: Error contains our injected syntax? YES (shows the quote)",
|
||||
"CHECK 3: Can we extract data? YES (UNION SELECT returns DB version)",
|
||||
"CHECK 4: Is data extraction real? YES (version string matches known MySQL format)",
|
||||
"VERDICT: Data extraction proven → TRUE POSITIVE"
|
||||
],
|
||||
"outcome": "CONFIRMED: True positive - actual data extraction achieved"
|
||||
}
|
||||
],
|
||||
"verification_fp": [
|
||||
{
|
||||
"vuln_type": "sqli",
|
||||
"technology": "generic",
|
||||
"scenario": "WAF error page mimics SQL error",
|
||||
"reasoning_chain": [
|
||||
"CHECK 1: Error message in response? YES, but it's a generic WAF block page",
|
||||
"CHECK 2: Same error for ANY special character? YES - WAF blocks all",
|
||||
"CHECK 3: Can we extract data? NO - all payloads return same WAF page",
|
||||
"VERDICT: WAF blocking, not SQL processing → FALSE POSITIVE"
|
||||
],
|
||||
"outcome": "REJECTED: False positive - WAF error page, not database error"
|
||||
}
|
||||
]
|
||||
},
|
||||
"ssrf": {
|
||||
"testing": [
|
||||
{
|
||||
"vuln_type": "ssrf",
|
||||
"technology": "Python/Flask",
|
||||
"scenario": "URL parameter used for fetching external content",
|
||||
"reasoning_chain": [
|
||||
"OBSERVE: Parameter 'url' fetches and displays content from URLs",
|
||||
"PROBE: Sent url=http://127.0.0.1:80 - got response from localhost",
|
||||
"TEST INTERNAL: url=http://169.254.169.254/latest/meta-data/ - got AWS metadata!",
|
||||
"EXTRACT: Retrieved IAM role name and temporary credentials",
|
||||
"PROVE: AWS metadata content (ami-id, instance-type) confirms internal access"
|
||||
],
|
||||
"outcome": "Confirmed: SSRF to AWS metadata endpoint with credential extraction",
|
||||
"payload": "http://169.254.169.254/latest/meta-data/iam/security-credentials/",
|
||||
"proof": "AWS IAM credentials retrieved from metadata endpoint"
|
||||
}
|
||||
],
|
||||
"verification_fp": [
|
||||
{
|
||||
"vuln_type": "ssrf",
|
||||
"technology": "generic",
|
||||
"scenario": "Status code difference is NOT proof of SSRF",
|
||||
"reasoning_chain": [
|
||||
"CHECK 1: Different status code with internal URL? YES (403→200)",
|
||||
"CHECK 2: But is the CONTENT from internal service? NO - same login page",
|
||||
"CHECK 3: Negative control (random URL) also returns 200? YES",
|
||||
"VERDICT: Status code change is application behavior, NOT SSRF → FALSE POSITIVE"
|
||||
],
|
||||
"outcome": "REJECTED: Status code diff without internal content is NOT SSRF"
|
||||
}
|
||||
]
|
||||
},
|
||||
"idor": {
|
||||
"testing": [
|
||||
{
|
||||
"vuln_type": "idor",
|
||||
"technology": "REST API",
|
||||
"scenario": "User profile API endpoint with numeric ID",
|
||||
"reasoning_chain": [
|
||||
"OBSERVE: GET /api/users/42 returns user profile (my ID is 42)",
|
||||
"TEST: GET /api/users/1 with my auth token - got different user's data!",
|
||||
"COMPARE DATA: Response contains name='Admin', email='admin@target.com'",
|
||||
"VERIFY: This is NOT my data - different name, email, role",
|
||||
"PROVE: Can access ANY user's profile by changing ID parameter"
|
||||
],
|
||||
"outcome": "Confirmed: IDOR - can access other users' profiles via ID enumeration",
|
||||
"proof": "Different user's PII (name, email) retrieved with attacker's token"
|
||||
}
|
||||
],
|
||||
"verification_fp": [
|
||||
{
|
||||
"vuln_type": "idor",
|
||||
"technology": "generic",
|
||||
"scenario": "Same response for different IDs is NOT IDOR",
|
||||
"reasoning_chain": [
|
||||
"CHECK 1: Different ID returns 200? YES",
|
||||
"CHECK 2: But compare the DATA content - is it actually DIFFERENT user data? NO",
|
||||
"CHECK 3: Both IDs return the SAME profile (my own data)",
|
||||
"VERDICT: Server ignores the ID parameter, always returns current user → FALSE POSITIVE"
|
||||
],
|
||||
"outcome": "REJECTED: Same data returned regardless of ID - no object-level access violation"
|
||||
}
|
||||
]
|
||||
},
|
||||
"rce": {
|
||||
"testing": [
|
||||
{
|
||||
"vuln_type": "rce",
|
||||
"technology": "Node.js",
|
||||
"scenario": "Template rendering endpoint with user-controlled input",
|
||||
"reasoning_chain": [
|
||||
"OBSERVE: Parameter 'name' rendered in template, endpoint uses eval-like function",
|
||||
"PROBE: Sent {{7*7}} - response shows 49 (template injection confirmed)",
|
||||
"ESCALATE: {{require('child_process').execSync('id')}} - returns uid=0(root)!",
|
||||
"EXTRACT: Read /etc/passwd via command execution",
|
||||
"PROVE: OS command output (uid, file contents) confirms RCE"
|
||||
],
|
||||
"outcome": "Confirmed: RCE via SSTI in Node.js (template to command execution chain)",
|
||||
"payload": "{{require('child_process').execSync('id')}}",
|
||||
"proof": "Command output 'uid=0(root)' in HTTP response"
|
||||
}
|
||||
]
|
||||
},
|
||||
"ssti": {
|
||||
"testing": [
|
||||
{
|
||||
"vuln_type": "ssti",
|
||||
"technology": "Python/Jinja2",
|
||||
"scenario": "Name field rendered via Jinja2 template engine",
|
||||
"reasoning_chain": [
|
||||
"OBSERVE: Input reflected in response via template rendering",
|
||||
"PROBE: {{7*7}} returns '49' - arithmetic evaluated = SSTI confirmed",
|
||||
"IDENTIFY ENGINE: {{config}} returns Flask config = Jinja2 confirmed",
|
||||
"ESCALATE: Use MRO chain to access subprocess module",
|
||||
"EXECUTE: {{''.__class__.__mro__[1].__subclasses__()}} - list Python classes",
|
||||
"PROVE: Achieved code execution via Popen subclass"
|
||||
],
|
||||
"outcome": "Confirmed: SSTI in Jinja2 with RCE via Python class chain",
|
||||
"payload": "{{config.__class__.__init__.__globals__['os'].popen('id').read()}}",
|
||||
"proof": "OS command output returned in template render"
|
||||
}
|
||||
]
|
||||
},
|
||||
"lfi": {
|
||||
"testing": [
|
||||
{
|
||||
"vuln_type": "lfi",
|
||||
"technology": "PHP",
|
||||
"scenario": "File include parameter loading page templates",
|
||||
"reasoning_chain": [
|
||||
"OBSERVE: Parameter 'page=about' loads about.php template",
|
||||
"TEST: page=../../../etc/passwd - returned 'root:x:0:0:root' content!",
|
||||
"VERIFY: Content matches /etc/passwd format (user:x:uid:gid:...)",
|
||||
"ESCALATE: Read application config via page=../config/database.php",
|
||||
"PROVE: Extracted database credentials from config file"
|
||||
],
|
||||
"outcome": "Confirmed: LFI with path traversal, read sensitive system and app files",
|
||||
"payload": "../../../etc/passwd",
|
||||
"proof": "/etc/passwd content with valid user entries in response"
|
||||
}
|
||||
]
|
||||
},
|
||||
"auth_bypass": {
|
||||
"testing": [
|
||||
{
|
||||
"vuln_type": "auth_bypass",
|
||||
"technology": "REST API",
|
||||
"scenario": "Admin panel behind authentication check",
|
||||
"reasoning_chain": [
|
||||
"OBSERVE: /admin returns 302 redirect to /login",
|
||||
"TEST: Send request without following redirect - check if body has admin content",
|
||||
"PROBE: Try /admin with modified headers (X-Forwarded-For: 127.0.0.1)",
|
||||
"DISCOVER: /admin/ (trailing slash) bypasses auth check! Returns admin panel",
|
||||
"VERIFY: Compare authenticated vs unauthenticated response - SAME admin content",
|
||||
"PROVE: Full admin functionality accessible without any credentials"
|
||||
],
|
||||
"outcome": "Confirmed: Auth bypass via trailing slash normalization bug",
|
||||
"proof": "Admin panel content accessible without authentication"
|
||||
}
|
||||
]
|
||||
},
|
||||
"_strategies": {
|
||||
"php": {
|
||||
"vuln_type": "strategy",
|
||||
"technology": "PHP",
|
||||
"scenario": "Planning attack strategy for PHP application",
|
||||
"reasoning_chain": [
|
||||
"PHP apps are prone to: SQL injection (especially with raw queries), LFI/RFI (include/require), XSS (echo without htmlspecialchars), file upload bypass, deserialization (unserialize)",
|
||||
"Priority: Test SQL injection on login/search forms, check for LFI in page/template parameters, test file upload functionality for webshell",
|
||||
"PHP-specific: Check for type juggling (== vs ===), test PHP wrapper protocols (php://input, php://filter), check for exposed phpinfo()",
|
||||
"Framework detection: Look for Laravel (.env exposure, debug mode), WordPress (wp-admin, xmlrpc.php), CodeIgniter (CI paths)"
|
||||
],
|
||||
"outcome": "Focus on SQLi > LFI > XSS > Upload > Deserialization for PHP targets"
|
||||
},
|
||||
"node": {
|
||||
"vuln_type": "strategy",
|
||||
"technology": "Node.js",
|
||||
"scenario": "Planning attack strategy for Node.js application",
|
||||
"reasoning_chain": [
|
||||
"Node.js apps are prone to: Prototype pollution, SSTI (pug/ejs/handlebars), NoSQL injection (MongoDB), SSRF, insecure deserialization (node-serialize), path traversal",
|
||||
"Priority: Test prototype pollution via JSON body (__proto__), check for SSTI in template params, test NoSQL injection on MongoDB endpoints",
|
||||
"Node-specific: Check for npm package vulns, test for event loop blocking (ReDoS), look for Express middleware bypasses",
|
||||
"API focus: GraphQL introspection, JWT implementation flaws, WebSocket injection"
|
||||
],
|
||||
"outcome": "Focus on Prototype Pollution > SSTI > NoSQL > SSRF for Node.js targets"
|
||||
},
|
||||
"python": {
|
||||
"vuln_type": "strategy",
|
||||
"technology": "Python",
|
||||
"scenario": "Planning attack strategy for Python application",
|
||||
"reasoning_chain": [
|
||||
"Python apps are prone to: SSTI (Jinja2/Mako), SQL injection (raw queries, ORM bypass), SSRF, pickle deserialization, command injection (os.system/subprocess)",
|
||||
"Priority: Test SSTI with {{7*7}} on all input fields, check for pickle endpoints, test SSRF on URL parameters",
|
||||
"Python-specific: Django debug mode, Flask debug/secret key exposure, YAML deserialization (yaml.load), eval/exec injection",
|
||||
"Framework: Django admin exposure, Flask /console (Werkzeug debugger), FastAPI /docs endpoint"
|
||||
],
|
||||
"outcome": "Focus on SSTI > SQLi > SSRF > Deserialization for Python targets"
|
||||
},
|
||||
"java": {
|
||||
"vuln_type": "strategy",
|
||||
"technology": "Java",
|
||||
"scenario": "Planning attack strategy for Java application",
|
||||
"reasoning_chain": [
|
||||
"Java apps are prone to: Deserialization (ObjectInputStream), XXE (SAX/DOM parsers), SSTI (Velocity/Freemarker), Expression Language injection, Log4Shell",
|
||||
"Priority: Test deserialization on all serialized object endpoints, check XXE on XML parsing endpoints, test EL injection",
|
||||
"Java-specific: Check for Java serialization magic bytes (aced0005), test Log4j via ${jndi:ldap://} in headers, Struts OGNL injection",
|
||||
"Framework: Spring Boot Actuator endpoints (/env, /heapdump), Tomcat manager exposure, JBoss/WildFly admin"
|
||||
],
|
||||
"outcome": "Focus on Deserialization > XXE > Log4Shell > SSTI for Java targets"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Reasoning Memory - Cross-scan storage of successful reasoning traces.
|
||||
|
||||
This is "pseudo-fine-tuning": instead of modifying model weights, we
|
||||
accumulate successful reasoning chains and inject them as context
|
||||
into future prompts. Over time, the system learns from its own
|
||||
successful analyses.
|
||||
|
||||
Stores:
|
||||
- Confirmed finding reasoning chains
|
||||
- Failed hypothesis patterns (what didn't work and why)
|
||||
- Technology-specific successful strategies
|
||||
- Payload effectiveness per context
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_FILE = "data/reasoning_memory.json"
|
||||
MAX_TRACES = 500
|
||||
MAX_FAILURES = 200
|
||||
MAX_STRATEGIES = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReasoningTrace:
|
||||
"""A successful reasoning chain from a confirmed finding."""
|
||||
vuln_type: str
|
||||
technology: str
|
||||
endpoint_pattern: str # Normalized pattern (IDs removed)
|
||||
parameter: str
|
||||
reasoning_steps: List[str]
|
||||
payload_used: str
|
||||
evidence_summary: str
|
||||
confidence: float
|
||||
timestamp: float = 0.0
|
||||
scan_target: str = ""
|
||||
trace_id: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.timestamp:
|
||||
self.timestamp = time.time()
|
||||
if not self.trace_id:
|
||||
key = f"{self.vuln_type}_{self.endpoint_pattern}_{self.payload_used}"
|
||||
self.trace_id = hashlib.md5(key.encode()).hexdigest()[:10]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FailureRecord:
|
||||
"""A record of what didn't work and why."""
|
||||
vuln_type: str
|
||||
technology: str
|
||||
endpoint_pattern: str
|
||||
attempted_payloads: List[str]
|
||||
failure_reason: str # "waf_blocked", "encoded", "no_reflection", "same_behavior", etc.
|
||||
timestamp: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.timestamp:
|
||||
self.timestamp = time.time()
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyRecord:
|
||||
"""A successful technology-specific strategy."""
|
||||
technology: str
|
||||
vuln_types_found: List[str]
|
||||
priority_order: List[str]
|
||||
key_insights: List[str]
|
||||
scan_count: int = 1
|
||||
success_rate: float = 0.0
|
||||
timestamp: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.timestamp:
|
||||
self.timestamp = time.time()
|
||||
|
||||
|
||||
class ReasoningMemory:
|
||||
"""
|
||||
Persistent storage and retrieval of reasoning experience.
|
||||
Learns from successful attacks and failed attempts to provide
|
||||
increasingly relevant context over time.
|
||||
"""
|
||||
|
||||
def __init__(self, memory_path: str = None):
|
||||
self.memory_path = Path(memory_path or MEMORY_FILE)
|
||||
self.memory_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._traces: List[Dict] = []
|
||||
self._failures: List[Dict] = []
|
||||
self._strategies: Dict[str, Dict] = {} # tech -> StrategyRecord dict
|
||||
self._dirty = False
|
||||
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
"""Load persisted memory."""
|
||||
if not self.memory_path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.memory_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
self._traces = data.get("traces", [])
|
||||
self._failures = data.get("failures", [])
|
||||
self._strategies = data.get("strategies", {})
|
||||
logger.info(
|
||||
f"ReasoningMemory: Loaded {len(self._traces)} traces, "
|
||||
f"{len(self._failures)} failures, {len(self._strategies)} strategies"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"ReasoningMemory: Failed to load: {e}")
|
||||
|
||||
def _save(self):
|
||||
"""Persist memory to disk."""
|
||||
if not self._dirty:
|
||||
return
|
||||
|
||||
# Enforce size limits
|
||||
self._traces = self._traces[-MAX_TRACES:]
|
||||
self._failures = self._failures[-MAX_FAILURES:]
|
||||
|
||||
try:
|
||||
data = {
|
||||
"traces": self._traces,
|
||||
"failures": self._failures,
|
||||
"strategies": self._strategies,
|
||||
"last_updated": time.time(),
|
||||
"stats": {
|
||||
"total_traces": len(self._traces),
|
||||
"total_failures": len(self._failures),
|
||||
"technologies": list(self._strategies.keys())
|
||||
}
|
||||
}
|
||||
with open(self.memory_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, default=str)
|
||||
self._dirty = False
|
||||
except Exception as e:
|
||||
logger.warning(f"ReasoningMemory: Failed to save: {e}")
|
||||
|
||||
# ── Recording ──────────────────────────────────────────────
|
||||
|
||||
def record_success(self, trace: ReasoningTrace):
|
||||
"""Record a successful reasoning chain from a confirmed finding."""
|
||||
trace_dict = asdict(trace)
|
||||
self._traces.append(trace_dict)
|
||||
self._dirty = True
|
||||
|
||||
# Auto-save periodically
|
||||
if len(self._traces) % 10 == 0:
|
||||
self._save()
|
||||
|
||||
logger.debug(
|
||||
f"ReasoningMemory: Recorded success - {trace.vuln_type} "
|
||||
f"on {trace.technology} ({trace.endpoint_pattern})"
|
||||
)
|
||||
|
||||
def record_failure(self, failure: FailureRecord):
|
||||
"""Record a failed attack attempt for future avoidance."""
|
||||
self._failures.append(asdict(failure))
|
||||
self._dirty = True
|
||||
|
||||
if len(self._failures) % 20 == 0:
|
||||
self._save()
|
||||
|
||||
def record_strategy(self, technology: str, vuln_types_found: List[str],
|
||||
priority_order: List[str], insights: List[str]):
|
||||
"""Record a successful scanning strategy for a technology."""
|
||||
tech_key = technology.lower()
|
||||
|
||||
if tech_key in self._strategies:
|
||||
# Update existing
|
||||
existing = self._strategies[tech_key]
|
||||
existing["scan_count"] = existing.get("scan_count", 0) + 1
|
||||
existing["vuln_types_found"] = list(set(
|
||||
existing.get("vuln_types_found", []) + vuln_types_found
|
||||
))
|
||||
# Merge insights
|
||||
existing_insights = set(existing.get("key_insights", []))
|
||||
for insight in insights:
|
||||
existing_insights.add(insight)
|
||||
existing["key_insights"] = list(existing_insights)[:20]
|
||||
existing["timestamp"] = time.time()
|
||||
|
||||
# Recalculate priority based on accumulated experience
|
||||
if priority_order:
|
||||
existing["priority_order"] = priority_order
|
||||
else:
|
||||
self._strategies[tech_key] = asdict(StrategyRecord(
|
||||
technology=technology,
|
||||
vuln_types_found=vuln_types_found,
|
||||
priority_order=priority_order,
|
||||
key_insights=insights
|
||||
))
|
||||
|
||||
self._dirty = True
|
||||
self._save()
|
||||
|
||||
# ── Retrieval ──────────────────────────────────────────────
|
||||
|
||||
def get_relevant_traces(self, vuln_type: str, technology: str = "",
|
||||
max_traces: int = 3) -> List[Dict]:
|
||||
"""
|
||||
Retrieve relevant successful reasoning traces.
|
||||
Prioritizes exact vuln_type match, then technology match.
|
||||
"""
|
||||
candidates = []
|
||||
|
||||
for trace in self._traces:
|
||||
score = 0.0
|
||||
|
||||
# Vuln type match (primary)
|
||||
if trace.get("vuln_type") == vuln_type:
|
||||
score += 5.0
|
||||
elif vuln_type.split("_")[0] in trace.get("vuln_type", ""):
|
||||
score += 2.0
|
||||
else:
|
||||
continue # Skip irrelevant types
|
||||
|
||||
# Technology match (secondary)
|
||||
if technology and technology.lower() in trace.get("technology", "").lower():
|
||||
score += 3.0
|
||||
|
||||
# Recency boost
|
||||
age_days = (time.time() - trace.get("timestamp", 0)) / 86400
|
||||
if age_days < 7:
|
||||
score += 1.0
|
||||
elif age_days < 30:
|
||||
score += 0.5
|
||||
|
||||
# High confidence boost
|
||||
if trace.get("confidence", 0) >= 0.9:
|
||||
score += 1.0
|
||||
|
||||
candidates.append((score, trace))
|
||||
|
||||
candidates.sort(key=lambda x: x[0], reverse=True)
|
||||
return [trace for _, trace in candidates[:max_traces]]
|
||||
|
||||
def get_failure_patterns(self, vuln_type: str, technology: str = "",
|
||||
max_patterns: int = 3) -> List[Dict]:
|
||||
"""
|
||||
Retrieve relevant failure patterns to avoid.
|
||||
"""
|
||||
candidates = []
|
||||
|
||||
for failure in self._failures:
|
||||
if failure.get("vuln_type") != vuln_type:
|
||||
continue
|
||||
|
||||
score = 1.0
|
||||
if technology and technology.lower() in failure.get("technology", "").lower():
|
||||
score += 2.0
|
||||
|
||||
candidates.append((score, failure))
|
||||
|
||||
candidates.sort(key=lambda x: x[0], reverse=True)
|
||||
return [f for _, f in candidates[:max_patterns]]
|
||||
|
||||
def get_strategy_for_tech(self, technology: str) -> Optional[Dict]:
|
||||
"""Get accumulated strategy knowledge for a technology."""
|
||||
tech_key = technology.lower()
|
||||
return self._strategies.get(tech_key)
|
||||
|
||||
def get_context_for_testing(self, vuln_type: str, technology: str = "",
|
||||
max_chars: int = 1500) -> str:
|
||||
"""
|
||||
Format reasoning memory into a prompt-ready context string.
|
||||
Includes successful traces and failure avoidance.
|
||||
"""
|
||||
sections = []
|
||||
|
||||
# Successful reasoning traces
|
||||
traces = self.get_relevant_traces(vuln_type, technology, max_traces=2)
|
||||
if traces:
|
||||
section = "## Successful Past Reasoning (from confirmed findings)\n"
|
||||
for i, trace in enumerate(traces, 1):
|
||||
section += f"\n### Past Success #{i}: {trace.get('vuln_type')} on {trace.get('technology', 'unknown')}\n"
|
||||
steps = trace.get("reasoning_steps", [])
|
||||
if steps:
|
||||
for step in steps[:4]:
|
||||
section += f" - {step}\n"
|
||||
payload = trace.get("payload_used", "")
|
||||
if payload:
|
||||
section += f" Effective payload: {payload}\n"
|
||||
evidence = trace.get("evidence_summary", "")
|
||||
if evidence:
|
||||
section += f" Evidence: {evidence[:200]}\n"
|
||||
sections.append(section)
|
||||
|
||||
# Failure avoidance
|
||||
failures = self.get_failure_patterns(vuln_type, technology, max_patterns=2)
|
||||
if failures:
|
||||
section = "## Failed Approaches to AVOID\n"
|
||||
for failure in failures:
|
||||
reason = failure.get("failure_reason", "unknown")
|
||||
endpoint = failure.get("endpoint_pattern", "")
|
||||
payloads = failure.get("attempted_payloads", [])[:3]
|
||||
section += f" - {reason} on {endpoint}: payloads {payloads} did NOT work\n"
|
||||
sections.append(section)
|
||||
|
||||
# Technology strategy
|
||||
if technology:
|
||||
strategy = self.get_strategy_for_tech(technology)
|
||||
if strategy:
|
||||
section = f"## Learned Strategy for {technology}\n"
|
||||
priority = strategy.get("priority_order", [])
|
||||
if priority:
|
||||
section += f" Priority order: {', '.join(priority[:8])}\n"
|
||||
insights = strategy.get("key_insights", [])
|
||||
for insight in insights[:3]:
|
||||
section += f" - {insight}\n"
|
||||
found = strategy.get("vuln_types_found", [])
|
||||
if found:
|
||||
section += f" Previously found: {', '.join(found[:5])}\n"
|
||||
sections.append(section)
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
result = "\n=== REASONING MEMORY (Learned from past scans) ===\n"
|
||||
result += "Apply these lessons learned. Avoid previously failed approaches.\n\n"
|
||||
|
||||
current_len = len(result)
|
||||
for section in sections:
|
||||
if current_len + len(section) > max_chars:
|
||||
remaining = max_chars - current_len - 20
|
||||
if remaining > 100:
|
||||
result += section[:remaining] + "...\n"
|
||||
break
|
||||
result += section
|
||||
current_len += len(section)
|
||||
|
||||
result += "\n=== END REASONING MEMORY ===\n"
|
||||
return result
|
||||
|
||||
def get_strategy_context(self, technologies: List[str],
|
||||
max_chars: int = 1000) -> str:
|
||||
"""
|
||||
Format accumulated strategy knowledge for attack planning.
|
||||
"""
|
||||
sections = []
|
||||
|
||||
for tech in technologies[:3]:
|
||||
strategy = self.get_strategy_for_tech(tech)
|
||||
if strategy:
|
||||
section = f"### {tech} (tested {strategy.get('scan_count', 0)} times)\n"
|
||||
priority = strategy.get("priority_order", [])
|
||||
if priority:
|
||||
section += f" Recommended priority: {', '.join(priority[:6])}\n"
|
||||
found = strategy.get("vuln_types_found", [])
|
||||
if found:
|
||||
section += f" Previously successful: {', '.join(found[:5])}\n"
|
||||
insights = strategy.get("key_insights", [])
|
||||
for insight in insights[:2]:
|
||||
section += f" Insight: {insight}\n"
|
||||
sections.append(section)
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
result = "\n=== ACCUMULATED STRATEGY KNOWLEDGE ===\n"
|
||||
current_len = len(result)
|
||||
for section in sections:
|
||||
if current_len + len(section) > max_chars:
|
||||
break
|
||||
result += section
|
||||
current_len += len(section)
|
||||
result += "=== END STRATEGY KNOWLEDGE ===\n"
|
||||
return result
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""Return memory statistics."""
|
||||
vuln_type_counts = {}
|
||||
for trace in self._traces:
|
||||
vt = trace.get("vuln_type", "unknown")
|
||||
vuln_type_counts[vt] = vuln_type_counts.get(vt, 0) + 1
|
||||
|
||||
return {
|
||||
"total_traces": len(self._traces),
|
||||
"total_failures": len(self._failures),
|
||||
"technologies_known": list(self._strategies.keys()),
|
||||
"vuln_type_distribution": vuln_type_counts,
|
||||
"memory_file": str(self.memory_path),
|
||||
"file_exists": self.memory_path.exists()
|
||||
}
|
||||
|
||||
def flush(self):
|
||||
"""Force save to disk."""
|
||||
self._dirty = True
|
||||
self._save()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,643 @@
|
||||
"""
|
||||
Multi-backend vector store for RAG knowledge retrieval.
|
||||
|
||||
Backends (in priority order):
|
||||
1. ChromaDB + sentence-transformers (semantic embeddings, persistent)
|
||||
2. TF-IDF via scikit-learn (statistical similarity)
|
||||
3. BM25 (zero dependencies, keyword-based ranking)
|
||||
|
||||
All backends provide the same interface: add(), query(), delete_collection().
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Optional dependencies
|
||||
try:
|
||||
import chromadb
|
||||
from chromadb.config import Settings as ChromaSettings
|
||||
HAS_CHROMADB = True
|
||||
except ImportError:
|
||||
HAS_CHROMADB = False
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
HAS_SENTENCE_TRANSFORMERS = True
|
||||
except ImportError:
|
||||
HAS_SENTENCE_TRANSFORMERS = False
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
HAS_NUMPY = True
|
||||
except ImportError:
|
||||
HAS_NUMPY = False
|
||||
|
||||
try:
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
HAS_SKLEARN = True
|
||||
except ImportError:
|
||||
HAS_SKLEARN = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedChunk:
|
||||
"""A retrieved knowledge chunk with relevance score."""
|
||||
text: str
|
||||
score: float
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
chunk_id: str = ""
|
||||
source: str = ""
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
"""A document to be indexed in the vector store."""
|
||||
text: str
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
doc_id: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.doc_id:
|
||||
self.doc_id = hashlib.md5(self.text[:500].encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
"""Abstract vector store interface."""
|
||||
|
||||
@abstractmethod
|
||||
def add(self, collection: str, documents: List[Document]) -> int:
|
||||
"""Add documents to a collection. Returns count added."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query(self, collection: str, query_text: str, top_k: int = 5,
|
||||
metadata_filter: Optional[Dict] = None) -> List[RetrievedChunk]:
|
||||
"""Query a collection for relevant documents."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collection_exists(self, collection: str) -> bool:
|
||||
"""Check if a collection has been indexed."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_collection(self, collection: str) -> None:
|
||||
"""Delete a collection and all its documents."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collection_count(self, collection: str) -> int:
|
||||
"""Return number of documents in a collection."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def backend_name(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class BM25VectorStore(BaseVectorStore):
|
||||
"""
|
||||
BM25 (Best Matching 25) keyword-based ranking.
|
||||
Zero external dependencies - works with pure Python.
|
||||
Good for exact keyword matching and term-frequency scoring.
|
||||
"""
|
||||
|
||||
def __init__(self, persist_dir: str, k1: float = 1.5, b: float = 0.75):
|
||||
self.persist_dir = Path(persist_dir)
|
||||
self.persist_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.k1 = k1
|
||||
self.b = b
|
||||
self._collections: Dict[str, Dict] = {}
|
||||
self._load_persisted()
|
||||
|
||||
@property
|
||||
def backend_name(self) -> str:
|
||||
return "bm25"
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
"""Simple whitespace + punctuation tokenizer."""
|
||||
import re
|
||||
text = text.lower()
|
||||
tokens = re.findall(r'\b[a-z0-9_]{2,}\b', text)
|
||||
return tokens
|
||||
|
||||
def _load_persisted(self):
|
||||
"""Load persisted collections from disk."""
|
||||
index_file = self.persist_dir / "bm25_index.json"
|
||||
if index_file.exists():
|
||||
try:
|
||||
with open(index_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
self._collections = data.get("collections", {})
|
||||
logger.info(f"BM25: Loaded {len(self._collections)} collections from disk")
|
||||
except Exception as e:
|
||||
logger.warning(f"BM25: Failed to load index: {e}")
|
||||
self._collections = {}
|
||||
|
||||
def _persist(self):
|
||||
"""Persist collections to disk."""
|
||||
index_file = self.persist_dir / "bm25_index.json"
|
||||
try:
|
||||
with open(index_file, 'w') as f:
|
||||
json.dump({"collections": self._collections, "timestamp": time.time()}, f)
|
||||
except Exception as e:
|
||||
logger.warning(f"BM25: Failed to persist index: {e}")
|
||||
|
||||
def add(self, collection: str, documents: List[Document]) -> int:
|
||||
if not documents:
|
||||
return 0
|
||||
|
||||
if collection not in self._collections:
|
||||
self._collections[collection] = {
|
||||
"documents": [],
|
||||
"doc_freqs": [],
|
||||
"df": {},
|
||||
"doc_lengths": [],
|
||||
"avgdl": 0,
|
||||
"N": 0
|
||||
}
|
||||
|
||||
col = self._collections[collection]
|
||||
|
||||
added = 0
|
||||
existing_ids = {d.get("doc_id", "") for d in col["documents"]}
|
||||
|
||||
for doc in documents:
|
||||
if doc.doc_id in existing_ids:
|
||||
continue
|
||||
|
||||
tokens = self._tokenize(doc.text)
|
||||
token_freq = dict(Counter(tokens))
|
||||
unique_tokens = set(tokens)
|
||||
|
||||
col["documents"].append({
|
||||
"doc_id": doc.doc_id,
|
||||
"text": doc.text[:5000], # Cap storage
|
||||
"metadata": doc.metadata
|
||||
})
|
||||
col["doc_freqs"].append(token_freq)
|
||||
col["doc_lengths"].append(len(tokens))
|
||||
|
||||
for token in unique_tokens:
|
||||
col["df"][token] = col["df"].get(token, 0) + 1
|
||||
|
||||
added += 1
|
||||
|
||||
col["N"] = len(col["documents"])
|
||||
col["avgdl"] = sum(col["doc_lengths"]) / max(col["N"], 1)
|
||||
|
||||
if added > 0:
|
||||
self._persist()
|
||||
|
||||
return added
|
||||
|
||||
def query(self, collection: str, query_text: str, top_k: int = 5,
|
||||
metadata_filter: Optional[Dict] = None) -> List[RetrievedChunk]:
|
||||
if collection not in self._collections:
|
||||
return []
|
||||
|
||||
col = self._collections[collection]
|
||||
if col["N"] == 0:
|
||||
return []
|
||||
|
||||
query_tokens = self._tokenize(query_text)
|
||||
if not query_tokens:
|
||||
return []
|
||||
|
||||
scores = []
|
||||
N = col["N"]
|
||||
avgdl = col["avgdl"]
|
||||
|
||||
for i in range(N):
|
||||
# Metadata filter
|
||||
if metadata_filter:
|
||||
doc_meta = col["documents"][i].get("metadata", {})
|
||||
skip = False
|
||||
for key, val in metadata_filter.items():
|
||||
if isinstance(val, list):
|
||||
if doc_meta.get(key) not in val:
|
||||
skip = True
|
||||
break
|
||||
elif doc_meta.get(key) != val:
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
scores.append(0.0)
|
||||
continue
|
||||
|
||||
doc_freq = col["doc_freqs"][i]
|
||||
doc_len = col["doc_lengths"][i]
|
||||
score = 0.0
|
||||
|
||||
for token in query_tokens:
|
||||
if token not in doc_freq:
|
||||
continue
|
||||
|
||||
tf = doc_freq[token]
|
||||
df = col["df"].get(token, 0)
|
||||
|
||||
# BM25 IDF
|
||||
idf = math.log((N - df + 0.5) / (df + 0.5) + 1.0)
|
||||
|
||||
# BM25 TF normalization
|
||||
tf_norm = (tf * (self.k1 + 1)) / (
|
||||
tf + self.k1 * (1.0 - self.b + self.b * doc_len / avgdl)
|
||||
)
|
||||
|
||||
score += idf * tf_norm
|
||||
|
||||
scores.append(score)
|
||||
|
||||
# Get top-k
|
||||
indexed_scores = [(i, s) for i, s in enumerate(scores) if s > 0]
|
||||
indexed_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
results = []
|
||||
for i, score in indexed_scores[:top_k]:
|
||||
doc = col["documents"][i]
|
||||
results.append(RetrievedChunk(
|
||||
text=doc["text"],
|
||||
score=score,
|
||||
metadata=doc.get("metadata", {}),
|
||||
chunk_id=doc.get("doc_id", f"doc_{i}"),
|
||||
source=collection
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
def collection_exists(self, collection: str) -> bool:
|
||||
return collection in self._collections and self._collections[collection]["N"] > 0
|
||||
|
||||
def delete_collection(self, collection: str) -> None:
|
||||
if collection in self._collections:
|
||||
del self._collections[collection]
|
||||
self._persist()
|
||||
|
||||
def collection_count(self, collection: str) -> int:
|
||||
if collection not in self._collections:
|
||||
return 0
|
||||
return self._collections[collection]["N"]
|
||||
|
||||
|
||||
class TFIDFVectorStore(BaseVectorStore):
|
||||
"""
|
||||
TF-IDF based vector store using scikit-learn.
|
||||
Better than BM25 for capturing document-level similarity.
|
||||
Requires: scikit-learn, numpy
|
||||
"""
|
||||
|
||||
def __init__(self, persist_dir: str):
|
||||
if not HAS_SKLEARN or not HAS_NUMPY:
|
||||
raise ImportError("TF-IDF backend requires scikit-learn and numpy")
|
||||
|
||||
self.persist_dir = Path(persist_dir)
|
||||
self.persist_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._collections: Dict[str, Dict] = {}
|
||||
|
||||
@property
|
||||
def backend_name(self) -> str:
|
||||
return "tfidf"
|
||||
|
||||
def add(self, collection: str, documents: List[Document]) -> int:
|
||||
if not documents:
|
||||
return 0
|
||||
|
||||
if collection not in self._collections:
|
||||
self._collections[collection] = {
|
||||
"documents": [],
|
||||
"texts": [],
|
||||
"vectorizer": None,
|
||||
"matrix": None
|
||||
}
|
||||
|
||||
col = self._collections[collection]
|
||||
existing_ids = {d.get("doc_id", "") for d in col["documents"]}
|
||||
|
||||
added = 0
|
||||
for doc in documents:
|
||||
if doc.doc_id in existing_ids:
|
||||
continue
|
||||
col["documents"].append({
|
||||
"doc_id": doc.doc_id,
|
||||
"text": doc.text[:5000],
|
||||
"metadata": doc.metadata
|
||||
})
|
||||
col["texts"].append(doc.text[:5000])
|
||||
added += 1
|
||||
|
||||
if added > 0:
|
||||
# Rebuild TF-IDF matrix
|
||||
vectorizer = TfidfVectorizer(
|
||||
max_features=10000,
|
||||
stop_words='english',
|
||||
ngram_range=(1, 2),
|
||||
min_df=1,
|
||||
max_df=0.95
|
||||
)
|
||||
col["matrix"] = vectorizer.fit_transform(col["texts"])
|
||||
col["vectorizer"] = vectorizer
|
||||
|
||||
return added
|
||||
|
||||
def query(self, collection: str, query_text: str, top_k: int = 5,
|
||||
metadata_filter: Optional[Dict] = None) -> List[RetrievedChunk]:
|
||||
if collection not in self._collections:
|
||||
return []
|
||||
|
||||
col = self._collections[collection]
|
||||
if col["vectorizer"] is None or col["matrix"] is None:
|
||||
return []
|
||||
|
||||
query_vec = col["vectorizer"].transform([query_text])
|
||||
similarities = cosine_similarity(query_vec, col["matrix"]).flatten()
|
||||
|
||||
# Apply metadata filter
|
||||
if metadata_filter:
|
||||
for i, doc in enumerate(col["documents"]):
|
||||
meta = doc.get("metadata", {})
|
||||
for key, val in metadata_filter.items():
|
||||
if isinstance(val, list):
|
||||
if meta.get(key) not in val:
|
||||
similarities[i] = 0.0
|
||||
elif meta.get(key) != val:
|
||||
similarities[i] = 0.0
|
||||
|
||||
top_indices = np.argsort(similarities)[::-1][:top_k]
|
||||
|
||||
results = []
|
||||
for i in top_indices:
|
||||
if similarities[i] <= 0:
|
||||
continue
|
||||
doc = col["documents"][i]
|
||||
results.append(RetrievedChunk(
|
||||
text=doc["text"],
|
||||
score=float(similarities[i]),
|
||||
metadata=doc.get("metadata", {}),
|
||||
chunk_id=doc.get("doc_id", f"doc_{i}"),
|
||||
source=collection
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
def collection_exists(self, collection: str) -> bool:
|
||||
return (collection in self._collections and
|
||||
len(self._collections[collection]["documents"]) > 0)
|
||||
|
||||
def delete_collection(self, collection: str) -> None:
|
||||
if collection in self._collections:
|
||||
del self._collections[collection]
|
||||
|
||||
def collection_count(self, collection: str) -> int:
|
||||
if collection not in self._collections:
|
||||
return 0
|
||||
return len(self._collections[collection]["documents"])
|
||||
|
||||
|
||||
class ChromaVectorStore(BaseVectorStore):
|
||||
"""
|
||||
ChromaDB + sentence-transformers for true semantic embeddings.
|
||||
Best quality: understands meaning, not just keywords.
|
||||
Requires: chromadb, sentence-transformers
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL = "all-MiniLM-L6-v2" # Fast, 384-dim, good quality
|
||||
|
||||
def __init__(self, persist_dir: str, model_name: str = None):
|
||||
if not HAS_CHROMADB:
|
||||
raise ImportError("ChromaDB backend requires: pip install chromadb")
|
||||
|
||||
self.persist_dir = Path(persist_dir)
|
||||
self.persist_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=str(self.persist_dir / "chromadb")
|
||||
)
|
||||
|
||||
# Embedding model
|
||||
self._embed_model = None
|
||||
self._model_name = model_name or self.DEFAULT_MODEL
|
||||
if HAS_SENTENCE_TRANSFORMERS:
|
||||
try:
|
||||
self._embed_model = SentenceTransformer(self._model_name)
|
||||
logger.info(f"ChromaDB: Loaded embedding model '{self._model_name}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"ChromaDB: Failed to load model: {e}")
|
||||
|
||||
@property
|
||||
def backend_name(self) -> str:
|
||||
return "chromadb"
|
||||
|
||||
def _get_collection(self, name: str):
|
||||
"""Get or create a ChromaDB collection."""
|
||||
if self._embed_model:
|
||||
return self.client.get_or_create_collection(
|
||||
name=name,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
else:
|
||||
return self.client.get_or_create_collection(name=name)
|
||||
|
||||
def _embed(self, texts: List[str]) -> Optional[List[List[float]]]:
|
||||
"""Generate embeddings using sentence-transformers."""
|
||||
if not self._embed_model:
|
||||
return None
|
||||
try:
|
||||
embeddings = self._embed_model.encode(texts, show_progress_bar=False)
|
||||
return embeddings.tolist()
|
||||
except Exception as e:
|
||||
logger.warning(f"ChromaDB: Embedding failed: {e}")
|
||||
return None
|
||||
|
||||
def add(self, collection: str, documents: List[Document]) -> int:
|
||||
if not documents:
|
||||
return 0
|
||||
|
||||
col = self._get_collection(collection)
|
||||
|
||||
# Filter already-indexed docs
|
||||
existing = set()
|
||||
try:
|
||||
result = col.get()
|
||||
if result and result.get("ids"):
|
||||
existing = set(result["ids"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
new_docs = [d for d in documents if d.doc_id not in existing]
|
||||
if not new_docs:
|
||||
return 0
|
||||
|
||||
# Batch add (ChromaDB limit: 41666 per batch)
|
||||
batch_size = 500
|
||||
added = 0
|
||||
|
||||
for start in range(0, len(new_docs), batch_size):
|
||||
batch = new_docs[start:start + batch_size]
|
||||
|
||||
ids = [d.doc_id for d in batch]
|
||||
texts = [d.text[:5000] for d in batch]
|
||||
metadatas = []
|
||||
for d in batch:
|
||||
# ChromaDB metadata must be str/int/float/bool
|
||||
meta = {}
|
||||
for k, v in d.metadata.items():
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
meta[k] = v
|
||||
elif isinstance(v, list):
|
||||
meta[k] = ",".join(str(x) for x in v)
|
||||
else:
|
||||
meta[k] = str(v)
|
||||
metadatas.append(meta)
|
||||
|
||||
embeddings = self._embed(texts)
|
||||
|
||||
try:
|
||||
if embeddings:
|
||||
col.add(
|
||||
ids=ids,
|
||||
documents=texts,
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings
|
||||
)
|
||||
else:
|
||||
col.add(
|
||||
ids=ids,
|
||||
documents=texts,
|
||||
metadatas=metadatas
|
||||
)
|
||||
added += len(batch)
|
||||
except Exception as e:
|
||||
logger.warning(f"ChromaDB: Failed to add batch: {e}")
|
||||
|
||||
return added
|
||||
|
||||
def query(self, collection: str, query_text: str, top_k: int = 5,
|
||||
metadata_filter: Optional[Dict] = None) -> List[RetrievedChunk]:
|
||||
try:
|
||||
col = self._get_collection(collection)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
if col.count() == 0:
|
||||
return []
|
||||
|
||||
# Build ChromaDB where clause
|
||||
where = None
|
||||
if metadata_filter:
|
||||
conditions = []
|
||||
for key, val in metadata_filter.items():
|
||||
if isinstance(val, list):
|
||||
conditions.append({key: {"$in": val}})
|
||||
else:
|
||||
conditions.append({key: {"$eq": val}})
|
||||
if len(conditions) == 1:
|
||||
where = conditions[0]
|
||||
elif len(conditions) > 1:
|
||||
where = {"$and": conditions}
|
||||
|
||||
# Query with embeddings if available
|
||||
query_embedding = self._embed([query_text])
|
||||
|
||||
try:
|
||||
if query_embedding:
|
||||
results = col.query(
|
||||
query_embeddings=query_embedding,
|
||||
n_results=min(top_k, col.count()),
|
||||
where=where
|
||||
)
|
||||
else:
|
||||
results = col.query(
|
||||
query_texts=[query_text],
|
||||
n_results=min(top_k, col.count()),
|
||||
where=where
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"ChromaDB query failed: {e}")
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
if results and results.get("documents"):
|
||||
docs = results["documents"][0]
|
||||
ids = results["ids"][0] if results.get("ids") else [""] * len(docs)
|
||||
distances = results["distances"][0] if results.get("distances") else [0.0] * len(docs)
|
||||
metadatas = results["metadatas"][0] if results.get("metadatas") else [{}] * len(docs)
|
||||
|
||||
for text, doc_id, distance, meta in zip(docs, ids, distances, metadatas):
|
||||
# ChromaDB returns distance (lower = better), convert to similarity score
|
||||
score = max(0.0, 1.0 - distance)
|
||||
chunks.append(RetrievedChunk(
|
||||
text=text,
|
||||
score=score,
|
||||
metadata=meta or {},
|
||||
chunk_id=doc_id,
|
||||
source=collection
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
def collection_exists(self, collection: str) -> bool:
|
||||
try:
|
||||
col = self.client.get_collection(collection)
|
||||
return col.count() > 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection: str) -> None:
|
||||
try:
|
||||
self.client.delete_collection(collection)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def collection_count(self, collection: str) -> int:
|
||||
try:
|
||||
col = self.client.get_collection(collection)
|
||||
return col.count()
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def create_vectorstore(persist_dir: str, backend: str = "auto") -> BaseVectorStore:
|
||||
"""
|
||||
Factory function to create the best available vector store.
|
||||
|
||||
Args:
|
||||
persist_dir: Directory for persistent storage
|
||||
backend: "auto" (best available), "chromadb", "tfidf", or "bm25"
|
||||
|
||||
Returns:
|
||||
Configured vector store instance
|
||||
"""
|
||||
if backend == "chromadb" or (backend == "auto" and HAS_CHROMADB):
|
||||
try:
|
||||
store = ChromaVectorStore(persist_dir)
|
||||
logger.info(f"RAG: Using ChromaDB backend (semantic embeddings)")
|
||||
return store
|
||||
except Exception as e:
|
||||
logger.warning(f"RAG: ChromaDB init failed: {e}, falling back")
|
||||
|
||||
if backend == "tfidf" or (backend == "auto" and HAS_SKLEARN):
|
||||
try:
|
||||
store = TFIDFVectorStore(persist_dir)
|
||||
logger.info(f"RAG: Using TF-IDF backend (statistical similarity)")
|
||||
return store
|
||||
except Exception as e:
|
||||
logger.warning(f"RAG: TF-IDF init failed: {e}, falling back")
|
||||
|
||||
store = BM25VectorStore(persist_dir)
|
||||
logger.info(f"RAG: Using BM25 backend (keyword ranking)")
|
||||
return store
|
||||
Reference in New Issue
Block a user