Files
NeuroSploit/core/llm_manager.py
2025-12-19 13:26:15 -03:00

514 lines
24 KiB
Python

#!/usr/bin/env python3
"""
LLM Manager - Unified interface for multiple LLM providers
Supports: Claude, GPT, Gemini (API and CLI), and Ollama
"""
import os
import json
import subprocess
from typing import Dict, List, Optional, Any
import logging
import requests
from pathlib import Path # Added for Path
import re # Added for regex operations
logger = logging.getLogger(__name__)
class LLMManager:
"""Manage multiple LLM providers"""
def __init__(self, config: Dict):
    """Initialize the manager from the ``llm`` section of *config*.

    The profile named by ``default_profile`` is looked up under ``profiles``
    and its values (provider, model, API key, sampling and guardrail options)
    are copied onto the instance; any key absent from the profile falls back
    to the default shown below. The prompt libraries are then loaded from
    disk.

    Args:
        config: Application configuration mapping. Only the ``llm`` key is
            read; a missing or null section behaves like an empty one.
    """
    # ``or {}`` tolerates a missing config and keys whose value is null.
    self.config = (config or {}).get('llm') or {}
    self.default_profile_name = self.config.get('default_profile', 'gemini_pro_default')
    self.profiles = self.config.get('profiles') or {}
    self.active_profile = self.profiles.get(self.default_profile_name) or {}
    if not self.active_profile:
        # Without this the manager silently runs on built-in defaults.
        logger.warning(f"LLM profile '{self.default_profile_name}' not found or empty; using built-in defaults.")

    # Load active profile settings
    self.provider = (self.active_profile.get('provider') or 'gemini').lower()
    self.model = self.active_profile.get('model', 'gemini-pro')
    self.api_key = self._get_api_key(self.active_profile.get('api_key') or '')
    self.temperature = self.active_profile.get('temperature', 0.7)
    self.max_tokens = self.active_profile.get('max_tokens', 4096)

    # Extended LLM parameters
    self.input_token_limit = self.active_profile.get('input_token_limit', 4096)
    self.output_token_limit = self.active_profile.get('output_token_limit', 4096)
    self.cache_enabled = self.active_profile.get('cache_enabled', False)
    self.search_context_level = self.active_profile.get('search_context_level', 'medium')  # low, medium, high
    self.pdf_support_enabled = self.active_profile.get('pdf_support_enabled', False)
    self.guardrails_enabled = self.active_profile.get('guardrails_enabled', False)
    self.hallucination_mitigation_strategy = self.active_profile.get('hallucination_mitigation_strategy', None)

    # Prompt libraries; both paths are relative to the current working directory.
    self.json_prompts_file_path = Path("prompts/library.json")
    self.md_prompts_dir_path = Path("prompts/md_library")
    self.prompts = self._load_all_prompts()

    logger.info(f"Initialized LLM Manager - Provider: {self.provider}, Model: {self.model}, Profile: {self.default_profile_name}")
def _get_api_key(self, api_key_config: str) -> str:
"""Helper to get API key from config or environment variable"""
if api_key_config.startswith('${') and api_key_config.endswith('}'):
env_var = api_key_config[2:-1]
return os.getenv(env_var, '')
return api_key_config
def _load_all_prompts(self) -> Dict:
"""Load prompts from both JSON library and Markdown library files."""
all_prompts = {
"json_prompts": {},
"md_prompts": {}
}
# Load from JSON library
if self.json_prompts_file_path.exists():
try:
with open(self.json_prompts_file_path, 'r') as f:
all_prompts["json_prompts"] = json.load(f)
logger.info(f"Loaded prompts from JSON library: {self.json_prompts_file_path}")
except Exception as e:
logger.error(f"Error loading prompts from {self.json_prompts_file_path}: {e}")
else:
logger.warning(f"JSON prompts file not found at {self.json_prompts_file_path}. Some AI functionalities might be limited.")
# Load from Markdown library
if self.md_prompts_dir_path.is_dir():
for md_file in self.md_prompts_dir_path.glob("*.md"):
try:
content = md_file.read_text()
prompt_name = md_file.stem # Use filename as prompt name
user_prompt_match = re.search(r"## User Prompt\n(.*?)(?=\n## System Prompt|\Z)", content, re.DOTALL)
system_prompt_match = re.search(r"## System Prompt\n(.*?)(?=\n## User Prompt|\Z)", content, re.DOTALL)
user_prompt = user_prompt_match.group(1).strip() if user_prompt_match else ""
system_prompt = system_prompt_match.group(1).strip() if system_prompt_match else ""
if user_prompt or system_prompt:
all_prompts["md_prompts"][prompt_name] = {
"user_prompt": user_prompt,
"system_prompt": system_prompt
}
else:
logger.warning(f"No valid User or System Prompt found in {md_file.name}. Skipping.")
except Exception as e:
logger.error(f"Error loading prompt from {md_file.name}: {e}")
logger.info(f"Loaded {len(all_prompts['md_prompts'])} prompts from Markdown library.")
else:
logger.warning(f"Markdown prompts directory not found at {self.md_prompts_dir_path}. Some AI functionalities might be limited.")
return all_prompts
def get_prompt(self, library_type: str, category: str, name: str, default: str = "") -> str:
    """Retrieve a specific prompt by library type, category, and name.

    Args:
        library_type: ``"json_prompts"`` or ``"md_prompts"``.
        category: A JSON top-level key (e.g. ``'exploitation'``) or an MD
            file stem (e.g. ``'red_team_agent'``).
        name: A JSON sub-key (e.g. ``'ai_exploit_planning_user'``) or
            ``'user_prompt'`` / ``'system_prompt'`` for MD prompts.
        default: Value returned when any step of the lookup is missing.

    Returns:
        The stored prompt, or *default* when the path does not resolve.
    """
    node = self.prompts
    for key in (library_type, category, name):
        # A hand-edited library may hold a string or null where a mapping is
        # expected; treat that as "not found" rather than raising.
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node
def generate(self, prompt: str, system_prompt: Optional[str] = None) -> str:
    """Generate a response from the configured provider.

    The raw provider output is passed through the guardrails when they are
    enabled, and then through the configured hallucination mitigation
    strategy when one of the known strategies is set.

    Args:
        prompt: User prompt sent to the model.
        system_prompt: Optional system instruction.

    Returns:
        The (possibly post-processed) response text. Provider failures and an
        unsupported provider are reported as a string of the form
        ``"Error: <message>"`` instead of being raised.
    """
    # Provider name -> name of the method implementing it.
    provider_methods = {
        'claude': '_generate_claude',
        'gpt': '_generate_gpt',
        'gemini': '_generate_gemini',
        'ollama': '_generate_ollama',
        'gemini-cli': '_generate_gemini_cli',
    }
    try:
        method_name = provider_methods.get(self.provider)
        if method_name is None:
            raise ValueError(f"Unsupported provider: {self.provider}")
        raw_response = getattr(self, method_name)(prompt, system_prompt)
    except Exception as e:
        logger.error(f"Error generating raw response: {e}")
        return f"Error: {str(e)}"

    if raw_response is None:
        # A provider may yield no text; the post-processing below expects a str.
        raw_response = ""

    if self.guardrails_enabled:
        raw_response = self._apply_guardrails(raw_response)

    if self.hallucination_mitigation_strategy in ("grounding", "self_reflection", "consistency_check"):
        logger.debug(f"Applying hallucination mitigation strategy: {self.hallucination_mitigation_strategy}")
        return self._mitigate_hallucination(raw_response, prompt, system_prompt)

    return raw_response
def _apply_guardrails(self, response: str) -> str:
"""Applies basic guardrails to the LLM response."""
if not self.guardrails_enabled:
return response
logger.debug("Applying guardrails...")
# Example: Simple keyword filtering
harmful_keywords = ["malicious_exploit_command", "destroy_system", "wipe_data", "unauthorized_access"] # Placeholder keywords
for keyword in harmful_keywords:
if keyword in response.lower():
logger.warning(f"Guardrail triggered: Found potentially harmful keyword '{keyword}'. Response will be sanitized or flagged.")
# A more robust solution would involve redaction, re-prompting, or flagging for human review.
# For this example, we'll replace the keyword.
response = response.replace(keyword, "[REDACTED_HARMFUL_CONTENT]")
response = response.replace(keyword.upper(), "[REDACTED_HARMFUL_CONTENT]")
# Example: Length check (if response is excessively long and not expected)
# Using output_token_limit for a more accurate comparison
if len(response.split()) > self.output_token_limit * 1.5: # Roughly estimate tokens by word count
logger.warning("Guardrail triggered: Response is excessively long. Truncating or flagging.")
response = " ".join(response.split()[:int(self.output_token_limit * 1.5)]) + "\n[RESPONSE TRUNCATED BY GUARDRAIL]"
# Ethical check (can be another LLM call, but for simplicity, a fixed instruction)
# This is more about ensuring the tone and content align with ethical hacking principles.
# This is a very simplistic example. A real ethical check would be more nuanced.
# For now, just a log or a general check for explicit unethical instructions.
if any(bad_phrase in response.lower() for bad_phrase in ["perform illegal activity", "bypass security illegally"]):
logger.warning("Guardrail triggered: Response contains potentially unethical instructions. Flagging for review.")
response = "[UNETHICAL CONTENT FLAGGED FOR REVIEW]\n" + response
return response
def _mitigate_hallucination(self, raw_response: str, original_prompt: str, original_system_prompt: Optional[str]) -> str:
"""Applies configured hallucination mitigation strategy."""
strategy = self.hallucination_mitigation_strategy
# Temporarily disable mitigation to prevent infinite recursion when calling self.generate internally
original_mitigation_state = self.hallucination_mitigation_strategy
self.hallucination_mitigation_strategy = None
try:
if strategy == "grounding":
verification_prompt = f"""Review the following response:
---
{raw_response}
---
Based *only* on the context provided in the original prompt (user: '{original_prompt}', system: '{original_system_prompt or "None"}'), is this response factual and directly supported by the context? If not, correct it to be factual. If the response is completely unsourced or makes claims beyond the context, state 'UNSOURCED'."""
logger.debug("Applying grounding strategy: Re-prompting for factual verification.")
return self.generate(verification_prompt, "You are a fact-checker whose sole purpose is to verify LLM output against provided context.")
elif strategy == "self_reflection":
reflection_prompt = f"""Critically review the following response for accuracy, logical consistency, and adherence to the original prompt's instructions:
Original Prompt (User): {original_prompt}
Original Prompt (System): {original_system_prompt or "None"}
Generated Response: {raw_response}
Identify any potential hallucinations, inconsistencies, or areas where the response might have deviated from facts or instructions. If you find issues, provide a corrected and more reliable version of the response. If the response is good, state 'ACCURATE'."""
logger.debug("Applying self-reflection strategy: Re-prompting for self-critique.")
return self.generate(reflection_prompt, "You are an AI assistant designed to critically evaluate and improve other AI-generated content.")
elif strategy == "consistency_check":
logger.debug("Applying consistency check strategy: Generating multiple responses for comparison.")
responses = []
for i in range(3): # Generate 3 responses for consistency check
logger.debug(f"Generating response {i+1} for consistency check.")
res = self.generate(original_prompt, original_system_prompt)
responses.append(res)
if len(set(responses)) == 1:
return responses[0]
else:
logger.warning("Consistency check found varying responses. Attempting to synthesize a consistent answer.")
synthesis_prompt = (
f"Synthesize a single, consistent, and factual response from the following AI-generated options. "
f"Prioritize factual accuracy and avoid information present in only one response if contradictory. "
f"If there's significant disagreement, state the core disagreement.\n\n"
f"Options:\n" + "\n---\n".join(responses)
)
return self.generate(synthesis_prompt, "You are a highly analytical AI assistant tasked with synthesizing consistent information from multiple sources.")
return raw_response # Fallback if strategy not recognized or implemented
finally:
self.hallucination_mitigation_strategy = original_mitigation_state # Restore original state
def _generate_claude(self, prompt: str, system_prompt: Optional[str] = None) -> str:
    """Generate a response using the Anthropic Claude Messages API.

    Args:
        prompt: User message content.
        system_prompt: Optional system instruction; omitted from the request
            when empty.

    Returns:
        The concatenated text of the response's content blocks.

    Raises:
        Whatever the ``anthropic`` client raises; ``generate`` converts
        exceptions into an ``"Error: ..."`` string.
    """
    import anthropic  # imported here so the SDK is only needed when this provider is used

    client = anthropic.Anthropic(api_key=self.api_key)
    request = {
        "model": self.model,
        "max_tokens": self.max_tokens,
        "temperature": self.temperature,
        "messages": [{"role": "user", "content": prompt}],
    }
    if system_prompt:
        # Only send ``system`` when there is one, rather than an empty string.
        request["system"] = system_prompt
    response = client.messages.create(**request)
    # Join every block that carries text instead of assuming exactly one
    # text block at index 0.
    return "".join(getattr(block, "text", "") or "" for block in response.content)
def _generate_gpt(self, prompt: str, system_prompt: Optional[str] = None) -> str:
    """Generate a response using the OpenAI Chat Completions API.

    Args:
        prompt: User message content.
        system_prompt: Optional system message, sent first when provided.

    Returns:
        The text of the first choice, or ``''`` when the message carries no
        text content.

    Raises:
        Whatever the ``openai`` client raises; ``generate`` converts
        exceptions into an ``"Error: ..."`` string.
    """
    import openai  # imported here so the SDK is only needed when this provider is used

    client = openai.OpenAI(api_key=self.api_key)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    response = client.chat.completions.create(
        model=self.model,
        messages=messages,
        temperature=self.temperature,
        max_tokens=self.max_tokens,
    )
    # ``content`` can be None (e.g. a refusal or tool call); keep the str contract.
    return response.choices[0].message.content or ""
def _generate_gemini(self, prompt: str, system_prompt: Optional[str] = None) -> str:
    """Generate a response using the Google Gemini API.

    Args:
        prompt: User prompt.
        system_prompt: Optional system instruction; when provided it is
            prepended to the prompt, separated by a blank line.

    Returns:
        The ``text`` of the generated response.

    Raises:
        Whatever ``google.generativeai`` raises; ``generate`` converts
        exceptions into an ``"Error: ..."`` string.
    """
    import google.generativeai as genai  # imported here so the SDK is only needed for this provider

    genai.configure(api_key=self.api_key)
    model = genai.GenerativeModel(self.model)
    # The system prompt travels as a prefix of the single prompt string.
    full_prompt = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
    response = model.generate_content(
        full_prompt,
        generation_config={
            'temperature': self.temperature,
            'max_output_tokens': self.max_tokens,
        },
    )
    return response.text
def _generate_gemini_cli(self, prompt: str, system_prompt: Optional[str] = None) -> str:
    """Generate a response by piping the prompt to the ``gemini`` CLI.

    Runs ``gemini chat -m <model>`` with the prompt on stdin and a 120 second
    timeout.

    Args:
        prompt: User prompt.
        system_prompt: Optional system instruction; when provided it is
            prepended to the prompt, separated by a blank line.

    Returns:
        The CLI's stripped stdout on success. Failures are not raised: a
        non-zero exit, a timeout or any other exception yields an
        ``"Error: ..."`` string.
    """
    try:
        full_prompt = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
        # Argument list (no shell), so the model name is never shell-interpreted.
        cmd = ['gemini', 'chat', '-m', self.model]
        result = subprocess.run(
            cmd,
            input=full_prompt.encode(),
            capture_output=True,
            timeout=120,
        )
        # errors="replace": stray non-UTF-8 bytes must not discard the output.
        if result.returncode == 0:
            return result.stdout.decode(errors="replace").strip()
        error = result.stderr.decode(errors="replace").strip()
        logger.error(f"Gemini CLI error: {error}")
        return f"Error: {error}"
    except subprocess.TimeoutExpired:
        logger.error("Gemini CLI timeout")
        return "Error: Request timeout"
    except Exception as e:
        logger.error(f"Gemini CLI error: {e}")
        return f"Error: {str(e)}"
def _generate_ollama(self, prompt: str, system_prompt: Optional[str] = None) -> str:
    """Generate a response from a local Ollama server.

    Posts a non-streaming request to ``http://localhost:11434/api/generate``
    with a 120 second timeout.

    Args:
        prompt: User prompt.
        system_prompt: Optional system instruction, sent as the ``system``
            field when provided.

    Returns:
        The ``response`` field of the JSON reply. Failures are not raised:
        any exception (connection, HTTP status, malformed reply) yields an
        ``"Error: ..."`` string.
    """
    try:
        url = "http://localhost:11434/api/generate"
        data = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,  # ask for one JSON document rather than a stream
            "options": {
                "temperature": self.temperature,
                "num_predict": self.max_tokens,
            },
        }
        if system_prompt:
            data["system"] = system_prompt
        response = requests.post(url, json=data, timeout=120)
        response.raise_for_status()
        return response.json()["response"]
    except Exception as e:
        logger.error(f"Ollama error: {e}")
        return f"Error: {str(e)}"
def analyze_vulnerability(self, vulnerability_data: Dict) -> Dict:
    """Ask the model to assess a vulnerability record.

    The user and system prompts are looked up in the JSON prompt library
    (category ``exploitation``); built-in defaults are used when absent.

    NOTE(review): a prompt found in the library is used verbatim - only the
    built-in default embeds *vulnerability_data*. Confirm whether library
    templates are expected to be formatted with the data.

    Args:
        vulnerability_data: JSON-serializable description of the finding.

    Returns:
        The model's reply parsed as a JSON object, or
        ``{"raw_response": <reply>}`` when the reply is not a JSON object.
    """
    default_prompt = (
        "\n"
        "Analyze the following vulnerability data and provide exploitation recommendations:\n"
        f"Vulnerability: {json.dumps(vulnerability_data, indent=2)}\n"
        "Provide:\n"
        "1. Severity assessment (Critical/High/Medium/Low)\n"
        "2. Exploitation difficulty\n"
        "3. Potential impact\n"
        "4. Recommended exploit techniques\n"
        "5. Detection evasion strategies\n"
        "6. Post-exploitation actions\n"
        "Response in JSON format.\n"
    )
    default_system_prompt = (
        "You are an expert penetration tester and security researcher.\n"
        "Analyze vulnerabilities and provide detailed, actionable exploitation strategies.\n"
        "Consider OWASP, CWE, and MITRE ATT&CK frameworks.\n"
        "Always include ethical considerations and legal boundaries."
    )
    prompt = self.get_prompt("json_prompts", "exploitation", "analyze_vulnerability_user", default=default_prompt)
    system_prompt = self.get_prompt("json_prompts", "exploitation", "analyze_vulnerability_system", default=default_system_prompt)

    response = self.generate(prompt, system_prompt)
    try:
        parsed = json.loads(response)
    except (ValueError, TypeError):
        # Not JSON (free text or an "Error: ..." string): hand back the raw reply.
        return {"raw_response": response}
    # Valid JSON that is not an object (list, number, ...) would break the
    # declared Dict return type, so it is wrapped as well.
    return parsed if isinstance(parsed, dict) else {"raw_response": response}
def generate_payload(self, target_info: Dict, vulnerability_type: str) -> str:
    """Ask the model for payload code for a target and vulnerability type.

    The user and system prompts are looked up in the JSON prompt library
    (category ``exploitation``); built-in defaults are used when absent.

    NOTE(review): a prompt found in the library is used verbatim - only the
    built-in default embeds *target_info* and *vulnerability_type*. Confirm
    whether library templates are expected to be formatted with the data.

    Args:
        target_info: JSON-serializable description of the target.
        vulnerability_type: Name of the vulnerability class.

    Returns:
        The model's reply as returned by ``generate``.
    """
    default_prompt = (
        "\n"
        "Generate an exploit payload for the following scenario:\n"
        "Target Information:\n"
        f"{json.dumps(target_info, indent=2)}\n"
        f"Vulnerability Type: {vulnerability_type}\n"
        "Requirements:\n"
        "1. Generate a working payload\n"
        "2. Include obfuscation techniques\n"
        "3. Add error handling\n"
        "4. Ensure minimal detection footprint\n"
        "5. Include cleanup procedures\n"
        "Provide the payload code with detailed comments.\n"
    )
    default_system_prompt = (
        "You are an expert exploit developer.\n"
        "Generate sophisticated, tested payloads that are effective yet responsible.\n"
        "Always include safety mechanisms and ethical guidelines."
    )
    prompt = self.get_prompt("json_prompts", "exploitation", "generate_payload_user", default=default_prompt)
    system_prompt = self.get_prompt("json_prompts", "exploitation", "generate_payload_system", default=default_system_prompt)
    return self.generate(prompt, system_prompt)
def suggest_privilege_escalation(self, system_info: Dict) -> List[str]:
    """Ask the model for privilege escalation techniques for a system.

    The user and system prompts are looked up in the JSON prompt library
    (category ``privesc``); built-in defaults are used when absent.

    NOTE(review): a prompt found in the library is used verbatim - only the
    built-in default embeds *system_info*. Confirm whether library templates
    are expected to be formatted with the data.

    Args:
        system_info: JSON-serializable description of the system.

    Returns:
        The ``techniques`` list of the JSON reply, or the reply itself when
        it is a top-level JSON list. An empty list is returned when the reply
        is not JSON or carries no such list.
    """
    default_prompt = (
        "\n"
        "Based on the following system information, suggest privilege escalation techniques:\n"
        "System Info:\n"
        f"{json.dumps(system_info, indent=2)}\n"
        "Provide:\n"
        "1. Top 5 privilege escalation vectors\n"
        "2. Required tools and commands\n"
        "3. Detection likelihood\n"
        "4. Success probability\n"
        "5. Alternative approaches\n"
        "Response in JSON format with prioritized list.\n"
    )
    default_system_prompt = (
        "You are a privilege escalation specialist.\n"
        "Analyze system configurations and suggest effective escalation paths.\n"
        "Consider Windows, Linux, and Active Directory environments."
    )
    prompt = self.get_prompt("json_prompts", "privesc", "suggest_privilege_escalation_user", default=default_prompt)
    system_prompt = self.get_prompt("json_prompts", "privesc", "suggest_privilege_escalation_system", default=default_system_prompt)

    response = self.generate(prompt, system_prompt)
    try:
        result = json.loads(response)
    except (ValueError, TypeError):
        # Not JSON (free text or an "Error: ..." string).
        return []
    if isinstance(result, list):
        # The prompt asks for a "prioritized list", so a bare JSON array is
        # a plausible reply shape; keep it instead of discarding it.
        return result
    if isinstance(result, dict):
        techniques = result.get('techniques', [])
        return techniques if isinstance(techniques, list) else []
    return []
def analyze_network_topology(self, scan_results: Dict) -> Dict:
    """Ask the model to analyze scan results and describe the network.

    The user and system prompts are looked up in the JSON prompt library
    (category ``network_recon``); built-in defaults are used when absent.

    NOTE(review): a prompt found in the library is used verbatim - only the
    built-in default embeds *scan_results*. Confirm whether library templates
    are expected to be formatted with the data.

    Args:
        scan_results: JSON-serializable scan output.

    Returns:
        The model's reply parsed as a JSON object, or
        ``{"raw_response": <reply>}`` when the reply is not a JSON object.
    """
    default_prompt = (
        "\n"
        "Analyze the network topology and suggest attack paths:\n"
        "Scan Results:\n"
        f"{json.dumps(scan_results, indent=2)}\n"
        "Provide:\n"
        "1. Network architecture overview\n"
        "2. Critical assets identification\n"
        "3. Attack surface analysis\n"
        "4. Recommended attack paths (prioritized)\n"
        "5. Lateral movement opportunities\n"
        "6. Persistence locations\n"
        "Response in JSON format.\n"
    )
    default_system_prompt = (
        "You are a network penetration testing expert.\n"
        "Analyze network structures and identify optimal attack vectors.\n"
        "Consider defense-in-depth and detection mechanisms."
    )
    prompt = self.get_prompt("json_prompts", "network_recon", "analyze_network_topology_user", default=default_prompt)
    system_prompt = self.get_prompt("json_prompts", "network_recon", "analyze_network_topology_system", default=default_system_prompt)

    response = self.generate(prompt, system_prompt)
    try:
        parsed = json.loads(response)
    except (ValueError, TypeError):
        # Not JSON (free text or an "Error: ..." string): hand back the raw reply.
        return {"raw_response": response}
    # Valid JSON that is not an object would break the declared Dict return type.
    return parsed if isinstance(parsed, dict) else {"raw_response": response}
def analyze_web_vulnerability(self, vulnerability_type: str, vulnerability_data: Dict) -> Dict:
    """Analyze a specific web vulnerability using its prompt from library.json.

    The prompts ``<type>_user`` and ``<type>_system`` are looked up in the
    ``vulnerability_testing`` category. The user prompt is a ``str.format``
    template with one type-specific placeholder that receives
    *vulnerability_data* as indented JSON.

    Args:
        vulnerability_type: One of ``ssrf``, ``sql_injection``, ``xss``,
            ``lfi``, ``broken_object``, ``broken_auth`` (case-insensitive).
        vulnerability_data: JSON-serializable data for the template.

    Returns:
        The model's reply parsed as a JSON object;
        ``{"raw_response": <reply>}`` when the reply is not a JSON object;
        or ``{"error": <message>}`` when the user template is missing, the
        type is unsupported, or the template cannot be formatted.
    """
    # Vulnerability type -> placeholder name used by its user prompt template.
    placeholders = {
        "ssrf": "http_data_json",
        "sql_injection": "input_data_json",
        "xss": "xss_data_json",
        "lfi": "lfi_data_json",
        "broken_object": "api_data_json",
        "broken_auth": "auth_data_json",
    }
    type_key = vulnerability_type.lower()
    user_prompt_name = f"{type_key}_user"
    system_prompt_name = f"{type_key}_system"

    user_prompt_template = self.get_prompt("json_prompts", "vulnerability_testing", user_prompt_name)
    if not user_prompt_template:
        logger.warning(f"No user prompt found for vulnerability type: {vulnerability_type}")
        return {"error": f"No user prompt template for {vulnerability_type}"}

    placeholder = placeholders.get(type_key)
    if placeholder is None:
        logger.warning(f"Unsupported vulnerability type for analysis: {vulnerability_type}")
        return {"error": f"Unsupported vulnerability type: {vulnerability_type}"}

    try:
        prompt = user_prompt_template.format(**{placeholder: json.dumps(vulnerability_data, indent=2)})
    except (KeyError, IndexError, ValueError) as e:
        # The template has a stray brace or an unexpected placeholder; report
        # it like the other input errors instead of raising.
        logger.error(f"Invalid user prompt template for {vulnerability_type}: {e!r}")
        return {"error": f"Invalid user prompt template for {vulnerability_type}: {e!r}"}

    system_prompt = self.get_prompt("json_prompts", "vulnerability_testing", system_prompt_name)
    if not system_prompt:
        logger.warning(f"No system prompt found for vulnerability type: {vulnerability_type}")
        # Use a generic system prompt if a specific one isn't found
        system_prompt = "You are an expert web security tester. Analyze the provided data for vulnerabilities and offer exploitation steps and remediation."

    response = self.generate(prompt, system_prompt)
    try:
        parsed = json.loads(response)
    except json.JSONDecodeError:
        logger.error(f"Failed to decode JSON response for {vulnerability_type} analysis: {response}")
        return {"raw_response": response}
    except Exception as e:
        logger.error(f"Error during {vulnerability_type} analysis: {e}")
        return {"error": str(e), "raw_response": response}
    # Valid JSON that is not an object would break the declared Dict return type.
    return parsed if isinstance(parsed, dict) else {"raw_response": response}