NeuroSploit v3.2 - Autonomous AI Penetration Testing Platform

116 modules | 100 vuln types | 18 API routes | 18 frontend pages

Major features:
- VulnEngine: 100 vuln types, 526+ payloads, 12 testers, anti-hallucination prompts
- Autonomous Agent: 3-stream auto pentest, multi-session (5 concurrent), pause/resume/stop
- CLI Agent: Claude Code / Gemini CLI / Codex CLI inside Kali containers
- Validation Pipeline: negative controls, proof of execution, confidence scoring, judge
- AI Reasoning: ReACT engine, token budget, endpoint classifier, CVE hunter, deep recon
- Multi-Agent: 5 specialists + orchestrator + researcher AI + vuln type agents
- RAG System: BM25/TF-IDF/ChromaDB vectorstore, few-shot, reasoning templates
- Smart Router: 20 providers (8 CLI OAuth + 12 API), tier failover, token refresh
- Kali Sandbox: container-per-scan, 56 tools, VPN support, on-demand install
- Full IA Testing: methodology-driven comprehensive pentest sessions
- Notifications: Discord, Telegram, WhatsApp/Twilio multi-channel alerts
- Frontend: React/TypeScript with 18 pages, real-time WebSocket updates
This commit is contained in:
CyberSecurityUP
2026-02-22 17:58:12 -03:00
commit e0935793c5
271 changed files with 132462 additions and 0 deletions
+84
View File
@@ -0,0 +1,84 @@
"""
RAG (Retrieval-Augmented Generation) system for NeuroSploitv2.
Enhances AI reasoning by providing relevant context from multiple knowledge
sources without modifying the underlying model. This is a "reasoning amplifier"
that teaches the AI HOW to think about vulnerabilities through:
1. Semantic retrieval from 9000+ bug bounty reports
2. Few-shot examples showing successful exploitation reasoning
3. Chain-of-Thought reasoning templates per vulnerability type
4. Cross-scan reasoning memory (learning from past successes/failures)
Usage:
from backend.core.rag import RAGEngine, FewShotSelector, ReasoningMemory
from backend.core.rag.reasoning_templates import format_reasoning_prompt
# Initialize
rag = RAGEngine(data_dir="data")
rag.index_all() # One-time indexing
# Get testing context
context = rag.get_testing_context("xss", technology="PHP")
# Get few-shot examples
few_shot = FewShotSelector(rag_engine=rag)
examples = few_shot.get_testing_examples("sqli", technology="MySQL")
# Get reasoning framework
reasoning = format_reasoning_prompt("ssrf")
# Record success for future learning
memory = ReasoningMemory()
memory.record_success(trace)
Backends (auto-selected, best available):
- ChromaDB + sentence-transformers: Semantic embeddings (best quality)
- TF-IDF (scikit-learn): Statistical similarity (good quality)
- BM25 (zero deps): Keyword ranking (works out of box)
"""
from .engine import RAGEngine, RAGContext
from .few_shot import FewShotSelector, FewShotExample
from .reasoning_memory import ReasoningMemory, ReasoningTrace, FailureRecord
from .reasoning_templates import (
get_reasoning_template,
format_reasoning_prompt,
get_available_types,
REASONING_TEMPLATES
)
from .vectorstore import (
BaseVectorStore,
BM25VectorStore,
RetrievedChunk,
Document,
create_vectorstore
)
__all__ = [
# Core engine
"RAGEngine",
"RAGContext",
# Few-shot selection
"FewShotSelector",
"FewShotExample",
# Reasoning memory
"ReasoningMemory",
"ReasoningTrace",
"FailureRecord",
# Reasoning templates
"get_reasoning_template",
"format_reasoning_prompt",
"get_available_types",
"REASONING_TEMPLATES",
# Vector store
"BaseVectorStore",
"BM25VectorStore",
"RetrievedChunk",
"Document",
"create_vectorstore",
]
+877
View File
@@ -0,0 +1,877 @@
"""
RAG Engine - Retrieval-Augmented Generation for enhanced AI reasoning.
Indexes all knowledge sources (bug bounty reports, vuln KB, custom docs,
reasoning traces) and provides semantic retrieval for context-enriched
LLM prompts. Does NOT modify the model - only augments input context.
Collections:
- bug_bounty_patterns: 9131 real-world vulnerability reports
- vuln_methodologies: 100 vulnerability type methodologies
- custom_knowledge: User-uploaded research documents
- reasoning_traces: Successful reasoning chains from past scans
- attack_patterns: Extracted attack patterns and techniques
"""
import json
import logging
import re
import time
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, field
from .vectorstore import (
BaseVectorStore, Document, RetrievedChunk,
create_vectorstore
)
logger = logging.getLogger(__name__)
# Collection names
COL_BUG_BOUNTY = "bug_bounty_patterns"
COL_VULN_METHODS = "vuln_methodologies"
COL_CUSTOM = "custom_knowledge"
COL_REASONING = "reasoning_traces"
COL_ATTACK = "attack_patterns"
# Defaults
DEFAULT_TOP_K = 5
MAX_CONTEXT_CHARS = 4000
INDEX_BATCH_SIZE = 200
@dataclass
class RAGContext:
"""Assembled RAG context for a specific query."""
query: str
chunks: List[RetrievedChunk] = field(default_factory=list)
total_score: float = 0.0
sources_used: List[str] = field(default_factory=list)
token_estimate: int = 0
def to_prompt_text(self, max_chars: int = MAX_CONTEXT_CHARS) -> str:
"""Format retrieved context for injection into LLM prompt."""
if not self.chunks:
return ""
sections = []
current_len = 0
for chunk in self.chunks:
source_label = chunk.metadata.get("source_type", chunk.source)
vuln_type = chunk.metadata.get("vuln_type", "")
score_pct = int(chunk.score * 100) if chunk.score <= 1.0 else int(chunk.score)
header = f"[{source_label}]"
if vuln_type:
header += f" ({vuln_type})"
header += f" [relevance: {score_pct}%]"
text = chunk.text.strip()
section = f"{header}\n{text}\n"
if current_len + len(section) > max_chars:
remaining = max_chars - current_len - len(header) - 20
if remaining > 100:
section = f"{header}\n{text[:remaining]}...\n"
else:
break
sections.append(section)
current_len += len(section)
if not sections:
return ""
result = "=== RETRIEVED KNOWLEDGE (RAG) ===\n"
result += "Use this knowledge to inform your analysis. Adapt techniques to the target.\n\n"
result += "\n---\n".join(sections)
result += "\n=== END RETRIEVED KNOWLEDGE ===\n"
self.token_estimate = len(result) // 4 # rough token estimate
return result
class RAGEngine:
"""
Main RAG orchestrator. Indexes knowledge sources and provides
semantic retrieval for context-enriched AI reasoning.
"""
def __init__(self, data_dir: str = "data", backend: str = "auto",
persist_dir: str = None):
self.data_dir = Path(data_dir)
self.persist_dir = persist_dir or str(self.data_dir / "vectorstore")
self.store: BaseVectorStore = create_vectorstore(
self.persist_dir, backend=backend
)
self._indexed = False
self._index_stats: Dict[str, int] = {}
logger.info(f"RAG Engine initialized with '{self.store.backend_name}' backend")
@property
def backend_name(self) -> str:
return self.store.backend_name
@property
def is_indexed(self) -> bool:
return self._indexed
def get_stats(self) -> Dict:
"""Return indexing statistics."""
stats = {
"backend": self.store.backend_name,
"indexed": self._indexed,
"collections": {}
}
for col_name in [COL_BUG_BOUNTY, COL_VULN_METHODS, COL_CUSTOM,
COL_REASONING, COL_ATTACK]:
count = self.store.collection_count(col_name)
if count > 0:
stats["collections"][col_name] = count
return stats
# ── Indexing ────────────────────────────────────────────────
def index_all(self, force: bool = False) -> Dict[str, int]:
"""
Index all available knowledge sources.
Returns dict of collection_name -> documents_indexed.
"""
stats = {}
# Only re-index if forced or collections are empty
if not force and self._all_collections_populated():
logger.info("RAG: All collections already populated, skipping index")
self._indexed = True
return stats
start = time.time()
stats[COL_BUG_BOUNTY] = self._index_bug_bounty()
stats[COL_VULN_METHODS] = self._index_vuln_knowledge_base()
stats[COL_CUSTOM] = self._index_custom_knowledge()
stats[COL_ATTACK] = self._index_attack_patterns()
elapsed = time.time() - start
total = sum(stats.values())
self._indexed = True
self._index_stats = stats
logger.info(f"RAG: Indexed {total} documents across {len(stats)} collections in {elapsed:.1f}s")
return stats
def _all_collections_populated(self) -> bool:
"""Check if main collections already have data."""
return (self.store.collection_exists(COL_BUG_BOUNTY) and
self.store.collection_exists(COL_VULN_METHODS))
def _index_bug_bounty(self) -> int:
"""Index the bug bounty finetuning dataset."""
dataset_path = Path("models/bug-bounty/bugbounty_finetuning_dataset.json")
if not dataset_path.exists():
logger.warning(f"RAG: Bug bounty dataset not found at {dataset_path}")
return 0
if self.store.collection_exists(COL_BUG_BOUNTY):
existing = self.store.collection_count(COL_BUG_BOUNTY)
if existing > 1000:
logger.info(f"RAG: Bug bounty already indexed ({existing} docs)")
return 0
try:
with open(dataset_path, 'r', encoding='utf-8') as f:
entries = json.load(f)
except Exception as e:
logger.error(f"RAG: Failed to load bug bounty dataset: {e}")
return 0
if not isinstance(entries, list):
return 0
documents = []
for i, entry in enumerate(entries):
instruction = entry.get("instruction", "")
output = entry.get("output", "")
if not output or len(output) < 50:
continue
# Extract vulnerability types from content
vuln_types = self._detect_vuln_types(instruction + " " + output)
# Extract technologies
technologies = self._detect_technologies(output)
# Chunk 1: Full methodology (primary chunk)
methodology = self._extract_section(output, [
"passos para reproduzir", "steps to reproduce",
"methodology", "exploitation", "proof of concept",
"como reproduzir", "reprodução"
])
if methodology and len(methodology) > 100:
documents.append(Document(
text=methodology[:4000],
metadata={
"source_type": "bug_bounty",
"vuln_type": vuln_types[0] if vuln_types else "unknown",
"vuln_types": ",".join(vuln_types[:5]),
"technologies": ",".join(technologies[:5]),
"chunk_type": "methodology",
"entry_index": i
},
doc_id=f"bb_method_{i}"
))
# Chunk 2: Summary + Impact (secondary chunk)
summary = self._extract_section(output, [
"resumo", "summary", "descrição", "description",
"overview"
])
impact = self._extract_section(output, [
"impacto", "impact", "severity", "risco"
])
summary_text = f"{instruction}\n\n{summary or output[:500]}"
if impact:
summary_text += f"\n\nImpact: {impact}"
documents.append(Document(
text=summary_text[:3000],
metadata={
"source_type": "bug_bounty",
"vuln_type": vuln_types[0] if vuln_types else "unknown",
"vuln_types": ",".join(vuln_types[:5]),
"technologies": ",".join(technologies[:5]),
"chunk_type": "summary",
"entry_index": i
},
doc_id=f"bb_summary_{i}"
))
# Chunk 3: Payloads & PoC code (if present)
payloads = self._extract_code_blocks(output)
if payloads:
payload_text = f"Vulnerability: {vuln_types[0] if vuln_types else 'unknown'}\n"
payload_text += f"Technologies: {', '.join(technologies[:3])}\n\n"
payload_text += "Payloads/PoC:\n" + "\n\n".join(payloads[:10])
documents.append(Document(
text=payload_text[:3000],
metadata={
"source_type": "bug_bounty",
"vuln_type": vuln_types[0] if vuln_types else "unknown",
"vuln_types": ",".join(vuln_types[:5]),
"technologies": ",".join(technologies[:5]),
"chunk_type": "payload",
"entry_index": i
},
doc_id=f"bb_payload_{i}"
))
# Index in batches
total_added = 0
for start in range(0, len(documents), INDEX_BATCH_SIZE):
batch = documents[start:start + INDEX_BATCH_SIZE]
added = self.store.add(COL_BUG_BOUNTY, batch)
total_added += added
logger.info(f"RAG: Indexed {total_added} bug bounty chunks from {len(entries)} entries")
return total_added
def _index_vuln_knowledge_base(self) -> int:
"""Index the 100-type vulnerability knowledge base."""
kb_path = self.data_dir / "vuln_knowledge_base.json"
if not kb_path.exists():
return 0
if self.store.collection_exists(COL_VULN_METHODS):
existing = self.store.collection_count(COL_VULN_METHODS)
if existing >= 90:
return 0
try:
with open(kb_path, 'r', encoding='utf-8') as f:
kb = json.load(f)
except Exception as e:
logger.error(f"RAG: Failed to load vuln KB: {e}")
return 0
vuln_types = kb.get("vulnerability_types", {})
if not vuln_types:
return 0
documents = []
for vuln_type, info in vuln_types.items():
text = f"Vulnerability: {info.get('title', vuln_type)}\n"
text += f"Type: {vuln_type}\n"
text += f"CWE: {info.get('cwe_id', 'N/A')}\n"
text += f"Severity: {info.get('severity', 'N/A')}\n\n"
text += f"Description: {info.get('description', '')}\n\n"
text += f"Impact: {info.get('impact', '')}\n\n"
text += f"Remediation: {info.get('remediation', '')}\n"
fp_markers = info.get("false_positive_markers", [])
if fp_markers:
text += f"\nFalse Positive Indicators: {', '.join(fp_markers)}\n"
documents.append(Document(
text=text,
metadata={
"source_type": "vuln_kb",
"vuln_type": vuln_type,
"severity": info.get("severity", "medium"),
"cwe_id": info.get("cwe_id", ""),
"chunk_type": "methodology"
},
doc_id=f"vkb_{vuln_type}"
))
# Index XBOW insights if available
xbow = kb.get("xbow_insights", {})
if xbow:
for category, insights in xbow.items():
if isinstance(insights, str):
text = f"XBOW Benchmark Insight - {category}:\n{insights}"
elif isinstance(insights, dict):
text = f"XBOW Benchmark Insight - {category}:\n{json.dumps(insights, indent=2)}"
elif isinstance(insights, list):
text = f"XBOW Benchmark Insight - {category}:\n" + "\n".join(str(i) for i in insights)
else:
continue
documents.append(Document(
text=text[:3000],
metadata={
"source_type": "vuln_kb",
"vuln_type": category,
"chunk_type": "insight"
},
doc_id=f"xbow_{category}"
))
added = self.store.add(COL_VULN_METHODS, documents)
logger.info(f"RAG: Indexed {added} vuln KB entries")
return added
def _index_custom_knowledge(self) -> int:
"""Index user-uploaded custom knowledge documents."""
index_path = self.data_dir / "custom-knowledge" / "index.json"
if not index_path.exists():
return 0
try:
with open(index_path, 'r', encoding='utf-8') as f:
index = json.load(f)
except Exception:
return 0
documents = []
for doc_entry in index.get("documents", []):
for entry in doc_entry.get("knowledge_entries", []):
vuln_type = entry.get("vuln_type", "unknown")
text = f"Custom Knowledge - {vuln_type}\n"
text += f"Source: {doc_entry.get('filename', 'unknown')}\n\n"
if entry.get("methodology"):
text += f"Methodology: {entry['methodology']}\n\n"
if entry.get("key_insights"):
if isinstance(entry["key_insights"], list):
text += "Key Insights:\n" + "\n".join(f"- {i}" for i in entry["key_insights"]) + "\n\n"
else:
text += f"Key Insights: {entry['key_insights']}\n\n"
if entry.get("payloads"):
payloads = entry["payloads"][:10]
text += "Payloads:\n" + "\n".join(f" {p}" for p in payloads) + "\n\n"
if entry.get("bypass_techniques"):
techniques = entry["bypass_techniques"][:10]
text += "Bypass Techniques:\n" + "\n".join(f"- {t}" for t in techniques) + "\n"
documents.append(Document(
text=text[:4000],
metadata={
"source_type": "custom",
"vuln_type": vuln_type,
"filename": doc_entry.get("filename", ""),
"chunk_type": "methodology"
},
doc_id=f"custom_{doc_entry.get('id', '')}_{vuln_type}"
))
if not documents:
return 0
added = self.store.add(COL_CUSTOM, documents)
logger.info(f"RAG: Indexed {added} custom knowledge entries")
return added
def _index_attack_patterns(self) -> int:
"""Index extracted attack patterns from execution history."""
hist_path = self.data_dir / "execution_history.json"
if not hist_path.exists():
return 0
try:
with open(hist_path, 'r', encoding='utf-8') as f:
history = json.load(f)
except Exception:
return 0
attacks = history.get("attacks", [])
if not attacks:
return 0
# Group successful attacks by vuln_type + tech
successes: Dict[str, List[Dict]] = {}
for attack in attacks:
if not attack.get("success"):
continue
key = f"{attack.get('vuln_type', '')}_{attack.get('tech', '')}"
if key not in successes:
successes[key] = []
successes[key].append(attack)
documents = []
for key, attack_list in successes.items():
vuln_type = attack_list[0].get("vuln_type", "unknown")
tech = attack_list[0].get("tech", "unknown")
text = f"Successful Attack Pattern: {vuln_type} on {tech}\n"
text += f"Success count: {len(attack_list)}\n\n"
for atk in attack_list[:5]:
evidence = atk.get("evidence_preview", "")
domain = atk.get("target_domain", "")
text += f"- Target: {domain}, Evidence: {evidence}\n"
documents.append(Document(
text=text[:2000],
metadata={
"source_type": "attack_pattern",
"vuln_type": vuln_type,
"technology": tech,
"success_count": len(attack_list),
"chunk_type": "pattern"
},
doc_id=f"atk_{key}"
))
if not documents:
return 0
added = self.store.add(COL_ATTACK, documents)
logger.info(f"RAG: Indexed {added} attack patterns")
return added
def index_reasoning_trace(self, trace: Dict) -> bool:
"""
Index a successful reasoning trace for future retrieval.
Called when a finding is confirmed.
trace = {
"vuln_type": str,
"technology": str,
"endpoint": str,
"reasoning_chain": List[str],
"payload_used": str,
"evidence": str,
"confidence": float,
"timestamp": float
}
"""
vuln_type = trace.get("vuln_type", "unknown")
tech = trace.get("technology", "unknown")
text = f"Confirmed Reasoning Trace - {vuln_type}\n"
text += f"Technology: {tech}\n"
text += f"Endpoint: {trace.get('endpoint', '')}\n"
text += f"Confidence: {trace.get('confidence', 0):.0%}\n\n"
chain = trace.get("reasoning_chain", [])
if chain:
text += "Reasoning Chain:\n"
for i, step in enumerate(chain, 1):
text += f" {i}. {step}\n"
text += "\n"
if trace.get("payload_used"):
text += f"Payload Used: {trace['payload_used']}\n"
if trace.get("evidence"):
text += f"Evidence: {trace['evidence'][:500]}\n"
doc = Document(
text=text[:3000],
metadata={
"source_type": "reasoning_trace",
"vuln_type": vuln_type,
"technology": tech,
"confidence": trace.get("confidence", 0),
"chunk_type": "reasoning",
"timestamp": trace.get("timestamp", time.time())
},
doc_id=f"trace_{vuln_type}_{int(time.time())}"
)
try:
self.store.add(COL_REASONING, [doc])
return True
except Exception as e:
logger.warning(f"RAG: Failed to index reasoning trace: {e}")
return False
# ── Querying ────────────────────────────────────────────────
def query(self, query_text: str, collections: List[str] = None,
top_k: int = DEFAULT_TOP_K,
vuln_type: str = None,
technology: str = None,
chunk_type: str = None) -> RAGContext:
"""
Query across collections for relevant knowledge.
Args:
query_text: The search query
collections: Which collections to search (default: all)
top_k: Number of results per collection
vuln_type: Filter by vulnerability type
technology: Filter by technology
chunk_type: Filter by chunk type (methodology, payload, summary, etc.)
Returns:
RAGContext with ranked, deduplicated results
"""
if not collections:
collections = [COL_BUG_BOUNTY, COL_VULN_METHODS, COL_CUSTOM,
COL_REASONING, COL_ATTACK]
# Build metadata filter
meta_filter = {}
if vuln_type:
meta_filter["vuln_type"] = vuln_type
if chunk_type:
meta_filter["chunk_type"] = chunk_type
all_chunks: List[RetrievedChunk] = []
sources_used = []
for col_name in collections:
if not self.store.collection_exists(col_name):
continue
chunks = self.store.query(
collection=col_name,
query_text=query_text,
top_k=top_k,
metadata_filter=meta_filter if meta_filter else None
)
if chunks:
all_chunks.extend(chunks)
sources_used.append(col_name)
# Also search with technology-enhanced query if provided
if technology and technology not in query_text.lower():
enhanced_query = f"{query_text} {technology}"
for col_name in collections:
if not self.store.collection_exists(col_name):
continue
chunks = self.store.query(
collection=col_name,
query_text=enhanced_query,
top_k=max(2, top_k // 2),
metadata_filter=meta_filter if meta_filter else None
)
if chunks:
all_chunks.extend(chunks)
# Deduplicate by chunk_id
seen = set()
unique_chunks = []
for chunk in all_chunks:
if chunk.chunk_id not in seen:
seen.add(chunk.chunk_id)
unique_chunks.append(chunk)
# Sort by relevance score
unique_chunks.sort(key=lambda c: c.score, reverse=True)
# Limit total results
max_results = top_k * 2
unique_chunks = unique_chunks[:max_results]
total_score = sum(c.score for c in unique_chunks)
return RAGContext(
query=query_text,
chunks=unique_chunks,
total_score=total_score,
sources_used=sources_used
)
def get_testing_context(self, vuln_type: str, target_url: str = "",
technology: str = "", endpoint: str = "",
parameter: str = "",
max_chars: int = MAX_CONTEXT_CHARS) -> str:
"""
Get optimized RAG context for vulnerability testing.
Combines methodology, real examples, and attack patterns.
"""
# Build a rich query
query_parts = [vuln_type.replace("_", " ")]
if technology:
query_parts.append(technology)
if endpoint:
query_parts.append(f"endpoint {endpoint}")
if parameter:
query_parts.append(f"parameter {parameter}")
query = " ".join(query_parts)
# Query with vuln_type preference
context = self.query(
query_text=query,
vuln_type=vuln_type,
technology=technology,
top_k=5
)
# Also get broader results without vuln_type filter
broad_context = self.query(
query_text=query,
technology=technology,
top_k=3
)
# Merge, preferring vuln-specific results
seen = {c.chunk_id for c in context.chunks}
for chunk in broad_context.chunks:
if chunk.chunk_id not in seen:
context.chunks.append(chunk)
seen.add(chunk.chunk_id)
# Re-sort and limit
context.chunks.sort(key=lambda c: c.score, reverse=True)
context.chunks = context.chunks[:8]
return context.to_prompt_text(max_chars=max_chars)
def get_verification_context(self, vuln_type: str, evidence: str,
technology: str = "",
max_chars: int = 2000) -> str:
"""
Get RAG context for finding verification/judgment.
Focuses on confirmed examples and false positive patterns.
"""
query = f"{vuln_type.replace('_', ' ')} verification proof confirmed {evidence[:200]}"
if technology:
query += f" {technology}"
# Get confirmed reasoning traces
trace_ctx = self.query(
query_text=query,
collections=[COL_REASONING],
vuln_type=vuln_type,
top_k=3
)
# Get methodology for verification criteria
method_ctx = self.query(
query_text=f"{vuln_type} false positive verification criteria proof",
collections=[COL_VULN_METHODS, COL_BUG_BOUNTY],
vuln_type=vuln_type,
chunk_type="methodology",
top_k=3
)
# Combine
all_chunks = trace_ctx.chunks + method_ctx.chunks
all_chunks.sort(key=lambda c: c.score, reverse=True)
combined = RAGContext(
query=query,
chunks=all_chunks[:6],
total_score=sum(c.score for c in all_chunks[:6]),
sources_used=list(set(trace_ctx.sources_used + method_ctx.sources_used))
)
return combined.to_prompt_text(max_chars=max_chars)
def get_strategy_context(self, technologies: List[str],
endpoints: List[str] = None,
max_chars: int = 3000) -> str:
"""
Get RAG context for attack strategy planning.
Focuses on tech-specific patterns and successful attack history.
"""
query_parts = ["penetration testing attack strategy"]
query_parts.extend(technologies[:3])
if endpoints:
query_parts.extend(endpoints[:3])
query = " ".join(query_parts)
# Get attack patterns
attack_ctx = self.query(
query_text=query,
collections=[COL_ATTACK, COL_BUG_BOUNTY],
top_k=5
)
# Get methodology per technology
for tech in technologies[:2]:
tech_ctx = self.query(
query_text=f"{tech} common vulnerabilities exploitation",
collections=[COL_BUG_BOUNTY, COL_VULN_METHODS],
technology=tech,
top_k=3
)
attack_ctx.chunks.extend(tech_ctx.chunks)
# Deduplicate and sort
seen = set()
unique = []
for c in attack_ctx.chunks:
if c.chunk_id not in seen:
seen.add(c.chunk_id)
unique.append(c)
unique.sort(key=lambda c: c.score, reverse=True)
combined = RAGContext(
query=query,
chunks=unique[:10],
total_score=sum(c.score for c in unique[:10]),
sources_used=attack_ctx.sources_used
)
return combined.to_prompt_text(max_chars=max_chars)
# ── Helpers ─────────────────────────────────────────────────
def _detect_vuln_types(self, text: str) -> List[str]:
"""Detect vulnerability types mentioned in text."""
text_lower = text.lower()
VULN_KEYWORDS = {
"xss": ["xss", "cross-site scripting", "cross site scripting", "script injection", "reflected xss", "stored xss"],
"sqli": ["sql injection", "sqli", "sql injeção", "union select", "sqlmap"],
"ssrf": ["ssrf", "server-side request forgery", "server side request"],
"idor": ["idor", "insecure direct object", "referência direta"],
"rce": ["rce", "remote code execution", "command injection", "execução remota", "os command"],
"lfi": ["lfi", "local file inclusion", "path traversal", "directory traversal", "inclusão de arquivo"],
"ssti": ["ssti", "server-side template injection", "template injection", "jinja", "twig"],
"xxe": ["xxe", "xml external entity", "xml injection"],
"csrf": ["csrf", "cross-site request forgery", "request forgery"],
"open_redirect": ["open redirect", "redirecionamento aberto", "redirect"],
"auth_bypass": ["authentication bypass", "auth bypass", "bypass autenticação"],
"race_condition": ["race condition", "condição de corrida", "toctou"],
"deserialization": ["deserialization", "deserialização", "unserialize", "pickle"],
"upload": ["file upload", "upload", "unrestricted upload"],
"cors": ["cors", "cross-origin"],
"prototype_pollution": ["prototype pollution", "poluição de protótipo"],
"request_smuggling": ["request smuggling", "http smuggling", "cl.te", "te.cl"],
"graphql": ["graphql", "introspection"],
"jwt": ["jwt", "json web token"],
"nosql": ["nosql injection", "mongodb injection", "nosql"],
"crlf": ["crlf injection", "header injection", "injeção de cabeçalho"],
"subdomain_takeover": ["subdomain takeover", "tomada de subdomínio"],
"information_disclosure": ["information disclosure", "divulgação de informação", "sensitive data"],
"bola": ["bola", "broken object level"],
"bfla": ["bfla", "broken function level"],
"privilege_escalation": ["privilege escalation", "escalação de privilégio"],
}
detected = []
for vuln_type, keywords in VULN_KEYWORDS.items():
for kw in keywords:
if kw in text_lower:
detected.append(vuln_type)
break
return detected if detected else ["unknown"]
def _detect_technologies(self, text: str) -> List[str]:
"""Detect technologies mentioned in text."""
text_lower = text.lower()
TECH_KEYWORDS = {
"php": ["php", "laravel", "wordpress", "drupal", "symfony", "codeigniter"],
"python": ["python", "django", "flask", "fastapi", "tornado"],
"java": ["java", "spring", "struts", "tomcat", "jboss", "wildfly"],
"node": ["node.js", "nodejs", "express", "next.js", "nuxt"],
"ruby": ["ruby", "rails", "sinatra"],
"dotnet": [".net", "asp.net", "c#", "iis"],
"go": ["golang", " go ", "gin", "echo"],
"nginx": ["nginx"],
"apache": ["apache", "httpd"],
"react": ["react", "reactjs"],
"angular": ["angular"],
"vue": ["vue.js", "vuejs"],
"graphql": ["graphql"],
"docker": ["docker", "kubernetes", "k8s"],
"aws": ["aws", "amazon", "s3", "lambda", "ec2"],
"azure": ["azure", "microsoft cloud"],
"mysql": ["mysql", "mariadb"],
"postgres": ["postgresql", "postgres"],
"mongodb": ["mongodb", "mongo"],
"redis": ["redis"],
}
detected = []
for tech, keywords in TECH_KEYWORDS.items():
for kw in keywords:
if kw in text_lower:
detected.append(tech)
break
return detected
def _extract_section(self, text: str, markers: List[str],
max_chars: int = 2000) -> Optional[str]:
"""Extract a section from text based on header markers."""
text_lower = text.lower()
for marker in markers:
idx = text_lower.find(marker)
if idx != -1:
# Find section start (after the marker line)
newline_after = text.find("\n", idx)
if newline_after == -1:
continue
section_start = newline_after + 1
# Find section end (next ## header or end)
next_header = re.search(r'\n#{1,3}\s', text[section_start:])
if next_header:
section_end = section_start + next_header.start()
else:
section_end = min(section_start + max_chars, len(text))
section = text[section_start:section_end].strip()
if len(section) > 50:
return section[:max_chars]
return None
def _extract_code_blocks(self, text: str) -> List[str]:
"""Extract code blocks and payloads from text."""
blocks = []
# Fenced code blocks
for match in re.finditer(r'```[\w]*\n(.*?)```', text, re.DOTALL):
code = match.group(1).strip()
if len(code) > 20:
blocks.append(code[:500])
# Inline code with attack indicators
for match in re.finditer(r'`([^`]{10,500})`', text):
code = match.group(1)
attack_indicators = ['<script', 'alert(', 'SELECT', 'UNION',
'../', 'curl ', 'wget ', '{{', '${',
'eval(', 'exec(', 'system(']
if any(ind in code for ind in attack_indicators):
blocks.append(code)
return blocks[:20]
+644
View File
@@ -0,0 +1,644 @@
"""
Few-Shot Example Selector for RAG-enhanced reasoning.
Selects the most relevant real-world bug bounty examples and formats
them as few-shot reasoning demonstrations for the LLM. This teaches
the model HOW to reason about vulnerabilities by showing worked examples.
The key insight: instead of just giving the AI information, we show it
examples of successful reasoning chains, so it learns the PATTERN of
good pentesting analysis.
"""
import json
import logging
import re
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
@dataclass
class FewShotExample:
"""A formatted few-shot example for prompt injection."""
vuln_type: str
technology: str
scenario: str
reasoning_chain: List[str]
outcome: str
payload: str = ""
proof: str = ""
score: float = 0.0
def format(self, include_chain: bool = True) -> str:
"""Format as a prompt-ready example."""
text = f"--- Example: {self.vuln_type.upper()} in {self.technology} ---\n"
text += f"Scenario: {self.scenario}\n"
if include_chain and self.reasoning_chain:
text += "Reasoning:\n"
for i, step in enumerate(self.reasoning_chain, 1):
text += f" {i}. {step}\n"
if self.payload:
text += f"Payload: {self.payload}\n"
text += f"Outcome: {self.outcome}\n"
if self.proof:
text += f"Proof: {self.proof}\n"
return text
class FewShotSelector:
"""
Selects and formats few-shot examples for LLM prompts.
Provides three types of examples:
1. Testing examples: How to test for a specific vuln type
2. Verification examples: How to verify if a finding is real
3. Strategy examples: How to plan an attack approach
"""
def __init__(self, rag_engine=None, dataset_path: str = None):
"""
Args:
rag_engine: RAGEngine instance for semantic retrieval
dataset_path: Path to bug bounty dataset (fallback if no RAG engine)
"""
self.rag_engine = rag_engine
self.dataset_path = dataset_path or "models/bug-bounty/bugbounty_finetuning_dataset.json"
self._example_cache: Dict[str, List[FewShotExample]] = {}
self._curated_examples = self._build_curated_examples()
def get_testing_examples(self, vuln_type: str, technology: str = "",
max_examples: int = 3) -> str:
"""
Get few-shot examples showing how to test for a vulnerability type.
Demonstrates the reasoning chain: observe → hypothesize → test → verify.
"""
cache_key = f"test_{vuln_type}_{technology}"
if cache_key in self._example_cache:
examples = self._example_cache[cache_key][:max_examples]
return self._format_examples(examples, "TESTING EXAMPLES")
examples = []
# 1. Try curated examples first (highest quality)
curated = self._get_curated_for_type(vuln_type)
examples.extend(curated)
# 2. Get RAG-retrieved examples
if self.rag_engine:
rag_examples = self._retrieve_rag_examples(
vuln_type, technology, "testing", max_examples
)
examples.extend(rag_examples)
# 3. Deduplicate and rank
examples = self._rank_examples(examples, vuln_type, technology)[:max_examples]
self._example_cache[cache_key] = examples
return self._format_examples(examples, "TESTING EXAMPLES")
def get_verification_examples(self, vuln_type: str, evidence: str = "",
max_examples: int = 2) -> str:
"""
Get few-shot examples showing how to verify/judge a finding.
Includes both TRUE POSITIVE and FALSE POSITIVE examples.
"""
examples = []
# Get curated verification examples
curated_tp = self._get_curated_verification(vuln_type, is_tp=True)
curated_fp = self._get_curated_verification(vuln_type, is_tp=False)
examples.extend(curated_tp[:1])
examples.extend(curated_fp[:1])
# RAG-retrieved
if self.rag_engine and len(examples) < max_examples:
rag_examples = self._retrieve_rag_examples(
vuln_type, "", "verification", max_examples - len(examples)
)
examples.extend(rag_examples)
return self._format_examples(examples[:max_examples], "VERIFICATION EXAMPLES")
def get_strategy_examples(self, technologies: List[str],
max_examples: int = 2) -> str:
"""
Get few-shot examples showing attack strategy planning.
"""
examples = []
for tech in technologies[:2]:
tech_examples = self._get_curated_strategy(tech)
examples.extend(tech_examples)
if self.rag_engine and len(examples) < max_examples:
query = f"penetration testing strategy {' '.join(technologies[:3])}"
rag_examples = self._retrieve_rag_examples(
"strategy", " ".join(technologies[:3]), "strategy",
max_examples - len(examples)
)
examples.extend(rag_examples)
return self._format_examples(examples[:max_examples], "STRATEGY EXAMPLES")
def _retrieve_rag_examples(self, vuln_type: str, technology: str,
context: str, max_examples: int) -> List[FewShotExample]:
"""Retrieve and convert RAG chunks into few-shot examples."""
if not self.rag_engine:
return []
query = f"{vuln_type.replace('_', ' ')} {technology} {context}"
rag_ctx = self.rag_engine.query(
query_text=query,
vuln_type=vuln_type if vuln_type != "strategy" else None,
technology=technology if technology else None,
top_k=max_examples * 2
)
examples = []
for chunk in rag_ctx.chunks[:max_examples]:
example = self._chunk_to_example(chunk, vuln_type)
if example:
examples.append(example)
return examples
def _chunk_to_example(self, chunk, vuln_type: str) -> Optional[FewShotExample]:
"""Convert a retrieved chunk into a few-shot example."""
text = chunk.text
meta = chunk.metadata
# Extract reasoning chain from the text
chain = self._extract_reasoning_from_text(text)
# Extract payload
payload = ""
payload_match = re.search(r'(?:payload|exploit|poc)[:\s]*[`"]?([^\n`"]{10,200})', text, re.I)
if payload_match:
payload = payload_match.group(1).strip()
# Extract outcome
outcome = "See methodology above for complete exploitation details."
if "confirmed" in text.lower() or "success" in text.lower():
outcome = "Vulnerability confirmed with proof of exploitation."
elif "impacto" in text.lower() or "impact" in text.lower():
impact_match = re.search(r'(?:impacto|impact)[:\s]*(.{20,200})', text, re.I)
if impact_match:
outcome = impact_match.group(1).strip()
technology = meta.get("technology", meta.get("technologies", "unknown"))
if isinstance(technology, str) and "," in technology:
technology = technology.split(",")[0]
scenario = text[:200].replace("\n", " ").strip()
return FewShotExample(
vuln_type=meta.get("vuln_type", vuln_type),
technology=str(technology),
scenario=scenario,
reasoning_chain=chain,
outcome=outcome,
payload=payload,
score=chunk.score
)
def _extract_reasoning_from_text(self, text: str) -> List[str]:
"""Extract reasoning steps from a bug bounty report."""
steps = []
# Try numbered steps
numbered = re.findall(r'(?:^|\n)\s*(\d+)\.\s+(.{10,200})', text)
if len(numbered) >= 2:
for num, step in numbered[:6]:
steps.append(step.strip())
return steps
# Try bullet points
bullets = re.findall(r'(?:^|\n)\s*[-*]\s+(.{10,200})', text)
if len(bullets) >= 2:
for bullet in bullets[:6]:
steps.append(bullet.strip())
return steps
# Try section-based extraction
sections = re.findall(r'###?\s+(.+?)(?:\n|$)', text)
for section in sections[:6]:
steps.append(section.strip())
if not steps:
# Fall back to sentence extraction
sentences = re.split(r'[.!]\s+', text[:800])
for sent in sentences[:4]:
if len(sent.strip()) > 20:
steps.append(sent.strip())
return steps
def _rank_examples(self, examples: List[FewShotExample],
vuln_type: str, technology: str) -> List[FewShotExample]:
"""Rank examples by relevance to the target vuln type and technology."""
for example in examples:
score = example.score
# Boost exact vuln type match
if example.vuln_type == vuln_type:
score += 2.0
# Boost technology match
if technology and technology.lower() in example.technology.lower():
score += 1.5
# Boost examples with reasoning chains
if example.reasoning_chain and len(example.reasoning_chain) >= 3:
score += 1.0
# Boost examples with payloads
if example.payload:
score += 0.5
# Boost examples with proof
if example.proof:
score += 0.5
example.score = score
examples.sort(key=lambda e: e.score, reverse=True)
# Deduplicate by scenario similarity
seen_starts = set()
unique = []
for ex in examples:
start = ex.scenario[:50].lower()
if start not in seen_starts:
seen_starts.add(start)
unique.append(ex)
return unique
def _format_examples(self, examples: List[FewShotExample],
header: str) -> str:
"""Format examples into a prompt-ready string."""
if not examples:
return ""
text = f"\n=== {header} (Learn from these real-world cases) ===\n"
text += "Study these examples to understand the REASONING PATTERN, then apply similar logic.\n\n"
for i, example in enumerate(examples, 1):
text += f"[Example {i}]\n"
text += example.format(include_chain=True)
text += "\n"
text += f"=== END {header} ===\n"
return text
def _get_curated_for_type(self, vuln_type: str) -> List[FewShotExample]:
"""Get curated examples for a vulnerability type."""
vtype = vuln_type.lower().replace("-", "_")
examples = []
if vtype in self._curated_examples:
for ex_data in self._curated_examples[vtype].get("testing", []):
examples.append(FewShotExample(**ex_data, score=10.0))
# Also check parent types (e.g., xss_reflected -> xss)
base_type = vtype.split("_")[0]
if base_type != vtype and base_type in self._curated_examples:
for ex_data in self._curated_examples[base_type].get("testing", []):
examples.append(FewShotExample(**ex_data, score=8.0))
return examples
def _get_curated_verification(self, vuln_type: str,
is_tp: bool) -> List[FewShotExample]:
"""Get curated verification examples (TP or FP)."""
vtype = vuln_type.lower().replace("-", "_")
key = "verification_tp" if is_tp else "verification_fp"
examples = []
if vtype in self._curated_examples:
for ex_data in self._curated_examples[vtype].get(key, []):
examples.append(FewShotExample(**ex_data, score=10.0))
base_type = vtype.split("_")[0]
if base_type != vtype and base_type in self._curated_examples:
for ex_data in self._curated_examples[base_type].get(key, []):
examples.append(FewShotExample(**ex_data, score=8.0))
return examples
def _get_curated_strategy(self, technology: str) -> List[FewShotExample]:
"""Get curated strategy examples for a technology."""
tech = technology.lower()
if tech in self._curated_examples.get("_strategies", {}):
data = self._curated_examples["_strategies"][tech]
return [FewShotExample(**data, score=10.0)]
return []
def _build_curated_examples(self) -> Dict:
"""
Build curated high-quality few-shot examples.
These are hand-crafted to demonstrate ideal reasoning patterns.
"""
return {
"xss": {
"testing": [
{
"vuln_type": "xss_reflected",
"technology": "PHP",
"scenario": "Search parameter reflected in HTML body without encoding",
"reasoning_chain": [
"OBSERVE: Parameter 'q' is reflected verbatim in <div class='results'>",
"IDENTIFY CONTEXT: Reflection is inside HTML body (not attribute, not JS)",
"TEST FILTERS: Sent <b>test</b> - HTML tags rendered, no encoding",
"ESCALATE: Injected <script>alert(document.domain)</script>",
"VERIFY: Script executed in browser, alert showed domain name",
"PROVE: DOM inspection confirms injected <script> tag is live"
],
"outcome": "Confirmed: Reflected XSS via unencoded HTML body injection",
"payload": "<script>alert(document.domain)</script>",
"proof": "Playwright confirmed script execution, DOM shows injected tag"
}
],
"verification_tp": [
{
"vuln_type": "xss_reflected",
"technology": "generic",
"scenario": "Verifying XSS finding is a TRUE POSITIVE",
"reasoning_chain": [
"CHECK 1: Payload appears in response body unencoded? YES",
"CHECK 2: Payload is in executable context (inside HTML, not comment)? YES",
"CHECK 3: No CSP header blocking inline scripts? CORRECT, no CSP",
"CHECK 4: Browser actually executes the script? YES (Playwright confirms)",
"VERDICT: All 4 checks pass → TRUE POSITIVE"
],
"outcome": "CONFIRMED: True positive - all verification checks passed",
"proof": "Browser execution confirmed via Playwright"
}
],
"verification_fp": [
{
"vuln_type": "xss_reflected",
"technology": "generic",
"scenario": "Verifying XSS finding that is a FALSE POSITIVE",
"reasoning_chain": [
"CHECK 1: Payload in response? YES, but only inside HTML comment <!-- -->",
"CHECK 2: Executable context? NO - HTML comments are not executed",
"CHECK 3: Even if we break out of comment, CSP blocks inline scripts",
"VERDICT: Payload present but NOT executable → FALSE POSITIVE"
],
"outcome": "REJECTED: False positive - payload in non-executable context",
"proof": "No browser execution possible due to HTML comment context + CSP"
}
]
},
"sqli": {
"testing": [
{
"vuln_type": "sqli",
"technology": "PHP/MySQL",
"scenario": "Login form with username and password fields",
"reasoning_chain": [
"OBSERVE: Login form sends POST with username & password params",
"PROBE: Single quote in username returns MySQL error: 'syntax error near '''",
"IDENTIFY: Error-based SQL injection, MySQL backend confirmed",
"TEST UNION: ' UNION SELECT 1,2,3-- - reveals 3 columns",
"EXTRACT: ' UNION SELECT user(),database(),version()-- - shows root@localhost",
"PROVE: Extracted real DB info (database name, MySQL version, current user)"
],
"outcome": "Confirmed: UNION-based SQL injection with full data extraction",
"payload": "' UNION SELECT user(),database(),version()-- -",
"proof": "Database name, version and user extracted from response"
}
],
"verification_tp": [
{
"vuln_type": "sqli",
"technology": "generic",
"scenario": "Verifying SQL injection is TRUE POSITIVE",
"reasoning_chain": [
"CHECK 1: Database error message in response? YES (MySQL syntax error)",
"CHECK 2: Error contains our injected syntax? YES (shows the quote)",
"CHECK 3: Can we extract data? YES (UNION SELECT returns DB version)",
"CHECK 4: Is data extraction real? YES (version string matches known MySQL format)",
"VERDICT: Data extraction proven → TRUE POSITIVE"
],
"outcome": "CONFIRMED: True positive - actual data extraction achieved"
}
],
"verification_fp": [
{
"vuln_type": "sqli",
"technology": "generic",
"scenario": "WAF error page mimics SQL error",
"reasoning_chain": [
"CHECK 1: Error message in response? YES, but it's a generic WAF block page",
"CHECK 2: Same error for ANY special character? YES - WAF blocks all",
"CHECK 3: Can we extract data? NO - all payloads return same WAF page",
"VERDICT: WAF blocking, not SQL processing → FALSE POSITIVE"
],
"outcome": "REJECTED: False positive - WAF error page, not database error"
}
]
},
"ssrf": {
"testing": [
{
"vuln_type": "ssrf",
"technology": "Python/Flask",
"scenario": "URL parameter used for fetching external content",
"reasoning_chain": [
"OBSERVE: Parameter 'url' fetches and displays content from URLs",
"PROBE: Sent url=http://127.0.0.1:80 - got response from localhost",
"TEST INTERNAL: url=http://169.254.169.254/latest/meta-data/ - got AWS metadata!",
"EXTRACT: Retrieved IAM role name and temporary credentials",
"PROVE: AWS metadata content (ami-id, instance-type) confirms internal access"
],
"outcome": "Confirmed: SSRF to AWS metadata endpoint with credential extraction",
"payload": "http://169.254.169.254/latest/meta-data/iam/security-credentials/",
"proof": "AWS IAM credentials retrieved from metadata endpoint"
}
],
"verification_fp": [
{
"vuln_type": "ssrf",
"technology": "generic",
"scenario": "Status code difference is NOT proof of SSRF",
"reasoning_chain": [
"CHECK 1: Different status code with internal URL? YES (403→200)",
"CHECK 2: But is the CONTENT from internal service? NO - same login page",
"CHECK 3: Negative control (random URL) also returns 200? YES",
"VERDICT: Status code change is application behavior, NOT SSRF → FALSE POSITIVE"
],
"outcome": "REJECTED: Status code diff without internal content is NOT SSRF"
}
]
},
"idor": {
"testing": [
{
"vuln_type": "idor",
"technology": "REST API",
"scenario": "User profile API endpoint with numeric ID",
"reasoning_chain": [
"OBSERVE: GET /api/users/42 returns user profile (my ID is 42)",
"TEST: GET /api/users/1 with my auth token - got different user's data!",
"COMPARE DATA: Response contains name='Admin', email='admin@target.com'",
"VERIFY: This is NOT my data - different name, email, role",
"PROVE: Can access ANY user's profile by changing ID parameter"
],
"outcome": "Confirmed: IDOR - can access other users' profiles via ID enumeration",
"proof": "Different user's PII (name, email) retrieved with attacker's token"
}
],
"verification_fp": [
{
"vuln_type": "idor",
"technology": "generic",
"scenario": "Same response for different IDs is NOT IDOR",
"reasoning_chain": [
"CHECK 1: Different ID returns 200? YES",
"CHECK 2: But compare the DATA content - is it actually DIFFERENT user data? NO",
"CHECK 3: Both IDs return the SAME profile (my own data)",
"VERDICT: Server ignores the ID parameter, always returns current user → FALSE POSITIVE"
],
"outcome": "REJECTED: Same data returned regardless of ID - no object-level access violation"
}
]
},
"rce": {
"testing": [
{
"vuln_type": "rce",
"technology": "Node.js",
"scenario": "Template rendering endpoint with user-controlled input",
"reasoning_chain": [
"OBSERVE: Parameter 'name' rendered in template, endpoint uses eval-like function",
"PROBE: Sent {{7*7}} - response shows 49 (template injection confirmed)",
"ESCALATE: {{require('child_process').execSync('id')}} - returns uid=0(root)!",
"EXTRACT: Read /etc/passwd via command execution",
"PROVE: OS command output (uid, file contents) confirms RCE"
],
"outcome": "Confirmed: RCE via SSTI in Node.js (template to command execution chain)",
"payload": "{{require('child_process').execSync('id')}}",
"proof": "Command output 'uid=0(root)' in HTTP response"
}
]
},
"ssti": {
"testing": [
{
"vuln_type": "ssti",
"technology": "Python/Jinja2",
"scenario": "Name field rendered via Jinja2 template engine",
"reasoning_chain": [
"OBSERVE: Input reflected in response via template rendering",
"PROBE: {{7*7}} returns '49' - arithmetic evaluated = SSTI confirmed",
"IDENTIFY ENGINE: {{config}} returns Flask config = Jinja2 confirmed",
"ESCALATE: Use MRO chain to access subprocess module",
"EXECUTE: {{''.__class__.__mro__[1].__subclasses__()}} - list Python classes",
"PROVE: Achieved code execution via Popen subclass"
],
"outcome": "Confirmed: SSTI in Jinja2 with RCE via Python class chain",
"payload": "{{config.__class__.__init__.__globals__['os'].popen('id').read()}}",
"proof": "OS command output returned in template render"
}
]
},
"lfi": {
"testing": [
{
"vuln_type": "lfi",
"technology": "PHP",
"scenario": "File include parameter loading page templates",
"reasoning_chain": [
"OBSERVE: Parameter 'page=about' loads about.php template",
"TEST: page=../../../etc/passwd - returned 'root:x:0:0:root' content!",
"VERIFY: Content matches /etc/passwd format (user:x:uid:gid:...)",
"ESCALATE: Read application config via page=../config/database.php",
"PROVE: Extracted database credentials from config file"
],
"outcome": "Confirmed: LFI with path traversal, read sensitive system and app files",
"payload": "../../../etc/passwd",
"proof": "/etc/passwd content with valid user entries in response"
}
]
},
"auth_bypass": {
"testing": [
{
"vuln_type": "auth_bypass",
"technology": "REST API",
"scenario": "Admin panel behind authentication check",
"reasoning_chain": [
"OBSERVE: /admin returns 302 redirect to /login",
"TEST: Send request without following redirect - check if body has admin content",
"PROBE: Try /admin with modified headers (X-Forwarded-For: 127.0.0.1)",
"DISCOVER: /admin/ (trailing slash) bypasses auth check! Returns admin panel",
"VERIFY: Compare authenticated vs unauthenticated response - SAME admin content",
"PROVE: Full admin functionality accessible without any credentials"
],
"outcome": "Confirmed: Auth bypass via trailing slash normalization bug",
"proof": "Admin panel content accessible without authentication"
}
]
},
"_strategies": {
"php": {
"vuln_type": "strategy",
"technology": "PHP",
"scenario": "Planning attack strategy for PHP application",
"reasoning_chain": [
"PHP apps are prone to: SQL injection (especially with raw queries), LFI/RFI (include/require), XSS (echo without htmlspecialchars), file upload bypass, deserialization (unserialize)",
"Priority: Test SQL injection on login/search forms, check for LFI in page/template parameters, test file upload functionality for webshell",
"PHP-specific: Check for type juggling (== vs ===), test PHP wrapper protocols (php://input, php://filter), check for exposed phpinfo()",
"Framework detection: Look for Laravel (.env exposure, debug mode), WordPress (wp-admin, xmlrpc.php), CodeIgniter (CI paths)"
],
"outcome": "Focus on SQLi > LFI > XSS > Upload > Deserialization for PHP targets"
},
"node": {
"vuln_type": "strategy",
"technology": "Node.js",
"scenario": "Planning attack strategy for Node.js application",
"reasoning_chain": [
"Node.js apps are prone to: Prototype pollution, SSTI (pug/ejs/handlebars), NoSQL injection (MongoDB), SSRF, insecure deserialization (node-serialize), path traversal",
"Priority: Test prototype pollution via JSON body (__proto__), check for SSTI in template params, test NoSQL injection on MongoDB endpoints",
"Node-specific: Check for npm package vulns, test for event loop blocking (ReDoS), look for Express middleware bypasses",
"API focus: GraphQL introspection, JWT implementation flaws, WebSocket injection"
],
"outcome": "Focus on Prototype Pollution > SSTI > NoSQL > SSRF for Node.js targets"
},
"python": {
"vuln_type": "strategy",
"technology": "Python",
"scenario": "Planning attack strategy for Python application",
"reasoning_chain": [
"Python apps are prone to: SSTI (Jinja2/Mako), SQL injection (raw queries, ORM bypass), SSRF, pickle deserialization, command injection (os.system/subprocess)",
"Priority: Test SSTI with {{7*7}} on all input fields, check for pickle endpoints, test SSRF on URL parameters",
"Python-specific: Django debug mode, Flask debug/secret key exposure, YAML deserialization (yaml.load), eval/exec injection",
"Framework: Django admin exposure, Flask /console (Werkzeug debugger), FastAPI /docs endpoint"
],
"outcome": "Focus on SSTI > SQLi > SSRF > Deserialization for Python targets"
},
"java": {
"vuln_type": "strategy",
"technology": "Java",
"scenario": "Planning attack strategy for Java application",
"reasoning_chain": [
"Java apps are prone to: Deserialization (ObjectInputStream), XXE (SAX/DOM parsers), SSTI (Velocity/Freemarker), Expression Language injection, Log4Shell",
"Priority: Test deserialization on all serialized object endpoints, check XXE on XML parsing endpoints, test EL injection",
"Java-specific: Check for Java serialization magic bytes (aced0005), test Log4j via ${jndi:ldap://} in headers, Struts OGNL injection",
"Framework: Spring Boot Actuator endpoints (/env, /heapdump), Tomcat manager exposure, JBoss/WildFly admin"
],
"outcome": "Focus on Deserialization > XXE > Log4Shell > SSTI for Java targets"
}
}
}
+399
View File
@@ -0,0 +1,399 @@
"""
Reasoning Memory - Cross-scan storage of successful reasoning traces.
This is "pseudo-fine-tuning": instead of modifying model weights, we
accumulate successful reasoning chains and inject them as context
into future prompts. Over time, the system learns from its own
successful analyses.
Stores:
- Confirmed finding reasoning chains
- Failed hypothesis patterns (what didn't work and why)
- Technology-specific successful strategies
- Payload effectiveness per context
"""
import json
import logging
import time
import hashlib
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field, asdict
logger = logging.getLogger(__name__)
MEMORY_FILE = "data/reasoning_memory.json"
MAX_TRACES = 500
MAX_FAILURES = 200
MAX_STRATEGIES = 100
@dataclass
class ReasoningTrace:
"""A successful reasoning chain from a confirmed finding."""
vuln_type: str
technology: str
endpoint_pattern: str # Normalized pattern (IDs removed)
parameter: str
reasoning_steps: List[str]
payload_used: str
evidence_summary: str
confidence: float
timestamp: float = 0.0
scan_target: str = ""
trace_id: str = ""
def __post_init__(self):
if not self.timestamp:
self.timestamp = time.time()
if not self.trace_id:
key = f"{self.vuln_type}_{self.endpoint_pattern}_{self.payload_used}"
self.trace_id = hashlib.md5(key.encode()).hexdigest()[:10]
@dataclass
class FailureRecord:
"""A record of what didn't work and why."""
vuln_type: str
technology: str
endpoint_pattern: str
attempted_payloads: List[str]
failure_reason: str # "waf_blocked", "encoded", "no_reflection", "same_behavior", etc.
timestamp: float = 0.0
def __post_init__(self):
if not self.timestamp:
self.timestamp = time.time()
@dataclass
class StrategyRecord:
"""A successful technology-specific strategy."""
technology: str
vuln_types_found: List[str]
priority_order: List[str]
key_insights: List[str]
scan_count: int = 1
success_rate: float = 0.0
timestamp: float = 0.0
def __post_init__(self):
if not self.timestamp:
self.timestamp = time.time()
class ReasoningMemory:
"""
Persistent storage and retrieval of reasoning experience.
Learns from successful attacks and failed attempts to provide
increasingly relevant context over time.
"""
def __init__(self, memory_path: str = None):
self.memory_path = Path(memory_path or MEMORY_FILE)
self.memory_path.parent.mkdir(parents=True, exist_ok=True)
self._traces: List[Dict] = []
self._failures: List[Dict] = []
self._strategies: Dict[str, Dict] = {} # tech -> StrategyRecord dict
self._dirty = False
self._load()
def _load(self):
"""Load persisted memory."""
if not self.memory_path.exists():
return
try:
with open(self.memory_path, 'r', encoding='utf-8') as f:
data = json.load(f)
self._traces = data.get("traces", [])
self._failures = data.get("failures", [])
self._strategies = data.get("strategies", {})
logger.info(
f"ReasoningMemory: Loaded {len(self._traces)} traces, "
f"{len(self._failures)} failures, {len(self._strategies)} strategies"
)
except Exception as e:
logger.warning(f"ReasoningMemory: Failed to load: {e}")
def _save(self):
"""Persist memory to disk."""
if not self._dirty:
return
# Enforce size limits
self._traces = self._traces[-MAX_TRACES:]
self._failures = self._failures[-MAX_FAILURES:]
try:
data = {
"traces": self._traces,
"failures": self._failures,
"strategies": self._strategies,
"last_updated": time.time(),
"stats": {
"total_traces": len(self._traces),
"total_failures": len(self._failures),
"technologies": list(self._strategies.keys())
}
}
with open(self.memory_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, default=str)
self._dirty = False
except Exception as e:
logger.warning(f"ReasoningMemory: Failed to save: {e}")
# ── Recording ──────────────────────────────────────────────
def record_success(self, trace: ReasoningTrace):
"""Record a successful reasoning chain from a confirmed finding."""
trace_dict = asdict(trace)
self._traces.append(trace_dict)
self._dirty = True
# Auto-save periodically
if len(self._traces) % 10 == 0:
self._save()
logger.debug(
f"ReasoningMemory: Recorded success - {trace.vuln_type} "
f"on {trace.technology} ({trace.endpoint_pattern})"
)
def record_failure(self, failure: FailureRecord):
"""Record a failed attack attempt for future avoidance."""
self._failures.append(asdict(failure))
self._dirty = True
if len(self._failures) % 20 == 0:
self._save()
def record_strategy(self, technology: str, vuln_types_found: List[str],
priority_order: List[str], insights: List[str]):
"""Record a successful scanning strategy for a technology."""
tech_key = technology.lower()
if tech_key in self._strategies:
# Update existing
existing = self._strategies[tech_key]
existing["scan_count"] = existing.get("scan_count", 0) + 1
existing["vuln_types_found"] = list(set(
existing.get("vuln_types_found", []) + vuln_types_found
))
# Merge insights
existing_insights = set(existing.get("key_insights", []))
for insight in insights:
existing_insights.add(insight)
existing["key_insights"] = list(existing_insights)[:20]
existing["timestamp"] = time.time()
# Recalculate priority based on accumulated experience
if priority_order:
existing["priority_order"] = priority_order
else:
self._strategies[tech_key] = asdict(StrategyRecord(
technology=technology,
vuln_types_found=vuln_types_found,
priority_order=priority_order,
key_insights=insights
))
self._dirty = True
self._save()
# ── Retrieval ──────────────────────────────────────────────
def get_relevant_traces(self, vuln_type: str, technology: str = "",
max_traces: int = 3) -> List[Dict]:
"""
Retrieve relevant successful reasoning traces.
Prioritizes exact vuln_type match, then technology match.
"""
candidates = []
for trace in self._traces:
score = 0.0
# Vuln type match (primary)
if trace.get("vuln_type") == vuln_type:
score += 5.0
elif vuln_type.split("_")[0] in trace.get("vuln_type", ""):
score += 2.0
else:
continue # Skip irrelevant types
# Technology match (secondary)
if technology and technology.lower() in trace.get("technology", "").lower():
score += 3.0
# Recency boost
age_days = (time.time() - trace.get("timestamp", 0)) / 86400
if age_days < 7:
score += 1.0
elif age_days < 30:
score += 0.5
# High confidence boost
if trace.get("confidence", 0) >= 0.9:
score += 1.0
candidates.append((score, trace))
candidates.sort(key=lambda x: x[0], reverse=True)
return [trace for _, trace in candidates[:max_traces]]
def get_failure_patterns(self, vuln_type: str, technology: str = "",
max_patterns: int = 3) -> List[Dict]:
"""
Retrieve relevant failure patterns to avoid.
"""
candidates = []
for failure in self._failures:
if failure.get("vuln_type") != vuln_type:
continue
score = 1.0
if technology and technology.lower() in failure.get("technology", "").lower():
score += 2.0
candidates.append((score, failure))
candidates.sort(key=lambda x: x[0], reverse=True)
return [f for _, f in candidates[:max_patterns]]
def get_strategy_for_tech(self, technology: str) -> Optional[Dict]:
"""Get accumulated strategy knowledge for a technology."""
tech_key = technology.lower()
return self._strategies.get(tech_key)
def get_context_for_testing(self, vuln_type: str, technology: str = "",
max_chars: int = 1500) -> str:
"""
Format reasoning memory into a prompt-ready context string.
Includes successful traces and failure avoidance.
"""
sections = []
# Successful reasoning traces
traces = self.get_relevant_traces(vuln_type, technology, max_traces=2)
if traces:
section = "## Successful Past Reasoning (from confirmed findings)\n"
for i, trace in enumerate(traces, 1):
section += f"\n### Past Success #{i}: {trace.get('vuln_type')} on {trace.get('technology', 'unknown')}\n"
steps = trace.get("reasoning_steps", [])
if steps:
for step in steps[:4]:
section += f" - {step}\n"
payload = trace.get("payload_used", "")
if payload:
section += f" Effective payload: {payload}\n"
evidence = trace.get("evidence_summary", "")
if evidence:
section += f" Evidence: {evidence[:200]}\n"
sections.append(section)
# Failure avoidance
failures = self.get_failure_patterns(vuln_type, technology, max_patterns=2)
if failures:
section = "## Failed Approaches to AVOID\n"
for failure in failures:
reason = failure.get("failure_reason", "unknown")
endpoint = failure.get("endpoint_pattern", "")
payloads = failure.get("attempted_payloads", [])[:3]
section += f" - {reason} on {endpoint}: payloads {payloads} did NOT work\n"
sections.append(section)
# Technology strategy
if technology:
strategy = self.get_strategy_for_tech(technology)
if strategy:
section = f"## Learned Strategy for {technology}\n"
priority = strategy.get("priority_order", [])
if priority:
section += f" Priority order: {', '.join(priority[:8])}\n"
insights = strategy.get("key_insights", [])
for insight in insights[:3]:
section += f" - {insight}\n"
found = strategy.get("vuln_types_found", [])
if found:
section += f" Previously found: {', '.join(found[:5])}\n"
sections.append(section)
if not sections:
return ""
result = "\n=== REASONING MEMORY (Learned from past scans) ===\n"
result += "Apply these lessons learned. Avoid previously failed approaches.\n\n"
current_len = len(result)
for section in sections:
if current_len + len(section) > max_chars:
remaining = max_chars - current_len - 20
if remaining > 100:
result += section[:remaining] + "...\n"
break
result += section
current_len += len(section)
result += "\n=== END REASONING MEMORY ===\n"
return result
def get_strategy_context(self, technologies: List[str],
max_chars: int = 1000) -> str:
"""
Format accumulated strategy knowledge for attack planning.
"""
sections = []
for tech in technologies[:3]:
strategy = self.get_strategy_for_tech(tech)
if strategy:
section = f"### {tech} (tested {strategy.get('scan_count', 0)} times)\n"
priority = strategy.get("priority_order", [])
if priority:
section += f" Recommended priority: {', '.join(priority[:6])}\n"
found = strategy.get("vuln_types_found", [])
if found:
section += f" Previously successful: {', '.join(found[:5])}\n"
insights = strategy.get("key_insights", [])
for insight in insights[:2]:
section += f" Insight: {insight}\n"
sections.append(section)
if not sections:
return ""
result = "\n=== ACCUMULATED STRATEGY KNOWLEDGE ===\n"
current_len = len(result)
for section in sections:
if current_len + len(section) > max_chars:
break
result += section
current_len += len(section)
result += "=== END STRATEGY KNOWLEDGE ===\n"
return result
def get_stats(self) -> Dict:
"""Return memory statistics."""
vuln_type_counts = {}
for trace in self._traces:
vt = trace.get("vuln_type", "unknown")
vuln_type_counts[vt] = vuln_type_counts.get(vt, 0) + 1
return {
"total_traces": len(self._traces),
"total_failures": len(self._failures),
"technologies_known": list(self._strategies.keys()),
"vuln_type_distribution": vuln_type_counts,
"memory_file": str(self.memory_path),
"file_exists": self.memory_path.exists()
}
def flush(self):
"""Force save to disk."""
self._dirty = True
self._save()
File diff suppressed because it is too large Load Diff
+643
View File
@@ -0,0 +1,643 @@
"""
Multi-backend vector store for RAG knowledge retrieval.
Backends (in priority order):
1. ChromaDB + sentence-transformers (semantic embeddings, persistent)
2. TF-IDF via scikit-learn (statistical similarity)
3. BM25 (zero dependencies, keyword-based ranking)
All backends provide the same interface: add(), query(), delete_collection().
"""
import json
import math
import hashlib
import logging
import time
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple
logger = logging.getLogger(__name__)
# Optional dependencies
try:
import chromadb
from chromadb.config import Settings as ChromaSettings
HAS_CHROMADB = True
except ImportError:
HAS_CHROMADB = False
try:
from sentence_transformers import SentenceTransformer
HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
HAS_SENTENCE_TRANSFORMERS = False
try:
import numpy as np
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
try:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False
@dataclass
class RetrievedChunk:
"""A retrieved knowledge chunk with relevance score."""
text: str
score: float
metadata: Dict[str, Any] = field(default_factory=dict)
chunk_id: str = ""
source: str = ""
def to_dict(self) -> Dict:
return asdict(self)
@dataclass
class Document:
"""A document to be indexed in the vector store."""
text: str
metadata: Dict[str, Any] = field(default_factory=dict)
doc_id: str = ""
def __post_init__(self):
if not self.doc_id:
self.doc_id = hashlib.md5(self.text[:500].encode()).hexdigest()[:12]
class BaseVectorStore(ABC):
"""Abstract vector store interface."""
@abstractmethod
def add(self, collection: str, documents: List[Document]) -> int:
"""Add documents to a collection. Returns count added."""
pass
@abstractmethod
def query(self, collection: str, query_text: str, top_k: int = 5,
metadata_filter: Optional[Dict] = None) -> List[RetrievedChunk]:
"""Query a collection for relevant documents."""
pass
@abstractmethod
def collection_exists(self, collection: str) -> bool:
"""Check if a collection has been indexed."""
pass
@abstractmethod
def delete_collection(self, collection: str) -> None:
"""Delete a collection and all its documents."""
pass
@abstractmethod
def collection_count(self, collection: str) -> int:
"""Return number of documents in a collection."""
pass
@property
@abstractmethod
def backend_name(self) -> str:
pass
class BM25VectorStore(BaseVectorStore):
"""
BM25 (Best Matching 25) keyword-based ranking.
Zero external dependencies - works with pure Python.
Good for exact keyword matching and term-frequency scoring.
"""
def __init__(self, persist_dir: str, k1: float = 1.5, b: float = 0.75):
self.persist_dir = Path(persist_dir)
self.persist_dir.mkdir(parents=True, exist_ok=True)
self.k1 = k1
self.b = b
self._collections: Dict[str, Dict] = {}
self._load_persisted()
@property
def backend_name(self) -> str:
return "bm25"
def _tokenize(self, text: str) -> List[str]:
"""Simple whitespace + punctuation tokenizer."""
import re
text = text.lower()
tokens = re.findall(r'\b[a-z0-9_]{2,}\b', text)
return tokens
def _load_persisted(self):
"""Load persisted collections from disk."""
index_file = self.persist_dir / "bm25_index.json"
if index_file.exists():
try:
with open(index_file, 'r') as f:
data = json.load(f)
self._collections = data.get("collections", {})
logger.info(f"BM25: Loaded {len(self._collections)} collections from disk")
except Exception as e:
logger.warning(f"BM25: Failed to load index: {e}")
self._collections = {}
def _persist(self):
"""Persist collections to disk."""
index_file = self.persist_dir / "bm25_index.json"
try:
with open(index_file, 'w') as f:
json.dump({"collections": self._collections, "timestamp": time.time()}, f)
except Exception as e:
logger.warning(f"BM25: Failed to persist index: {e}")
def add(self, collection: str, documents: List[Document]) -> int:
if not documents:
return 0
if collection not in self._collections:
self._collections[collection] = {
"documents": [],
"doc_freqs": [],
"df": {},
"doc_lengths": [],
"avgdl": 0,
"N": 0
}
col = self._collections[collection]
added = 0
existing_ids = {d.get("doc_id", "") for d in col["documents"]}
for doc in documents:
if doc.doc_id in existing_ids:
continue
tokens = self._tokenize(doc.text)
token_freq = dict(Counter(tokens))
unique_tokens = set(tokens)
col["documents"].append({
"doc_id": doc.doc_id,
"text": doc.text[:5000], # Cap storage
"metadata": doc.metadata
})
col["doc_freqs"].append(token_freq)
col["doc_lengths"].append(len(tokens))
for token in unique_tokens:
col["df"][token] = col["df"].get(token, 0) + 1
added += 1
col["N"] = len(col["documents"])
col["avgdl"] = sum(col["doc_lengths"]) / max(col["N"], 1)
if added > 0:
self._persist()
return added
def query(self, collection: str, query_text: str, top_k: int = 5,
metadata_filter: Optional[Dict] = None) -> List[RetrievedChunk]:
if collection not in self._collections:
return []
col = self._collections[collection]
if col["N"] == 0:
return []
query_tokens = self._tokenize(query_text)
if not query_tokens:
return []
scores = []
N = col["N"]
avgdl = col["avgdl"]
for i in range(N):
# Metadata filter
if metadata_filter:
doc_meta = col["documents"][i].get("metadata", {})
skip = False
for key, val in metadata_filter.items():
if isinstance(val, list):
if doc_meta.get(key) not in val:
skip = True
break
elif doc_meta.get(key) != val:
skip = True
break
if skip:
scores.append(0.0)
continue
doc_freq = col["doc_freqs"][i]
doc_len = col["doc_lengths"][i]
score = 0.0
for token in query_tokens:
if token not in doc_freq:
continue
tf = doc_freq[token]
df = col["df"].get(token, 0)
# BM25 IDF
idf = math.log((N - df + 0.5) / (df + 0.5) + 1.0)
# BM25 TF normalization
tf_norm = (tf * (self.k1 + 1)) / (
tf + self.k1 * (1.0 - self.b + self.b * doc_len / avgdl)
)
score += idf * tf_norm
scores.append(score)
# Get top-k
indexed_scores = [(i, s) for i, s in enumerate(scores) if s > 0]
indexed_scores.sort(key=lambda x: x[1], reverse=True)
results = []
for i, score in indexed_scores[:top_k]:
doc = col["documents"][i]
results.append(RetrievedChunk(
text=doc["text"],
score=score,
metadata=doc.get("metadata", {}),
chunk_id=doc.get("doc_id", f"doc_{i}"),
source=collection
))
return results
def collection_exists(self, collection: str) -> bool:
return collection in self._collections and self._collections[collection]["N"] > 0
def delete_collection(self, collection: str) -> None:
if collection in self._collections:
del self._collections[collection]
self._persist()
def collection_count(self, collection: str) -> int:
if collection not in self._collections:
return 0
return self._collections[collection]["N"]
class TFIDFVectorStore(BaseVectorStore):
"""
TF-IDF based vector store using scikit-learn.
Better than BM25 for capturing document-level similarity.
Requires: scikit-learn, numpy
"""
def __init__(self, persist_dir: str):
if not HAS_SKLEARN or not HAS_NUMPY:
raise ImportError("TF-IDF backend requires scikit-learn and numpy")
self.persist_dir = Path(persist_dir)
self.persist_dir.mkdir(parents=True, exist_ok=True)
self._collections: Dict[str, Dict] = {}
@property
def backend_name(self) -> str:
return "tfidf"
def add(self, collection: str, documents: List[Document]) -> int:
if not documents:
return 0
if collection not in self._collections:
self._collections[collection] = {
"documents": [],
"texts": [],
"vectorizer": None,
"matrix": None
}
col = self._collections[collection]
existing_ids = {d.get("doc_id", "") for d in col["documents"]}
added = 0
for doc in documents:
if doc.doc_id in existing_ids:
continue
col["documents"].append({
"doc_id": doc.doc_id,
"text": doc.text[:5000],
"metadata": doc.metadata
})
col["texts"].append(doc.text[:5000])
added += 1
if added > 0:
# Rebuild TF-IDF matrix
vectorizer = TfidfVectorizer(
max_features=10000,
stop_words='english',
ngram_range=(1, 2),
min_df=1,
max_df=0.95
)
col["matrix"] = vectorizer.fit_transform(col["texts"])
col["vectorizer"] = vectorizer
return added
def query(self, collection: str, query_text: str, top_k: int = 5,
metadata_filter: Optional[Dict] = None) -> List[RetrievedChunk]:
if collection not in self._collections:
return []
col = self._collections[collection]
if col["vectorizer"] is None or col["matrix"] is None:
return []
query_vec = col["vectorizer"].transform([query_text])
similarities = cosine_similarity(query_vec, col["matrix"]).flatten()
# Apply metadata filter
if metadata_filter:
for i, doc in enumerate(col["documents"]):
meta = doc.get("metadata", {})
for key, val in metadata_filter.items():
if isinstance(val, list):
if meta.get(key) not in val:
similarities[i] = 0.0
elif meta.get(key) != val:
similarities[i] = 0.0
top_indices = np.argsort(similarities)[::-1][:top_k]
results = []
for i in top_indices:
if similarities[i] <= 0:
continue
doc = col["documents"][i]
results.append(RetrievedChunk(
text=doc["text"],
score=float(similarities[i]),
metadata=doc.get("metadata", {}),
chunk_id=doc.get("doc_id", f"doc_{i}"),
source=collection
))
return results
def collection_exists(self, collection: str) -> bool:
return (collection in self._collections and
len(self._collections[collection]["documents"]) > 0)
def delete_collection(self, collection: str) -> None:
if collection in self._collections:
del self._collections[collection]
def collection_count(self, collection: str) -> int:
if collection not in self._collections:
return 0
return len(self._collections[collection]["documents"])
class ChromaVectorStore(BaseVectorStore):
"""
ChromaDB + sentence-transformers for true semantic embeddings.
Best quality: understands meaning, not just keywords.
Requires: chromadb, sentence-transformers
"""
DEFAULT_MODEL = "all-MiniLM-L6-v2" # Fast, 384-dim, good quality
def __init__(self, persist_dir: str, model_name: str = None):
if not HAS_CHROMADB:
raise ImportError("ChromaDB backend requires: pip install chromadb")
self.persist_dir = Path(persist_dir)
self.persist_dir.mkdir(parents=True, exist_ok=True)
self.client = chromadb.PersistentClient(
path=str(self.persist_dir / "chromadb")
)
# Embedding model
self._embed_model = None
self._model_name = model_name or self.DEFAULT_MODEL
if HAS_SENTENCE_TRANSFORMERS:
try:
self._embed_model = SentenceTransformer(self._model_name)
logger.info(f"ChromaDB: Loaded embedding model '{self._model_name}'")
except Exception as e:
logger.warning(f"ChromaDB: Failed to load model: {e}")
@property
def backend_name(self) -> str:
return "chromadb"
def _get_collection(self, name: str):
"""Get or create a ChromaDB collection."""
if self._embed_model:
return self.client.get_or_create_collection(
name=name,
metadata={"hnsw:space": "cosine"}
)
else:
return self.client.get_or_create_collection(name=name)
def _embed(self, texts: List[str]) -> Optional[List[List[float]]]:
"""Generate embeddings using sentence-transformers."""
if not self._embed_model:
return None
try:
embeddings = self._embed_model.encode(texts, show_progress_bar=False)
return embeddings.tolist()
except Exception as e:
logger.warning(f"ChromaDB: Embedding failed: {e}")
return None
def add(self, collection: str, documents: List[Document]) -> int:
if not documents:
return 0
col = self._get_collection(collection)
# Filter already-indexed docs
existing = set()
try:
result = col.get()
if result and result.get("ids"):
existing = set(result["ids"])
except Exception:
pass
new_docs = [d for d in documents if d.doc_id not in existing]
if not new_docs:
return 0
# Batch add (ChromaDB limit: 41666 per batch)
batch_size = 500
added = 0
for start in range(0, len(new_docs), batch_size):
batch = new_docs[start:start + batch_size]
ids = [d.doc_id for d in batch]
texts = [d.text[:5000] for d in batch]
metadatas = []
for d in batch:
# ChromaDB metadata must be str/int/float/bool
meta = {}
for k, v in d.metadata.items():
if isinstance(v, (str, int, float, bool)):
meta[k] = v
elif isinstance(v, list):
meta[k] = ",".join(str(x) for x in v)
else:
meta[k] = str(v)
metadatas.append(meta)
embeddings = self._embed(texts)
try:
if embeddings:
col.add(
ids=ids,
documents=texts,
metadatas=metadatas,
embeddings=embeddings
)
else:
col.add(
ids=ids,
documents=texts,
metadatas=metadatas
)
added += len(batch)
except Exception as e:
logger.warning(f"ChromaDB: Failed to add batch: {e}")
return added
def query(self, collection: str, query_text: str, top_k: int = 5,
metadata_filter: Optional[Dict] = None) -> List[RetrievedChunk]:
try:
col = self._get_collection(collection)
except Exception:
return []
if col.count() == 0:
return []
# Build ChromaDB where clause
where = None
if metadata_filter:
conditions = []
for key, val in metadata_filter.items():
if isinstance(val, list):
conditions.append({key: {"$in": val}})
else:
conditions.append({key: {"$eq": val}})
if len(conditions) == 1:
where = conditions[0]
elif len(conditions) > 1:
where = {"$and": conditions}
# Query with embeddings if available
query_embedding = self._embed([query_text])
try:
if query_embedding:
results = col.query(
query_embeddings=query_embedding,
n_results=min(top_k, col.count()),
where=where
)
else:
results = col.query(
query_texts=[query_text],
n_results=min(top_k, col.count()),
where=where
)
except Exception as e:
logger.warning(f"ChromaDB query failed: {e}")
return []
chunks = []
if results and results.get("documents"):
docs = results["documents"][0]
ids = results["ids"][0] if results.get("ids") else [""] * len(docs)
distances = results["distances"][0] if results.get("distances") else [0.0] * len(docs)
metadatas = results["metadatas"][0] if results.get("metadatas") else [{}] * len(docs)
for text, doc_id, distance, meta in zip(docs, ids, distances, metadatas):
# ChromaDB returns distance (lower = better), convert to similarity score
score = max(0.0, 1.0 - distance)
chunks.append(RetrievedChunk(
text=text,
score=score,
metadata=meta or {},
chunk_id=doc_id,
source=collection
))
return chunks
def collection_exists(self, collection: str) -> bool:
try:
col = self.client.get_collection(collection)
return col.count() > 0
except Exception:
return False
def delete_collection(self, collection: str) -> None:
try:
self.client.delete_collection(collection)
except Exception:
pass
def collection_count(self, collection: str) -> int:
try:
col = self.client.get_collection(collection)
return col.count()
except Exception:
return 0
def create_vectorstore(persist_dir: str, backend: str = "auto") -> BaseVectorStore:
"""
Factory function to create the best available vector store.
Args:
persist_dir: Directory for persistent storage
backend: "auto" (best available), "chromadb", "tfidf", or "bm25"
Returns:
Configured vector store instance
"""
if backend == "chromadb" or (backend == "auto" and HAS_CHROMADB):
try:
store = ChromaVectorStore(persist_dir)
logger.info(f"RAG: Using ChromaDB backend (semantic embeddings)")
return store
except Exception as e:
logger.warning(f"RAG: ChromaDB init failed: {e}, falling back")
if backend == "tfidf" or (backend == "auto" and HAS_SKLEARN):
try:
store = TFIDFVectorStore(persist_dir)
logger.info(f"RAG: Using TF-IDF backend (statistical similarity)")
return store
except Exception as e:
logger.warning(f"RAG: TF-IDF init failed: {e}, falling back")
store = BM25VectorStore(persist_dir)
logger.info(f"RAG: Using BM25 backend (keyword ranking)")
return store