#!/usr/bin/env python3 """ LLM Manager - Unified interface for multiple LLM providers Supports: Claude, GPT, Gemini, Ollama, and custom models """ import os import json import subprocess from typing import Dict, List, Optional, Any import logging import requests from pathlib import Path # Added for Path import re # Added for regex operations logger = logging.getLogger(__name__) class LLMManager: """Manage multiple LLM providers""" def __init__(self, config: Dict): """Initialize LLM manager""" self.config = config.get('llm', {}) self.default_profile_name = self.config.get('default_profile', 'gemini_pro_default') self.profiles = self.config.get('profiles', {}) self.active_profile = self.profiles.get(self.default_profile_name, {}) # Load active profile settings self.provider = self.active_profile.get('provider', 'gemini').lower() self.model = self.active_profile.get('model', 'gemini-pro') self.api_key = self._get_api_key(self.active_profile.get('api_key', '')) self.temperature = self.active_profile.get('temperature', 0.7) self.max_tokens = self.active_profile.get('max_tokens', 4096) # New LLM parameters self.input_token_limit = self.active_profile.get('input_token_limit', 4096) self.output_token_limit = self.active_profile.get('output_token_limit', 4096) self.cache_enabled = self.active_profile.get('cache_enabled', False) self.search_context_level = self.active_profile.get('search_context_level', 'medium') # low, medium, high self.pdf_support_enabled = self.active_profile.get('pdf_support_enabled', False) self.guardrails_enabled = self.active_profile.get('guardrails_enabled', False) self.hallucination_mitigation_strategy = self.active_profile.get('hallucination_mitigation_strategy', None) # New prompt loading self.json_prompts_file_path = Path("prompts/library.json") self.md_prompts_dir_path = Path("prompts/md_library") self.prompts = self._load_all_prompts() # New method to load both logger.info(f"Initialized LLM Manager - Provider: {self.provider}, Model: {self.model}, Profile: {self.default_profile_name}") def _get_api_key(self, api_key_config: str) -> str: """Helper to get API key from config or environment variable""" if api_key_config.startswith('${') and api_key_config.endswith('}'): env_var = api_key_config[2:-1] return os.getenv(env_var, '') return api_key_config def _load_all_prompts(self) -> Dict: """Load prompts from both JSON library and Markdown library files.""" all_prompts = { "json_prompts": {}, "md_prompts": {} } # Load from JSON library if self.json_prompts_file_path.exists(): try: with open(self.json_prompts_file_path, 'r') as f: all_prompts["json_prompts"] = json.load(f) logger.info(f"Loaded prompts from JSON library: {self.json_prompts_file_path}") except Exception as e: logger.error(f"Error loading prompts from {self.json_prompts_file_path}: {e}") else: logger.warning(f"JSON prompts file not found at {self.json_prompts_file_path}. Some AI functionalities might be limited.") # Load from Markdown library if self.md_prompts_dir_path.is_dir(): for md_file in self.md_prompts_dir_path.glob("*.md"): try: content = md_file.read_text() prompt_name = md_file.stem # Use filename as prompt name user_prompt_match = re.search(r"## User Prompt\n(.*?)(?=\n## System Prompt|\Z)", content, re.DOTALL) system_prompt_match = re.search(r"## System Prompt\n(.*?)(?=\n## User Prompt|\Z)", content, re.DOTALL) user_prompt = user_prompt_match.group(1).strip() if user_prompt_match else "" system_prompt = system_prompt_match.group(1).strip() if system_prompt_match else "" if user_prompt or system_prompt: all_prompts["md_prompts"][prompt_name] = { "user_prompt": user_prompt, "system_prompt": system_prompt } else: logger.warning(f"No valid User or System Prompt found in {md_file.name}. Skipping.") except Exception as e: logger.error(f"Error loading prompt from {md_file.name}: {e}") logger.info(f"Loaded {len(all_prompts['md_prompts'])} prompts from Markdown library.") else: logger.warning(f"Markdown prompts directory not found at {self.md_prompts_dir_path}. Some AI functionalities might be limited.") return all_prompts def get_prompt(self, library_type: str, category: str, name: str, default: str = "") -> str: """Retrieve a specific prompt by library type, category, and name. `library_type` can be "json_prompts" or "md_prompts". `category` can be a JSON top-level key (e.g., 'exploitation') or an MD filename (e.g., 'red_team_agent'). `name` can be a JSON sub-key (e.g., 'ai_exploit_planning_user') or 'user_prompt'/'system_prompt' for MD. """ return self.prompts.get(library_type, {}).get(category, {}).get(name, default) def generate(self, prompt: str, system_prompt: Optional[str] = None) -> str: """Generate response from LLM and apply hallucination mitigation if configured.""" raw_response = "" try: if self.provider == 'claude': raw_response = self._generate_claude(prompt, system_prompt) elif self.provider == 'gpt': raw_response = self._generate_gpt(prompt, system_prompt) elif self.provider == 'gemini': raw_response = self._generate_gemini(prompt, system_prompt) elif self.provider == 'ollama': raw_response = self._generate_ollama(prompt, system_prompt) elif self.provider == 'gemini-cli': raw_response = self._generate_gemini_cli(prompt, system_prompt) elif self.provider == 'lmstudio': raw_response = self._generate_lmstudio(prompt, system_prompt) else: raise ValueError(f"Unsupported provider: {self.provider}") except Exception as e: logger.error(f"Error generating raw response: {e}") return f"Error: {str(e)}" if self.guardrails_enabled: raw_response = self._apply_guardrails(raw_response) # Apply guardrails here if self.hallucination_mitigation_strategy and self.hallucination_mitigation_strategy in ["grounding", "self_reflection", "consistency_check"]: logger.debug(f"Applying hallucination mitigation strategy: {self.hallucination_mitigation_strategy}") return self._mitigate_hallucination(raw_response, prompt, system_prompt) return raw_response def _apply_guardrails(self, response: str) -> str: """Applies basic guardrails to the LLM response.""" if not self.guardrails_enabled: return response logger.debug("Applying guardrails...") # Example: Simple keyword filtering harmful_keywords = ["malicious_exploit_command", "destroy_system", "wipe_data", "unauthorized_access"] # Placeholder keywords for keyword in harmful_keywords: if keyword in response.lower(): logger.warning(f"Guardrail triggered: Found potentially harmful keyword '{keyword}'. Response will be sanitized or flagged.") # A more robust solution would involve redaction, re-prompting, or flagging for human review. # For this example, we'll replace the keyword. response = response.replace(keyword, "[REDACTED_HARMFUL_CONTENT]") response = response.replace(keyword.upper(), "[REDACTED_HARMFUL_CONTENT]") # Example: Length check (if response is excessively long and not expected) # Using output_token_limit for a more accurate comparison if len(response.split()) > self.output_token_limit * 1.5: # Roughly estimate tokens by word count logger.warning("Guardrail triggered: Response is excessively long. Truncating or flagging.") response = " ".join(response.split()[:int(self.output_token_limit * 1.5)]) + "\n[RESPONSE TRUNCATED BY GUARDRAIL]" # Ethical check (can be another LLM call, but for simplicity, a fixed instruction) # This is more about ensuring the tone and content align with ethical hacking principles. # This is a very simplistic example. A real ethical check would be more nuanced. # For now, just a log or a general check for explicit unethical instructions. if any(bad_phrase in response.lower() for bad_phrase in ["perform illegal activity", "bypass security illegally"]): logger.warning("Guardrail triggered: Response contains potentially unethical instructions. Flagging for review.") response = "[UNETHICAL CONTENT FLAGGED FOR REVIEW]\n" + response return response def _mitigate_hallucination(self, raw_response: str, original_prompt: str, original_system_prompt: Optional[str]) -> str: """Applies configured hallucination mitigation strategy.""" strategy = self.hallucination_mitigation_strategy # Temporarily disable mitigation to prevent infinite recursion when calling self.generate internally original_mitigation_state = self.hallucination_mitigation_strategy self.hallucination_mitigation_strategy = None try: if strategy == "grounding": verification_prompt = f"""Review the following response: --- {raw_response} --- Based *only* on the context provided in the original prompt (user: '{original_prompt}', system: '{original_system_prompt or "None"}'), is this response factual and directly supported by the context? If not, correct it to be factual. If the response is completely unsourced or makes claims beyond the context, state 'UNSOURCED'.""" logger.debug("Applying grounding strategy: Re-prompting for factual verification.") return self.generate(verification_prompt, "You are a fact-checker whose sole purpose is to verify LLM output against provided context.") elif strategy == "self_reflection": reflection_prompt = f"""Critically review the following response for accuracy, logical consistency, and adherence to the original prompt's instructions: Original Prompt (User): {original_prompt} Original Prompt (System): {original_system_prompt or "None"} Generated Response: {raw_response} Identify any potential hallucinations, inconsistencies, or areas where the response might have deviated from facts or instructions. If you find issues, provide a corrected and more reliable version of the response. If the response is good, state 'ACCURATE'.""" logger.debug("Applying self-reflection strategy: Re-prompting for self-critique.") return self.generate(reflection_prompt, "You are an AI assistant designed to critically evaluate and improve other AI-generated content.") elif strategy == "consistency_check": logger.debug("Applying consistency check strategy: Generating multiple responses for comparison.") responses = [] for i in range(3): # Generate 3 responses for consistency check logger.debug(f"Generating response {i+1} for consistency check.") res = self.generate(original_prompt, original_system_prompt) responses.append(res) if len(set(responses)) == 1: return responses[0] else: logger.warning("Consistency check found varying responses. Attempting to synthesize a consistent answer.") synthesis_prompt = ( f"Synthesize a single, consistent, and factual response from the following AI-generated options. " f"Prioritize factual accuracy and avoid information present in only one response if contradictory. " f"If there's significant disagreement, state the core disagreement.\n\n" f"Options:\n" + "\n---\n".join(responses) ) return self.generate(synthesis_prompt, "You are a highly analytical AI assistant tasked with synthesizing consistent information from multiple sources.") return raw_response # Fallback if strategy not recognized or implemented finally: self.hallucination_mitigation_strategy = original_mitigation_state # Restore original state def _generate_claude(self, prompt: str, system_prompt: Optional[str] = None) -> str: """Generate using Claude API""" import anthropic client = anthropic.Anthropic(api_key=self.api_key) messages = [{"role": "user", "content": prompt}] response = client.messages.create( model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, system=system_prompt or "", messages=messages ) return response.content[0].text def _generate_gpt(self, prompt: str, system_prompt: Optional[str] = None) -> str: """Generate using OpenAI GPT API""" import openai client = openai.OpenAI(api_key=self.api_key) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) response = client.chat.completions.create( model=self.model, messages=messages, temperature=self.temperature, max_tokens=self.max_tokens ) return response.choices[0].message.content def _generate_gemini(self, prompt: str, system_prompt: Optional[str] = None) -> str: """Generate using Google Gemini API""" import google.generativeai as genai genai.configure(api_key=self.api_key) model = genai.GenerativeModel(self.model) full_prompt = prompt if system_prompt: full_prompt = f"{system_prompt}\n\n{prompt}" response = model.generate_content( full_prompt, generation_config={ 'temperature': self.temperature, 'max_output_tokens': self.max_tokens, } ) return response.text def _generate_gemini_cli(self, prompt: str, system_prompt: Optional[str] = None) -> str: """Generate using Gemini CLI""" try: full_prompt = prompt if system_prompt: full_prompt = f"{system_prompt}\n\n{prompt}" # Use gemini CLI tool cmd = ['gemini', 'chat', '-m', self.model] result = subprocess.run( cmd, input=full_prompt.encode(), capture_output=True, timeout=120 ) if result.returncode == 0: return result.stdout.decode().strip() else: error = result.stderr.decode().strip() logger.error(f"Gemini CLI error: {error}") return f"Error: {error}" except subprocess.TimeoutExpired: logger.error("Gemini CLI timeout") return "Error: Request timeout" except Exception as e: logger.error(f"Gemini CLI error: {e}") return f"Error: {str(e)}" def _generate_ollama(self, prompt: str, system_prompt: Optional[str] = None) -> str: """Generate using Ollama local models""" try: url = "http://localhost:11434/api/generate" data = { "model": self.model, "prompt": prompt, "stream": False, "options": { "temperature": self.temperature, "num_predict": self.max_tokens } } if system_prompt: data["system"] = system_prompt response = requests.post(url, json=data, timeout=120) response.raise_for_status() return response.json()["response"] except Exception as e: logger.error(f"Ollama error: {e}") return f"Error: {str(e)}" def _generate_lmstudio(self, prompt: str, system_prompt: Optional[str] = None) -> str: """ Generate using LM Studio local server. LM Studio provides an OpenAI-compatible API at http://localhost:1234/v1 """ try: # LM Studio uses OpenAI-compatible API url = "http://localhost:1234/v1/chat/completions" messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) data = { "model": self.model, # LM Studio auto-detects loaded model "messages": messages, "temperature": self.temperature, "max_tokens": self.max_tokens, "stream": False } logger.debug(f"Sending request to LM Studio at {url}") response = requests.post(url, json=data, timeout=120) response.raise_for_status() result = response.json() return result["choices"][0]["message"]["content"] except requests.exceptions.ConnectionError: logger.error("LM Studio connection error. Ensure LM Studio server is running on http://localhost:1234") return "Error: Cannot connect to LM Studio. Please ensure LM Studio server is running on port 1234." except requests.exceptions.Timeout: logger.error("LM Studio request timeout") return "Error: LM Studio request timeout after 120 seconds" except KeyError as e: logger.error(f"LM Studio response format error: {e}") return f"Error: Unexpected response format from LM Studio: {str(e)}" except Exception as e: logger.error(f"LM Studio error: {e}") return f"Error: {str(e)}" def analyze_vulnerability(self, vulnerability_data: Dict) -> Dict: """Analyze vulnerability and suggest exploits""" # This prompt will be fetched from library.json later prompt = self.get_prompt("json_prompts", "exploitation", "analyze_vulnerability_user", default=f""" Analyze the following vulnerability data and provide exploitation recommendations: Vulnerability: {json.dumps(vulnerability_data, indent=2)} Provide: 1. Severity assessment (Critical/High/Medium/Low) 2. Exploitation difficulty 3. Potential impact 4. Recommended exploit techniques 5. Detection evasion strategies 6. Post-exploitation actions Response in JSON format. """) system_prompt = self.get_prompt("json_prompts", "exploitation", "analyze_vulnerability_system", default="""You are an expert penetration tester and security researcher. Analyze vulnerabilities and provide detailed, actionable exploitation strategies. Consider OWASP, CWE, and MITRE ATT&CK frameworks. Always include ethical considerations and legal boundaries.""") response = self.generate(prompt, system_prompt) try: return json.loads(response) except: return {"raw_response": response} def generate_payload(self, target_info: Dict, vulnerability_type: str) -> str: """Generate exploit payload""" # This prompt will be fetched from library.json later prompt = self.get_prompt("json_prompts", "exploitation", "generate_payload_user", default=f""" Generate an exploit payload for the following scenario: Target Information: {json.dumps(target_info, indent=2)} Vulnerability Type: {vulnerability_type} Requirements: 1. Generate a working payload 2. Include obfuscation techniques 3. Add error handling 4. Ensure minimal detection footprint 5. Include cleanup procedures Provide the payload code with detailed comments. """) system_prompt = self.get_prompt("json_prompts", "exploitation", "generate_payload_system", default="""You are an expert exploit developer. Generate sophisticated, tested payloads that are effective yet responsible. Always include safety mechanisms and ethical guidelines.""") return self.generate(prompt, system_prompt) def suggest_privilege_escalation(self, system_info: Dict) -> List[str]: """Suggest privilege escalation techniques""" # This prompt will be fetched from library.json later prompt = self.get_prompt("json_prompts", "privesc", "suggest_privilege_escalation_user", default=f""" Based on the following system information, suggest privilege escalation techniques: System Info: {json.dumps(system_info, indent=2)} Provide: 1. Top 5 privilege escalation vectors 2. Required tools and commands 3. Detection likelihood 4. Success probability 5. Alternative approaches Response in JSON format with prioritized list. """) system_prompt = self.get_prompt("json_prompts", "privesc", "suggest_privilege_escalation_system", default="""You are a privilege escalation specialist. Analyze system configurations and suggest effective escalation paths. Consider Windows, Linux, and Active Directory environments.""") response = self.generate(prompt, system_prompt) try: result = json.loads(response) return result.get('techniques', []) except: return [] def analyze_network_topology(self, scan_results: Dict) -> Dict: """Analyze network topology and suggest attack paths""" # This prompt will be fetched from library.json later prompt = self.get_prompt("json_prompts", "network_recon", "analyze_network_topology_user", default=f""" Analyze the network topology and suggest attack paths: Scan Results: {json.dumps(scan_results, indent=2)} Provide: 1. Network architecture overview 2. Critical assets identification 3. Attack surface analysis 4. Recommended attack paths (prioritized) 5. Lateral movement opportunities 6. Persistence locations Response in JSON format. """) system_prompt = self.get_prompt("json_prompts", "network_recon", "analyze_network_topology_system", default="""You are a network penetration testing expert. Analyze network structures and identify optimal attack vectors. Consider defense-in-depth and detection mechanisms.""") response = self.generate(prompt, system_prompt) try: return json.loads(response) except: return {"raw_response": response} def analyze_web_vulnerability(self, vulnerability_type: str, vulnerability_data: Dict) -> Dict: """Analyze a specific web vulnerability using the appropriate prompt from library.json""" user_prompt_name = f"{vulnerability_type.lower()}_user" system_prompt_name = f"{vulnerability_type.lower()}_system" # Dynamically fetch user prompt, passing vulnerability_data user_prompt_template = self.get_prompt("json_prompts", "vulnerability_testing", user_prompt_name) if not user_prompt_template: logger.warning(f"No user prompt found for vulnerability type: {vulnerability_type}") return {"error": f"No user prompt template for {vulnerability_type}"} # Replace placeholder in the user prompt template if vulnerability_type.lower() == "ssrf": prompt = user_prompt_template.format(http_data_json=json.dumps(vulnerability_data, indent=2)) elif vulnerability_type.lower() == "sql_injection": prompt = user_prompt_template.format(input_data_json=json.dumps(vulnerability_data, indent=2)) elif vulnerability_type.lower() == "xss": prompt = user_prompt_template.format(xss_data_json=json.dumps(vulnerability_data, indent=2)) elif vulnerability_type.lower() == "lfi": prompt = user_prompt_template.format(lfi_data_json=json.dumps(vulnerability_data, indent=2)) elif vulnerability_type.lower() == "broken_object": prompt = user_prompt_template.format(api_data_json=json.dumps(vulnerability_data, indent=2)) elif vulnerability_type.lower() == "broken_auth": prompt = user_prompt_template.format(auth_data_json=json.dumps(vulnerability_data, indent=2)) else: logger.warning(f"Unsupported vulnerability type for analysis: {vulnerability_type}") return {"error": f"Unsupported vulnerability type: {vulnerability_type}"} system_prompt = self.get_prompt("json_prompts", "vulnerability_testing", system_prompt_name) if not system_prompt: logger.warning(f"No system prompt found for vulnerability type: {vulnerability_type}") # Use a generic system prompt if a specific one isn't found system_prompt = "You are an expert web security tester. Analyze the provided data for vulnerabilities and offer exploitation steps and remediation." response = self.generate(prompt, system_prompt) try: return json.loads(response) except json.JSONDecodeError: logger.error(f"Failed to decode JSON response for {vulnerability_type} analysis: {response}") return {"raw_response": response} except Exception as e: logger.error(f"Error during {vulnerability_type} analysis: {e}") return {"error": str(e), "raw_response": response}