#!/usr/bin/env python3
"""
LLM Manager - Unified interface for multiple LLM providers
Supports: Claude, GPT, Gemini, Ollama, and custom models
"""

import os
import json
import subprocess
from typing import Dict, List, Optional, Any
import logging
import requests
from pathlib import Path  # Added for Path
import re  # Added for regex operations

logger = logging.getLogger(__name__)


class LLMManager:
    """Manage multiple LLM providers"""

    def __init__(self, config: Dict):
        """Initialize LLM manager"""
        self.config = config.get('llm', {})
        self.default_profile_name = self.config.get('default_profile', 'gemini_pro_default')
        self.profiles = self.config.get('profiles', {})
        self.active_profile = self.profiles.get(self.default_profile_name, {})

        # Load active profile settings
        self.provider = self.active_profile.get('provider', 'gemini').lower()
        self.model = self.active_profile.get('model', 'gemini-pro')
        self.api_key = self._get_api_key(self.active_profile.get('api_key', ''))
        self.temperature = self.active_profile.get('temperature', 0.7)
        self.max_tokens = self.active_profile.get('max_tokens', 4096)

        # New LLM parameters
        self.input_token_limit = self.active_profile.get('input_token_limit', 4096)
        self.output_token_limit = self.active_profile.get('output_token_limit', 4096)
        self.cache_enabled = self.active_profile.get('cache_enabled', False)
        self.search_context_level = self.active_profile.get('search_context_level', 'medium')  # low, medium, high
        self.pdf_support_enabled = self.active_profile.get('pdf_support_enabled', False)
        self.guardrails_enabled = self.active_profile.get('guardrails_enabled', False)
        self.hallucination_mitigation_strategy = self.active_profile.get('hallucination_mitigation_strategy', None)

        # New prompt loading
        self.json_prompts_file_path = Path("prompts/library.json")
        self.md_prompts_dir_path = Path("prompts/md_library")
        self.prompts = self._load_all_prompts()  # New method to load both

        logger.info(f"Initialized LLM Manager - Provider: {self.provider}, Model: {self.model}, Profile: {self.default_profile_name}")

    def _get_api_key(self, api_key_config: str) -> str:
        """Helper to get API key from config or environment variable"""
        if api_key_config.startswith('${') and api_key_config.endswith('}'):
            env_var = api_key_config[2:-1]
            return os.getenv(env_var, '')
        return api_key_config
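    # Example of the configuration shape __init__ consumes (a sketch; the
    # profile name and env-var name below are illustrative, not required):
    #
    #   config = {
    #       "llm": {
    #           "default_profile": "gemini_pro_default",
    #           "profiles": {
    #               "gemini_pro_default": {
    #                   "provider": "gemini",
    #                   "model": "gemini-pro",
    #                   "api_key": "${GEMINI_API_KEY}",  # resolved by _get_api_key
    #                   "temperature": 0.7,
    #                   "max_tokens": 4096,
    #                   "guardrails_enabled": False,
    #                   "hallucination_mitigation_strategy": None,
    #               }
    #           }
    #       }
    #   }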
    def _load_all_prompts(self) -> Dict:
        """Load prompts from both JSON library and Markdown library files."""
        all_prompts = {
            "json_prompts": {},
            "md_prompts": {}
        }

        # Load from JSON library
        if self.json_prompts_file_path.exists():
            try:
                with open(self.json_prompts_file_path, 'r') as f:
                    all_prompts["json_prompts"] = json.load(f)
                logger.info(f"Loaded prompts from JSON library: {self.json_prompts_file_path}")
            except Exception as e:
                logger.error(f"Error loading prompts from {self.json_prompts_file_path}: {e}")
        else:
            logger.warning(f"JSON prompts file not found at {self.json_prompts_file_path}. Some AI functionalities might be limited.")

        # Load from Markdown library
        if self.md_prompts_dir_path.is_dir():
            for md_file in self.md_prompts_dir_path.glob("*.md"):
                try:
                    content = md_file.read_text()
                    prompt_name = md_file.stem  # Use filename as prompt name

                    user_prompt_match = re.search(r"## User Prompt\n(.*?)(?=\n## System Prompt|\Z)", content, re.DOTALL)
                    system_prompt_match = re.search(r"## System Prompt\n(.*?)(?=\n## User Prompt|\Z)", content, re.DOTALL)

                    user_prompt = user_prompt_match.group(1).strip() if user_prompt_match else ""
                    system_prompt = system_prompt_match.group(1).strip() if system_prompt_match else ""

                    if user_prompt or system_prompt:
                        all_prompts["md_prompts"][prompt_name] = {
                            "user_prompt": user_prompt,
                            "system_prompt": system_prompt
                        }
                    else:
                        logger.warning(f"No valid User or System Prompt found in {md_file.name}. Skipping.")
                except Exception as e:
                    logger.error(f"Error loading prompt from {md_file.name}: {e}")
            logger.info(f"Loaded {len(all_prompts['md_prompts'])} prompts from Markdown library.")
        else:
            logger.warning(f"Markdown prompts directory not found at {self.md_prompts_dir_path}. Some AI functionalities might be limited.")

        return all_prompts

    def get_prompt(self, library_type: str, category: str, name: str, default: str = "") -> str:
        """Retrieve a specific prompt by library type, category, and name.

        `library_type` can be "json_prompts" or "md_prompts".
        `category` can be a JSON top-level key (e.g., 'exploitation') or an MD filename (e.g., 'red_team_agent').
        `name` can be a JSON sub-key (e.g., 'ai_exploit_planning_user') or 'user_prompt'/'system_prompt' for MD.
        """
        return self.prompts.get(library_type, {}).get(category, {}).get(name, default)
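    # Layout a Markdown prompt file must follow to match the regexes in
    # _load_all_prompts (filename and prompt text here are hypothetical):
    #
    #   prompts/md_library/red_team_agent.md:
    #       ## User Prompt
    #       Plan the next engagement step for the given target.
    #
    #       ## System Prompt
    #       You are a red-team planning assistant.
    #
    #   Retrieved as:
    #       manager.get_prompt("md_prompts", "red_team_agent", "user_prompt")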
    def generate(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Generate response from LLM and apply hallucination mitigation if configured."""
        raw_response = ""
        try:
            if self.provider == 'claude':
                raw_response = self._generate_claude(prompt, system_prompt)
            elif self.provider == 'gpt':
                raw_response = self._generate_gpt(prompt, system_prompt)
            elif self.provider == 'gemini':
                raw_response = self._generate_gemini(prompt, system_prompt)
            elif self.provider == 'ollama':
                raw_response = self._generate_ollama(prompt, system_prompt)
            elif self.provider == 'gemini-cli':
                raw_response = self._generate_gemini_cli(prompt, system_prompt)
            else:
                raise ValueError(f"Unsupported provider: {self.provider}")
        except Exception as e:
            logger.error(f"Error generating raw response: {e}")
            return f"Error: {str(e)}"

        if self.guardrails_enabled:
            raw_response = self._apply_guardrails(raw_response)  # Apply guardrails here

        if self.hallucination_mitigation_strategy and self.hallucination_mitigation_strategy in ["grounding", "self_reflection", "consistency_check"]:
            logger.debug(f"Applying hallucination mitigation strategy: {self.hallucination_mitigation_strategy}")
            return self._mitigate_hallucination(raw_response, prompt, system_prompt)

        return raw_response

    def _apply_guardrails(self, response: str) -> str:
        """Applies basic guardrails to the LLM response."""
        if not self.guardrails_enabled:
            return response

        logger.debug("Applying guardrails...")

        # Example: Simple keyword filtering
        harmful_keywords = ["malicious_exploit_command", "destroy_system", "wipe_data", "unauthorized_access"]  # Placeholder keywords
        for keyword in harmful_keywords:
            if keyword in response.lower():
                logger.warning(f"Guardrail triggered: Found potentially harmful keyword '{keyword}'. Response will be sanitized or flagged.")
                # A more robust solution would involve redaction, re-prompting, or flagging for human review.
                # For this example, we'll replace the keyword.
                response = response.replace(keyword, "[REDACTED_HARMFUL_CONTENT]")
                response = response.replace(keyword.upper(), "[REDACTED_HARMFUL_CONTENT]")

        # Example: Length check (if response is excessively long and not expected)
        # Using output_token_limit for a more accurate comparison
        if len(response.split()) > self.output_token_limit * 1.5:  # Roughly estimate tokens by word count
            logger.warning("Guardrail triggered: Response is excessively long. Truncating or flagging.")
            response = " ".join(response.split()[:int(self.output_token_limit * 1.5)]) + "\n[RESPONSE TRUNCATED BY GUARDRAIL]"

        # Ethical check (could be another LLM call, but kept as a fixed instruction for simplicity).
        # This is about ensuring the tone and content align with ethical hacking principles.
        # A real ethical check would be more nuanced; for now, just a general check
        # for explicit unethical instructions.
        if any(bad_phrase in response.lower() for bad_phrase in ["perform illegal activity", "bypass security illegally"]):
            logger.warning("Guardrail triggered: Response contains potentially unethical instructions. Flagging for review.")
            response = "[UNETHICAL CONTENT FLAGGED FOR REVIEW]\n" + response

        return response

    def _mitigate_hallucination(self, raw_response: str, original_prompt: str, original_system_prompt: Optional[str]) -> str:
        """Applies the configured hallucination mitigation strategy."""
        strategy = self.hallucination_mitigation_strategy

        # Temporarily disable mitigation to prevent infinite recursion when calling self.generate internally
        original_mitigation_state = self.hallucination_mitigation_strategy
        self.hallucination_mitigation_strategy = None
        try:
            if strategy == "grounding":
                verification_prompt = (
                    f"Review the following response:\n\n---\n{raw_response}\n---\n\n"
                    f"Based *only* on the context provided in the original prompt (user: '{original_prompt}', system: '{original_system_prompt or 'None'}'), "
                    f"is this response factual and directly supported by the context? If not, correct it to be factual. "
                    f"If the response is completely unsourced or makes claims beyond the context, state 'UNSOURCED'."
                )
                logger.debug("Applying grounding strategy: Re-prompting for factual verification.")
                return self.generate(verification_prompt, "You are a fact-checker whose sole purpose is to verify LLM output against provided context.")

            elif strategy == "self_reflection":
                reflection_prompt = (
                    f"Critically review the following response for accuracy, logical consistency, and adherence to the original prompt's instructions:\n\n"
                    f"Original Prompt (User): {original_prompt}\n"
                    f"Original Prompt (System): {original_system_prompt or 'None'}\n\n"
                    f"Generated Response: {raw_response}\n\n"
                    f"Identify any potential hallucinations, inconsistencies, or areas where the response might have deviated from facts or instructions. "
                    f"If you find issues, provide a corrected and more reliable version of the response. If the response is good, state 'ACCURATE'."
                )
                logger.debug("Applying self-reflection strategy: Re-prompting for self-critique.")
                return self.generate(reflection_prompt, "You are an AI assistant designed to critically evaluate and improve other AI-generated content.")

            elif strategy == "consistency_check":
                logger.debug("Applying consistency check strategy: Generating multiple responses for comparison.")
                responses = []
                for i in range(3):  # Generate 3 responses for consistency check
                    logger.debug(f"Generating response {i+1} for consistency check.")
                    res = self.generate(original_prompt, original_system_prompt)
                    responses.append(res)

                if len(set(responses)) == 1:
                    return responses[0]
                else:
                    logger.warning("Consistency check found varying responses. Attempting to synthesize a consistent answer.")
                    synthesis_prompt = (
                        f"Synthesize a single, consistent, and factual response from the following AI-generated options. "
                        f"Prioritize factual accuracy and avoid information present in only one response if contradictory. "
                        f"If there's significant disagreement, state the core disagreement.\n\n"
                        f"Options:\n" + "\n---\n".join(responses)
                    )
                    return self.generate(synthesis_prompt, "You are a highly analytical AI assistant tasked with synthesizing consistent information from multiple sources.")

            return raw_response  # Fallback if strategy not recognized or implemented
        finally:
            self.hallucination_mitigation_strategy = original_mitigation_state  # Restore original state
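    # Profile fields that drive the two post-processing stages above
    # (illustrative values):
    #
    #   "guardrails_enabled": True,
    #   "hallucination_mitigation_strategy": "self_reflection",
    #       # or "grounding" / "consistency_check" / None
    #
    # Note: "consistency_check" makes three extra generate() calls (plus a
    # synthesis call when they disagree), and "grounding" / "self_reflection"
    # each make one, so mitigation multiplies token usage accordingly.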
    def _generate_claude(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Generate using Claude API"""
        import anthropic

        client = anthropic.Anthropic(api_key=self.api_key)

        messages = [{"role": "user", "content": prompt}]

        response = client.messages.create(
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            system=system_prompt or "",
            messages=messages
        )

        return response.content[0].text

    def _generate_gpt(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Generate using OpenAI GPT API"""
        import openai

        client = openai.OpenAI(api_key=self.api_key)

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens
        )

        return response.choices[0].message.content

    def _generate_gemini(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Generate using Google Gemini API"""
        import google.generativeai as genai

        genai.configure(api_key=self.api_key)
        model = genai.GenerativeModel(self.model)

        full_prompt = prompt
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"

        response = model.generate_content(
            full_prompt,
            generation_config={
                'temperature': self.temperature,
                'max_output_tokens': self.max_tokens,
            }
        )

        return response.text

    def _generate_gemini_cli(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Generate using Gemini CLI"""
        try:
            full_prompt = prompt
            if system_prompt:
                full_prompt = f"{system_prompt}\n\n{prompt}"

            # Use gemini CLI tool
            cmd = ['gemini', 'chat', '-m', self.model]

            result = subprocess.run(
                cmd,
                input=full_prompt.encode(),
                capture_output=True,
                timeout=120
            )

            if result.returncode == 0:
                return result.stdout.decode().strip()
            else:
                error = result.stderr.decode().strip()
                logger.error(f"Gemini CLI error: {error}")
                return f"Error: {error}"

        except subprocess.TimeoutExpired:
            logger.error("Gemini CLI timeout")
            return "Error: Request timeout"
        except Exception as e:
            logger.error(f"Gemini CLI error: {e}")
            return f"Error: {str(e)}"
f"Error: {str(e)}" def _generate_ollama(self, prompt: str, system_prompt: Optional[str] = None) -> str: """Generate using Ollama local models""" try: url = "http://localhost:11434/api/generate" data = { "model": self.model, "prompt": prompt, "stream": False, "options": { "temperature": self.temperature, "num_predict": self.max_tokens } } if system_prompt: data["system"] = system_prompt response = requests.post(url, json=data, timeout=120) response.raise_for_status() return response.json()["response"] except Exception as e: logger.error(f"Ollama error: {e}") return f"Error: {str(e)}" def analyze_vulnerability(self, vulnerability_data: Dict) -> Dict: """Analyze vulnerability and suggest exploits""" # This prompt will be fetched from library.json later prompt = self.get_prompt("json_prompts", "exploitation", "analyze_vulnerability_user", default=f""" Analyze the following vulnerability data and provide exploitation recommendations: Vulnerability: {json.dumps(vulnerability_data, indent=2)} Provide: 1. Severity assessment (Critical/High/Medium/Low) 2. Exploitation difficulty 3. Potential impact 4. Recommended exploit techniques 5. Detection evasion strategies 6. Post-exploitation actions Response in JSON format. """) system_prompt = self.get_prompt("json_prompts", "exploitation", "analyze_vulnerability_system", default="""You are an expert penetration tester and security researcher. Analyze vulnerabilities and provide detailed, actionable exploitation strategies. Consider OWASP, CWE, and MITRE ATT&CK frameworks. Always include ethical considerations and legal boundaries.""") response = self.generate(prompt, system_prompt) try: return json.loads(response) except: return {"raw_response": response} def generate_payload(self, target_info: Dict, vulnerability_type: str) -> str: """Generate exploit payload""" # This prompt will be fetched from library.json later prompt = self.get_prompt("json_prompts", "exploitation", "generate_payload_user", default=f""" Generate an exploit payload for the following scenario: Target Information: {json.dumps(target_info, indent=2)} Vulnerability Type: {vulnerability_type} Requirements: 1. Generate a working payload 2. Include obfuscation techniques 3. Add error handling 4. Ensure minimal detection footprint 5. Include cleanup procedures Provide the payload code with detailed comments. """) system_prompt = self.get_prompt("json_prompts", "exploitation", "generate_payload_system", default="""You are an expert exploit developer. Generate sophisticated, tested payloads that are effective yet responsible. Always include safety mechanisms and ethical guidelines.""") return self.generate(prompt, system_prompt) def suggest_privilege_escalation(self, system_info: Dict) -> List[str]: """Suggest privilege escalation techniques""" # This prompt will be fetched from library.json later prompt = self.get_prompt("json_prompts", "privesc", "suggest_privilege_escalation_user", default=f""" Based on the following system information, suggest privilege escalation techniques: System Info: {json.dumps(system_info, indent=2)} Provide: 1. Top 5 privilege escalation vectors 2. Required tools and commands 3. Detection likelihood 4. Success probability 5. Alternative approaches Response in JSON format with prioritized list. """) system_prompt = self.get_prompt("json_prompts", "privesc", "suggest_privilege_escalation_system", default="""You are a privilege escalation specialist. Analyze system configurations and suggest effective escalation paths. 
    def analyze_network_topology(self, scan_results: Dict) -> Dict:
        """Analyze network topology and suggest attack paths"""
        # This prompt will be fetched from library.json later
        prompt = self.get_prompt("json_prompts", "network_recon", "analyze_network_topology_user", default=f"""
Analyze the network topology and suggest attack paths:

Scan Results: {json.dumps(scan_results, indent=2)}

Provide:
1. Network architecture overview
2. Critical assets identification
3. Attack surface analysis
4. Recommended attack paths (prioritized)
5. Lateral movement opportunities
6. Persistence locations

Response in JSON format.
""")

        system_prompt = self.get_prompt("json_prompts", "network_recon", "analyze_network_topology_system", default="""You are a network penetration testing expert.
Analyze network structures and identify optimal attack vectors.
Consider defense-in-depth and detection mechanisms.""")

        response = self.generate(prompt, system_prompt)

        try:
            return json.loads(response)
        except json.JSONDecodeError:
            return {"raw_response": response}

    def analyze_web_vulnerability(self, vulnerability_type: str, vulnerability_data: Dict) -> Dict:
        """Analyze a specific web vulnerability using the appropriate prompt from library.json"""
        user_prompt_name = f"{vulnerability_type.lower()}_user"
        system_prompt_name = f"{vulnerability_type.lower()}_system"

        # Dynamically fetch user prompt, passing vulnerability_data
        user_prompt_template = self.get_prompt("json_prompts", "vulnerability_testing", user_prompt_name)
        if not user_prompt_template:
            logger.warning(f"No user prompt found for vulnerability type: {vulnerability_type}")
            return {"error": f"No user prompt template for {vulnerability_type}"}

        # Each supported vulnerability type fills a different placeholder in its template
        placeholder_names = {
            "ssrf": "http_data_json",
            "sql_injection": "input_data_json",
            "xss": "xss_data_json",
            "lfi": "lfi_data_json",
            "broken_object": "api_data_json",
            "broken_auth": "auth_data_json",
        }
        placeholder = placeholder_names.get(vulnerability_type.lower())
        if placeholder is None:
            logger.warning(f"Unsupported vulnerability type for analysis: {vulnerability_type}")
            return {"error": f"Unsupported vulnerability type: {vulnerability_type}"}

        prompt = user_prompt_template.format(**{placeholder: json.dumps(vulnerability_data, indent=2)})

        system_prompt = self.get_prompt("json_prompts", "vulnerability_testing", system_prompt_name)
        if not system_prompt:
            logger.warning(f"No system prompt found for vulnerability type: {vulnerability_type}")
            # Use a generic system prompt if a specific one isn't found
            system_prompt = "You are an expert web security tester. Analyze the provided data for vulnerabilities and offer exploitation steps and remediation."
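        # Example library.json entry backing the templates fetched above
        # (hypothetical content; the real file lives at prompts/library.json):
        #
        #   {
        #       "vulnerability_testing": {
        #           "ssrf_user": "Analyze this HTTP traffic for SSRF:\n{http_data_json}",
        #           "ssrf_system": "You are an SSRF testing specialist."
        #       }
        #   }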
        response = self.generate(prompt, system_prompt)

        try:
            return json.loads(response)
        except json.JSONDecodeError:
            logger.error(f"Failed to decode JSON response for {vulnerability_type} analysis: {response}")
            return {"raw_response": response}
        except Exception as e:
            logger.error(f"Error during {vulnerability_type} analysis: {e}")
            return {"error": str(e), "raw_response": response}
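
# Minimal usage sketch. The profile below is illustrative: it targets a local
# Ollama server (no API key required) and assumes a model named "llama3" has
# already been pulled; any profile shape shown above works the same way.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_config = {
        "llm": {
            "default_profile": "local_ollama",
            "profiles": {
                "local_ollama": {
                    "provider": "ollama",
                    "model": "llama3",  # hypothetical local model name
                },
            },
        }
    }
    manager = LLMManager(demo_config)
    print(manager.generate("Summarize the OWASP Top 10 in one paragraph."))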