mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-05 03:58:16 +02:00
Add files via upload
This commit is contained in:
@@ -9,23 +9,43 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Executor 安全工具执行器
|
||||
type Executor struct {
|
||||
config *config.SecurityConfig
|
||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewExecutor 创建新的执行器
|
||||
func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor {
|
||||
return &Executor{
|
||||
executor := &Executor{
|
||||
config: cfg,
|
||||
toolIndex: make(map[string]*config.ToolConfig),
|
||||
mcpServer: mcpServer,
|
||||
logger: logger,
|
||||
}
|
||||
// 构建工具索引
|
||||
executor.buildToolIndex()
|
||||
return executor
|
||||
}
|
||||
|
||||
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
||||
func (e *Executor) buildToolIndex() {
|
||||
e.toolIndex = make(map[string]*config.ToolConfig)
|
||||
for i := range e.config.Tools {
|
||||
if e.config.Tools[i].Enabled {
|
||||
e.toolIndex[e.config.Tools[i].Name] = &e.config.Tools[i]
|
||||
}
|
||||
}
|
||||
e.logger.Info("工具索引构建完成",
|
||||
zap.Int("totalTools", len(e.config.Tools)),
|
||||
zap.Int("enabledTools", len(e.toolIndex)),
|
||||
)
|
||||
}
|
||||
|
||||
// ExecuteTool 执行安全工具
|
||||
@@ -34,30 +54,24 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
zap.String("toolName", toolName),
|
||||
zap.Any("args", args),
|
||||
)
|
||||
|
||||
|
||||
// 特殊处理:exec工具直接执行系统命令
|
||||
if toolName == "exec" {
|
||||
e.logger.Info("执行exec工具")
|
||||
return e.executeSystemCommand(ctx, args)
|
||||
}
|
||||
|
||||
// 查找工具配置
|
||||
var toolConfig *config.ToolConfig
|
||||
for i := range e.config.Tools {
|
||||
if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled {
|
||||
toolConfig = &e.config.Tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if toolConfig == nil {
|
||||
// 使用索引查找工具配置(O(1) 查找)
|
||||
toolConfig, exists := e.toolIndex[toolName]
|
||||
if !exists {
|
||||
e.logger.Error("工具未找到或未启用",
|
||||
zap.String("toolName", toolName),
|
||||
zap.Int("totalTools", len(e.config.Tools)),
|
||||
zap.Int("enabledTools", len(e.toolIndex)),
|
||||
)
|
||||
return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName)
|
||||
}
|
||||
|
||||
|
||||
e.logger.Info("找到工具配置",
|
||||
zap.String("toolName", toolName),
|
||||
zap.String("command", toolConfig.Command),
|
||||
@@ -66,13 +80,13 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
|
||||
// 构建命令 - 根据工具类型使用不同的参数格式
|
||||
cmdArgs := e.buildCommandArgs(toolName, toolConfig, args)
|
||||
|
||||
|
||||
e.logger.Info("构建命令参数完成",
|
||||
zap.String("toolName", toolName),
|
||||
zap.Strings("cmdArgs", cmdArgs),
|
||||
zap.Int("argsCount", len(cmdArgs)),
|
||||
)
|
||||
|
||||
|
||||
// 验证命令参数
|
||||
if len(cmdArgs) == 0 {
|
||||
e.logger.Warn("命令参数为空",
|
||||
@@ -92,7 +106,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
|
||||
// 执行命令
|
||||
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||
|
||||
|
||||
e.logger.Info("执行安全工具",
|
||||
zap.String("tool", toolName),
|
||||
zap.Strings("args", cmdArgs),
|
||||
@@ -136,8 +150,12 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
|
||||
e.logger.Info("开始注册工具",
|
||||
zap.Int("totalTools", len(e.config.Tools)),
|
||||
zap.Int("enabledTools", len(e.toolIndex)),
|
||||
)
|
||||
|
||||
|
||||
// 重新构建索引(以防配置更新)
|
||||
e.buildToolIndex()
|
||||
|
||||
for i, toolConfig := range e.config.Tools {
|
||||
if !toolConfig.Enabled {
|
||||
e.logger.Debug("跳过未启用的工具",
|
||||
@@ -149,7 +167,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
|
||||
// 创建工具配置的副本,避免闭包问题
|
||||
toolName := toolConfig.Name
|
||||
toolConfigCopy := toolConfig
|
||||
|
||||
|
||||
// 使用简短描述(如果存在),否则使用详细描述的前100个字符
|
||||
shortDesc := toolConfigCopy.ShortDescription
|
||||
if shortDesc == "" {
|
||||
@@ -166,7 +184,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
|
||||
shortDesc = desc
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
tool := mcp.Tool{
|
||||
Name: toolConfigCopy.Name,
|
||||
Description: toolConfigCopy.Description,
|
||||
@@ -189,7 +207,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
|
||||
zap.Int("index", i),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
e.logger.Info("工具注册完成",
|
||||
zap.Int("registeredCount", len(e.config.Tools)),
|
||||
)
|
||||
@@ -208,7 +226,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
|
||||
hasScanType = true
|
||||
scanTypeValue = scanType
|
||||
}
|
||||
|
||||
|
||||
// 添加固定参数(如果指定了 scan_type,可能需要过滤掉默认的扫描类型参数)
|
||||
if hasScanType && toolName == "nmap" {
|
||||
// 对于 nmap,如果指定了 scan_type,跳过默认的 -sT -sV -sC
|
||||
@@ -237,7 +255,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
|
||||
if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
value := e.getParamValue(args, param)
|
||||
if value == nil {
|
||||
if param.Required {
|
||||
@@ -255,7 +273,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
|
||||
if param.Type == "bool" {
|
||||
var boolVal bool
|
||||
var ok bool
|
||||
|
||||
|
||||
// 尝试多种类型转换
|
||||
if boolVal, ok = value.(bool); ok {
|
||||
// 已经是布尔值
|
||||
@@ -272,7 +290,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
|
||||
boolVal = strVal == "true" || strVal == "1" || strVal == "yes"
|
||||
ok = true
|
||||
}
|
||||
|
||||
|
||||
if ok {
|
||||
if !boolVal {
|
||||
continue // false 时不添加任何参数
|
||||
@@ -340,7 +358,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
|
||||
if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
if param.Position != nil && *param.Position == i {
|
||||
value := e.getParamValue(args, param)
|
||||
if value == nil {
|
||||
@@ -365,14 +383,14 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 特殊处理:additional_args 参数(需要按空格分割成多个参数)
|
||||
if additionalArgs, ok := args["additional_args"].(string); ok && additionalArgs != "" {
|
||||
// 按空格分割,但保留引号内的内容
|
||||
additionalArgsList := e.parseAdditionalArgs(additionalArgs)
|
||||
cmdArgs = append(cmdArgs, additionalArgsList...)
|
||||
}
|
||||
|
||||
|
||||
// 特殊处理:scan_type 参数(需要按空格分割并插入到合适位置)
|
||||
if hasScanType {
|
||||
scanTypeArgs := e.parseAdditionalArgs(scanTypeValue)
|
||||
@@ -403,7 +421,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
|
||||
// 如果没有定义参数配置,使用固定参数和通用处理
|
||||
// 添加固定参数
|
||||
cmdArgs = append(cmdArgs, toolConfig.Args...)
|
||||
|
||||
|
||||
// 通用处理:将参数转换为命令行参数
|
||||
for key, value := range args {
|
||||
if key == "_tool_name" {
|
||||
@@ -426,23 +444,23 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string {
|
||||
if argsStr == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
|
||||
result := make([]string, 0)
|
||||
var current strings.Builder
|
||||
inQuotes := false
|
||||
var quoteChar rune
|
||||
escapeNext := false
|
||||
|
||||
|
||||
runes := []rune(argsStr)
|
||||
for i := 0; i < len(runes); i++ {
|
||||
r := runes[i]
|
||||
|
||||
|
||||
if escapeNext {
|
||||
current.WriteRune(r)
|
||||
escapeNext = false
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
if r == '\\' {
|
||||
// 检查下一个字符是否是引号
|
||||
if i+1 < len(runes) && (runes[i+1] == '"' || runes[i+1] == '\'') {
|
||||
@@ -456,19 +474,19 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string {
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
if !inQuotes && (r == '"' || r == '\'') {
|
||||
inQuotes = true
|
||||
quoteChar = r
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
if inQuotes && r == quoteChar {
|
||||
inQuotes = false
|
||||
quoteChar = 0
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
if !inQuotes && (r == ' ' || r == '\t' || r == '\n') {
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
@@ -476,20 +494,20 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string {
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
current.WriteRune(r)
|
||||
}
|
||||
|
||||
|
||||
// 处理最后一个参数(如果存在)
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
}
|
||||
|
||||
|
||||
// 如果解析结果为空,使用简单的空格分割作为降级方案
|
||||
if len(result) == 0 {
|
||||
result = strings.Fields(argsStr)
|
||||
}
|
||||
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -638,9 +656,9 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
// buildInputSchema 构建输入模式
|
||||
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
|
||||
schema := map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
"required": []string{},
|
||||
"required": []string{},
|
||||
}
|
||||
|
||||
// 如果配置中定义了参数,优先使用配置中的参数定义
|
||||
@@ -651,7 +669,7 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in
|
||||
for _, param := range toolConfig.Parameters {
|
||||
// 转换类型为OpenAI/JSON Schema标准类型
|
||||
openAIType := e.convertToOpenAIType(param.Type)
|
||||
|
||||
|
||||
prop := map[string]interface{}{
|
||||
"type": openAIType,
|
||||
"description": param.Description,
|
||||
@@ -750,5 +768,3 @@ func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[s
|
||||
"generatedAt": time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user