diff --git a/cmd/test-config/main.go b/cmd/test-config/main.go new file mode 100644 index 00000000..e8dfd150 --- /dev/null +++ b/cmd/test-config/main.go @@ -0,0 +1,45 @@ +package main + +import ( + "cyberstrike-ai/internal/config" + "fmt" + "os" +) + +func main() { + cfg, err := config.Load("config.yaml") + if err != nil { + fmt.Printf("❌ 加载配置失败: %v\n", err) + os.Exit(1) + } + + fmt.Printf("✅ 配置加载成功\n") + fmt.Printf(" 工具目录: %s\n", cfg.Security.ToolsDir) + fmt.Printf(" 工具数量: %d\n", len(cfg.Security.Tools)) + + if len(cfg.Security.Tools) > 0 { + fmt.Printf("\n 已加载的工具:\n") + for _, tool := range cfg.Security.Tools { + status := "❌ 禁用" + if tool.Enabled { + status = "✅ 启用" + } + shortDesc := tool.ShortDescription + if shortDesc == "" { + shortDesc = "(无简短描述,将自动提取)" + } + fmt.Printf(" %s %s\n", status, tool.Name) + fmt.Printf(" 简短描述: %s\n", shortDesc) + if len(tool.Description) > 100 { + fmt.Printf(" 详细描述: %s...\n", tool.Description[:100]) + } else { + fmt.Printf(" 详细描述: %s\n", tool.Description) + } + fmt.Printf(" 参数数量: %d\n", len(tool.Parameters)) + fmt.Println() + } + } else { + fmt.Printf(" ⚠️ 未加载任何工具\n") + } +} + diff --git a/config.yaml b/config.yaml index fa2672db..9b5f9886 100644 --- a/config.yaml +++ b/config.yaml @@ -20,155 +20,12 @@ database: path: "data/conversations.db" security: - tools: - # 示例1: 使用参数定义的工具(推荐方式) - - name: "nmap" - command: "nmap" - args: ["-sT", "-sV", "-sC"] # 固定参数 - description: "网络扫描工具,用于发现网络主机和服务" - enabled: true - parameters: - - name: "target" - type: "string" - description: "目标IP地址或域名" - required: true - position: 0 # 位置参数,放在最后 - format: "positional" - - name: "ports" - type: "string" - description: "端口范围,例如: 1-1000, 80,443,8080" - required: false - flag: "-p" - format: "flag" - - # 示例2: 标志参数工具 - - name: "sqlmap" - command: "sqlmap" - description: "SQL注入检测和利用工具" - enabled: true - parameters: - - name: "url" - type: "string" - description: "目标URL,例如: http://example.com/page?id=1" - required: true - flag: "-u" - format: "flag" - - name: "batch" - type: "bool" - description: "非交互模式" - required: false - default: true - flag: "--batch" - format: "flag" - - name: "level" - type: "int" - description: "测试级别 (1-5)" - required: false - default: 3 - flag: "--level" - format: "combined" # --level=3 - - # 示例3: 位置参数工具 - - name: "nikto" - command: "nikto" - description: "Web服务器扫描工具" - enabled: true - parameters: - - name: "target" - type: "string" - description: "目标URL或IP地址" - required: true - flag: "-h" - format: "flag" - - # 示例4: 简单位置参数 - - name: "dirb" - command: "dirb" - description: "Web目录扫描工具" - enabled: true - parameters: - - name: "url" - type: "string" - description: "目标URL" - required: true - position: 0 - format: "positional" - - name: "wordlist" - type: "string" - description: "字典文件路径" - required: false - flag: "-w" - format: "flag" - - # 示例5: 执行系统命令 - - name: "exec" - command: "sh" - args: ["-c"] - description: "执行系统命令(谨慎使用)" - enabled: true - parameters: - - name: "command" - type: "string" - description: "要执行的系统命令" - required: true - position: 0 - format: "positional" - - name: "shell" - type: "string" - description: "使用的shell(可选,默认为sh)" - required: false - default: "sh" - - name: "workdir" - type: "string" - description: "工作目录" - required: false - - # 示例6: 自定义工具 - 使用模板格式 - - name: "custom_scanner" - command: "my-scanner" - description: "自定义扫描工具示例" - enabled: false # 默认禁用,需要时启用 - parameters: - - name: "target" - type: "string" - description: "扫描目标" - required: true - flag: "--target" - format: "flag" - - name: "mode" - type: "string" - description: "扫描模式" - required: false - default: "normal" - options: ["normal", "aggressive", "stealth"] # 枚举值 - flag: "--mode" - format: "combined" # --mode=normal - - name: "threads" - type: "int" - description: "线程数" - required: false - default: 10 - flag: "-t" - format: "flag" - - name: "output" - type: "string" - description: "输出文件路径" - required: false - flag: "-o" - format: "template" - template: "-o {value}" # 自定义模板 - - name: "verbose" - type: "bool" - description: "详细输出" - required: false - default: false - flag: "-v" - format: "flag" # 布尔值:如果为true,只添加-v,不添加值 - - # 示例7: 向后兼容 - 不定义parameters,使用旧的硬编码逻辑 - # - name: "legacy_tool" - # command: "legacy" - # args: ["--option"] - # description: "旧工具(使用硬编码逻辑)" - # enabled: false + # 工具配置文件目录(推荐方式) + # 系统会自动加载 tools/ 目录下的所有 .yaml 和 .yml 文件 + # 每个工具一个配置文件,便于维护和管理 + tools_dir: "tools" + + # 向后兼容:也可以在主配置文件中直接定义工具 + # 如果 tools_dir 和 tools 都配置了,tools_dir 中的工具优先 + # tools: [] diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 7110138c..a6451212 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -193,12 +193,40 @@ type AgentLoopResult struct { MCPExecutionIDs []string } +// ProgressCallback 进度回调函数类型 +type ProgressCallback func(eventType, message string, data interface{}) + // AgentLoop 执行Agent循环 func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) { + return a.AgentLoopWithProgress(ctx, userInput, historyMessages, nil) +} + +// AgentLoopWithProgress 执行Agent循环(带进度回调) +func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, callback ProgressCallback) (*AgentLoopResult, error) { + // 发送进度更新 + sendProgress := func(eventType, message string, data interface{}) { + if callback != nil { + callback(eventType, message, data) + } + } + + // 系统提示词,指导AI如何处理工具错误 + systemPrompt := `你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。 + +重要:当工具调用失败时,请遵循以下原则: +1. 仔细分析错误信息,理解失败的具体原因 +2. 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标 +3. 如果参数错误,根据错误提示修正参数后重试 +4. 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析 +5. 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作 +6. 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务 + +当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。` + messages := []ChatMessage{ { Role: "system", - Content: "你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。当需要执行工具时,使用提供的工具函数。", + Content: systemPrompt, }, } @@ -248,6 +276,13 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages // 获取可用工具 tools := a.getAvailableTools() + // 发送进度更新 + if i == 0 { + sendProgress("progress", "正在分析请求并制定测试策略...", nil) + } else { + sendProgress("progress", fmt.Sprintf("正在继续分析(第 %d 轮迭代)...", i+1), nil) + } + // 记录每次调用OpenAI if i == 0 { a.logger.Info("调用OpenAI", @@ -277,6 +312,7 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages } // 调用OpenAI + sendProgress("progress", "正在调用AI模型...", nil) response, err := a.callOpenAI(ctx, messages, tools) if err != nil { result.Response = "" @@ -304,17 +340,46 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages ToolCalls: choice.Message.ToolCalls, }) + // 发送工具调用进度 + sendProgress("progress", fmt.Sprintf("检测到 %d 个工具调用,开始执行...", len(choice.Message.ToolCalls)), nil) + // 执行所有工具调用 - for _, toolCall := range choice.Message.ToolCalls { + for idx, toolCall := range choice.Message.ToolCalls { + // 发送工具调用开始事件 + toolArgsJSON, _ := json.Marshal(toolCall.Function.Arguments) + sendProgress("tool_call", fmt.Sprintf("正在调用工具: %s", toolCall.Function.Name), map[string]interface{}{ + "toolName": toolCall.Function.Name, + "arguments": string(toolArgsJSON), + "index": idx + 1, + "total": len(choice.Message.ToolCalls), + }) + // 执行工具 execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments) if err != nil { + // 构建详细的错误信息,帮助AI理解问题并做出决策 + errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err) messages = append(messages, ChatMessage{ Role: "tool", ToolCallID: toolCall.ID, - Content: fmt.Sprintf("工具执行失败: %v", err), + Content: errorMsg, }) + + // 发送工具执行失败事件 + sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{ + "toolName": toolCall.Function.Name, + "success": false, + "error": err.Error(), + "index": idx + 1, + "total": len(choice.Message.ToolCalls), + }) + + a.logger.Warn("工具执行失败,已返回详细错误信息", + zap.String("tool", toolCall.Function.Name), + zap.Error(err), + ) } else { + // 即使工具返回了错误结果(IsError=true),也继续处理,让AI决定下一步 messages = append(messages, ChatMessage{ Role: "tool", ToolCallID: toolCall.ID, @@ -324,6 +389,29 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages if execResult.ExecutionID != "" { result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID) } + + // 发送工具执行成功事件 + resultPreview := execResult.Result + if len(resultPreview) > 200 { + resultPreview = resultPreview[:200] + "..." + } + sendProgress("tool_result", fmt.Sprintf("工具 %s 执行完成", toolCall.Function.Name), map[string]interface{}{ + "toolName": toolCall.Function.Name, + "success": !execResult.IsError, + "isError": execResult.IsError, + "result": resultPreview, + "executionId": execResult.ExecutionID, + "index": idx + 1, + "total": len(choice.Message.ToolCalls), + }) + + // 如果工具返回了错误,记录日志但不中断流程 + if execResult.IsError { + a.logger.Warn("工具返回错误结果,但继续处理", + zap.String("tool", toolCall.Function.Name), + zap.String("result", execResult.Result), + ) + } } } continue @@ -337,6 +425,7 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages // 如果完成,返回结果 if choice.FinishReason == "stop" { + sendProgress("progress", "正在生成最终回复...", nil) result.Response = choice.Message.Content return result, nil } @@ -347,117 +436,98 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages } // getAvailableTools 获取可用工具 +// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗 func (a *Agent) getAvailableTools() []Tool { - // 从MCP服务器获取工具列表 - executions := a.mcpServer.GetAllExecutions() - toolNames := make(map[string]bool) - for _, exec := range executions { - toolNames[exec.ToolName] = true + // 从MCP服务器获取所有已注册的工具 + mcpTools := a.mcpServer.GetAllTools() + + // 转换为OpenAI格式的工具定义 + tools := make([]Tool, 0, len(mcpTools)) + for _, mcpTool := range mcpTools { + // 使用简短描述(如果存在),否则使用详细描述 + description := mcpTool.ShortDescription + if description == "" { + description = mcpTool.Description + } + + // 转换schema中的类型为OpenAI标准类型 + convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema) + + tools = append(tools, Tool{ + Type: "function", + Function: FunctionDefinition{ + Name: mcpTool.Name, + Description: description, // 使用简短描述减少token消耗 + Parameters: convertedSchema, + }, + }) } - - tools := []Tool{ - { - Type: "function", - Function: FunctionDefinition{ - Name: "nmap", - Description: "使用nmap进行网络扫描,发现开放端口和服务。支持IP地址、域名或URL(会自动提取域名)。使用TCP连接扫描,不需要root权限。", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "target": map[string]interface{}{ - "type": "string", - "description": "目标IP地址、域名或URL(如 https://example.com)。如果是URL,会自动提取域名部分。", - }, - "ports": map[string]interface{}{ - "type": "string", - "description": "要扫描的端口范围,例如: 1-1000 或 80,443,8080。如果不指定,将扫描常用端口。", - }, - }, - "required": []string{"target"}, - }, - }, - }, - { - Type: "function", - Function: FunctionDefinition{ - Name: "sqlmap", - Description: "使用sqlmap检测SQL注入漏洞", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "url": map[string]interface{}{ - "type": "string", - "description": "目标URL", - }, - }, - "required": []string{"url"}, - }, - }, - }, - { - Type: "function", - Function: FunctionDefinition{ - Name: "nikto", - Description: "使用nikto扫描Web服务器漏洞", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "target": map[string]interface{}{ - "type": "string", - "description": "目标URL", - }, - }, - "required": []string{"target"}, - }, - }, - }, - { - Type: "function", - Function: FunctionDefinition{ - Name: "dirb", - Description: "使用dirb进行目录扫描", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "url": map[string]interface{}{ - "type": "string", - "description": "目标URL", - }, - }, - "required": []string{"url"}, - }, - }, - }, - { - Type: "function", - Function: FunctionDefinition{ - Name: "exec", - Description: "执行系统命令(谨慎使用,仅用于必要的系统操作)", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "command": map[string]interface{}{ - "type": "string", - "description": "要执行的系统命令", - }, - "shell": map[string]interface{}{ - "type": "string", - "description": "使用的shell(可选,默认为sh)", - }, - "workdir": map[string]interface{}{ - "type": "string", - "description": "工作目录(可选)", - }, - }, - "required": []string{"command"}, - }, - }, - }, - } - + + a.logger.Debug("获取可用工具列表", + zap.Int("count", len(tools)), + ) + return tools } +// convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型 +func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} { + if schema == nil { + return schema + } + + // 创建新的schema副本 + converted := make(map[string]interface{}) + for k, v := range schema { + converted[k] = v + } + + // 转换properties中的类型 + if properties, ok := converted["properties"].(map[string]interface{}); ok { + convertedProperties := make(map[string]interface{}) + for propName, propValue := range properties { + if prop, ok := propValue.(map[string]interface{}); ok { + convertedProp := make(map[string]interface{}) + for pk, pv := range prop { + if pk == "type" { + // 转换类型 + if typeStr, ok := pv.(string); ok { + convertedProp[pk] = a.convertToOpenAIType(typeStr) + } else { + convertedProp[pk] = pv + } + } else { + convertedProp[pk] = pv + } + } + convertedProperties[propName] = convertedProp + } else { + convertedProperties[propName] = propValue + } + } + converted["properties"] = convertedProperties + } + + return converted +} + +// convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型 +func (a *Agent) convertToOpenAIType(configType string) string { + switch configType { + case "bool": + return "boolean" + case "int", "integer": + return "number" + case "float", "double": + return "number" + case "string", "array", "object": + return configType + default: + // 默认返回原类型 + return configType + } +} + // callOpenAI 调用OpenAI API func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) { reqBody := OpenAIRequest{ @@ -546,9 +616,11 @@ func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) { type ToolExecutionResult struct { Result string ExecutionID string + IsError bool // 标记是否为错误结果 } // executeToolViaMCP 通过MCP执行工具 +// 即使工具执行失败,也返回结果而不是错误,让AI能够处理错误情况 func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map[string]interface{}) (*ToolExecutionResult, error) { a.logger.Info("通过MCP执行工具", zap.String("tool", toolName), @@ -557,8 +629,30 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map // 通过MCP服务器调用工具 result, executionID, err := a.mcpServer.CallTool(ctx, toolName, args) + + // 如果调用失败(如工具不存在),返回友好的错误信息而不是抛出异常 if err != nil { - return nil, fmt.Errorf("工具执行失败: %w", err) + errorMsg := fmt.Sprintf(`工具调用失败 + +工具名称: %s +错误类型: 系统错误 +错误详情: %v + +可能的原因: +- 工具 "%s" 不存在或未启用 +- 系统配置问题 +- 网络或权限问题 + +建议: +- 检查工具名称是否正确 +- 尝试使用其他替代工具 +- 如果这是必需的工具,请向用户说明情况`, toolName, err, toolName) + + return &ToolExecutionResult{ + Result: errorMsg, + ExecutionID: executionID, + IsError: true, + }, nil // 返回 nil 错误,让调用者处理结果 } // 格式化结果 @@ -571,6 +665,24 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map return &ToolExecutionResult{ Result: resultText.String(), ExecutionID: executionID, + IsError: result != nil && result.IsError, }, nil } +// formatToolError 格式化工具错误信息,提供更友好的错误描述 +func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string { + errorMsg := fmt.Sprintf(`工具执行失败 + +工具名称: %s +调用参数: %v +错误信息: %v + +请分析错误原因并采取以下行动之一: +1. 如果参数错误,请修正参数后重试 +2. 如果工具不可用,请尝试使用替代工具 +3. 如果这是系统问题,请向用户说明情况并提供建议 +4. 如果错误信息中包含有用信息,可以基于这些信息继续分析`, toolName, args, err) + + return errorMsg +} + diff --git a/internal/app/app.go b/internal/app/app.go index 8f78d2b7..8e1e2714 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -115,6 +115,8 @@ func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitor { // Agent Loop api.POST("/agent-loop", agentHandler.AgentLoop) + // Agent Loop 流式输出 + api.POST("/agent-loop/stream", agentHandler.AgentLoopStream) // 对话历史 api.POST("/conversations", conversationHandler.CreateConversation) diff --git a/internal/config/config.go b/internal/config/config.go index 2215bce7..82dd1ac0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,8 @@ package config import ( "fmt" "os" + "path/filepath" + "strings" "gopkg.in/yaml.v3" ) @@ -39,7 +41,8 @@ type OpenAIConfig struct { } type SecurityConfig struct { - Tools []ToolConfig `yaml:"tools"` + Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具 + ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式) } type DatabaseConfig struct { @@ -47,13 +50,14 @@ type DatabaseConfig struct { } type ToolConfig struct { - Name string `yaml:"name"` - Command string `yaml:"command"` - Args []string `yaml:"args,omitempty"` // 固定参数(可选) - Description string `yaml:"description"` - Enabled bool `yaml:"enabled"` - Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选) - ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选) + Name string `yaml:"name"` + Command string `yaml:"command"` + Args []string `yaml:"args,omitempty"` // 固定参数(可选) + ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗) + Description string `yaml:"description"` // 详细描述(用于工具文档) + Enabled bool `yaml:"enabled"` + Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选) + ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选) } // ParameterConfig 参数配置 @@ -81,9 +85,102 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("解析配置文件失败: %w", err) } + // 如果配置了工具目录,从目录加载工具配置 + if cfg.Security.ToolsDir != "" { + configDir := filepath.Dir(path) + toolsDir := cfg.Security.ToolsDir + + // 如果是相对路径,相对于配置文件所在目录 + if !filepath.IsAbs(toolsDir) { + toolsDir = filepath.Join(configDir, toolsDir) + } + + tools, err := LoadToolsFromDir(toolsDir) + if err != nil { + return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err) + } + + // 合并工具配置:目录中的工具优先,主配置中的工具作为补充 + existingTools := make(map[string]bool) + for _, tool := range tools { + existingTools[tool.Name] = true + } + + // 添加主配置中不存在于目录中的工具(向后兼容) + for _, tool := range cfg.Security.Tools { + if !existingTools[tool.Name] { + tools = append(tools, tool) + } + } + + cfg.Security.Tools = tools + } + return &cfg, nil } +// LoadToolsFromDir 从目录加载所有工具配置文件 +func LoadToolsFromDir(dir string) ([]ToolConfig, error) { + var tools []ToolConfig + + // 检查目录是否存在 + if _, err := os.Stat(dir); os.IsNotExist(err) { + return tools, nil // 目录不存在时返回空列表,不报错 + } + + // 读取目录中的所有 .yaml 和 .yml 文件 + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("读取工具目录失败: %w", err) + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") { + continue + } + + filePath := filepath.Join(dir, name) + tool, err := LoadToolFromFile(filePath) + if err != nil { + // 记录错误但继续加载其他文件 + fmt.Printf("警告: 加载工具配置文件 %s 失败: %v\n", filePath, err) + continue + } + + tools = append(tools, *tool) + } + + return tools, nil +} + +// LoadToolFromFile 从单个文件加载工具配置 +func LoadToolFromFile(path string) (*ToolConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取文件失败: %w", err) + } + + var tool ToolConfig + if err := yaml.Unmarshal(data, &tool); err != nil { + return nil, fmt.Errorf("解析工具配置失败: %w", err) + } + + // 验证必需字段 + if tool.Name == "" { + return nil, fmt.Errorf("工具名称不能为空") + } + if tool.Command == "" { + return nil, fmt.Errorf("工具命令不能为空") + } + + return &tool, nil +} + func Default() *Config { return &Config{ Server: ServerConfig{ @@ -104,7 +201,8 @@ func Default() *Config { Model: "gpt-4", }, Security: SecurityConfig{ - Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 加载,不在此硬编码 + Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载 + ToolsDir: "tools", // 默认工具目录 }, Database: DatabaseConfig{ Path: "data/conversations.db", diff --git a/internal/handler/agent.go b/internal/handler/agent.go index b628b08c..50974318 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -1,6 +1,8 @@ package handler import ( + "encoding/json" + "fmt" "net/http" "time" @@ -132,3 +134,120 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { }) } +// StreamEvent 流式事件 +type StreamEvent struct { + Type string `json:"type"` // progress, tool_call, tool_result, response, error, done + Message string `json:"message"` // 显示消息 + Data interface{} `json:"data,omitempty"` +} + +// AgentLoopStream 处理Agent Loop流式请求 +func (h *AgentHandler) AgentLoopStream(c *gin.Context) { + var req ChatRequest + if err := c.ShouldBindJSON(&req); err != nil { + // 对于流式请求,也发送SSE格式的错误 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + event := StreamEvent{ + Type: "error", + Message: "请求参数错误: " + err.Error(), + } + eventJSON, _ := json.Marshal(event) + fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON) + c.Writer.Flush() + return + } + + h.logger.Info("收到Agent Loop流式请求", + zap.String("message", req.Message), + zap.String("conversationId", req.ConversationID), + ) + + // 设置SSE响应头 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") // 禁用nginx缓冲 + + // 发送初始事件 + sendEvent := func(eventType, message string, data interface{}) { + event := StreamEvent{ + Type: eventType, + Message: message, + Data: data, + } + eventJSON, _ := json.Marshal(event) + fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON) + c.Writer.Flush() + } + + // 如果没有对话ID,创建新对话 + conversationID := req.ConversationID + if conversationID == "" { + title := req.Message + if len(title) > 50 { + title = title[:50] + "..." + } + conv, err := h.db.CreateConversation(title) + if err != nil { + h.logger.Error("创建对话失败", zap.Error(err)) + sendEvent("error", "创建对话失败: "+err.Error(), nil) + return + } + conversationID = conv.ID + } + + // 获取历史消息 + historyMessages, err := h.db.GetMessages(conversationID) + if err != nil { + h.logger.Warn("获取历史消息失败", zap.Error(err)) + historyMessages = []database.Message{} + } + + // 将数据库消息转换为Agent消息格式 + agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages)) + for _, msg := range historyMessages { + agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{ + Role: msg.Role, + Content: msg.Content, + }) + } + + // 保存用户消息 + _, err = h.db.AddMessage(conversationID, "user", req.Message, nil) + if err != nil { + h.logger.Error("保存用户消息失败", zap.Error(err)) + } + + // 创建进度回调函数 + progressCallback := func(eventType, message string, data interface{}) { + sendEvent(eventType, message, data) + } + + // 执行Agent Loop,传入进度回调 + sendEvent("progress", "正在分析您的请求...", nil) + result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), req.Message, agentHistoryMessages, progressCallback) + if err != nil { + h.logger.Error("Agent Loop执行失败", zap.Error(err)) + sendEvent("error", "执行失败: "+err.Error(), nil) + sendEvent("done", "", nil) + return + } + + // 保存助手回复 + _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs) + if err != nil { + h.logger.Error("保存助手消息失败", zap.Error(err)) + } + + // 发送最终响应 + sendEvent("response", result.Response, map[string]interface{}{ + "mcpExecutionIds": result.MCPExecutionIDs, + "conversationId": conversationID, + }) + sendEvent("done", "", map[string]interface{}{ + "conversationId": conversationID, + }) +} + diff --git a/internal/mcp/server.go b/internal/mcp/server.go index c9e53ea5..6de13152 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -361,6 +361,18 @@ func (s *Server) GetStats() map[string]*ToolStats { return stats } +// GetAllTools 获取所有已注册的工具(用于Agent动态获取工具列表) +func (s *Server) GetAllTools() []Tool { + s.mu.RLock() + defer s.mu.RUnlock() + + tools := make([]Tool, 0, len(s.toolDefs)) + for _, tool := range s.toolDefs { + tools = append(tools, tool) + } + return tools +} + // CallTool 直接调用工具(用于内部调用) func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) { s.mu.RLock() diff --git a/internal/mcp/types.go b/internal/mcp/types.go index a757196f..40618a54 100644 --- a/internal/mcp/types.go +++ b/internal/mcp/types.go @@ -36,9 +36,10 @@ type Error struct { // Tool 表示MCP工具定义 type Tool struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema map[string]interface{} `json:"inputSchema"` + Name string `json:"name"` + Description string `json:"description"` // 详细描述 + ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗) + InputSchema map[string]interface{} `json:"inputSchema"` } // ToolCall 表示工具调用 diff --git a/internal/security/executor.go b/internal/security/executor.go index 9131ea44..12f17c9b 100644 --- a/internal/security/executor.go +++ b/internal/security/executor.go @@ -150,10 +150,28 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) { toolName := toolConfig.Name toolConfigCopy := toolConfig + // 使用简短描述(如果存在),否则使用详细描述的前100个字符 + shortDesc := toolConfigCopy.ShortDescription + if shortDesc == "" { + // 如果没有简短描述,从详细描述中提取第一行或前100个字符 + desc := toolConfigCopy.Description + if len(desc) > 100 { + // 尝试找到第一个换行符 + if idx := strings.Index(desc, "\n"); idx > 0 && idx < 100 { + shortDesc = strings.TrimSpace(desc[:idx]) + } else { + shortDesc = desc[:100] + "..." + } + } else { + shortDesc = desc + } + } + tool := mcp.Tool{ - Name: toolConfigCopy.Name, - Description: toolConfigCopy.Description, - InputSchema: e.buildInputSchema(&toolConfigCopy), + Name: toolConfigCopy.Name, + Description: toolConfigCopy.Description, + ShortDescription: shortDesc, + InputSchema: e.buildInputSchema(&toolConfigCopy), } handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { @@ -543,8 +561,11 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in required := []string{} for _, param := range toolConfig.Parameters { + // 转换类型为OpenAI/JSON Schema标准类型 + openAIType := e.convertToOpenAIType(param.Type) + prop := map[string]interface{}{ - "type": param.Type, + "type": openAIType, "description": param.Description, } @@ -622,6 +643,26 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in return schema } +// convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型 +func (e *Executor) convertToOpenAIType(configType string) string { + switch configType { + case "bool": + return "boolean" + case "int", "integer": + return "number" + case "float", "double": + return "number" + case "string", "array", "object": + return configType + default: + // 默认返回原类型,但记录警告 + e.logger.Warn("未知的参数类型,使用原类型", + zap.String("type", configType), + ) + return configType + } +} + // Vulnerability 漏洞信息 type Vulnerability struct { ID string `json:"id"` diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 00000000..94baccb1 --- /dev/null +++ b/tools/README.md @@ -0,0 +1,121 @@ +# 工具配置文件说明 + +## 概述 + +每个工具现在都有独立的配置文件,存放在 `tools/` 目录下。这种方式使得工具配置更加清晰、易于维护和管理。 + +## 配置文件格式 + +每个工具配置文件是一个 YAML 文件,包含以下字段: + +### 必需字段 + +- `name`: 工具名称(唯一标识符) +- `command`: 要执行的命令 +- `enabled`: 是否启用(true/false) + +### 可选字段 + +- `args`: 固定参数列表(数组) +- `short_description`: 简短描述(一句话说明工具用途,用于工具列表,减少token消耗) +- `description`: 工具详细描述(支持多行文本,用于工具文档和详细说明) +- `parameters`: 参数定义列表 + +## 工具描述 + +### 简短描述 (`short_description`) + +- **用途**:用于工具列表,减少发送给大模型的token消耗 +- **要求**:一句话(20-50字)说明工具的核心用途 +- **示例**:`"网络扫描工具,用于发现网络主机、开放端口和服务"` + +### 详细描述 (`description`) + +支持多行文本,应该包含: + +1. **工具功能说明**:工具的主要功能 +2. **使用场景**:什么情况下使用这个工具 +3. **注意事项**:使用时的注意事项和警告 +4. **示例**:使用示例(可选) + +**重要说明**: +- 工具列表发送给大模型时,使用 `short_description`(如果存在) +- 如果没有 `short_description`,系统会自动从 `description` 中提取第一行或前100个字符 +- 详细描述可以通过 MCP 的 `resources/read` 接口获取(URI: `tool://tool_name`) + +这样可以大幅减少token消耗,特别是当工具数量很多时(如100个工具)。 + +## 参数定义 + +每个参数可以包含以下字段: + +- `name`: 参数名称 +- `type`: 参数类型(string, int, bool, array) +- `description`: 参数详细描述(支持多行) +- `required`: 是否必需(true/false) +- `default`: 默认值 +- `flag`: 命令行标志(如 "-u", "--url") +- `position`: 位置参数的位置(整数) +- `format`: 参数格式("flag", "positional", "combined", "template") +- `template`: 模板字符串(用于 format="template") +- `options`: 可选值列表(用于枚举类型) + +### 参数描述要求 + +参数描述应该包含: + +1. **参数用途**:这个参数是做什么的 +2. **格式要求**:参数值的格式要求 +3. **示例值**:具体的示例值 +4. **注意事项**:使用时需要注意的事项 + +## 示例 + +参考 `tools/` 目录下的现有工具配置文件: + +- `nmap.yaml`: 网络扫描工具 +- `sqlmap.yaml`: SQL注入检测工具 +- `nikto.yaml`: Web服务器扫描工具 +- `dirb.yaml`: Web目录扫描工具 +- `exec.yaml`: 系统命令执行工具 + +## 添加新工具 + +要添加新工具,只需在 `tools/` 目录下创建一个新的 YAML 文件,例如 `my_tool.yaml`: + +```yaml +name: "my_tool" +command: "my-command" +enabled: true +short_description: "一句话说明工具用途" # 简短描述(推荐) +description: | + 工具详细描述... + + **功能:** + - 功能1 + - 功能2 + + **使用场景:** + - 场景1 + - 场景2 + +parameters: + - name: "param1" + type: "string" + description: | + 参数详细描述... + + **示例值:** + - "value1" + - "value2" + required: true + flag: "-p" + format: "flag" +``` + +保存文件后,重启服务即可自动加载新工具。 + +## 禁用工具 + +要禁用某个工具,只需将配置文件中的 `enabled` 字段设置为 `false`,或者直接删除/重命名配置文件。 + diff --git a/tools/dirb.yaml b/tools/dirb.yaml new file mode 100644 index 00000000..eb906dbb --- /dev/null +++ b/tools/dirb.yaml @@ -0,0 +1,86 @@ +name: "dirb" +command: "dirb" +enabled: true + +# 简短描述(用于工具列表,减少token消耗) +short_description: "Web目录和文件扫描工具,通过暴力破解方式发现Web服务器上的隐藏目录和文件" + +# 工具详细描述 +description: | + Web目录和文件扫描工具,通过暴力破解方式发现Web服务器上的隐藏目录和文件。 + + **主要功能:** + - 目录和文件发现 + - 支持自定义字典文件 + - 检测常见的Web目录结构 + - 识别备份文件、配置文件等敏感文件 + - 支持多种HTTP方法 + + **使用场景:** + - Web应用目录枚举 + - 发现隐藏的管理界面 + - 查找备份文件和敏感信息 + - 渗透测试中的信息收集 + + **注意事项:** + - 扫描可能产生大量HTTP请求 + - 某些请求可能被WAF拦截 + - 建议使用合适的字典文件以提高效率 + - 扫描结果需要人工验证 + +# 参数定义 +parameters: + - name: "url" + type: "string" + description: | + 目标URL,要扫描的Web服务器地址。 + + **格式要求:** + - 必须包含协议(http:// 或 https://) + - 可以包含基础路径 + - 末尾不要带斜杠(除非要扫描特定目录) + + **示例值:** + - 基础URL: "http://example.com" + - HTTPS: "https://example.com" + - 带端口: "http://example.com:8080" + - 特定目录: "http://example.com/admin" + - 带路径: "http://example.com/app" + + **注意事项:** + - URL必须可访问 + - 确保URL格式正确,包含协议前缀 + - 必需参数,不能为空 + required: true + position: 0 + format: "positional" + + - name: "wordlist" + type: "string" + description: | + 字典文件路径,包含要尝试的目录和文件名列表。 + + **格式要求:** + - 文件路径,可以是绝对路径或相对路径 + - 文件每行一个目录或文件名 + - 支持常见的字典文件格式 + + **示例值:** + - 默认字典: "/usr/share/dirb/wordlists/common.txt" + - 自定义字典: "/path/to/custom-wordlist.txt" + - 常用字典: "/usr/share/wordlists/dirb/common.txt" + + **常用字典文件:** + - common.txt: 常见目录和文件 + - big.txt: 大型字典 + - small.txt: 小型快速字典 + - extensions_common.txt: 常见文件扩展名 + + **注意事项:** + - 如果不指定,将使用默认字典 + - 确保字典文件存在且可读 + - 大型字典会显著增加扫描时间 + required: false + flag: "-w" + format: "flag" + diff --git a/tools/exec.yaml b/tools/exec.yaml new file mode 100644 index 00000000..37171068 --- /dev/null +++ b/tools/exec.yaml @@ -0,0 +1,102 @@ +name: "exec" +command: "sh" +args: ["-c"] +enabled: true + +# 简短描述(用于工具列表,减少token消耗) +short_description: "系统命令执行工具,用于执行Shell命令和系统操作(谨慎使用)" + +# 工具详细描述 +description: | + 系统命令执行工具,用于执行Shell命令和系统操作。 + + **主要功能:** + - 执行任意Shell命令 + - 支持bash、sh等shell + - 可以指定工作目录 + - 返回命令执行结果 + + **使用场景:** + - 系统管理和维护 + - 自动化脚本执行 + - 文件操作和处理 + - 系统信息收集 + + **安全警告:** + - ⚠️ 此工具可以执行任意系统命令,存在安全风险 + - ⚠️ 仅应在受控环境中使用 + - ⚠️ 所有命令执行都会被记录 + - ⚠️ 建议限制可执行的命令范围 + - ⚠️ 不要执行不可信的命令 + +# 参数定义 +parameters: + - name: "command" + type: "string" + description: | + 要执行的系统命令。可以是任何有效的Shell命令。 + + **格式要求:** + - 完整的Shell命令 + - 可以包含管道、重定向等Shell特性 + - 支持环境变量 + + **示例值:** + - 简单命令: "ls -la" + - 带管道: "ps aux | grep nginx" + - 文件操作: "cat /etc/passwd" + - 网络命令: "curl http://example.com" + - 系统信息: "uname -a" + - 查找文件: "find /var/log -name '*.log'" + + **注意事项:** + - 命令会在指定的shell中执行 + - 确保命令语法正确 + - 注意命令的安全影响 + - 必需参数,不能为空 + required: true + position: 0 + format: "positional" + + - name: "shell" + type: "string" + description: | + 使用的Shell类型,默认为sh。 + + **可选值:** + - sh: 标准Shell(默认) + - bash: Bash Shell + - zsh: Z Shell + - 其他系统可用的shell + + **示例值:** + - "sh" (默认) + - "bash" + - "zsh" + + **注意事项:** + - 确保指定的shell在系统中可用 + - 不同shell的命令语法可能略有差异 + required: false + default: "sh" + + - name: "workdir" + type: "string" + description: | + 命令执行的工作目录。如果不指定,使用当前工作目录。 + + **格式要求:** + - 绝对路径或相对路径 + - 目录必须存在 + + **示例值:** + - "/tmp" + - "/var/log" + - "./data" + - "/home/user/project" + + **注意事项:** + - 确保目录存在且有访问权限 + - 相对路径相对于程序运行目录 + required: false + diff --git a/tools/nikto.yaml b/tools/nikto.yaml new file mode 100644 index 00000000..0c4aebd9 --- /dev/null +++ b/tools/nikto.yaml @@ -0,0 +1,58 @@ +name: "nikto" +command: "nikto" +enabled: true + +# 简短描述(用于工具列表,减少token消耗) +short_description: "Web服务器扫描工具,用于检测Web服务器和应用程序中的已知漏洞和配置错误" + +# 工具详细描述 +description: | + Web服务器扫描工具,用于检测Web服务器和应用程序中的已知漏洞、配置错误和潜在安全问题。 + + **主要功能:** + - 检测Web服务器版本和配置问题 + - 识别已知的Web漏洞和CVE + - 检测危险文件和目录 + - 检查服务器配置错误 + - 识别过时的软件版本 + - 检测默认文件和脚本 + + **使用场景:** + - Web应用安全评估 + - 服务器配置审计 + - 漏洞扫描和发现 + - 渗透测试前期信息收集 + + **注意事项:** + - 扫描可能产生大量日志,注意日志管理 + - 某些扫描可能触发WAF或IDS告警 + - 建议在授权范围内使用 + - 扫描结果需要人工验证 + +# 参数定义 +parameters: + - name: "target" + type: "string" + description: | + 目标URL或IP地址。可以是完整的URL或IP地址。 + + **格式要求:** + - 可以包含协议(http:// 或 https://) + - 可以只提供IP地址或域名 + - 如果只提供IP,默认使用http协议 + + **示例值:** + - 完整URL: "http://example.com" + - HTTPS: "https://example.com" + - IP地址: "192.168.1.1" + - 带端口: "http://example.com:8080" + - 带路径: "http://example.com/admin" + + **注意事项:** + - 如果只提供IP,工具会使用http协议 + - 建议提供完整URL以确保正确扫描 + - 必需参数,不能为空 + required: true + flag: "-h" + format: "flag" + diff --git a/tools/nmap.yaml b/tools/nmap.yaml new file mode 100644 index 00000000..1ac79372 --- /dev/null +++ b/tools/nmap.yaml @@ -0,0 +1,75 @@ +name: "nmap" +command: "nmap" +args: ["-sT", "-sV", "-sC"] # 固定参数:TCP连接扫描、版本检测、默认脚本 +enabled: true + +# 简短描述(用于工具列表,减少token消耗)- 一句话说明工具用途 +short_description: "网络扫描工具,用于发现网络主机、开放端口和服务" + +# 工具详细描述 - 帮助大模型理解工具用途和使用场景 +description: | + 网络映射和端口扫描工具,用于发现网络中的主机、服务和开放端口。 + + **主要功能:** + - 主机发现:检测网络中的活动主机 + - 端口扫描:识别目标主机上开放的端口 + - 服务识别:检测运行在端口上的服务类型和版本 + - 操作系统检测:识别目标主机的操作系统类型 + - 漏洞检测:使用NSE脚本检测常见漏洞 + + **使用场景:** + - 网络资产发现和枚举 + - 安全评估和渗透测试 + - 网络故障排查 + - 端口和服务审计 + + **注意事项:** + - 使用 -sT (TCP连接扫描) 而不是 -sS (SYN扫描),因为 -sS 需要root权限 + - 扫描速度取决于网络延迟和目标响应 + - 某些扫描可能被防火墙或IDS检测到 + - 请确保有权限扫描目标网络 + +# 参数定义 +parameters: + - name: "target" + type: "string" + description: | + 目标IP地址或域名。可以是单个IP、IP范围、CIDR格式或域名。 + + **示例值:** + - 单个IP: "192.168.1.1" + - IP范围: "192.168.1.1-100" + - CIDR: "192.168.1.0/24" + - 域名: "example.com" + - URL: "https://example.com" (会自动提取域名部分) + + **注意事项:** + - 如果提供URL,会自动提取域名部分 + - 确保目标地址格式正确 + - 必需参数,不能为空 + required: true + position: 0 # 位置参数,放在命令最后 + format: "positional" + + - name: "ports" + type: "string" + description: | + 要扫描的端口范围。可以是单个端口、端口范围、逗号分隔的端口列表,或特殊值。 + + **示例值:** + - 单个端口: "80" + - 端口范围: "1-1000" + - 多个端口: "80,443,8080,8443" + - 组合: "80,443,8000-9000" + - 常用端口: "1-1024" + - 所有端口: "1-65535" + - 快速扫描: "80,443,22,21,25,53,110,143,993,995" + + **注意事项:** + - 如果不指定,将扫描默认的1000个常用端口 + - 扫描所有端口(1-65535)会非常耗时 + - 建议先扫描常用端口,再根据结果决定是否扫描全部端口 + required: false + flag: "-p" + format: "flag" + diff --git a/tools/sqlmap.yaml b/tools/sqlmap.yaml new file mode 100644 index 00000000..f60c5ff5 --- /dev/null +++ b/tools/sqlmap.yaml @@ -0,0 +1,100 @@ +name: "sqlmap" +command: "sqlmap" +enabled: true + +# 简短描述(用于工具列表,减少token消耗) +short_description: "自动化SQL注入检测和利用工具,用于发现和利用SQL注入漏洞" + +# 工具详细描述 +description: | + 自动化SQL注入检测和利用工具,用于发现和利用SQL注入漏洞。 + + **主要功能:** + - 自动检测SQL注入漏洞 + - 支持多种数据库类型(MySQL, PostgreSQL, Oracle, MSSQL等) + - 自动提取数据库信息(表、列、数据) + - 支持多种注入技术(布尔盲注、时间盲注、联合查询等) + - 支持文件上传/下载、命令执行等高级功能 + + **使用场景:** + - Web应用安全测试 + - SQL注入漏洞检测 + - 数据库信息收集 + - 渗透测试和漏洞验证 + + **注意事项:** + - 仅用于授权的安全测试 + - 某些操作可能对目标系统造成影响 + - 建议在测试环境中先验证 + - 使用 --batch 参数避免交互式提示 + +# 参数定义 +parameters: + - name: "url" + type: "string" + description: | + 目标URL,包含可能存在SQL注入的参数。 + + **格式要求:** + - 完整的URL,包含协议(http:// 或 https://) + - 必须包含查询参数,参数值用 * 标记注入点 + - 或者直接提供完整的URL,sqlmap会自动检测参数 + + **示例值:** + - 标记注入点: "http://example.com/page?id=1*" + - 完整URL: "http://example.com/page?id=1" + - POST数据: "http://example.com/login" (需要配合data参数) + - Cookie注入: "http://example.com/page" (需要配合cookie参数) + + **注意事项:** + - URL必须可访问 + - 确保URL格式正确,包含协议前缀 + - 如果使用POST请求,需要配合data参数 + - 必需参数,不能为空 + required: true + flag: "-u" + format: "flag" + + - name: "batch" + type: "bool" + description: | + 非交互模式,自动选择默认选项,不需要用户输入。 + + **使用场景:** + - 自动化测试脚本 + - 批量扫描 + - 避免交互式提示 + + **注意事项:** + - 建议始终设置为true,避免工具等待用户输入 + - 默认值为true + required: false + default: true + flag: "--batch" + format: "flag" + + - name: "level" + type: "int" + description: | + 测试级别,范围1-5,级别越高测试越全面但耗时越长。 + + **级别说明:** + - Level 1: 基本测试,快速但可能遗漏漏洞 + - Level 2: 默认级别,平衡速度和覆盖率 + - Level 3: 扩展测试,更全面的检测 + - Level 4: 深度测试,包含更多payload + - Level 5: 最全面测试,包含所有已知技术 + + **建议:** + - 快速扫描使用1-2 + - 常规测试使用3(默认) + - 深度测试使用4-5 + + **注意事项:** + - 级别越高,请求数量越多,可能被WAF拦截 + - 默认值为3 + required: false + default: 3 + flag: "--level" + format: "combined" # --level=3 + diff --git a/web/static/css/style.css b/web/static/css/style.css index ad78f71d..f462ea5e 100644 --- a/web/static/css/style.css +++ b/web/static/css/style.css @@ -304,7 +304,131 @@ header { word-break: break-word; line-height: 1.6; box-shadow: var(--shadow-sm); - white-space: pre-wrap; +} + +/* Markdown 样式 */ +.message-bubble p { + margin: 0.5em 0; +} + +.message-bubble p:first-child { + margin-top: 0; +} + +.message-bubble p:last-child { + margin-bottom: 0; +} + +.message-bubble strong, +.message-bubble b { + font-weight: 600; + color: inherit; +} + +.message-bubble em, +.message-bubble i { + font-style: italic; +} + +.message-bubble code { + background: rgba(0, 0, 0, 0.05); + padding: 2px 6px; + border-radius: 3px; + font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; + font-size: 0.9em; +} + +.message.user .message-bubble code { + background: rgba(255, 255, 255, 0.2); +} + +.message-bubble pre { + background: rgba(0, 0, 0, 0.05); + padding: 12px; + border-radius: 6px; + overflow-x: auto; + margin: 0.5em 0; + font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; + font-size: 0.875em; + line-height: 1.5; +} + +.message.user .message-bubble pre { + background: rgba(255, 255, 255, 0.15); +} + +.message-bubble pre code { + background: none; + padding: 0; +} + +.message-bubble ul, +.message-bubble ol { + margin: 0.5em 0; + padding-left: 1.5em; +} + +.message-bubble li { + margin: 0.25em 0; +} + +.message-bubble blockquote { + border-left: 3px solid var(--border-color); + padding-left: 1em; + margin: 0.5em 0; + color: var(--text-secondary); +} + +.message-bubble h1, +.message-bubble h2, +.message-bubble h3, +.message-bubble h4, +.message-bubble h5, +.message-bubble h6 { + margin: 0.8em 0 0.4em 0; + font-weight: 600; + line-height: 1.3; +} + +.message-bubble h1:first-child, +.message-bubble h2:first-child, +.message-bubble h3:first-child, +.message-bubble h4:first-child, +.message-bubble h5:first-child, +.message-bubble h6:first-child { + margin-top: 0; +} + +.message-bubble h1 { + font-size: 1.5em; +} + +.message-bubble h2 { + font-size: 1.3em; +} + +.message-bubble h3 { + font-size: 1.1em; +} + +.message-bubble hr { + border: none; + border-top: 1px solid var(--border-color); + margin: 1em 0; +} + +.message-bubble a { + color: var(--accent-color); + text-decoration: none; +} + +.message-bubble a:hover { + text-decoration: underline; +} + +.message.user .message-bubble a { + color: rgba(255, 255, 255, 0.9); + text-decoration: underline; } .message.user .message-bubble { diff --git a/web/static/js/app.js b/web/static/js/app.js index fbb16ba2..50bbc9ba 100644 --- a/web/static/js/app.js +++ b/web/static/js/app.js @@ -15,11 +15,15 @@ async function sendMessage() { addMessage('user', message); input.value = ''; - // 显示加载状态 - const loadingId = addMessage('system', '正在处理中...'); + // 创建进度消息容器 + const progressId = addMessage('system', '正在处理中...'); + const progressElement = document.getElementById(progressId); + const progressBubble = progressElement.querySelector('.message-bubble'); + let assistantMessageId = null; + let mcpExecutionIds = []; try { - const response = await fetch('/api/agent-loop', { + const response = await fetch('/api/agent-loop/stream', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -30,30 +34,124 @@ async function sendMessage() { }), }); - const data = await response.json(); + if (!response.ok) { + throw new Error('请求失败: ' + response.status); + } - // 移除加载消息 - removeMessage(loadingId); + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; - if (response.ok) { - // 更新当前对话ID - if (data.conversationId) { - currentConversationId = data.conversationId; + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split('\n'); + buffer = lines.pop(); // 保留最后一个不完整的行 + + for (const line of lines) { + if (line.startsWith('data: ')) { + try { + const eventData = JSON.parse(line.slice(6)); + handleStreamEvent(eventData, progressElement, progressBubble, progressId, + () => assistantMessageId, (id) => { assistantMessageId = id; }, + () => mcpExecutionIds, (ids) => { mcpExecutionIds = ids; }); + } catch (e) { + console.error('解析事件数据失败:', e, line); + } + } + } + } + + // 处理剩余的buffer + if (buffer.trim()) { + const lines = buffer.split('\n'); + for (const line of lines) { + if (line.startsWith('data: ')) { + try { + const eventData = JSON.parse(line.slice(6)); + handleStreamEvent(eventData, progressElement, progressBubble, progressId, + () => assistantMessageId, (id) => { assistantMessageId = id; }, + () => mcpExecutionIds, (ids) => { mcpExecutionIds = ids; }); + } catch (e) { + console.error('解析事件数据失败:', e, line); + } + } + } + } + + } catch (error) { + removeMessage(progressId); + addMessage('system', '错误: ' + error.message); + } +} + +// 处理流式事件 +function handleStreamEvent(event, progressElement, progressBubble, progressId, + getAssistantId, setAssistantId, getMcpIds, setMcpIds) { + switch (event.type) { + case 'progress': + // 更新进度消息 + progressBubble.textContent = event.message; + break; + + case 'tool_call': + // 显示工具调用信息 + const toolInfo = event.data || {}; + const toolName = toolInfo.toolName || '未知工具'; + const index = toolInfo.index || 0; + const total = toolInfo.total || 0; + progressBubble.innerHTML = `🔧 正在调用工具: ${escapeHtml(toolName)} (${index}/${total})`; + break; + + case 'tool_result': + // 显示工具执行结果 + const resultInfo = event.data || {}; + const resultToolName = resultInfo.toolName || '未知工具'; + const success = resultInfo.success !== false; + const statusIcon = success ? '✅' : '❌'; + progressBubble.innerHTML = `${statusIcon} 工具 ${escapeHtml(resultToolName)} 执行${success ? '完成' : '失败'}`; + break; + + case 'response': + // 移除进度消息,显示最终回复 + removeMessage(progressId); + const responseData = event.data || {}; + const mcpIds = responseData.mcpExecutionIds || []; + setMcpIds(mcpIds); + + // 更新对话ID + if (responseData.conversationId) { + currentConversationId = responseData.conversationId; updateActiveConversation(); } - // 如果有MCP执行ID,显示所有调用 - const mcpIds = data.mcpExecutionIds || []; - addMessage('assistant', data.response, mcpIds); + // 添加助手回复 + const assistantId = addMessage('assistant', event.message, mcpIds); + setAssistantId(assistantId); // 刷新对话列表 loadConversations(); - } else { - addMessage('system', '错误: ' + (data.error || '未知错误')); - } - } catch (error) { - removeMessage(loadingId); - addMessage('system', '错误: ' + error.message); + break; + + case 'error': + // 显示错误 + removeMessage(progressId); + addMessage('system', '错误: ' + event.message); + break; + + case 'done': + // 完成,确保进度消息已移除 + if (progressElement && progressElement.parentNode) { + removeMessage(progressId); + } + // 更新对话ID + if (event.data && event.data.conversationId) { + currentConversationId = event.data.conversationId; + updateActiveConversation(); + } + break; } } @@ -88,8 +186,28 @@ function addMessage(role, content, mcpExecutionIds = null) { // 创建消息气泡 const bubble = document.createElement('div'); bubble.className = 'message-bubble'; - // 处理换行和格式化 - const formattedContent = content.replace(/\n/g, '
'); + + // 解析 Markdown 格式 + let formattedContent; + if (typeof marked !== 'undefined') { + // 使用 marked.js 解析 Markdown + try { + // 配置 marked 选项 + marked.setOptions({ + breaks: true, // 支持换行 + gfm: true, // 支持 GitHub Flavored Markdown + }); + formattedContent = marked.parse(content); + } catch (e) { + console.error('Markdown 解析失败:', e); + // 降级处理:转义 HTML 并保留换行 + formattedContent = escapeHtml(content).replace(/\n/g, '
'); + } + } else { + // 如果没有 marked.js,使用简单处理 + formattedContent = escapeHtml(content).replace(/\n/g, '
'); + } + bubble.innerHTML = formattedContent; contentWrapper.appendChild(bubble); diff --git a/web/templates/index.html b/web/templates/index.html index bc563f32..bd40fc09 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -86,6 +86,8 @@ + +