Files
CyberStrikeAI/internal/security/executor.go
2025-11-09 13:30:08 +08:00

676 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package security
import (
	"context"
	"fmt"
	"os/exec"
	"regexp"
	"sort"
	"strings"
	"time"

	"cyberstrike-ai/internal/config"
	"cyberstrike-ai/internal/mcp"

	"go.uber.org/zap"
)
// Executor runs the security tools described in the configuration, plus the
// built-in "exec" tool, and turns their output into MCP tool results.
type Executor struct {
	config    *config.SecurityConfig // tool definitions and parameter mappings
	mcpServer *mcp.Server            // server handed in at construction time
	logger    *zap.Logger
}

// NewExecutor creates an Executor that uses cfg for tool lookup, keeps a
// reference to mcpServer and logs through logger. The arguments are stored
// as given; nothing is validated here.
func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor {
	e := &Executor{logger: logger}
	e.config = cfg
	e.mcpServer = mcpServer
	return e
}
// ExecuteTool runs the tool named toolName with the given arguments and
// returns its combined stdout/stderr as a text ToolResult.
//
// The name "exec" is special: it is dispatched to executeSystemCommand without
// consulting the tool configuration. Any other name must match an enabled
// entry in e.config.Tools, otherwise a non-nil error is returned. When no
// arguments could be built (e.g. a required parameter is missing) or the
// process fails, the failure is reported as a ToolResult with IsError set and
// a nil error, so the message can be handed back to the caller.
func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[string]interface{}) (*mcp.ToolResult, error) {
	e.logger.Info("ExecuteTool被调用",
		zap.String("toolName", toolName),
		zap.Any("args", args),
	)

	// Built-in tool that runs an arbitrary shell command.
	if toolName == "exec" {
		e.logger.Info("执行exec工具")
		return e.executeSystemCommand(ctx, args)
	}

	// Find the first enabled tool with this name. Index into the slice so the
	// pointer refers to the configured entry rather than a loop-variable copy.
	var toolConfig *config.ToolConfig
	for i := range e.config.Tools {
		if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled {
			toolConfig = &e.config.Tools[i]
			break
		}
	}
	if toolConfig == nil {
		e.logger.Error("工具未找到或未启用",
			zap.String("toolName", toolName),
			zap.Int("totalTools", len(e.config.Tools)),
		)
		return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName)
	}
	e.logger.Info("找到工具配置",
		zap.String("toolName", toolName),
		zap.String("command", toolConfig.Command),
		zap.Strings("args", toolConfig.Args),
	)

	// Translate the caller's arguments into argv using the tool's parameter
	// mapping.
	cmdArgs := e.buildCommandArgs(toolName, toolConfig, args)
	e.logger.Info("构建命令参数完成",
		zap.String("toolName", toolName),
		zap.Strings("cmdArgs", cmdArgs),
		zap.Int("argsCount", len(cmdArgs)),
	)

	// buildCommandArgs signals a missing required parameter with an empty
	// slice. NOTE(review): a tool without fixed Args whose call legitimately
	// yields no arguments is rejected here as well.
	if len(cmdArgs) == 0 {
		e.logger.Warn("命令参数为空",
			zap.String("toolName", toolName),
			zap.Any("inputArgs", args),
		)
		return &mcp.ToolResult{
			Content: []mcp.Content{
				{
					Type: "text",
					Text: fmt.Sprintf("错误: 工具 %s 缺少必需的参数。接收到的参数: %v", toolName, args),
				},
			},
			IsError: true,
		}, nil
	}

	// Arguments are passed straight to the program (no shell), so argument
	// values cannot inject shell syntax. The process is killed if ctx ends.
	cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
	e.logger.Info("执行安全工具",
		zap.String("tool", toolName),
		zap.Strings("args", cmdArgs),
	)
	output, err := cmd.CombinedOutput()
	if err != nil {
		e.logger.Error("工具执行失败",
			zap.String("tool", toolName),
			zap.Error(err),
			zap.String("output", string(output)),
		)
		return &mcp.ToolResult{
			Content: []mcp.Content{
				{
					Type: "text",
					Text: fmt.Sprintf("工具执行失败: %v\n输出: %s", err, string(output)),
				},
			},
			IsError: true,
		}, nil
	}

	// Log only the output size on success: scanner output can be very large
	// and the full text is returned to the caller anyway.
	e.logger.Info("工具执行成功",
		zap.String("tool", toolName),
		zap.Int("output_length", len(output)),
	)
	return &mcp.ToolResult{
		Content: []mcp.Content{
			{
				Type: "text",
				Text: string(output),
			},
		},
		IsError: false,
	}, nil
}
// RegisterTools 注册工具到MCP服务器
func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
e.logger.Info("开始注册工具",
zap.Int("totalTools", len(e.config.Tools)),
)
for i, toolConfig := range e.config.Tools {
if !toolConfig.Enabled {
e.logger.Debug("跳过未启用的工具",
zap.String("tool", toolConfig.Name),
)
continue
}
// 创建工具配置的副本,避免闭包问题
toolName := toolConfig.Name
toolConfigCopy := toolConfig
// 使用简短描述如果存在否则使用详细描述的前100个字符
shortDesc := toolConfigCopy.ShortDescription
if shortDesc == "" {
// 如果没有简短描述从详细描述中提取第一行或前100个字符
desc := toolConfigCopy.Description
if len(desc) > 100 {
// 尝试找到第一个换行符
if idx := strings.Index(desc, "\n"); idx > 0 && idx < 100 {
shortDesc = strings.TrimSpace(desc[:idx])
} else {
shortDesc = desc[:100] + "..."
}
} else {
shortDesc = desc
}
}
tool := mcp.Tool{
Name: toolConfigCopy.Name,
Description: toolConfigCopy.Description,
ShortDescription: shortDesc,
InputSchema: e.buildInputSchema(&toolConfigCopy),
}
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
e.logger.Info("工具handler被调用",
zap.String("toolName", toolName),
zap.Any("args", args),
)
return e.ExecuteTool(ctx, toolName, args)
}
mcpServer.RegisterTool(tool, handler)
e.logger.Info("注册安全工具成功",
zap.String("tool", toolConfigCopy.Name),
zap.String("command", toolConfigCopy.Command),
zap.Int("index", i),
)
}
e.logger.Info("工具注册完成",
zap.Int("registeredCount", len(e.config.Tools)),
)
}
// buildCommandArgs translates the caller-supplied args into the argument list
// passed to toolConfig.Command. The command name itself is not included.
//
// With declared Parameters, the result contains, in order:
//   - toolConfig.Args;
//   - the positional parameters, ordered by Position;
//   - the flag parameters in declaration order, each rendered according to
//     its Format ("flag", "combined", "template"; anything else appends the
//     bare value).
//
// Optional parameters without a value are skipped. If a required parameter
// has no value, an empty slice is returned so ExecuteTool can report the
// missing argument.
//
// Without declared Parameters, the result is toolConfig.Args followed by
// "--key value" for every entry in args except "_tool_name". Keys are sorted
// so the command line is deterministic (Go map iteration order is not).
func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConfig, args map[string]interface{}) []string {
	cmdArgs := make([]string, 0)

	if len(toolConfig.Parameters) == 0 {
		cmdArgs = append(cmdArgs, toolConfig.Args...)
		keys := make([]string, 0, len(args))
		for key := range args {
			if key != "_tool_name" {
				keys = append(keys, key)
			}
		}
		sort.Strings(keys)
		for _, key := range keys {
			// %v of a string is the string itself, so one path covers every type.
			cmdArgs = append(cmdArgs, "--"+key, fmt.Sprintf("%v", args[key]))
		}
		return cmdArgs
	}

	// Fixed arguments always come first.
	cmdArgs = append(cmdArgs, toolConfig.Args...)

	positionalParams := make([]config.ParameterConfig, 0)
	flagParams := make([]config.ParameterConfig, 0)
	for _, param := range toolConfig.Parameters {
		if param.Position != nil {
			positionalParams = append(positionalParams, param)
		} else {
			flagParams = append(flagParams, param)
		}
	}

	// Order positional parameters by Position. Probing only positions
	// 0..len-1 would silently drop parameters with 1-based or gapped
	// positions, so sort instead. The stable sort keeps declaration order
	// among equal positions.
	sort.SliceStable(positionalParams, func(i, j int) bool {
		return *positionalParams[i].Position < *positionalParams[j].Position
	})
	for _, param := range positionalParams {
		value := e.getParamValue(args, param)
		if value == nil {
			if param.Required {
				e.logger.Warn("缺少必需的位置参数",
					zap.String("tool", toolName),
					zap.String("param", param.Name),
					zap.Int("position", *param.Position),
				)
				return []string{}
			}
			continue
		}
		cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
	}

	for _, param := range flagParams {
		value := e.getParamValue(args, param)
		if value == nil {
			if param.Required {
				e.logger.Warn("缺少必需的标志参数",
					zap.String("tool", toolName),
					zap.String("param", param.Name),
				)
				return []string{}
			}
			continue
		}

		// Booleans act as switches: true emits only the flag, false emits
		// nothing. A type switch (rather than chained `ok :=` assertions,
		// which shadowed the outer ok) ensures every recognized form is
		// honored. Values of any other type fall through to the generic
		// formatting below.
		if param.Type == "bool" {
			enabled, recognized := false, true
			switch v := value.(type) {
			case bool:
				enabled = v
			case float64: // JSON numbers decode as float64
				enabled = v != 0
			case int:
				enabled = v != 0
			case string:
				enabled = v == "true" || v == "1" || v == "yes"
			default:
				recognized = false
			}
			if recognized {
				if enabled && param.Flag != "" {
					cmdArgs = append(cmdArgs, param.Flag)
				}
				continue
			}
		}

		format := param.Format
		if format == "" {
			format = "flag"
		}
		formatted := e.formatParamValue(param, value)

		switch format {
		case "flag": // --flag value / -f value
			if param.Flag != "" {
				cmdArgs = append(cmdArgs, param.Flag)
			}
			if formatted != "" {
				cmdArgs = append(cmdArgs, formatted)
			}
		case "combined": // --flag=value / -f=value
			if param.Flag != "" {
				cmdArgs = append(cmdArgs, param.Flag+"="+formatted)
			} else {
				cmdArgs = append(cmdArgs, formatted)
			}
		case "template":
			if param.Template != "" {
				// Placeholders are {flag}, {value} and {name}. The expansion is
				// split on whitespace, so a value containing spaces becomes
				// several arguments.
				expanded := strings.ReplaceAll(param.Template, "{flag}", param.Flag)
				expanded = strings.ReplaceAll(expanded, "{value}", formatted)
				expanded = strings.ReplaceAll(expanded, "{name}", param.Name)
				cmdArgs = append(cmdArgs, strings.Fields(expanded)...)
			} else {
				if param.Flag != "" {
					cmdArgs = append(cmdArgs, param.Flag)
				}
				cmdArgs = append(cmdArgs, formatted)
			}
		default:
			// "positional" (declared without a Position) and unknown formats
			// append the bare value.
			cmdArgs = append(cmdArgs, formatted)
		}
	}
	return cmdArgs
}
// getParamValue looks up param in args. A present, non-nil entry is returned
// as is. Otherwise required parameters yield nil, so the caller can report
// them as missing, and optional ones yield param.Default.
func (e *Executor) getParamValue(args map[string]interface{}, param config.ParameterConfig) interface{} {
	supplied, present := args[param.Name]
	switch {
	case present && supplied != nil:
		return supplied
	case param.Required:
		return nil
	default:
		return param.Default
	}
}
// formatParamValue renders value as a single command-line token according to
// param.Type:
//   - "bool": "true"/"false" for bool values, "false" for anything else;
//   - "array": a []interface{} becomes a comma-separated list of its elements;
//   - anything else: fmt's %v formatting.
func (e *Executor) formatParamValue(param config.ParameterConfig, value interface{}) string {
	if param.Type == "bool" {
		// buildCommandArgs normally turns bools into bare flags before this
		// point; this is only a fallback.
		b, isBool := value.(bool)
		if !isBool {
			return "false"
		}
		return fmt.Sprintf("%v", b)
	}
	if param.Type == "array" {
		if items, isSlice := value.([]interface{}); isSlice {
			parts := make([]string, len(items))
			for i, item := range items {
				parts[i] = fmt.Sprintf("%v", item)
			}
			return strings.Join(parts, ",")
		}
	}
	return fmt.Sprintf("%v", value)
}
// executeSystemCommand 执行系统命令
func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
// 获取命令
command, ok := args["command"].(string)
if !ok {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: 缺少command参数",
},
},
IsError: true,
}, nil
}
if command == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: command参数不能为空",
},
},
IsError: true,
}, nil
}
// 安全检查:记录执行的命令
e.logger.Warn("执行系统命令",
zap.String("command", command),
)
// 获取shell类型可选默认为sh
shell := "sh"
if s, ok := args["shell"].(string); ok && s != "" {
shell = s
}
// 获取工作目录(可选)
workDir := ""
if wd, ok := args["workdir"].(string); ok && wd != "" {
workDir = wd
}
// 构建命令
var cmd *exec.Cmd
if workDir != "" {
cmd = exec.CommandContext(ctx, shell, "-c", command)
cmd.Dir = workDir
} else {
cmd = exec.CommandContext(ctx, shell, "-c", command)
}
// 执行命令
e.logger.Info("执行系统命令",
zap.String("command", command),
zap.String("shell", shell),
zap.String("workdir", workDir),
)
output, err := cmd.CombinedOutput()
if err != nil {
e.logger.Error("系统命令执行失败",
zap.String("command", command),
zap.Error(err),
zap.String("output", string(output)),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)),
},
},
IsError: true,
}, nil
}
e.logger.Info("系统命令执行成功",
zap.String("command", command),
zap.String("output_length", fmt.Sprintf("%d", len(output))),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: string(output),
},
},
IsError: false,
}, nil
}
// buildInputSchema builds the JSON Schema object that describes a tool's input.
//
// For each declared parameter it emits a property with the converted type and
// the description. It also adds the default value and enum options when
// configured, and "items" for arrays. Required parameters are listed under
// "required". A tool without declared parameters gets an empty object schema,
// and a warning is logged.
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
	schema := map[string]interface{}{
		"type":       "object",
		"properties": map[string]interface{}{},
		"required":   []string{},
	}

	if len(toolConfig.Parameters) == 0 {
		// Such a tool can only use its fixed Args, unless parameters are added
		// to its YAML configuration.
		e.logger.Warn("工具未定义参数配置返回空schema",
			zap.String("tool", toolConfig.Name),
		)
		return schema
	}

	properties := make(map[string]interface{}, len(toolConfig.Parameters))
	required := []string{}
	for _, param := range toolConfig.Parameters {
		openAIType := e.convertToOpenAIType(param.Type)
		prop := map[string]interface{}{
			"type":        openAIType,
			"description": param.Description,
		}
		if openAIType == "array" {
			// OpenAI function calling rejects array properties that have no
			// "items" schema. Array values are stringified element by element
			// when the command line is built, so string items are accepted.
			prop["items"] = map[string]interface{}{"type": "string"}
		}
		if param.Default != nil {
			prop["default"] = param.Default
		}
		if len(param.Options) > 0 {
			prop["enum"] = param.Options
		}
		properties[param.Name] = prop

		if param.Required {
			required = append(required, param.Name)
		}
	}
	schema["properties"] = properties
	schema["required"] = required
	return schema
}
// convertToOpenAIType maps a parameter type from the tool configuration to a
// JSON Schema type name, as used in OpenAI function definitions:
//
//	bool            -> boolean
//	int, integer    -> integer
//	float, double   -> number
//	"" (unset), string -> string
//	array, object   -> unchanged
//
// Unknown types are passed through unchanged and logged as a warning.
func (e *Executor) convertToOpenAIType(configType string) string {
	switch configType {
	case "bool":
		return "boolean"
	case "int", "integer":
		// JSON Schema has a dedicated integer type; "number" would also admit
		// fractional values.
		return "integer"
	case "float", "double":
		return "number"
	case "", "string":
		// An omitted type is a plain value (formatted with %v when building
		// the command line). An empty "type" is not a valid JSON Schema type.
		return "string"
	case "array", "object":
		return configType
	default:
		e.logger.Warn("未知的参数类型,使用原类型",
			zap.String("type", configType),
		)
		return configType
	}
}
// Vulnerability describes one finding extracted from tool output. The field
// order also fixes the order of the JSON fields, so it is kept stable.
type Vulnerability struct {
	ID          string    `json:"id"`
	Type        string    `json:"type"`
	Severity    string    `json:"severity"` // expected: low, medium, high, critical
	Title       string    `json:"title"`
	Description string    `json:"description"`
	Target      string    `json:"target"`
	FoundAt     time.Time `json:"foundAt"`
	Details     string    `json:"details"`
}
// AnalyzeResults runs parseToolOutput over every "text" content item of
// result and concatenates the findings. An error result yields no findings.
// The returned slice is empty but non-nil when nothing is found.
func (e *Executor) AnalyzeResults(toolName string, result *mcp.ToolResult) []Vulnerability {
	found := []Vulnerability{}
	if !result.IsError {
		for _, item := range result.Content {
			if item.Type != "text" {
				continue
			}
			found = append(found, e.parseToolOutput(toolName, item.Text)...)
		}
	}
	return found
}
// sqliWordPattern matches "sqli" only as a standalone word, so ordinary words
// such as "sqlite" in tool output do not produce an SQL injection finding.
var sqliWordPattern = regexp.MustCompile(`\bsqli\b`)

// parseToolOutput applies simple keyword heuristics to one tool's text output.
// It returns the findings, or an empty non-nil slice when there are none:
//   - "sql injection" or the word "sqli" (case-insensitive): one high-severity
//     SQL Injection entry;
//   - "xss" or "cross-site scripting" (case-insensitive): one medium-severity
//     XSS entry;
//   - for toolName "nmap": one low-severity Open Port entry per line that
//     contains both "open" and "port" (case-sensitive).
//
// The SQL and XSS entries carry the whole output in Details.
func (e *Executor) parseToolOutput(toolName, output string) []Vulnerability {
	vulnerabilities := []Vulnerability{}
	outputLower := strings.ToLower(output)
	// One timestamp per call, so all findings from this output agree.
	now := time.Now()

	if strings.Contains(outputLower, "sql injection") || sqliWordPattern.MatchString(outputLower) {
		vulnerabilities = append(vulnerabilities, Vulnerability{
			ID:          fmt.Sprintf("sql-%d", now.Unix()),
			Type:        "SQL Injection",
			Severity:    "high",
			Title:       "SQL注入漏洞",
			Description: "检测到潜在的SQL注入漏洞",
			FoundAt:     now,
			Details:     output,
		})
	}

	if strings.Contains(outputLower, "xss") || strings.Contains(outputLower, "cross-site scripting") {
		vulnerabilities = append(vulnerabilities, Vulnerability{
			ID:          fmt.Sprintf("xss-%d", now.Unix()),
			Type:        "XSS",
			Severity:    "medium",
			Title:       "跨站脚本攻击漏洞",
			Description: "检测到潜在的XSS漏洞",
			FoundAt:     now,
			Details:     output,
		})
	}

	if toolName == "nmap" {
		// NOTE(review): nmap's normal table rows (e.g. "80/tcp open http")
		// contain no lowercase "port", so only lines such as verbose-mode
		// "Discovered open port ..." match this check.
		for i, line := range strings.Split(output, "\n") {
			if !strings.Contains(line, "open") || !strings.Contains(line, "port") {
				continue
			}
			vulnerabilities = append(vulnerabilities, Vulnerability{
				// Include the line index: an ID built from the Unix second
				// alone is identical for every port found in one scan.
				ID:          fmt.Sprintf("port-%d-%d", now.Unix(), i),
				Type:        "Open Port",
				Severity:    "low",
				Title:       "开放端口",
				Description: fmt.Sprintf("发现开放端口: %s", line),
				FoundAt:     now,
				Details:     line,
			})
		}
	}
	return vulnerabilities
}
// GetVulnerabilityReport summarizes vulnerabilities into a report map with the
// keys "total", "severityCount", "vulnerabilities" and "generatedAt" (the
// current time).
//
// severityCount always holds critical/high/medium/low, which are zero when
// absent. Any other severity string gets its own entry.
func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[string]interface{} {
	counts := make(map[string]int, 4)
	for _, level := range []string{"critical", "high", "medium", "low"} {
		counts[level] = 0
	}
	for i := range vulnerabilities {
		counts[vulnerabilities[i].Severity]++
	}

	report := make(map[string]interface{}, 4)
	report["total"] = len(vulnerabilities)
	report["severityCount"] = counts
	report["vulnerabilities"] = vulnerabilities
	report["generatedAt"] = time.Now()
	return report
}