#!/usr/bin/env python3
"""
LLM Manager - Unified interface for multiple LLM providers
Supports: Claude, GPT, Gemini, Ollama, and custom models
"""
import os
import json
import subprocess
from typing import Dict, List, Optional, Any
import logging
import requests
from pathlib import Path  # prompt library paths
import re  # parsing Markdown prompt sections
logger = logging.getLogger(__name__)
class LLMManager:
"""Manage multiple LLM providers"""
def __init__(self, config: Dict):
"""Initialize LLM manager"""
self.config = config.get('llm', {})
self.default_profile_name = self.config.get('default_profile', 'gemini_pro_default')
self.profiles = self.config.get('profiles', {})
self.active_profile = self.profiles.get(self.default_profile_name, {})
# Load active profile settings
self.provider = self.active_profile.get('provider', 'gemini').lower()
self.model = self.active_profile.get('model', 'gemini-pro')
self.api_key = self._get_api_key(self.active_profile.get('api_key', ''))
self.temperature = self.active_profile.get('temperature', 0.7)
self.max_tokens = self.active_profile.get('max_tokens', 4096)
# New LLM parameters
self.input_token_limit = self.active_profile.get('input_token_limit', 4096)
self.output_token_limit = self.active_profile.get('output_token_limit', 4096)
self.cache_enabled = self.active_profile.get('cache_enabled', False)
self.search_context_level = self.active_profile.get('search_context_level', 'medium') # low, medium, high
self.pdf_support_enabled = self.active_profile.get('pdf_support_enabled', False)
self.guardrails_enabled = self.active_profile.get('guardrails_enabled', False)
self.hallucination_mitigation_strategy = self.active_profile.get('hallucination_mitigation_strategy', None)
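# Illustrative profile shape assumed by the lookups above (field names come from the
# getters in this constructor; the values below are examples, not the shipped config):
#   llm:
#     default_profile: gemini_pro_default
#     profiles:
#       gemini_pro_default:
#         provider: gemini
#         model: gemini-pro
#         api_key: "${GEMINI_API_KEY}"
#         temperature: 0.7
#         max_tokens: 4096
#         guardrails_enabled: false
#         hallucination_mitigation_strategy: null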
# New prompt loading
self.json_prompts_file_path = Path("prompts/library.json")
self.md_prompts_dir_path = Path("prompts/md_library")
self.prompts = self._load_all_prompts() # New method to load both
logger.info(f"Initialized LLM Manager - Provider: {self.provider}, Model: {self.model}, Profile: {self.default_profile_name}")
def _get_api_key(self, api_key_config: str) -> str:
"""Helper to get API key from config or environment variable"""
if api_key_config.startswith('${') and api_key_config.endswith('}'):
env_var = api_key_config[2:-1]
return os.getenv(env_var, '')
return api_key_config
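# Example: an api_key entry of "${GEMINI_API_KEY}" is resolved via os.getenv("GEMINI_API_KEY");
# any other value is used as a literal key. (GEMINI_API_KEY here is an illustrative variable name.)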
def _load_all_prompts(self) -> Dict:
"""Load prompts from both JSON library and Markdown library files."""
all_prompts = {
"json_prompts": {},
"md_prompts": {}
}
# Load from JSON library
if self.json_prompts_file_path.exists():
try:
with open(self.json_prompts_file_path, 'r') as f:
all_prompts["json_prompts"] = json.load(f)
logger.info(f"Loaded prompts from JSON library: {self.json_prompts_file_path}")
except Exception as e:
logger.error(f"Error loading prompts from {self.json_prompts_file_path}: {e}")
else:
logger.warning(f"JSON prompts file not found at {self.json_prompts_file_path}. Some AI functionalities might be limited.")
# Load from Markdown library
if self.md_prompts_dir_path.is_dir():
for md_file in self.md_prompts_dir_path.glob("*.md"):
try:
content = md_file.read_text()
prompt_name = md_file.stem # Use filename as prompt name
user_prompt_match = re.search(r"## User Prompt\n(.*?)(?=\n## System Prompt|\Z)", content, re.DOTALL)
system_prompt_match = re.search(r"## System Prompt\n(.*?)(?=\n## User Prompt|\Z)", content, re.DOTALL)
user_prompt = user_prompt_match.group(1).strip() if user_prompt_match else ""
system_prompt = system_prompt_match.group(1).strip() if system_prompt_match else ""
if user_prompt or system_prompt:
all_prompts["md_prompts"][prompt_name] = {
"user_prompt": user_prompt,
"system_prompt": system_prompt
}
else:
logger.warning(f"No valid User or System Prompt found in {md_file.name}. Skipping.")
except Exception as e:
logger.error(f"Error loading prompt from {md_file.name}: {e}")
logger.info(f"Loaded {len(all_prompts['md_prompts'])} prompts from Markdown library.")
else:
logger.warning(f"Markdown prompts directory not found at {self.md_prompts_dir_path}. Some AI functionalities might be limited.")
return all_prompts
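# Assumed on-disk layout for the two prompt libraries (a sketch inferred from the parsing
# above; the actual files ship with the project):
#   prompts/library.json     -> {"exploitation": {"analyze_vulnerability_user": "...", ...}, ...}
#   prompts/md_library/*.md  -> files containing "## User Prompt" and/or "## System Prompt"
#                               sections, keyed by filename stem (e.g. red_team_agent.md -> "red_team_agent")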
def get_prompt(self, library_type: str, category: str, name: str, default: str = "") -> str:
"""Retrieve a specific prompt by library type, category, and name.
`library_type` can be "json_prompts" or "md_prompts".
`category` can be a JSON top-level key (e.g., 'exploitation') or an MD filename (e.g., 'red_team_agent').
`name` can be a JSON sub-key (e.g., 'ai_exploit_planning_user') or 'user_prompt'/'system_prompt' for MD.
"""
return self.prompts.get(library_type, {}).get(category, {}).get(name, default)
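# Example lookups, matching the docstring above:
#   self.get_prompt("json_prompts", "exploitation", "analyze_vulnerability_user")
#   self.get_prompt("md_prompts", "red_team_agent", "system_prompt")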
def generate(self, prompt: str, system_prompt: Optional[str] = None) -> str:
"""Generate response from LLM and apply hallucination mitigation if configured."""
raw_response = ""
try:
if self.provider == 'claude':
raw_response = self._generate_claude(prompt, system_prompt)
elif self.provider == 'gpt':
raw_response = self._generate_gpt(prompt, system_prompt)
elif self.provider == 'gemini':
raw_response = self._generate_gemini(prompt, system_prompt)
elif self.provider == 'ollama':
raw_response = self._generate_ollama(prompt, system_prompt)
elif self.provider == 'gemini-cli':
raw_response = self._generate_gemini_cli(prompt, system_prompt)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
logger.error(f"Error generating raw response: {e}")
return f"Error: {str(e)}"
if self.guardrails_enabled:
raw_response = self._apply_guardrails(raw_response) # Apply guardrails here
if self.hallucination_mitigation_strategy and self.hallucination_mitigation_strategy in ["grounding", "self_reflection", "consistency_check"]:
logger.debug(f"Applying hallucination mitigation strategy: {self.hallucination_mitigation_strategy}")
return self._mitigate_hallucination(raw_response, prompt, system_prompt)
return raw_response
def _apply_guardrails(self, response: str) -> str:
"""Applies basic guardrails to the LLM response."""
if not self.guardrails_enabled:
return response
logger.debug("Applying guardrails...")
# Example: Simple keyword filtering
harmful_keywords = ["malicious_exploit_command", "destroy_system", "wipe_data", "unauthorized_access"] # Placeholder keywords
for keyword in harmful_keywords:
if keyword in response.lower():
logger.warning(f"Guardrail triggered: Found potentially harmful keyword '{keyword}'. Response will be sanitized or flagged.")
# A more robust solution would involve redaction, re-prompting, or flagging for human review.
# Redact case-insensitively so mixed-case variants are caught as well.
response = re.sub(re.escape(keyword), "[REDACTED_HARMFUL_CONTENT]", response, flags=re.IGNORECASE)
# Example: Length check (if response is excessively long and not expected)
# Using output_token_limit for a more accurate comparison
if len(response.split()) > self.output_token_limit * 1.5: # Roughly estimate tokens by word count
logger.warning("Guardrail triggered: Response is excessively long. Truncating or flagging.")
response = " ".join(response.split()[:int(self.output_token_limit * 1.5)]) + "\n[RESPONSE TRUNCATED BY GUARDRAIL]"
# Ethical check: ensure the tone and content align with ethical hacking principles.
# This phrase scan is deliberately simplistic; a real check could use a separate LLM call
# or a policy engine and route matches to human review.
if any(bad_phrase in response.lower() for bad_phrase in ["perform illegal activity", "bypass security illegally"]):
logger.warning("Guardrail triggered: Response contains potentially unethical instructions. Flagging for review.")
response = "[UNETHICAL CONTENT FLAGGED FOR REVIEW]\n" + response
return response
def _mitigate_hallucination(self, raw_response: str, original_prompt: str, original_system_prompt: Optional[str]) -> str:
"""Applies configured hallucination mitigation strategy."""
strategy = self.hallucination_mitigation_strategy
# Temporarily disable mitigation to prevent infinite recursion when calling self.generate internally
original_mitigation_state = self.hallucination_mitigation_strategy
self.hallucination_mitigation_strategy = None
try:
if strategy == "grounding":
verification_prompt = (
f"Review the following response:\n\n---\n{raw_response}\n---\n\n"
f"Based *only* on the context provided in the original prompt (user: '{original_prompt}', system: '{original_system_prompt or "None"}'), "
f"is this response factual and directly supported by the context? If not, correct it to be factual. "
f"If the response is completely unsourced or makes claims beyond the context, state 'UNSOURCED'."
)
logger.debug("Applying grounding strategy: Re-prompting for factual verification.")
return self.generate(verification_prompt, "You are a fact-checker whose sole purpose is to verify LLM output against provided context.")
elif strategy == "self_reflection":
reflection_prompt = (
f"Critically review the following response for accuracy, logical consistency, and adherence to the original prompt's instructions:\n\n"
f"Original Prompt (User): {original_prompt}\n"
f"Original Prompt (System): {original_system_prompt or "None"}\n\n"
f"Generated Response: {raw_response}\n\n"
f"Identify any potential hallucinations, inconsistencies, or areas where the response might have deviated from facts or instructions. "
f"If you find issues, provide a corrected and more reliable version of the response. If the response is good, state 'ACCURATE'."
)
logger.debug("Applying self-reflection strategy: Re-prompting for self-critique.")
return self.generate(reflection_prompt, "You are an AI assistant designed to critically evaluate and improve other AI-generated content.")
elif strategy == "consistency_check":
logger.debug("Applying consistency check strategy: Generating multiple responses for comparison.")
responses = []
for i in range(3): # Generate 3 responses for consistency check
logger.debug(f"Generating response {i+1} for consistency check.")
res = self.generate(original_prompt, original_system_prompt)
responses.append(res)
if len(set(responses)) == 1:
return responses[0]
else:
logger.warning("Consistency check found varying responses. Attempting to synthesize a consistent answer.")
synthesis_prompt = (
f"Synthesize a single, consistent, and factual response from the following AI-generated options. "
f"Prioritize factual accuracy and avoid information present in only one response if contradictory. "
f"If there's significant disagreement, state the core disagreement.\n\n"
f"Options:\n" + "\n---\n".join(responses)
)
return self.generate(synthesis_prompt, "You are a highly analytical AI assistant tasked with synthesizing consistent information from multiple sources.")
return raw_response # Fallback if strategy not recognized or implemented
finally:
self.hallucination_mitigation_strategy = original_mitigation_state # Restore original state
def _generate_claude(self, prompt: str, system_prompt: Optional[str] = None) -> str:
"""Generate using Claude API"""
import anthropic
client = anthropic.Anthropic(api_key=self.api_key)
messages = [{"role": "user", "content": prompt}]
response = client.messages.create(
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
system=system_prompt or "",
messages=messages
)
return response.content[0].text
def _generate_gpt(self, prompt: str, system_prompt: Optional[str] = None) -> str:
"""Generate using OpenAI GPT API"""
import openai
client = openai.OpenAI(api_key=self.api_key)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=self.temperature,
max_tokens=self.max_tokens
)
return response.choices[0].message.content
def _generate_gemini(self, prompt: str, system_prompt: Optional[str] = None) -> str:
"""Generate using Google Gemini API"""
import google.generativeai as genai
genai.configure(api_key=self.api_key)
model = genai.GenerativeModel(self.model)
full_prompt = prompt
if system_prompt:
full_prompt = f"{system_prompt}\n\n{prompt}"
response = model.generate_content(
full_prompt,
generation_config={
'temperature': self.temperature,
'max_output_tokens': self.max_tokens,
}
)
return response.text
def _generate_gemini_cli(self, prompt: str, system_prompt: Optional[str] = None) -> str:
"""Generate using Gemini CLI"""
try:
full_prompt = prompt
if system_prompt:
full_prompt = f"{system_prompt}\n\n{prompt}"
# Use gemini CLI tool
cmd = ['gemini', 'chat', '-m', self.model]
result = subprocess.run(
cmd,
input=full_prompt.encode(),
capture_output=True,
timeout=120
)
if result.returncode == 0:
return result.stdout.decode().strip()
else:
error = result.stderr.decode().strip()
logger.error(f"Gemini CLI error: {error}")
return f"Error: {error}"
except subprocess.TimeoutExpired:
logger.error("Gemini CLI timeout")
return "Error: Request timeout"
except Exception as e:
logger.error(f"Gemini CLI error: {e}")
return f"Error: {str(e)}"
def _generate_ollama(self, prompt: str, system_prompt: Optional[str] = None) -> str:
"""Generate using Ollama local models"""
try:
url = "http://localhost:11434/api/generate"
data = {
"model": self.model,
"prompt": prompt,
"stream": False,
"options": {
"temperature": self.temperature,
"num_predict": self.max_tokens
}
}
if system_prompt:
data["system"] = system_prompt
response = requests.post(url, json=data, timeout=120)
response.raise_for_status()
return response.json()["response"]
except Exception as e:
logger.error(f"Ollama error: {e}")
return f"Error: {str(e)}"
def analyze_vulnerability(self, vulnerability_data: Dict) -> Dict:
"""Analyze vulnerability and suggest exploits"""
# Fetched from library.json when available; falls back to the inline default below
prompt = self.get_prompt("json_prompts", "exploitation", "analyze_vulnerability_user", default=f"""
Analyze the following vulnerability data and provide exploitation recommendations:
Vulnerability: {json.dumps(vulnerability_data, indent=2)}
Provide:
1. Severity assessment (Critical/High/Medium/Low)
2. Exploitation difficulty
3. Potential impact
4. Recommended exploit techniques
5. Detection evasion strategies
6. Post-exploitation actions
Response in JSON format.
""")
system_prompt = self.get_prompt("json_prompts", "exploitation", "analyze_vulnerability_system", default="""You are an expert penetration tester and security researcher.
Analyze vulnerabilities and provide detailed, actionable exploitation strategies.
Consider OWASP, CWE, and MITRE ATT&CK frameworks.
Always include ethical considerations and legal boundaries.""")
response = self.generate(prompt, system_prompt)
try:
return json.loads(response)
except json.JSONDecodeError:
return {"raw_response": response}
def generate_payload(self, target_info: Dict, vulnerability_type: str) -> str:
"""Generate exploit payload"""
# Fetched from library.json when available; falls back to the inline default below
prompt = self.get_prompt("json_prompts", "exploitation", "generate_payload_user", default=f"""
Generate an exploit payload for the following scenario:
Target Information:
{json.dumps(target_info, indent=2)}
Vulnerability Type: {vulnerability_type}
Requirements:
1. Generate a working payload
2. Include obfuscation techniques
3. Add error handling
4. Ensure minimal detection footprint
5. Include cleanup procedures
Provide the payload code with detailed comments.
""")
system_prompt = self.get_prompt("json_prompts", "exploitation", "generate_payload_system", default="""You are an expert exploit developer.
Generate sophisticated, tested payloads that are effective yet responsible.
Always include safety mechanisms and ethical guidelines.""")
return self.generate(prompt, system_prompt)
def suggest_privilege_escalation(self, system_info: Dict) -> List[str]:
"""Suggest privilege escalation techniques"""
# Fetched from library.json when available; falls back to the inline default below
prompt = self.get_prompt("json_prompts", "privesc", "suggest_privilege_escalation_user", default=f"""
Based on the following system information, suggest privilege escalation techniques:
System Info:
{json.dumps(system_info, indent=2)}
Provide:
1. Top 5 privilege escalation vectors
2. Required tools and commands
3. Detection likelihood
4. Success probability
5. Alternative approaches
Response in JSON format with prioritized list.
""")
system_prompt = self.get_prompt("json_prompts", "privesc", "suggest_privilege_escalation_system", default="""You are a privilege escalation specialist.
Analyze system configurations and suggest effective escalation paths.
Consider Windows, Linux, and Active Directory environments.""")
response = self.generate(prompt, system_prompt)
try:
result = json.loads(response)
return result.get('techniques', [])
except (json.JSONDecodeError, AttributeError):
return []
def analyze_network_topology(self, scan_results: Dict) -> Dict:
"""Analyze network topology and suggest attack paths"""
# Fetched from library.json when available; falls back to the inline default below
prompt = self.get_prompt("json_prompts", "network_recon", "analyze_network_topology_user", default=f"""
Analyze the network topology and suggest attack paths:
Scan Results:
{json.dumps(scan_results, indent=2)}
Provide:
1. Network architecture overview
2. Critical assets identification
3. Attack surface analysis
4. Recommended attack paths (prioritized)
5. Lateral movement opportunities
6. Persistence locations
Response in JSON format.
""")
system_prompt = self.get_prompt("json_prompts", "network_recon", "analyze_network_topology_system", default="""You are a network penetration testing expert.
Analyze network structures and identify optimal attack vectors.
Consider defense-in-depth and detection mechanisms.""")
response = self.generate(prompt, system_prompt)
try:
return json.loads(response)
except json.JSONDecodeError:
return {"raw_response": response}
def analyze_web_vulnerability(self, vulnerability_type: str, vulnerability_data: Dict) -> Dict:
"""Analyze a specific web vulnerability using the appropriate prompt from library.json"""
user_prompt_name = f"{vulnerability_type.lower()}_user"
system_prompt_name = f"{vulnerability_type.lower()}_system"
# Dynamically fetch the user prompt template; vulnerability_data is inserted via format() below
user_prompt_template = self.get_prompt("json_prompts", "vulnerability_testing", user_prompt_name)
if not user_prompt_template:
logger.warning(f"No user prompt found for vulnerability type: {vulnerability_type}")
return {"error": f"No user prompt template for {vulnerability_type}"}
# Replace placeholder in the user prompt template
if vulnerability_type.lower() == "ssrf":
prompt = user_prompt_template.format(http_data_json=json.dumps(vulnerability_data, indent=2))
elif vulnerability_type.lower() == "sql_injection":
prompt = user_prompt_template.format(input_data_json=json.dumps(vulnerability_data, indent=2))
elif vulnerability_type.lower() == "xss":
prompt = user_prompt_template.format(xss_data_json=json.dumps(vulnerability_data, indent=2))
elif vulnerability_type.lower() == "lfi":
prompt = user_prompt_template.format(lfi_data_json=json.dumps(vulnerability_data, indent=2))
elif vulnerability_type.lower() == "broken_object":
prompt = user_prompt_template.format(api_data_json=json.dumps(vulnerability_data, indent=2))
elif vulnerability_type.lower() == "broken_auth":
prompt = user_prompt_template.format(auth_data_json=json.dumps(vulnerability_data, indent=2))
else:
logger.warning(f"Unsupported vulnerability type for analysis: {vulnerability_type}")
return {"error": f"Unsupported vulnerability type: {vulnerability_type}"}
system_prompt = self.get_prompt("json_prompts", "vulnerability_testing", system_prompt_name)
if not system_prompt:
logger.warning(f"No system prompt found for vulnerability type: {vulnerability_type}")
# Use a generic system prompt if a specific one isn't found
system_prompt = "You are an expert web security tester. Analyze the provided data for vulnerabilities and offer exploitation steps and remediation."
response = self.generate(prompt, system_prompt)
try:
return json.loads(response)
except json.JSONDecodeError:
logger.error(f"Failed to decode JSON response for {vulnerability_type} analysis: {response}")
return {"raw_response": response}
except Exception as e:
logger.error(f"Error during {vulnerability_type} analysis: {e}")
return {"error": str(e), "raw_response": response}