mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-22 02:36:40 +02:00
656 lines
17 KiB
Go
656 lines
17 KiB
Go
package security
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os/exec"
|
|
"strings"
|
|
"time"
|
|
|
|
"cyberstrike-ai/internal/config"
|
|
"cyberstrike-ai/internal/mcp"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Executor 安全工具执行器
|
|
type Executor struct {
|
|
config *config.SecurityConfig
|
|
mcpServer *mcp.Server
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewExecutor 创建新的执行器
|
|
func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor {
|
|
return &Executor{
|
|
config: cfg,
|
|
mcpServer: mcpServer,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// ExecuteTool 执行安全工具
|
|
func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[string]interface{}) (*mcp.ToolResult, error) {
|
|
e.logger.Info("ExecuteTool被调用",
|
|
zap.String("toolName", toolName),
|
|
zap.Any("args", args),
|
|
)
|
|
|
|
// 特殊处理:exec工具直接执行系统命令
|
|
if toolName == "exec" {
|
|
e.logger.Info("执行exec工具")
|
|
return e.executeSystemCommand(ctx, args)
|
|
}
|
|
|
|
// 查找工具配置
|
|
var toolConfig *config.ToolConfig
|
|
for i := range e.config.Tools {
|
|
if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled {
|
|
toolConfig = &e.config.Tools[i]
|
|
break
|
|
}
|
|
}
|
|
|
|
if toolConfig == nil {
|
|
e.logger.Error("工具未找到或未启用",
|
|
zap.String("toolName", toolName),
|
|
zap.Int("totalTools", len(e.config.Tools)),
|
|
)
|
|
return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName)
|
|
}
|
|
|
|
e.logger.Info("找到工具配置",
|
|
zap.String("toolName", toolName),
|
|
zap.String("command", toolConfig.Command),
|
|
zap.Strings("args", toolConfig.Args),
|
|
)
|
|
|
|
// 构建命令 - 根据工具类型使用不同的参数格式
|
|
cmdArgs := e.buildCommandArgs(toolName, toolConfig, args)
|
|
|
|
e.logger.Info("构建命令参数完成",
|
|
zap.String("toolName", toolName),
|
|
zap.Strings("cmdArgs", cmdArgs),
|
|
zap.Int("argsCount", len(cmdArgs)),
|
|
)
|
|
|
|
// 验证命令参数
|
|
if len(cmdArgs) == 0 {
|
|
e.logger.Warn("命令参数为空",
|
|
zap.String("toolName", toolName),
|
|
zap.Any("inputArgs", args),
|
|
)
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: fmt.Sprintf("错误: 工具 %s 缺少必需的参数。接收到的参数: %v", toolName, args),
|
|
},
|
|
},
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
|
|
// 执行命令
|
|
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
|
|
|
e.logger.Info("执行安全工具",
|
|
zap.String("tool", toolName),
|
|
zap.Strings("args", cmdArgs),
|
|
)
|
|
|
|
output, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
e.logger.Error("工具执行失败",
|
|
zap.String("tool", toolName),
|
|
zap.Error(err),
|
|
zap.String("output", string(output)),
|
|
)
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: fmt.Sprintf("工具执行失败: %v\n输出: %s", err, string(output)),
|
|
},
|
|
},
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
|
|
e.logger.Info("工具执行成功",
|
|
zap.String("tool", toolName),
|
|
zap.String("output", string(output)),
|
|
)
|
|
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: string(output),
|
|
},
|
|
},
|
|
IsError: false,
|
|
}, nil
|
|
}
|
|
|
|
// RegisterTools 注册工具到MCP服务器
|
|
func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
|
|
e.logger.Info("开始注册工具",
|
|
zap.Int("totalTools", len(e.config.Tools)),
|
|
)
|
|
|
|
for i, toolConfig := range e.config.Tools {
|
|
if !toolConfig.Enabled {
|
|
e.logger.Debug("跳过未启用的工具",
|
|
zap.String("tool", toolConfig.Name),
|
|
)
|
|
continue
|
|
}
|
|
|
|
// 创建工具配置的副本,避免闭包问题
|
|
toolName := toolConfig.Name
|
|
toolConfigCopy := toolConfig
|
|
|
|
// 使用简短描述(如果存在),否则使用详细描述的前100个字符
|
|
shortDesc := toolConfigCopy.ShortDescription
|
|
if shortDesc == "" {
|
|
// 如果没有简短描述,从详细描述中提取第一行或前100个字符
|
|
desc := toolConfigCopy.Description
|
|
if len(desc) > 100 {
|
|
// 尝试找到第一个换行符
|
|
if idx := strings.Index(desc, "\n"); idx > 0 && idx < 100 {
|
|
shortDesc = strings.TrimSpace(desc[:idx])
|
|
} else {
|
|
shortDesc = desc[:100] + "..."
|
|
}
|
|
} else {
|
|
shortDesc = desc
|
|
}
|
|
}
|
|
|
|
tool := mcp.Tool{
|
|
Name: toolConfigCopy.Name,
|
|
Description: toolConfigCopy.Description,
|
|
ShortDescription: shortDesc,
|
|
InputSchema: e.buildInputSchema(&toolConfigCopy),
|
|
}
|
|
|
|
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
|
e.logger.Info("工具handler被调用",
|
|
zap.String("toolName", toolName),
|
|
zap.Any("args", args),
|
|
)
|
|
return e.ExecuteTool(ctx, toolName, args)
|
|
}
|
|
|
|
mcpServer.RegisterTool(tool, handler)
|
|
e.logger.Info("注册安全工具成功",
|
|
zap.String("tool", toolConfigCopy.Name),
|
|
zap.String("command", toolConfigCopy.Command),
|
|
zap.Int("index", i),
|
|
)
|
|
}
|
|
|
|
e.logger.Info("工具注册完成",
|
|
zap.Int("registeredCount", len(e.config.Tools)),
|
|
)
|
|
}
|
|
|
|
// buildCommandArgs 构建命令参数
|
|
func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConfig, args map[string]interface{}) []string {
|
|
cmdArgs := make([]string, 0)
|
|
|
|
// 如果配置中定义了参数映射,使用配置中的映射规则
|
|
if len(toolConfig.Parameters) > 0 {
|
|
// 先添加固定参数
|
|
cmdArgs = append(cmdArgs, toolConfig.Args...)
|
|
|
|
// 按位置参数排序
|
|
positionalParams := make([]config.ParameterConfig, 0)
|
|
flagParams := make([]config.ParameterConfig, 0)
|
|
|
|
for _, param := range toolConfig.Parameters {
|
|
if param.Position != nil {
|
|
positionalParams = append(positionalParams, param)
|
|
} else {
|
|
flagParams = append(flagParams, param)
|
|
}
|
|
}
|
|
|
|
// 对位置参数按位置排序
|
|
for i := 0; i < len(positionalParams); i++ {
|
|
for _, param := range positionalParams {
|
|
if param.Position != nil && *param.Position == i {
|
|
value := e.getParamValue(args, param)
|
|
if value == nil {
|
|
if param.Required {
|
|
// 必需参数缺失,返回空数组让上层处理错误
|
|
e.logger.Warn("缺少必需的位置参数",
|
|
zap.String("tool", toolName),
|
|
zap.String("param", param.Name),
|
|
zap.Int("position", *param.Position),
|
|
)
|
|
return []string{}
|
|
}
|
|
break
|
|
}
|
|
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// 处理标志参数
|
|
for _, param := range flagParams {
|
|
value := e.getParamValue(args, param)
|
|
if value == nil {
|
|
if param.Required {
|
|
// 必需参数缺失,返回空数组让上层处理错误
|
|
e.logger.Warn("缺少必需的标志参数",
|
|
zap.String("tool", toolName),
|
|
zap.String("param", param.Name),
|
|
)
|
|
return []string{}
|
|
}
|
|
continue
|
|
}
|
|
|
|
// 布尔值特殊处理:如果为 false,跳过;如果为 true,只添加标志
|
|
if param.Type == "bool" {
|
|
if boolVal, ok := value.(bool); ok {
|
|
if !boolVal {
|
|
continue // false 时不添加任何参数
|
|
}
|
|
// true 时只添加标志,不添加值
|
|
if param.Flag != "" {
|
|
cmdArgs = append(cmdArgs, param.Flag)
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
format := param.Format
|
|
if format == "" {
|
|
format = "flag" // 默认格式
|
|
}
|
|
|
|
switch format {
|
|
case "flag":
|
|
// --flag value 或 -f value
|
|
if param.Flag != "" {
|
|
cmdArgs = append(cmdArgs, param.Flag)
|
|
}
|
|
formattedValue := e.formatParamValue(param, value)
|
|
if formattedValue != "" {
|
|
cmdArgs = append(cmdArgs, formattedValue)
|
|
}
|
|
case "combined":
|
|
// --flag=value 或 -f=value
|
|
if param.Flag != "" {
|
|
cmdArgs = append(cmdArgs, fmt.Sprintf("%s=%s", param.Flag, e.formatParamValue(param, value)))
|
|
} else {
|
|
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
|
}
|
|
case "template":
|
|
// 使用模板字符串
|
|
if param.Template != "" {
|
|
template := param.Template
|
|
template = strings.ReplaceAll(template, "{flag}", param.Flag)
|
|
template = strings.ReplaceAll(template, "{value}", e.formatParamValue(param, value))
|
|
template = strings.ReplaceAll(template, "{name}", param.Name)
|
|
cmdArgs = append(cmdArgs, strings.Fields(template)...)
|
|
} else {
|
|
// 如果没有模板,使用默认格式
|
|
if param.Flag != "" {
|
|
cmdArgs = append(cmdArgs, param.Flag)
|
|
}
|
|
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
|
}
|
|
case "positional":
|
|
// 位置参数(已在上面处理)
|
|
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
|
default:
|
|
// 默认:直接添加值
|
|
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
|
}
|
|
}
|
|
|
|
return cmdArgs
|
|
}
|
|
|
|
// 如果没有定义参数配置,使用固定参数和通用处理
|
|
// 添加固定参数
|
|
cmdArgs = append(cmdArgs, toolConfig.Args...)
|
|
|
|
// 通用处理:将参数转换为命令行参数
|
|
for key, value := range args {
|
|
if key == "_tool_name" {
|
|
continue
|
|
}
|
|
// 使用 --key value 格式
|
|
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", key))
|
|
if strValue, ok := value.(string); ok {
|
|
cmdArgs = append(cmdArgs, strValue)
|
|
} else {
|
|
cmdArgs = append(cmdArgs, fmt.Sprintf("%v", value))
|
|
}
|
|
}
|
|
|
|
return cmdArgs
|
|
}
|
|
|
|
// getParamValue 获取参数值,支持默认值
|
|
func (e *Executor) getParamValue(args map[string]interface{}, param config.ParameterConfig) interface{} {
|
|
// 从参数中获取值
|
|
if value, ok := args[param.Name]; ok && value != nil {
|
|
return value
|
|
}
|
|
|
|
// 如果参数是必需的但没有提供,返回 nil(让上层处理错误)
|
|
if param.Required {
|
|
return nil
|
|
}
|
|
|
|
// 返回默认值
|
|
return param.Default
|
|
}
|
|
|
|
// formatParamValue 格式化参数值
|
|
func (e *Executor) formatParamValue(param config.ParameterConfig, value interface{}) string {
|
|
switch param.Type {
|
|
case "bool":
|
|
// 布尔值应该在上层处理,这里不应该被调用
|
|
if boolVal, ok := value.(bool); ok {
|
|
return fmt.Sprintf("%v", boolVal)
|
|
}
|
|
return "false"
|
|
case "array":
|
|
// 数组:转换为逗号分隔的字符串
|
|
if arr, ok := value.([]interface{}); ok {
|
|
strs := make([]string, 0, len(arr))
|
|
for _, item := range arr {
|
|
strs = append(strs, fmt.Sprintf("%v", item))
|
|
}
|
|
return strings.Join(strs, ",")
|
|
}
|
|
return fmt.Sprintf("%v", value)
|
|
default:
|
|
return fmt.Sprintf("%v", value)
|
|
}
|
|
}
|
|
|
|
// executeSystemCommand 执行系统命令
|
|
func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
|
// 获取命令
|
|
command, ok := args["command"].(string)
|
|
if !ok {
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: "错误: 缺少command参数",
|
|
},
|
|
},
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
|
|
if command == "" {
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: "错误: command参数不能为空",
|
|
},
|
|
},
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
|
|
// 安全检查:记录执行的命令
|
|
e.logger.Warn("执行系统命令",
|
|
zap.String("command", command),
|
|
)
|
|
|
|
// 获取shell类型(可选,默认为sh)
|
|
shell := "sh"
|
|
if s, ok := args["shell"].(string); ok && s != "" {
|
|
shell = s
|
|
}
|
|
|
|
// 获取工作目录(可选)
|
|
workDir := ""
|
|
if wd, ok := args["workdir"].(string); ok && wd != "" {
|
|
workDir = wd
|
|
}
|
|
|
|
// 构建命令
|
|
var cmd *exec.Cmd
|
|
if workDir != "" {
|
|
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
|
cmd.Dir = workDir
|
|
} else {
|
|
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
|
}
|
|
|
|
// 执行命令
|
|
e.logger.Info("执行系统命令",
|
|
zap.String("command", command),
|
|
zap.String("shell", shell),
|
|
zap.String("workdir", workDir),
|
|
)
|
|
|
|
output, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
e.logger.Error("系统命令执行失败",
|
|
zap.String("command", command),
|
|
zap.Error(err),
|
|
zap.String("output", string(output)),
|
|
)
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)),
|
|
},
|
|
},
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
|
|
e.logger.Info("系统命令执行成功",
|
|
zap.String("command", command),
|
|
zap.String("output_length", fmt.Sprintf("%d", len(output))),
|
|
)
|
|
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: string(output),
|
|
},
|
|
},
|
|
IsError: false,
|
|
}, nil
|
|
}
|
|
|
|
// buildInputSchema 构建输入模式
|
|
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
|
|
schema := map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{},
|
|
"required": []string{},
|
|
}
|
|
|
|
// 如果配置中定义了参数,优先使用配置中的参数定义
|
|
if len(toolConfig.Parameters) > 0 {
|
|
properties := make(map[string]interface{})
|
|
required := []string{}
|
|
|
|
for _, param := range toolConfig.Parameters {
|
|
// 转换类型为OpenAI/JSON Schema标准类型
|
|
openAIType := e.convertToOpenAIType(param.Type)
|
|
|
|
prop := map[string]interface{}{
|
|
"type": openAIType,
|
|
"description": param.Description,
|
|
}
|
|
|
|
// 添加默认值
|
|
if param.Default != nil {
|
|
prop["default"] = param.Default
|
|
}
|
|
|
|
// 添加枚举选项
|
|
if len(param.Options) > 0 {
|
|
prop["enum"] = param.Options
|
|
}
|
|
|
|
properties[param.Name] = prop
|
|
|
|
// 添加到必需参数列表
|
|
if param.Required {
|
|
required = append(required, param.Name)
|
|
}
|
|
}
|
|
|
|
schema["properties"] = properties
|
|
schema["required"] = required
|
|
return schema
|
|
}
|
|
|
|
// 如果没有定义参数配置,返回空schema
|
|
// 这种情况下工具可能只使用固定参数(args字段)
|
|
// 或者需要通过YAML配置文件定义参数
|
|
e.logger.Warn("工具未定义参数配置,返回空schema",
|
|
zap.String("tool", toolConfig.Name),
|
|
)
|
|
return schema
|
|
}
|
|
|
|
// convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型
|
|
func (e *Executor) convertToOpenAIType(configType string) string {
|
|
switch configType {
|
|
case "bool":
|
|
return "boolean"
|
|
case "int", "integer":
|
|
return "number"
|
|
case "float", "double":
|
|
return "number"
|
|
case "string", "array", "object":
|
|
return configType
|
|
default:
|
|
// 默认返回原类型,但记录警告
|
|
e.logger.Warn("未知的参数类型,使用原类型",
|
|
zap.String("type", configType),
|
|
)
|
|
return configType
|
|
}
|
|
}
|
|
|
|
// Vulnerability 漏洞信息
|
|
type Vulnerability struct {
|
|
ID string `json:"id"`
|
|
Type string `json:"type"`
|
|
Severity string `json:"severity"` // low, medium, high, critical
|
|
Title string `json:"title"`
|
|
Description string `json:"description"`
|
|
Target string `json:"target"`
|
|
FoundAt time.Time `json:"foundAt"`
|
|
Details string `json:"details"`
|
|
}
|
|
|
|
// AnalyzeResults 分析工具执行结果,提取漏洞信息
|
|
func (e *Executor) AnalyzeResults(toolName string, result *mcp.ToolResult) []Vulnerability {
|
|
vulnerabilities := []Vulnerability{}
|
|
|
|
if result.IsError {
|
|
return vulnerabilities
|
|
}
|
|
|
|
// 分析输出内容
|
|
for _, content := range result.Content {
|
|
if content.Type == "text" {
|
|
vulns := e.parseToolOutput(toolName, content.Text)
|
|
vulnerabilities = append(vulnerabilities, vulns...)
|
|
}
|
|
}
|
|
|
|
return vulnerabilities
|
|
}
|
|
|
|
// parseToolOutput 解析工具输出
|
|
func (e *Executor) parseToolOutput(toolName, output string) []Vulnerability {
|
|
vulnerabilities := []Vulnerability{}
|
|
|
|
// 简单的漏洞检测逻辑
|
|
outputLower := strings.ToLower(output)
|
|
|
|
// SQL注入检测
|
|
if strings.Contains(outputLower, "sql injection") || strings.Contains(outputLower, "sqli") {
|
|
vulnerabilities = append(vulnerabilities, Vulnerability{
|
|
ID: fmt.Sprintf("sql-%d", time.Now().Unix()),
|
|
Type: "SQL Injection",
|
|
Severity: "high",
|
|
Title: "SQL注入漏洞",
|
|
Description: "检测到潜在的SQL注入漏洞",
|
|
FoundAt: time.Now(),
|
|
Details: output,
|
|
})
|
|
}
|
|
|
|
// XSS检测
|
|
if strings.Contains(outputLower, "xss") || strings.Contains(outputLower, "cross-site scripting") {
|
|
vulnerabilities = append(vulnerabilities, Vulnerability{
|
|
ID: fmt.Sprintf("xss-%d", time.Now().Unix()),
|
|
Type: "XSS",
|
|
Severity: "medium",
|
|
Title: "跨站脚本攻击漏洞",
|
|
Description: "检测到潜在的XSS漏洞",
|
|
FoundAt: time.Now(),
|
|
Details: output,
|
|
})
|
|
}
|
|
|
|
// 开放端口检测
|
|
if toolName == "nmap" {
|
|
lines := strings.Split(output, "\n")
|
|
for _, line := range lines {
|
|
if strings.Contains(line, "open") && strings.Contains(line, "port") {
|
|
vulnerabilities = append(vulnerabilities, Vulnerability{
|
|
ID: fmt.Sprintf("port-%d", time.Now().Unix()),
|
|
Type: "Open Port",
|
|
Severity: "low",
|
|
Title: "开放端口",
|
|
Description: fmt.Sprintf("发现开放端口: %s", line),
|
|
FoundAt: time.Now(),
|
|
Details: line,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
return vulnerabilities
|
|
}
|
|
|
|
// GetVulnerabilityReport 生成漏洞报告
|
|
func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[string]interface{} {
|
|
severityCount := map[string]int{
|
|
"critical": 0,
|
|
"high": 0,
|
|
"medium": 0,
|
|
"low": 0,
|
|
}
|
|
|
|
for _, vuln := range vulnerabilities {
|
|
severityCount[vuln.Severity]++
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"total": len(vulnerabilities),
|
|
"severityCount": severityCount,
|
|
"vulnerabilities": vulnerabilities,
|
|
"generatedAt": time.Now(),
|
|
}
|
|
}
|
|
|