Add files via upload

This commit is contained in:
公明
2025-11-14 01:44:28 +08:00
committed by GitHub
parent 1b14070cee
commit f8dbfbb65f
5 changed files with 144 additions and 66 deletions
+60 -44
View File
@@ -9,23 +9,43 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
// Executor 安全工具执行器
type Executor struct {
config *config.SecurityConfig
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
mcpServer *mcp.Server
logger *zap.Logger
}
// NewExecutor 创建新的执行器
func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor {
return &Executor{
executor := &Executor{
config: cfg,
toolIndex: make(map[string]*config.ToolConfig),
mcpServer: mcpServer,
logger: logger,
}
// 构建工具索引
executor.buildToolIndex()
return executor
}
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
func (e *Executor) buildToolIndex() {
e.toolIndex = make(map[string]*config.ToolConfig)
for i := range e.config.Tools {
if e.config.Tools[i].Enabled {
e.toolIndex[e.config.Tools[i].Name] = &e.config.Tools[i]
}
}
e.logger.Info("工具索引构建完成",
zap.Int("totalTools", len(e.config.Tools)),
zap.Int("enabledTools", len(e.toolIndex)),
)
}
// ExecuteTool 执行安全工具
@@ -34,30 +54,24 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
zap.String("toolName", toolName),
zap.Any("args", args),
)
// 特殊处理:exec工具直接执行系统命令
if toolName == "exec" {
e.logger.Info("执行exec工具")
return e.executeSystemCommand(ctx, args)
}
// 查找工具配置
var toolConfig *config.ToolConfig
for i := range e.config.Tools {
if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled {
toolConfig = &e.config.Tools[i]
break
}
}
if toolConfig == nil {
// 使用索引查找工具配置O(1) 查找)
toolConfig, exists := e.toolIndex[toolName]
if !exists {
e.logger.Error("工具未找到或未启用",
zap.String("toolName", toolName),
zap.Int("totalTools", len(e.config.Tools)),
zap.Int("enabledTools", len(e.toolIndex)),
)
return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName)
}
e.logger.Info("找到工具配置",
zap.String("toolName", toolName),
zap.String("command", toolConfig.Command),
@@ -66,13 +80,13 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
// 构建命令 - 根据工具类型使用不同的参数格式
cmdArgs := e.buildCommandArgs(toolName, toolConfig, args)
e.logger.Info("构建命令参数完成",
zap.String("toolName", toolName),
zap.Strings("cmdArgs", cmdArgs),
zap.Int("argsCount", len(cmdArgs)),
)
// 验证命令参数
if len(cmdArgs) == 0 {
e.logger.Warn("命令参数为空",
@@ -92,7 +106,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
// 执行命令
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
e.logger.Info("执行安全工具",
zap.String("tool", toolName),
zap.Strings("args", cmdArgs),
@@ -136,8 +150,12 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
e.logger.Info("开始注册工具",
zap.Int("totalTools", len(e.config.Tools)),
zap.Int("enabledTools", len(e.toolIndex)),
)
// 重新构建索引(以防配置更新)
e.buildToolIndex()
for i, toolConfig := range e.config.Tools {
if !toolConfig.Enabled {
e.logger.Debug("跳过未启用的工具",
@@ -149,7 +167,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
// 创建工具配置的副本,避免闭包问题
toolName := toolConfig.Name
toolConfigCopy := toolConfig
// 使用简短描述(如果存在),否则使用详细描述的前100个字符
shortDesc := toolConfigCopy.ShortDescription
if shortDesc == "" {
@@ -166,7 +184,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
shortDesc = desc
}
}
tool := mcp.Tool{
Name: toolConfigCopy.Name,
Description: toolConfigCopy.Description,
@@ -189,7 +207,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
zap.Int("index", i),
)
}
e.logger.Info("工具注册完成",
zap.Int("registeredCount", len(e.config.Tools)),
)
@@ -208,7 +226,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
hasScanType = true
scanTypeValue = scanType
}
// 添加固定参数(如果指定了 scan_type,可能需要过滤掉默认的扫描类型参数)
if hasScanType && toolName == "nmap" {
// 对于 nmap,如果指定了 scan_type,跳过默认的 -sT -sV -sC
@@ -237,7 +255,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" {
continue
}
value := e.getParamValue(args, param)
if value == nil {
if param.Required {
@@ -255,7 +273,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
if param.Type == "bool" {
var boolVal bool
var ok bool
// 尝试多种类型转换
if boolVal, ok = value.(bool); ok {
// 已经是布尔值
@@ -272,7 +290,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
boolVal = strVal == "true" || strVal == "1" || strVal == "yes"
ok = true
}
if ok {
if !boolVal {
continue // false 时不添加任何参数
@@ -340,7 +358,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" {
continue
}
if param.Position != nil && *param.Position == i {
value := e.getParamValue(args, param)
if value == nil {
@@ -365,14 +383,14 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
}
}
}
// 特殊处理:additional_args 参数(需要按空格分割成多个参数)
if additionalArgs, ok := args["additional_args"].(string); ok && additionalArgs != "" {
// 按空格分割,但保留引号内的内容
additionalArgsList := e.parseAdditionalArgs(additionalArgs)
cmdArgs = append(cmdArgs, additionalArgsList...)
}
// 特殊处理:scan_type 参数(需要按空格分割并插入到合适位置)
if hasScanType {
scanTypeArgs := e.parseAdditionalArgs(scanTypeValue)
@@ -403,7 +421,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
// 如果没有定义参数配置,使用固定参数和通用处理
// 添加固定参数
cmdArgs = append(cmdArgs, toolConfig.Args...)
// 通用处理:将参数转换为命令行参数
for key, value := range args {
if key == "_tool_name" {
@@ -426,23 +444,23 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string {
if argsStr == "" {
return []string{}
}
result := make([]string, 0)
var current strings.Builder
inQuotes := false
var quoteChar rune
escapeNext := false
runes := []rune(argsStr)
for i := 0; i < len(runes); i++ {
r := runes[i]
if escapeNext {
current.WriteRune(r)
escapeNext = false
continue
}
if r == '\\' {
// 检查下一个字符是否是引号
if i+1 < len(runes) && (runes[i+1] == '"' || runes[i+1] == '\'') {
@@ -456,19 +474,19 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string {
}
continue
}
if !inQuotes && (r == '"' || r == '\'') {
inQuotes = true
quoteChar = r
continue
}
if inQuotes && r == quoteChar {
inQuotes = false
quoteChar = 0
continue
}
if !inQuotes && (r == ' ' || r == '\t' || r == '\n') {
if current.Len() > 0 {
result = append(result, current.String())
@@ -476,20 +494,20 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string {
}
continue
}
current.WriteRune(r)
}
// 处理最后一个参数(如果存在)
if current.Len() > 0 {
result = append(result, current.String())
}
// 如果解析结果为空,使用简单的空格分割作为降级方案
if len(result) == 0 {
result = strings.Fields(argsStr)
}
return result
}
@@ -638,9 +656,9 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
// buildInputSchema 构建输入模式
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
schema := map[string]interface{}{
"type": "object",
"type": "object",
"properties": map[string]interface{}{},
"required": []string{},
"required": []string{},
}
// 如果配置中定义了参数,优先使用配置中的参数定义
@@ -651,7 +669,7 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in
for _, param := range toolConfig.Parameters {
// 转换类型为OpenAI/JSON Schema标准类型
openAIType := e.convertToOpenAIType(param.Type)
prop := map[string]interface{}{
"type": openAIType,
"description": param.Description,
@@ -750,5 +768,3 @@ func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[s
"generatedAt": time.Now(),
}
}