mirror of
https://github.com/CyberSecurityUP/NeuroSploit.git
synced 2026-02-12 14:02:45 +00:00
Merge pull request #6 from YatinChaubal/main
fix: handle missing placeholders in prompt template formatting
This commit is contained in:
@@ -3,7 +3,6 @@ import logging
|
|||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import shlex
|
|
||||||
|
|
||||||
from core.llm_manager import LLMManager
|
from core.llm_manager import LLMManager
|
||||||
|
|
||||||
@@ -36,35 +35,15 @@ class BaseAgent:
|
|||||||
logger.warning(f"No user prompt template found for agent {self.agent_name}.")
|
logger.warning(f"No user prompt template found for agent {self.agent_name}.")
|
||||||
return user_input # Fallback to raw user input
|
return user_input # Fallback to raw user input
|
||||||
|
|
||||||
# Create a dictionary with all the possible placeholders and default values
|
# Create a dictionary with all the possible placeholders
|
||||||
format_dict = {
|
format_dict = {
|
||||||
"user_input": user_input,
|
"user_input": user_input,
|
||||||
# For bug_bounty_hunter agent
|
"target_info_json": user_input, # for bug_bounty_hunter
|
||||||
"target_info_json": user_input,
|
"recon_data_json": json.dumps(additional_context or {}, indent=2), # for bug_bounty_hunter
|
||||||
"recon_data_json": json.dumps(additional_context or {}, indent=2),
|
"additional_context_json": json.dumps(additional_context or {}, indent=2),
|
||||||
# For red_team_agent
|
"mission_objectives_json": json.dumps(additional_context or {}, indent=2) # for red_team_agent
|
||||||
"mission_objectives_json": user_input,
|
|
||||||
"target_environment_json": json.dumps(additional_context or {}, indent=2),
|
|
||||||
# For pentest agent
|
|
||||||
"scope_json": user_input,
|
|
||||||
"initial_info_json": json.dumps(additional_context or {}, indent=2),
|
|
||||||
# For blue_team_agent
|
|
||||||
"logs_alerts_json": user_input,
|
|
||||||
"telemetry_json": json.dumps(additional_context or {}, indent=2),
|
|
||||||
# For exploit_expert agent
|
|
||||||
"vulnerability_details_json": user_input,
|
|
||||||
"target_info_json": json.dumps(additional_context or {}, indent=2),
|
|
||||||
# For cwe_expert agent
|
|
||||||
"code_vulnerability_json": user_input,
|
|
||||||
# For malware_analysis agent
|
|
||||||
"malware_sample_json": user_input,
|
|
||||||
# For replay_attack agent
|
|
||||||
"traffic_logs_json": user_input,
|
|
||||||
# Generic additional context
|
|
||||||
"additional_context_json": json.dumps(additional_context or {}, indent=2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Override with actual additional_context values if provided
|
|
||||||
if additional_context:
|
if additional_context:
|
||||||
for key, value in additional_context.items():
|
for key, value in additional_context.items():
|
||||||
if isinstance(value, (dict, list)):
|
if isinstance(value, (dict, list)):
|
||||||
@@ -72,14 +51,13 @@ class BaseAgent:
|
|||||||
else:
|
else:
|
||||||
format_dict[key] = value
|
format_dict[key] = value
|
||||||
|
|
||||||
# Use a safe way to format, only including keys that exist in the template
|
# Use a safe way to format, handling missing keys gracefully
|
||||||
try:
|
class SafeDict(dict):
|
||||||
formatted_prompt = user_prompt_template.format_map(format_dict)
|
def __missing__(self, key):
|
||||||
except KeyError as e:
|
return f"{{{key}}}" # Return the placeholder as-is for missing keys
|
||||||
logger.error(f"Missing key in format_dict: {e}")
|
|
||||||
# Fallback to user input if formatting fails
|
formatted_prompt = user_prompt_template.format_map(SafeDict(format_dict))
|
||||||
return user_input
|
|
||||||
|
|
||||||
return formatted_prompt
|
return formatted_prompt
|
||||||
|
|
||||||
def execute(self, user_input: str, campaign_data: Dict = None) -> Dict:
|
def execute(self, user_input: str, campaign_data: Dict = None) -> Dict:
|
||||||
@@ -121,134 +99,19 @@ class BaseAgent:
|
|||||||
return {"agent_name": self.agent_name, "input": user_input, "llm_response": llm_response_text}
|
return {"agent_name": self.agent_name, "input": user_input, "llm_response": llm_response_text}
|
||||||
|
|
||||||
def _parse_llm_response(self, response: str) -> (Optional[str], Optional[str]):
|
def _parse_llm_response(self, response: str) -> (Optional[str], Optional[str]):
|
||||||
"""
|
"""Parses the LLM response to find a tool to use."""
|
||||||
Parses the LLM response to find a tool to use.
|
match = re.search(r"\[TOOL\]\s*(\w+)\s*:\s*(.*)", response)
|
||||||
Supports both single tool format and multiple tool chain format.
|
|
||||||
"""
|
|
||||||
# Single tool format: [TOOL] toolname: args
|
|
||||||
match = re.search(r"\[TOOL\]\s*(\w+)\s*:\s*(.*?)(?:\n|$)", response, re.MULTILINE)
|
|
||||||
if match:
|
if match:
|
||||||
return match.group(1), match.group(2).strip()
|
return match.group(1), match.group(2)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def _parse_all_tools(self, response: str) -> List[tuple]:
|
|
||||||
"""
|
|
||||||
Parse multiple tool calls from LLM response for tool chaining.
|
|
||||||
Returns list of (tool_name, tool_args) tuples.
|
|
||||||
"""
|
|
||||||
tools = []
|
|
||||||
pattern = r"\[TOOL\]\s*(\w+)\s*:\s*(.*?)(?=\[TOOL\]|$)"
|
|
||||||
matches = re.finditer(pattern, response, re.MULTILINE | re.DOTALL)
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
tool_name = match.group(1)
|
|
||||||
tool_args = match.group(2).strip()
|
|
||||||
tools.append((tool_name, tool_args))
|
|
||||||
|
|
||||||
logger.debug(f"Parsed {len(tools)} tool calls from LLM response")
|
|
||||||
return tools
|
|
||||||
|
|
||||||
def execute_tool_chain(self, tools: List[tuple]) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Execute multiple tools in sequence (tool chaining).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tools: List of (tool_name, tool_args) tuples
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict]: Results from each tool execution
|
|
||||||
"""
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for tool_name, tool_args in tools:
|
|
||||||
logger.info(f"Executing tool in chain: {tool_name}")
|
|
||||||
|
|
||||||
# Check if tool is allowed for this agent
|
|
||||||
if tool_name not in self.tools_allowed and self.tools_allowed:
|
|
||||||
logger.warning(f"Tool '{tool_name}' not allowed for agent {self.agent_name}")
|
|
||||||
results.append({
|
|
||||||
"tool": tool_name,
|
|
||||||
"status": "denied",
|
|
||||||
"output": f"Tool '{tool_name}' not in allowed tools list"
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if tool exists in config
|
|
||||||
if tool_name not in self.config.get('tools', {}):
|
|
||||||
logger.warning(f"Tool '{tool_name}' not found in configuration")
|
|
||||||
results.append({
|
|
||||||
"tool": tool_name,
|
|
||||||
"status": "not_found",
|
|
||||||
"output": f"Tool '{tool_name}' not configured"
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Execute the tool
|
|
||||||
tool_path = self.config['tools'][tool_name]
|
|
||||||
output = self._execute_tool(tool_path, tool_args)
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"tool": tool_name,
|
|
||||||
"args": tool_args,
|
|
||||||
"status": "executed",
|
|
||||||
"output": output
|
|
||||||
})
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def _execute_tool(self, tool_path: str, args: str) -> str:
|
def _execute_tool(self, tool_path: str, args: str) -> str:
|
||||||
"""
|
"""Executes a tool and returns the output."""
|
||||||
Executes a tool safely and returns the output.
|
|
||||||
Uses shlex for safe argument parsing and includes timeout protection.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Sanitize and validate tool path
|
result = subprocess.run(f"{tool_path} {args}", shell=True, capture_output=True, text=True)
|
||||||
if not tool_path or '..' in tool_path:
|
return result.stdout + result.stderr
|
||||||
return f"[ERROR] Invalid tool path: {tool_path}"
|
|
||||||
|
|
||||||
# Parse arguments safely using shlex
|
|
||||||
try:
|
|
||||||
args_list = shlex.split(args) if args else []
|
|
||||||
except ValueError as e:
|
|
||||||
return f"[ERROR] Invalid arguments: {e}"
|
|
||||||
|
|
||||||
# Build command list (no shell=True for security)
|
|
||||||
cmd = [tool_path] + args_list
|
|
||||||
|
|
||||||
logger.info(f"Executing tool: {' '.join(cmd)}")
|
|
||||||
|
|
||||||
# Execute with timeout (60 seconds default)
|
|
||||||
result = subprocess.run(
|
|
||||||
cmd,
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=60,
|
|
||||||
shell=False # Security: never use shell=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Combine stdout and stderr
|
|
||||||
output = ""
|
|
||||||
if result.stdout:
|
|
||||||
output += f"[STDOUT]\n{result.stdout}\n"
|
|
||||||
if result.stderr:
|
|
||||||
output += f"[STDERR]\n{result.stderr}\n"
|
|
||||||
if result.returncode != 0:
|
|
||||||
output += f"[EXIT_CODE] {result.returncode}\n"
|
|
||||||
|
|
||||||
return output if output else "[NO_OUTPUT]"
|
|
||||||
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
logger.error(f"Tool execution timeout: {tool_path}")
|
|
||||||
return f"[ERROR] Tool execution timeout after 60 seconds"
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.error(f"Tool not found: {tool_path}")
|
|
||||||
return f"[ERROR] Tool not found at path: {tool_path}"
|
|
||||||
except PermissionError:
|
|
||||||
logger.error(f"Permission denied executing: {tool_path}")
|
|
||||||
return f"[ERROR] Permission denied for tool: {tool_path}"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error executing tool: {e}")
|
return f"Error executing tool: {e}"
|
||||||
return f"[ERROR] Unexpected error: {str(e)}"
|
|
||||||
|
|
||||||
def _ask_for_permission(self, message: str) -> bool:
|
def _ask_for_permission(self, message: str) -> bool:
|
||||||
"""Asks the user for permission."""
|
"""Asks the user for permission."""
|
||||||
|
|||||||
Reference in New Issue
Block a user