From 2de0bd4d315a761e13cf1b516ea7e91ee0d7cbdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sun, 19 Apr 2026 01:25:30 +0800 Subject: [PATCH] Add files via upload --- internal/attackchain/builder.go | 933 +++++++++++++ internal/config/config.go | 877 ++++++++++++ internal/database/attackchain.go | 168 +++ internal/database/batch_task.go | 537 +++++++ internal/database/conversation.go | 758 ++++++++++ internal/database/conversation_turn_test.go | 39 + internal/database/database.go | 809 +++++++++++ internal/database/group.go | 449 ++++++ internal/database/monitor.go | 537 +++++++ internal/database/skill_stats.go | 142 ++ internal/database/vulnerability.go | 281 ++++ internal/database/webshell.go | 148 ++ internal/logger/logger.go | 68 + internal/mcp/builtin/constants.go | 105 ++ internal/mcp/client_sdk.go | 551 ++++++++ internal/mcp/external_manager.go | 1105 +++++++++++++++ internal/mcp/external_manager_test.go | 239 ++++ internal/mcp/server.go | 1237 +++++++++++++++++ internal/mcp/types.go | 295 ++++ internal/multiagent/eino_skills.go | 85 ++ internal/multiagent/eino_summarize.go | 140 ++ internal/multiagent/no_nested_task.go | 62 + internal/multiagent/runner.go | 1068 ++++++++++++++ internal/multiagent/tool_args_json_retry.go | 51 + .../multiagent/tool_args_json_retry_test.go | 17 + internal/multiagent/tool_error_middleware.go | 131 ++ .../multiagent/tool_error_middleware_test.go | 166 +++ internal/multiagent/tool_execution_retry.go | 76 + internal/skillpackage/content.go | 165 +++ internal/skillpackage/frontmatter.go | 114 ++ internal/skillpackage/io.go | 200 +++ internal/skillpackage/layout.go | 66 + internal/skillpackage/service.go | 155 +++ internal/skillpackage/types.go | 67 + internal/skillpackage/validate.go | 102 ++ internal/storage/result_storage.go | 297 ++++ internal/storage/result_storage_test.go | 453 ++++++ 37 files changed, 12693 insertions(+) create mode 100644 internal/attackchain/builder.go create mode 100644 internal/config/config.go create mode 100644 internal/database/attackchain.go create mode 100644 internal/database/batch_task.go create mode 100644 internal/database/conversation.go create mode 100644 internal/database/conversation_turn_test.go create mode 100644 internal/database/database.go create mode 100644 internal/database/group.go create mode 100644 internal/database/monitor.go create mode 100644 internal/database/skill_stats.go create mode 100644 internal/database/vulnerability.go create mode 100644 internal/database/webshell.go create mode 100644 internal/logger/logger.go create mode 100644 internal/mcp/builtin/constants.go create mode 100644 internal/mcp/client_sdk.go create mode 100644 internal/mcp/external_manager.go create mode 100644 internal/mcp/external_manager_test.go create mode 100644 internal/mcp/server.go create mode 100644 internal/mcp/types.go create mode 100644 internal/multiagent/eino_skills.go create mode 100644 internal/multiagent/eino_summarize.go create mode 100644 internal/multiagent/no_nested_task.go create mode 100644 internal/multiagent/runner.go create mode 100644 internal/multiagent/tool_args_json_retry.go create mode 100644 internal/multiagent/tool_args_json_retry_test.go create mode 100644 internal/multiagent/tool_error_middleware.go create mode 100644 internal/multiagent/tool_error_middleware_test.go create mode 100644 internal/multiagent/tool_execution_retry.go create mode 100644 internal/skillpackage/content.go create mode 100644 internal/skillpackage/frontmatter.go create mode 100644 internal/skillpackage/io.go create mode 100644 internal/skillpackage/layout.go create mode 100644 internal/skillpackage/service.go create mode 100644 internal/skillpackage/types.go create mode 100644 internal/skillpackage/validate.go create mode 100644 internal/storage/result_storage.go create mode 100644 internal/storage/result_storage_test.go diff --git a/internal/attackchain/builder.go b/internal/attackchain/builder.go new file mode 100644 index 00000000..de1a7d52 --- /dev/null +++ b/internal/attackchain/builder.go @@ -0,0 +1,933 @@ +package attackchain + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/openai" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Builder 攻击链构建器 +type Builder struct { + db *database.DB + logger *zap.Logger + openAIClient *openai.Client + openAIConfig *config.OpenAIConfig + tokenCounter agent.TokenCounter + maxTokens int // 最大tokens限制,默认100000 +} + +// Node 攻击链节点(使用database包的类型) +type Node = database.AttackChainNode + +// Edge 攻击链边(使用database包的类型) +type Edge = database.AttackChainEdge + +// Chain 完整的攻击链 +type Chain struct { + Nodes []Node `json:"nodes"` + Edges []Edge `json:"edges"` +} + +// NewBuilder 创建新的攻击链构建器 +func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *Builder { + transport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + } + httpClient := &http.Client{Timeout: 5 * time.Minute, Transport: transport} + + // 优先使用配置文件中的统一 Token 上限(config.yaml -> openai.max_total_tokens) + maxTokens := 0 + if openAIConfig != nil && openAIConfig.MaxTotalTokens > 0 { + maxTokens = openAIConfig.MaxTotalTokens + } else if openAIConfig != nil { + // 如果未显式配置 max_total_tokens,则根据模型设置一个合理的默认值 + model := strings.ToLower(openAIConfig.Model) + if strings.Contains(model, "gpt-4") { + maxTokens = 128000 // gpt-4通常支持128k + } else if strings.Contains(model, "gpt-3.5") { + maxTokens = 16000 // gpt-3.5-turbo通常支持16k + } else if strings.Contains(model, "deepseek") { + maxTokens = 131072 // deepseek-chat通常支持131k + } else { + maxTokens = 100000 // 兜底默认值 + } + } else { + // 没有 OpenAI 配置时使用兜底值,避免为 0 + maxTokens = 100000 + } + + return &Builder{ + db: db, + logger: logger, + openAIClient: openai.NewClient(openAIConfig, httpClient, logger), + openAIConfig: openAIConfig, + tokenCounter: agent.NewTikTokenCounter(), + maxTokens: maxTokens, + } +} + +// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出) +func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) { + b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID)) + + // 0. 首先检查是否有实际的工具执行记录 + messages, err := b.db.GetMessages(conversationID) + if err != nil { + return nil, fmt.Errorf("获取对话消息失败: %w", err) + } + + if len(messages) == 0 { + b.logger.Info("对话中没有数据", zap.String("conversationId", conversationID)) + return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil + } + + // 检查是否有实际的工具执行:assistant 的 mcp_execution_ids,或过程详情中的 tool_call/tool_result + //(多代理下若 MCP 未返回 execution_id,IDs 可能为空,但工具已通过 Eino 执行并写入 process_details) + hasToolExecutions := false + for i := len(messages) - 1; i >= 0; i-- { + if strings.EqualFold(messages[i].Role, "assistant") { + if len(messages[i].MCPExecutionIDs) > 0 { + hasToolExecutions = true + break + } + } + } + if !hasToolExecutions { + if pdOK, err := b.db.ConversationHasToolProcessDetails(conversationID); err != nil { + b.logger.Warn("查询过程详情判定工具执行失败", zap.Error(err)) + } else if pdOK { + hasToolExecutions = true + } + } + + // 检查任务是否被取消(通过检查最后一条assistant消息内容或process_details) + taskCancelled := false + for i := len(messages) - 1; i >= 0; i-- { + if strings.EqualFold(messages[i].Role, "assistant") { + content := strings.ToLower(messages[i].Content) + if strings.Contains(content, "取消") || strings.Contains(content, "cancelled") { + taskCancelled = true + } + break + } + } + + // 如果任务被取消且没有实际工具执行,返回空攻击链 + if taskCancelled && !hasToolExecutions { + b.logger.Info("任务已取消且没有实际工具执行,返回空攻击链", + zap.String("conversationId", conversationID), + zap.Bool("taskCancelled", taskCancelled), + zap.Bool("hasToolExecutions", hasToolExecutions)) + return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil + } + + // 如果没有实际工具执行,也返回空攻击链(避免AI编造) + if !hasToolExecutions { + b.logger.Info("没有实际工具执行记录,返回空攻击链", + zap.String("conversationId", conversationID)) + return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil + } + + // 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出 + reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID) + if err != nil { + b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err)) + // 继续使用原来的逻辑 + reactInputJSON = "" + modelOutput = "" + } + + // var userInput string + var reactInputFinal string + var dataSource string // 记录数据来源 + + // 如果成功获取到保存的ReAct数据,直接使用 + if reactInputJSON != "" && modelOutput != "" { + // 计算 ReAct 输入的哈希值,用于追踪 + hash := sha256.Sum256([]byte(reactInputJSON)) + reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识 + + // 统计消息数量 + var messageCount int + var tempMessages []interface{} + if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil { + messageCount = len(tempMessages) + } + + dataSource = "database_last_react_input" + b.logger.Info("使用保存的ReAct数据构建攻击链", + zap.String("conversationId", conversationID), + zap.String("dataSource", dataSource), + zap.Int("reactInputSize", len(reactInputJSON)), + zap.Int("messageCount", messageCount), + zap.String("reactInputHash", reactInputHash), + zap.Int("modelOutputSize", len(modelOutput))) + + // 从保存的ReAct输入(JSON格式)中提取用户输入 + // userInput = b.extractUserInputFromReActInput(reactInputJSON) + + // 将JSON格式的messages转换为可读格式 + reactInputFinal = b.formatReActInputFromJSON(reactInputJSON) + } else { + // 2. 如果没有保存的ReAct数据,从对话消息构建 + dataSource = "messages_table" + b.logger.Info("从消息历史构建ReAct数据", + zap.String("conversationId", conversationID), + zap.String("dataSource", dataSource), + zap.Int("messageCount", len(messages))) + + // 提取用户输入(最后一条user消息) + for i := len(messages) - 1; i >= 0; i-- { + if strings.EqualFold(messages[i].Role, "user") { + // userInput = messages[i].Content + break + } + } + + // 提取最后一轮ReAct的输入(历史消息+当前用户输入) + reactInputFinal = b.buildReActInput(messages) + + // 提取大模型最后的输出(最后一条assistant消息) + for i := len(messages) - 1; i >= 0; i-- { + if strings.EqualFold(messages[i].Role, "assistant") { + modelOutput = messages[i].Content + break + } + } + } + + // 多代理:保存的 last_react_input 可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理「最后一轮 ReAct」对齐) + hasMCPOnAssistant := false + var lastAssistantID string + for i := len(messages) - 1; i >= 0; i-- { + if strings.EqualFold(messages[i].Role, "assistant") { + lastAssistantID = messages[i].ID + if len(messages[i].MCPExecutionIDs) > 0 { + hasMCPOnAssistant = true + } + break + } + } + if lastAssistantID != "" { + pdHasTools, _ := b.db.ConversationHasToolProcessDetails(conversationID) + if pdHasTools && !(hasMCPOnAssistant && reactInputContainsToolTrace(reactInputJSON)) { + detailsMap, err := b.db.GetProcessDetailsByConversation(conversationID) + if err != nil { + b.logger.Warn("加载过程详情用于攻击链失败", zap.Error(err)) + } else if dets := detailsMap[lastAssistantID]; len(dets) > 0 { + extra := b.formatProcessDetailsForAttackChain(dets) + if strings.TrimSpace(extra) != "" { + reactInputFinal = reactInputFinal + "\n\n## 执行过程与工具记录(含多代理编排与子任务)\n\n" + extra + b.logger.Info("攻击链输入已补充过程详情", + zap.String("conversationId", conversationID), + zap.String("messageId", lastAssistantID), + zap.Int("detailEvents", len(dets))) + } + } + } + } + + // 3. 构建简化的prompt,一次性传递给大模型 + prompt := b.buildSimplePrompt(reactInputFinal, modelOutput) + // fmt.Println(prompt) + // 6. 调用AI生成攻击链(一次性,不做任何处理) + chainJSON, err := b.callAIForChainGeneration(ctx, prompt) + if err != nil { + return nil, fmt.Errorf("AI生成失败: %w", err) + } + + // 7. 解析JSON并生成节点/边ID(前端需要有效的ID) + chainData, err := b.parseChainJSON(chainJSON) + if err != nil { + // 如果解析失败,返回空链,让前端处理错误 + b.logger.Warn("解析攻击链JSON失败", zap.Error(err), zap.String("raw_json", chainJSON)) + return &Chain{ + Nodes: []Node{}, + Edges: []Edge{}, + }, nil + } + + b.logger.Info("攻击链构建完成", + zap.String("conversationId", conversationID), + zap.String("dataSource", dataSource), + zap.Int("nodes", len(chainData.Nodes)), + zap.Int("edges", len(chainData.Edges))) + + // 保存到数据库(供后续加载使用) + if err := b.saveChain(conversationID, chainData.Nodes, chainData.Edges); err != nil { + b.logger.Warn("保存攻击链到数据库失败", zap.Error(err)) + // 即使保存失败,也返回数据给前端 + } + + // 直接返回,不做任何处理和校验 + return chainData, nil +} + +// reactInputContainsToolTrace 判断保存的 ReAct JSON 是否包含可解析的工具调用轨迹(单代理完整保存时为 true)。 +func reactInputContainsToolTrace(reactInputJSON string) bool { + s := strings.TrimSpace(reactInputJSON) + if s == "" { + return false + } + return strings.Contains(s, "tool_calls") || + strings.Contains(s, "tool_call_id") || + strings.Contains(s, `"role":"tool"`) || + strings.Contains(s, `"role": "tool"`) +} + +// formatProcessDetailsForAttackChain 将最后一轮助手的过程详情格式化为攻击链分析的输入(覆盖多代理下 last_react_input 不完整的情况)。 +func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessDetail) string { + if len(details) == 0 { + return "" + } + var sb strings.Builder + for _, d := range details { + // 目标:以主 agent(编排器)视角输出整轮迭代 + // - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理) + // - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程 + if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "planning" { + continue + } + + // 解析 data(JSON string),用于识别 einoRole / toolName 等 + var dataMap map[string]interface{} + if strings.TrimSpace(d.Data) != "" { + _ = json.Unmarshal([]byte(d.Data), &dataMap) + } + einoRole := "" + if v, ok := dataMap["einoRole"]; ok { + einoRole = strings.ToLower(strings.TrimSpace(fmt.Sprint(v))) + } + toolName := "" + if v, ok := dataMap["toolName"]; ok { + toolName = strings.TrimSpace(fmt.Sprint(v)) + } + + // 1) 编排器的工具调用/结果:保留(这是“主 agent 调了什么工具”) + if (d.EventType == "tool_call" || d.EventType == "tool_result" || d.EventType == "tool_calls_detected" || d.EventType == "iteration" || d.EventType == "eino_recovery") && einoRole == "orchestrator" { + sb.WriteString("[") + sb.WriteString(d.EventType) + sb.WriteString("] ") + sb.WriteString(strings.TrimSpace(d.Message)) + sb.WriteString("\n") + if strings.TrimSpace(d.Data) != "" { + sb.WriteString(d.Data) + sb.WriteString("\n") + } + sb.WriteString("\n") + continue + } + + // 2) 子代理调度:tool_call(toolName=="task") 代表编排器把子任务派发出去;保留(只需任务,不要子代理推理) + if d.EventType == "tool_call" && strings.EqualFold(toolName, "task") { + sb.WriteString("[dispatch_subagent_task] ") + sb.WriteString(strings.TrimSpace(d.Message)) + sb.WriteString("\n") + if strings.TrimSpace(d.Data) != "" { + sb.WriteString(d.Data) + sb.WriteString("\n") + } + sb.WriteString("\n") + continue + } + + // 3) 子代理最终回复:保留(只保留最终输出,不保留分析过程) + if d.EventType == "eino_agent_reply" && einoRole == "sub" { + sb.WriteString("[subagent_final_reply] ") + sb.WriteString(strings.TrimSpace(d.Message)) + sb.WriteString("\n") + // data 里含 einoAgent 等元信息,保留有助于追踪“哪个子代理说的” + if strings.TrimSpace(d.Data) != "" { + sb.WriteString(d.Data) + sb.WriteString("\n") + } + sb.WriteString("\n") + continue + } + + // 其他事件默认丢弃,避免把子代理工具细节/推理塞进 prompt,偏离“主 agent 一轮迭代”的视角。 + } + return strings.TrimSpace(sb.String()) +} + +// buildReActInput 构建最后一轮ReAct的输入(历史消息+当前用户输入) +func (b *Builder) buildReActInput(messages []database.Message) string { + var builder strings.Builder + for _, msg := range messages { + builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content)) + } + return builder.String() +} + +// extractUserInputFromReActInput 从保存的ReAct输入(JSON格式的messages数组)中提取最后一条用户输入 +// func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string { +// // reactInputJSON是JSON格式的ChatMessage数组,需要解析 +// var messages []map[string]interface{} +// if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { +// b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) +// return "" +// } + +// // 从后往前查找最后一条user消息 +// for i := len(messages) - 1; i >= 0; i-- { +// if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") { +// if content, ok := messages[i]["content"].(string); ok { +// return content +// } +// } +// } + +// return "" +// } + +// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式 +func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string { + var messages []map[string]interface{} + if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { + b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) + return reactInputJSON // 如果解析失败,返回原始JSON + } + + var builder strings.Builder + for _, msg := range messages { + role, _ := msg["role"].(string) + content, _ := msg["content"].(string) + + // 处理assistant消息:提取tool_calls信息 + if role == "assistant" { + if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 { + // 如果有文本内容,先显示 + if content != "" { + builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content)) + } + // 详细显示每个工具调用 + builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls))) + for i, toolCall := range toolCalls { + if tc, ok := toolCall.(map[string]interface{}); ok { + toolCallID, _ := tc["id"].(string) + if funcData, ok := tc["function"].(map[string]interface{}); ok { + toolName, _ := funcData["name"].(string) + arguments, _ := funcData["arguments"].(string) + builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1)) + builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID)) + builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName)) + builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments)) + } + } + } + builder.WriteString("\n") + continue + } + } + + // 处理tool消息:显示tool_call_id和完整内容 + if role == "tool" { + toolCallID, _ := msg["tool_call_id"].(string) + if toolCallID != "" { + builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content)) + } else { + builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) + } + continue + } + + // 其他消息类型(system, user等)正常显示 + builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) + } + + return builder.String() +} + +// buildSimplePrompt 构建简化的prompt +func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string { + return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径。 + +## 核心目标 + +构建一个能够讲述完整攻击故事的攻击链让学习者能够: +1. 理解渗透测试的完整流程和思维逻辑(从目标识别到漏洞发现的每一步) +2. 学习如何从失败中获取线索并调整策略 +3. 掌握工具使用的实际效果和局限性 +4. 理解漏洞发现和利用的因果关系 + +**关键原则**:完整性优先。必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而遗漏重要信息。 + +## 构建流程(按此顺序思考) + +### 第一步:理解上下文 +仔细分析ReAct输入中的工具调用序列和大模型输出,识别: +- 测试目标(IP、域名、URL等) +- 实际执行的工具和参数 +- 工具返回的关键信息(成功结果、错误信息、超时等) +- AI的分析和决策过程 + +### 第二步:提取关键节点 +从工具执行记录中提取有意义的节点,**确保不遗漏任何关键步骤**: +- **target节点**:每个独立的测试目标创建一个target节点 +- **action节点**:每个有意义的工具执行创建一个action节点(包括提供线索的失败、成功的信息收集、漏洞验证等) +- **vulnerability节点**:每个真实确认的漏洞创建一个vulnerability节点 +- **完整性检查**:对照ReAct输入中的工具调用序列,确保每个有意义的工具执行都被包含在攻击链中 + +### 第三步:构建逻辑关系(树状结构) +**重要:必须构建树状结构,而不是简单的线性链。** +按照因果关系连接节点,形成树状图(因为是单agent执行,所以可以不按照时间顺序): +- **分支结构**:一个节点可以有多个后续节点(例如:端口扫描发现多个端口后,可以同时进行多个不同的测试) +- **汇聚结构**:多个节点可以指向同一个节点(例如:多个不同的测试都发现了同一个漏洞) +- 识别哪些action是基于前面action的结果而执行的 +- 识别哪些vulnerability是由哪些action发现的 +- 识别失败节点如何为后续成功提供线索 +- **避免线性链**:不要将所有节点连成一条线,应该根据实际的并行测试和分支探索构建树状结构 + +### 第四步:优化和精简 +- **完整性检查**:确保所有有意义的工具执行都被包含,不要遗漏关键步骤 +- **合并规则**:只合并真正相似或重复的action节点(如多次相同工具的相似调用) +- **删除规则**:只删除完全无价值的失败节点(完全无输出、纯系统错误、重复的相同失败) +- **重要提醒**:宁可保留更多节点,也不要遗漏关键步骤。攻击链必须完整展现渗透测试过程 +- 确保攻击链逻辑连贯,能够讲述完整故事 + +## 节点类型详解 + +### target(目标节点) +- **用途**:标识测试目标 +- **创建规则**:每个独立目标(不同IP/域名)创建一个target节点 +- **多目标处理**:不同目标的节点不相互连接,各自形成独立的子图 +- **metadata.target**:精确记录目标标识(IP地址、域名、URL等) + +### action(行动节点) +- **用途**:记录工具执行和AI分析结果 +- **标签规则**: + * 15-25个汉字,动宾结构 + * 成功节点:描述执行结果(如"扫描端口发现80/443/8080"、"目录扫描发现/admin路径") + * 失败节点:描述失败原因(如"尝试SQL注入(被WAF拦截)"、"端口扫描超时(目标不可达)") +- **ai_analysis要求**: + * 成功节点:总结工具执行的关键发现,说明这些发现的意义 + * 失败节点:必须说明失败原因、获得的线索、这些线索如何指引后续行动 + * 不超过150字,要具体、有信息量 +- **findings要求**: + * 提取工具返回结果中的关键信息点 + * 每个finding应该是独立的、有价值的信息片段 + * 成功节点:列出关键发现(如["80端口开放", "443端口开放", "HTTP服务为Apache 2.4"]) + * 失败节点:列出失败线索(如["WAF拦截", "返回403", "检测到Cloudflare"]) +- **status标记**: + * 成功节点:不设置或设为"success" + * 提供线索的失败节点:必须设为"failed_insight" +- **risk_score**:始终为0(action节点不评估风险) + +### vulnerability(漏洞节点) +- **用途**:记录真实确认的安全漏洞 +- **创建规则**: + * 必须是真实确认的漏洞,不是所有发现都是漏洞 + * 需要明确的漏洞证据(如SQL注入返回数据库错误、XSS成功执行等) +- **risk_score规则**: + * critical(90-100):可导致系统完全沦陷(RCE、SQL注入导致数据泄露等) + * high(80-89):可导致敏感信息泄露或权限提升 + * medium(60-79):存在安全风险但影响有限 + * low(40-59):轻微安全问题 +- **metadata要求**: + * vulnerability_type:漏洞类型(SQL注入、XSS、RCE等) + * description:详细描述漏洞位置、原理、影响 + * severity:critical/high/medium/low + * location:精确的漏洞位置(URL、参数、文件路径等) + +## 节点过滤和合并规则 + +### 必须保留的失败节点 +以下失败情况必须创建节点,因为它们提供了有价值的线索: +- 工具返回明确的错误信息(权限错误、连接拒绝、认证失败等) +- 超时或连接失败(可能表明防火墙、网络隔离等) +- WAF/防火墙拦截(返回403、406等,表明存在防护机制) +- 工具未安装或配置错误(但执行了调用) +- 目标不可达(DNS解析失败、网络不通等) + +### 应该删除的失败节点 +以下情况不应创建节点: +- 完全无输出的工具调用 +- 纯系统错误(与目标无关,如本地环境问题) +- 重复的相同失败(多次相同错误只保留第一次) + +### 节点合并规则 +以下情况应合并节点: +- 同一工具的多次相似调用(如多次nmap扫描不同端口范围,合并为一个"端口扫描"节点) +- 同一目标的多个相似探测(如多个目录扫描工具,合并为一个"目录扫描"节点) + +### 节点数量控制 +- **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制数量而删除重要节点 +- **建议范围**:单目标通常8-15个节点,但如果实际执行步骤较多,可以适当增加(最多20个节点) +- **优先保留**:关键成功步骤、提供线索的失败、发现的漏洞、重要的信息收集步骤 +- **可以合并**:同一工具的多次相似调用(如多次nmap扫描不同端口范围,合并为一个"端口扫描"节点) +- **可以删除**:完全无输出的工具调用、纯系统错误、重复的相同失败(多次相同错误只保留第一次) +- **重要原则**:宁可节点稍多,也不要遗漏关键步骤。攻击链必须能够完整展现渗透测试的完整过程 + +## 边的类型和权重 + +### 边的类型 +- **leads_to**:表示"导致"或"引导到",用于action→action、target→action + * 例如:端口扫描 → 目录扫描(因为发现了80端口,所以进行目录扫描) +- **discovers**:表示"发现",**专门用于action→vulnerability** + * 例如:SQL注入测试 → SQL注入漏洞 + * **重要**:所有action→vulnerability的边都必须使用discovers类型,即使多个action都指向同一个vulnerability,也应该统一使用discovers +- **enables**:表示"使能"或"促成",**仅用于vulnerability→vulnerability、action→action(当后续行动依赖前面结果时)** + * 例如:信息泄露漏洞 → 权限提升漏洞(通过信息泄露获得的信息促成了权限提升) + * **重要**:enables不能用于action→vulnerability,action→vulnerability必须使用discovers + +### 边的权重 +- **权重1-2**:弱关联(如初步探测到进一步探测) +- **权重3-4**:中等关联(如发现端口到服务识别) +- **权重5-7**:强关联(如发现漏洞、关键信息泄露) +- **权重8-10**:极强关联(如漏洞利用成功、权限提升) + +### DAG结构要求(有向无环图) +**关键:必须确保生成的是真正的DAG(有向无环图),不能有任何循环。** + +- **节点编号规则**:节点id从"node_1"开始递增(node_1, node_2, node_3...) +- **边的方向规则**:所有边的source节点id必须严格小于target节点id(source < target),这是确保无环的关键 + * 例如:node_1 → node_2 ✓(正确) + * 例如:node_2 → node_1 ✗(错误,会形成环) + * 例如:node_3 → node_5 ✓(正确) +- **无环验证**:在输出JSON前,必须检查所有边,确保没有任何一条边的source >= target +- **无孤立节点**:确保每个节点至少有一条边连接(除了可能的根节点) +- **DAG结构特点**: + * 一个节点可以有多个后续节点(分支),例如:node_2(端口扫描)可以同时连接到node_3、node_4、node_5等多个节点 + * 多个节点可以汇聚到一个节点(汇聚),例如:node_3、node_4、node_5都指向node_6(漏洞节点) + * 避免将所有节点连成一条线,应该根据实际的并行测试和分支探索构建DAG结构 +- **拓扑排序验证**:如果按照节点id从小到大排序,所有边都应该从左指向右(从上指向下),这样就能保证无环 + +## 攻击链逻辑连贯性要求 + +构建的攻击链应该能够回答以下问题: +1. **起点**:测试从哪里开始?(target节点) +2. **探索过程**:如何逐步收集信息?(action节点序列) +3. **失败与调整**:遇到障碍时如何调整策略?(failed_insight节点) +4. **关键发现**:发现了哪些重要信息?(action的findings) +5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability) +6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径) + +## 最后一轮ReAct输入 + +%s + +## 大模型输出 + +%s + +## 输出格式 + +严格按照以下JSON格式输出,不要添加任何其他文字: + +**重要:示例展示的是树状结构,注意node_2(端口扫描)同时连接到多个后续节点(node_3、node_4),形成分支结构。** + +{ + "nodes": [ + { + "id": "node_1", + "type": "target", + "label": "测试目标: example.com", + "risk_score": 40, + "metadata": { + "target": "example.com" + } + }, + { + "id": "node_2", + "type": "action", + "label": "扫描端口发现80/443/8080", + "risk_score": 0, + "metadata": { + "tool_name": "nmap", + "tool_intent": "端口扫描", + "ai_analysis": "使用nmap对目标进行端口扫描,发现80、443、8080端口开放。80端口运行HTTP服务,443端口运行HTTPS服务,8080端口可能为管理后台。这些开放端口为后续Web应用测试提供了入口。", + "findings": ["80端口开放", "443端口开放", "8080端口开放", "HTTP服务为Apache 2.4"] + } + }, + { + "id": "node_3", + "type": "action", + "label": "目录扫描发现/admin后台", + "risk_score": 0, + "metadata": { + "tool_name": "dirsearch", + "tool_intent": "目录扫描", + "ai_analysis": "使用dirsearch对目标进行目录扫描,发现/admin目录存在且可访问。该目录可能为管理后台,是重要的测试目标。", + "findings": ["/admin目录存在", "返回200状态码", "疑似管理后台"] + } + }, + { + "id": "node_4", + "type": "action", + "label": "识别Web服务为Apache 2.4", + "risk_score": 0, + "metadata": { + "tool_name": "whatweb", + "tool_intent": "Web服务识别", + "ai_analysis": "识别出目标运行Apache 2.4服务器,这为后续的漏洞测试提供了重要信息。", + "findings": ["Apache 2.4", "PHP版本信息"] + } + }, + { + "id": "node_5", + "type": "action", + "label": "尝试SQL注入(被WAF拦截)", + "risk_score": 0, + "metadata": { + "tool_name": "sqlmap", + "tool_intent": "SQL注入检测", + "ai_analysis": "对/login.php进行SQL注入测试时被WAF拦截,返回403错误。错误信息显示检测到Cloudflare防护。这表明目标部署了WAF,需要调整测试策略。", + "findings": ["WAF拦截", "返回403", "检测到Cloudflare", "目标部署WAF"], + "status": "failed_insight" + } + }, + { + "id": "node_6", + "type": "vulnerability", + "label": "SQL注入漏洞", + "risk_score": 85, + "metadata": { + "vulnerability_type": "SQL注入", + "description": "在/admin/login.php的username参数发现SQL注入漏洞,可通过注入payload绕过登录验证,直接获取管理员权限。漏洞返回数据库错误信息,确认存在注入点。", + "severity": "high", + "location": "/admin/login.php?username=" + } + } + ], + "edges": [ + { + "source": "node_1", + "target": "node_2", + "type": "leads_to", + "weight": 3 + }, + { + "source": "node_2", + "target": "node_3", + "type": "leads_to", + "weight": 4 + }, + { + "source": "node_2", + "target": "node_4", + "type": "leads_to", + "weight": 3 + }, + { + "source": "node_3", + "target": "node_5", + "type": "leads_to", + "weight": 4 + }, + { + "source": "node_5", + "target": "node_6", + "type": "discovers", + "weight": 7 + } + ] +} + +## 重要提醒 + +1. **严禁杜撰**:只使用ReAct输入中实际执行的工具和实际返回的结果。如无实际数据,返回空的nodes和edges数组。 +2. **DAG结构必须**:必须构建真正的DAG(有向无环图),不能有任何循环。所有边的source节点id必须严格小于target节点id(source < target)。 +3. **拓扑顺序**:节点应该按照逻辑顺序编号,target节点通常是node_1,后续的action节点按执行顺序递增,vulnerability节点在最后。 +4. **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而删除重要节点。攻击链必须能够完整展现从目标识别到漏洞发现的完整过程。 +5. **逻辑连贯**:确保攻击链能够讲述一个完整、连贯的渗透测试故事,包括所有关键步骤和决策点。 +6. **教育价值**:优先保留有教育意义的节点,帮助学习者理解渗透测试思维和完整流程。 +7. **准确性**:所有节点信息必须基于实际数据,不要推测或假设。 +8. **完整性检查**:确保每个节点都有必要的metadata字段,每条边都有正确的source和target,没有孤立节点,没有循环。 +9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。 +10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。 + +现在开始分析并构建攻击链:`, reactInput, modelOutput) +} + +// saveChain 保存攻击链到数据库 +func (b *Builder) saveChain(conversationID string, nodes []Node, edges []Edge) error { + // 先删除旧的攻击链数据 + if err := b.db.DeleteAttackChain(conversationID); err != nil { + b.logger.Warn("删除旧攻击链失败", zap.Error(err)) + } + + for _, node := range nodes { + metadataJSON, _ := json.Marshal(node.Metadata) + if err := b.db.SaveAttackChainNode(conversationID, node.ID, node.Type, node.Label, "", string(metadataJSON), node.RiskScore); err != nil { + b.logger.Warn("保存攻击链节点失败", zap.String("nodeId", node.ID), zap.Error(err)) + } + } + + // 保存边 + for _, edge := range edges { + if err := b.db.SaveAttackChainEdge(conversationID, edge.ID, edge.Source, edge.Target, edge.Type, edge.Weight); err != nil { + b.logger.Warn("保存攻击链边失败", zap.String("edgeId", edge.ID), zap.Error(err)) + } + } + + return nil +} + +// LoadChainFromDatabase 从数据库加载攻击链 +func (b *Builder) LoadChainFromDatabase(conversationID string) (*Chain, error) { + nodes, err := b.db.LoadAttackChainNodes(conversationID) + if err != nil { + return nil, fmt.Errorf("加载攻击链节点失败: %w", err) + } + + edges, err := b.db.LoadAttackChainEdges(conversationID) + if err != nil { + return nil, fmt.Errorf("加载攻击链边失败: %w", err) + } + + return &Chain{ + Nodes: nodes, + Edges: edges, + }, nil +} + +// callAIForChainGeneration 调用AI生成攻击链 +func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) { + requestBody := map[string]interface{}{ + "model": b.openAIConfig.Model, + "messages": []map[string]interface{}{ + { + "role": "system", + "content": "你是一个专业的安全测试分析师,擅长构建攻击链图。请严格按照JSON格式返回攻击链数据。", + }, + { + "role": "user", + "content": prompt, + }, + }, + "temperature": 0.3, + "max_tokens": 8000, + } + + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + + if b.openAIClient == nil { + return "", fmt.Errorf("OpenAI客户端未初始化") + } + if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil { + var apiErr *openai.APIError + if errors.As(err, &apiErr) { + bodyStr := strings.ToLower(apiErr.Body) + if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") { + return "", fmt.Errorf("context length exceeded") + } + } else if strings.Contains(strings.ToLower(err.Error()), "context") || strings.Contains(strings.ToLower(err.Error()), "length") { + return "", fmt.Errorf("context length exceeded") + } + return "", fmt.Errorf("请求失败: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return "", fmt.Errorf("API未返回有效响应") + } + + content := strings.TrimSpace(apiResponse.Choices[0].Message.Content) + // 尝试提取JSON(可能包含markdown代码块) + content = strings.TrimPrefix(content, "```json") + content = strings.TrimPrefix(content, "```") + content = strings.TrimSuffix(content, "```") + content = strings.TrimSpace(content) + + return content, nil +} + +// ChainJSON 攻击链JSON结构 +type ChainJSON struct { + Nodes []struct { + ID string `json:"id"` + Type string `json:"type"` + Label string `json:"label"` + RiskScore int `json:"risk_score"` + Metadata map[string]interface{} `json:"metadata"` + } `json:"nodes"` + Edges []struct { + Source string `json:"source"` + Target string `json:"target"` + Type string `json:"type"` + Weight int `json:"weight"` + } `json:"edges"` +} + +// parseChainJSON 解析攻击链JSON +func (b *Builder) parseChainJSON(chainJSON string) (*Chain, error) { + var chainData ChainJSON + if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil { + return nil, fmt.Errorf("解析JSON失败: %w", err) + } + + // 创建节点ID映射(AI返回的ID -> 新的UUID) + nodeIDMap := make(map[string]string) + + // 转换为Chain结构 + nodes := make([]Node, 0, len(chainData.Nodes)) + for _, n := range chainData.Nodes { + // 生成新的UUID节点ID + newNodeID := fmt.Sprintf("node_%s", uuid.New().String()) + nodeIDMap[n.ID] = newNodeID + + node := Node{ + ID: newNodeID, + Type: n.Type, + Label: n.Label, + RiskScore: n.RiskScore, + Metadata: n.Metadata, + } + if node.Metadata == nil { + node.Metadata = make(map[string]interface{}) + } + nodes = append(nodes, node) + } + + // 转换边 + edges := make([]Edge, 0, len(chainData.Edges)) + for _, e := range chainData.Edges { + sourceID, ok := nodeIDMap[e.Source] + if !ok { + continue + } + targetID, ok := nodeIDMap[e.Target] + if !ok { + continue + } + + // 生成边的ID(前端需要) + edgeID := fmt.Sprintf("edge_%s", uuid.New().String()) + + edges = append(edges, Edge{ + ID: edgeID, + Source: sourceID, + Target: targetID, + Type: e.Type, + Weight: e.Weight, + }) + } + + return &Chain{ + Nodes: nodes, + Edges: edges, + }, nil +} + +// 以下所有方法已不再使用,已删除以简化代码 diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 00000000..c7ad6147 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,877 @@ +package config + +import ( + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "gopkg.in/yaml.v3" +) + +type Config struct { + Version string `yaml:"version,omitempty" json:"version,omitempty"` // 前端显示的版本号,如 v1.3.3 + Server ServerConfig `yaml:"server"` + Log LogConfig `yaml:"log"` + MCP MCPConfig `yaml:"mcp"` + OpenAI OpenAIConfig `yaml:"openai"` + FOFA FofaConfig `yaml:"fofa,omitempty" json:"fofa,omitempty"` + Agent AgentConfig `yaml:"agent"` + Security SecurityConfig `yaml:"security"` + Database DatabaseConfig `yaml:"database"` + Auth AuthConfig `yaml:"auth"` + ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` + Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` + Robots RobotsConfig `yaml:"robots,omitempty" json:"robots,omitempty"` // 企业微信/钉钉/飞书等机器人配置 + RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式) + Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色 + SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录 + AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.md,YAML front matter) + MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"` +} + +// MultiAgentConfig 基于 CloudWeGo Eino DeepAgent 的多代理编排(与单 Agent /agent-loop 并存)。 +type MultiAgentConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + DefaultMode string `yaml:"default_mode" json:"default_mode"` // single | multi,供前端默认展示 + RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理 + BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理 + MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // Deep 主代理最大推理轮次 + SubAgentMaxIterations int `yaml:"sub_agent_max_iterations" json:"sub_agent_max_iterations"` + WithoutGeneralSubAgent bool `yaml:"without_general_sub_agent" json:"without_general_sub_agent"` + WithoutWriteTodos bool `yaml:"without_write_todos" json:"without_write_todos"` + OrchestratorInstruction string `yaml:"orchestrator_instruction" json:"orchestrator_instruction"` + SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"` + // EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent. + EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"` +} + +// MultiAgentEinoSkillsConfig toggles Eino official skill progressive disclosure and host filesystem tools. +type MultiAgentEinoSkillsConfig struct { + // Disable skips skill middleware (and does not attach local FS tools for Deep). + Disable bool `yaml:"disable" json:"disable"` + // FilesystemTools registers read_file/glob/grep/write/edit/execute (eino-ext local backend). Nil/omitted = true. + FilesystemTools *bool `yaml:"filesystem_tools,omitempty" json:"filesystem_tools,omitempty"` + // SkillToolName overrides the default Eino tool name "skill". + SkillToolName string `yaml:"skill_tool_name,omitempty" json:"skill_tool_name,omitempty"` +} + +// EinoSkillFilesystemToolsEffective returns whether Deep/sub-agents should attach local filesystem + streaming shell. +func (c MultiAgentEinoSkillsConfig) EinoSkillFilesystemToolsEffective() bool { + if c.FilesystemTools != nil { + return *c.FilesystemTools + } + return true +} + +// MultiAgentSubConfig 子代理(Eino ChatModelAgent),由 DeepAgent 通过 task 工具调度。 +type MultiAgentSubConfig struct { + ID string `yaml:"id" json:"id"` + Name string `yaml:"name" json:"name"` + Description string `yaml:"description" json:"description"` + Instruction string `yaml:"instruction" json:"instruction"` + BindRole string `yaml:"bind_role,omitempty" json:"bind_role,omitempty"` // 可选:关联主配置 roles 中的角色名;未配 role_tools 时沿用该角色的 tools,并把 skills 写入指令提示 + RoleTools []string `yaml:"role_tools" json:"role_tools"` // 与单 Agent 角色工具相同 key;空表示全部工具(bind_role 可补全 tools) + MaxIterations int `yaml:"max_iterations" json:"max_iterations"` + Kind string `yaml:"kind,omitempty" json:"kind,omitempty"` // 仅 Markdown:kind=orchestrator 表示 Deep 主代理(与 orchestrator.md 二选一约定) +} + +// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。 +type MultiAgentPublic struct { + Enabled bool `json:"enabled"` + DefaultMode string `json:"default_mode"` + RobotUseMultiAgent bool `json:"robot_use_multi_agent"` + BatchUseMultiAgent bool `json:"batch_use_multi_agent"` + SubAgentCount int `json:"sub_agent_count"` +} + +// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。 +type MultiAgentAPIUpdate struct { + Enabled bool `json:"enabled"` + DefaultMode string `json:"default_mode"` + RobotUseMultiAgent bool `json:"robot_use_multi_agent"` + BatchUseMultiAgent bool `json:"batch_use_multi_agent"` +} + +// RobotsConfig 机器人配置(企业微信、钉钉、飞书等) +type RobotsConfig struct { + Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信 + Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉 + Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书 +} + +// RobotWecomConfig 企业微信机器人配置 +type RobotWecomConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + Token string `yaml:"token" json:"token"` // 回调 URL 校验 Token + EncodingAESKey string `yaml:"encoding_aes_key" json:"encoding_aes_key"` // EncodingAESKey + CorpID string `yaml:"corp_id" json:"corp_id"` // 企业 ID + Secret string `yaml:"secret" json:"secret"` // 应用 Secret + AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId +} + +// RobotDingtalkConfig 钉钉机器人配置 +type RobotDingtalkConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey) + ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret +} + +// RobotLarkConfig 飞书机器人配置 +type RobotLarkConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID + AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret + VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选) +} + +type ServerConfig struct { + Host string `yaml:"host"` + Port int `yaml:"port"` +} + +type LogConfig struct { + Level string `yaml:"level"` + Output string `yaml:"output"` +} + +type MCPConfig struct { + Enabled bool `yaml:"enabled"` + Host string `yaml:"host"` + Port int `yaml:"port"` + AuthHeader string `yaml:"auth_header,omitempty"` // 鉴权 header 名,留空表示不鉴权 + AuthHeaderValue string `yaml:"auth_header_value,omitempty"` // 鉴权 header 值,需与请求中该 header 一致 +} + +type OpenAIConfig struct { + Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` // API 提供商: "openai"(默认) 或 "claude",claude 时自动桥接为 Anthropic Messages API + APIKey string `yaml:"api_key" json:"api_key"` + BaseURL string `yaml:"base_url" json:"base_url"` + Model string `yaml:"model" json:"model"` + MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"` +} + +type FofaConfig struct { + // Email 为 FOFA 账号邮箱;APIKey 为 FOFA API Key(建议使用只读权限的 Key) + Email string `yaml:"email,omitempty" json:"email,omitempty"` + APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty"` + BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://fofa.info/api/v1/search/all +} + +type SecurityConfig struct { + Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具 + ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式) + ToolDescriptionMode string `yaml:"tool_description_mode,omitempty"` // 工具描述模式: "short" | "full",默认 short +} + +type DatabaseConfig struct { + Path string `yaml:"path"` // 会话数据库路径 + KnowledgeDBPath string `yaml:"knowledge_db_path,omitempty"` // 知识库数据库路径(可选,为空则使用会话数据库) +} + +type AgentConfig struct { + MaxIterations int `yaml:"max_iterations" json:"max_iterations"` + LargeResultThreshold int `yaml:"large_result_threshold" json:"large_result_threshold"` // 大结果阈值(字节),默认50KB + ResultStorageDir string `yaml:"result_storage_dir" json:"result_storage_dir"` // 结果存储目录,默认tmp + ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐) +} + +type AuthConfig struct { + Password string `yaml:"password" json:"password"` + SessionDurationHours int `yaml:"session_duration_hours" json:"session_duration_hours"` + GeneratedPassword string `yaml:"-" json:"-"` + GeneratedPasswordPersisted bool `yaml:"-" json:"-"` + GeneratedPasswordPersistErr string `yaml:"-" json:"-"` +} + +// ExternalMCPConfig 外部MCP配置 +type ExternalMCPConfig struct { + Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"` +} + +// ExternalMCPServerConfig 外部MCP服务器配置 +type ExternalMCPServerConfig struct { + // stdio模式配置 + Command string `yaml:"command,omitempty" json:"command,omitempty"` + Args []string `yaml:"args,omitempty" json:"args,omitempty"` + Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式) + + // HTTP模式配置 + Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点,如本机 http://127.0.0.1:8081/mcp) + URL string `yaml:"url,omitempty" json:"url,omitempty"` + Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key) + + // 通用配置 + Description string `yaml:"description,omitempty" json:"description,omitempty"` + Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 超时时间(秒) + ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用外部MCP + ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态(工具名称 -> 是否启用) + + // 向后兼容字段(已废弃,保留用于读取旧配置) + Enabled bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` // 已废弃,使用 external_mcp_enable + Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 已废弃,使用 external_mcp_enable +} +type ToolConfig struct { + Name string `yaml:"name"` + Command string `yaml:"command"` + Args []string `yaml:"args,omitempty"` // 固定参数(可选) + ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗) + Description string `yaml:"description"` // 详细描述(用于工具文档) + Enabled bool `yaml:"enabled"` + Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选) + ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选) + AllowedExitCodes []int `yaml:"allowed_exit_codes,omitempty"` // 允许的退出码列表(某些工具在成功时也返回非零退出码) +} + +// ParameterConfig 参数配置 +type ParameterConfig struct { + Name string `yaml:"name"` // 参数名称 + Type string `yaml:"type"` // 参数类型: string, int, bool, array + Description string `yaml:"description"` // 参数描述 + Required bool `yaml:"required,omitempty"` // 是否必需 + Default interface{} `yaml:"default,omitempty"` // 默认值 + ItemType string `yaml:"item_type,omitempty"` // 当 type 为 array 时,数组元素类型,如 string, number, object + Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p" + Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始) + Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template" + Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}" + Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举) +} + +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取配置文件失败: %w", err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("解析配置文件失败: %w", err) + } + + if cfg.Auth.SessionDurationHours <= 0 { + cfg.Auth.SessionDurationHours = 12 + } + + if strings.TrimSpace(cfg.Auth.Password) == "" { + password, err := generateStrongPassword(24) + if err != nil { + return nil, fmt.Errorf("生成默认密码失败: %w", err) + } + + cfg.Auth.Password = password + cfg.Auth.GeneratedPassword = password + + if err := PersistAuthPassword(path, password); err != nil { + cfg.Auth.GeneratedPasswordPersisted = false + cfg.Auth.GeneratedPasswordPersistErr = err.Error() + } else { + cfg.Auth.GeneratedPasswordPersisted = true + } + } + + // 如果配置了工具目录,从目录加载工具配置 + if cfg.Security.ToolsDir != "" { + configDir := filepath.Dir(path) + toolsDir := cfg.Security.ToolsDir + + // 如果是相对路径,相对于配置文件所在目录 + if !filepath.IsAbs(toolsDir) { + toolsDir = filepath.Join(configDir, toolsDir) + } + + tools, err := LoadToolsFromDir(toolsDir) + if err != nil { + return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err) + } + + // 合并工具配置:目录中的工具优先,主配置中的工具作为补充 + existingTools := make(map[string]bool) + for _, tool := range tools { + existingTools[tool.Name] = true + } + + // 添加主配置中不存在于目录中的工具(向后兼容) + for _, tool := range cfg.Security.Tools { + if !existingTools[tool.Name] { + tools = append(tools, tool) + } + } + + cfg.Security.Tools = tools + } + + // 迁移外部MCP配置:将旧的 enabled/disabled 字段迁移到 external_mcp_enable + if cfg.ExternalMCP.Servers != nil { + for name, serverCfg := range cfg.ExternalMCP.Servers { + // 如果已经设置了 external_mcp_enable,跳过迁移 + // 否则从 enabled/disabled 字段迁移 + // 注意:由于 ExternalMCPEnable 是 bool 类型,零值为 false,所以需要检查是否真的设置了 + // 这里我们通过检查旧的 enabled/disabled 字段来判断是否需要迁移 + if serverCfg.Disabled { + // 旧配置使用 disabled,迁移到 external_mcp_enable + serverCfg.ExternalMCPEnable = false + } else if serverCfg.Enabled { + // 旧配置使用 enabled,迁移到 external_mcp_enable + serverCfg.ExternalMCPEnable = true + } else { + // 都没有设置,默认为启用 + serverCfg.ExternalMCPEnable = true + } + cfg.ExternalMCP.Servers[name] = serverCfg + } + } + + // 从角色目录加载角色配置 + if cfg.RolesDir != "" { + configDir := filepath.Dir(path) + rolesDir := cfg.RolesDir + + // 如果是相对路径,相对于配置文件所在目录 + if !filepath.IsAbs(rolesDir) { + rolesDir = filepath.Join(configDir, rolesDir) + } + + roles, err := LoadRolesFromDir(rolesDir) + if err != nil { + return nil, fmt.Errorf("从角色目录加载角色配置失败: %w", err) + } + + cfg.Roles = roles + } else { + // 如果未配置 roles_dir,初始化为空 map + if cfg.Roles == nil { + cfg.Roles = make(map[string]RoleConfig) + } + } + + return &cfg, nil +} + +func generateStrongPassword(length int) (string, error) { + if length <= 0 { + length = 24 + } + + bytesLen := length + randomBytes := make([]byte, bytesLen) + if _, err := rand.Read(randomBytes); err != nil { + return "", err + } + + password := base64.RawURLEncoding.EncodeToString(randomBytes) + if len(password) > length { + password = password[:length] + } + return password, nil +} + +func PersistAuthPassword(path, password string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + + lines := strings.Split(string(data), "\n") + inAuthBlock := false + authIndent := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if !inAuthBlock { + if strings.HasPrefix(trimmed, "auth:") { + inAuthBlock = true + authIndent = len(line) - len(strings.TrimLeft(line, " ")) + } + continue + } + + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + leadingSpaces := len(line) - len(strings.TrimLeft(line, " ")) + if leadingSpaces <= authIndent { + // 离开 auth 块 + inAuthBlock = false + authIndent = -1 + // 继续寻找其它 auth 块(理论上没有) + if strings.HasPrefix(trimmed, "auth:") { + inAuthBlock = true + authIndent = leadingSpaces + } + continue + } + + if strings.HasPrefix(strings.TrimSpace(line), "password:") { + prefix := line[:len(line)-len(strings.TrimLeft(line, " "))] + comment := "" + if idx := strings.Index(line, "#"); idx >= 0 { + comment = strings.TrimRight(line[idx:], " ") + } + + newLine := fmt.Sprintf("%spassword: %s", prefix, password) + if comment != "" { + if !strings.HasPrefix(comment, " ") { + newLine += " " + } + newLine += comment + } + lines[i] = newLine + break + } + } + + return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644) +} + +func PrintGeneratedPasswordWarning(password string, persisted bool, persistErr string) { + if strings.TrimSpace(password) == "" { + return + } + + if persisted { + fmt.Println("[CyberStrikeAI] ✅ 已为您自动生成并写入 Web 登录密码。") + } else { + if persistErr != "" { + fmt.Printf("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码: %s\n", persistErr) + } else { + fmt.Println("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码。") + } + fmt.Println("请手动将以下随机密码写入 config.yaml 的 auth.password:") + } + + fmt.Println("----------------------------------------------------------------") + fmt.Println("CyberStrikeAI Auto-Generated Web Password") + fmt.Printf("Password: %s\n", password) + fmt.Println("WARNING: Anyone with this password can fully control CyberStrikeAI.") + fmt.Println("Please store it securely and change it in config.yaml as soon as possible.") + fmt.Println("警告:持有此密码的人将拥有对 CyberStrikeAI 的完全控制权限。") + fmt.Println("请妥善保管,并尽快在 config.yaml 中修改 auth.password!") + fmt.Println("----------------------------------------------------------------") +} + +// generateRandomToken 生成用于 MCP 鉴权的随机字符串(64 位十六进制) +func generateRandomToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +// persistMCPAuth 将 MCP 的 auth_header / auth_header_value 写回配置文件 +func persistMCPAuth(path string, mcp *MCPConfig) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + lines := strings.Split(string(data), "\n") + inMcpBlock := false + mcpIndent := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if !inMcpBlock { + if strings.HasPrefix(trimmed, "mcp:") { + inMcpBlock = true + mcpIndent = len(line) - len(strings.TrimLeft(line, " ")) + } + continue + } + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + leadingSpaces := len(line) - len(strings.TrimLeft(line, " ")) + if leadingSpaces <= mcpIndent { + inMcpBlock = false + mcpIndent = -1 + if strings.HasPrefix(trimmed, "mcp:") { + inMcpBlock = true + mcpIndent = leadingSpaces + } + continue + } + + prefix := line[:leadingSpaces] + rest := strings.TrimSpace(line[leadingSpaces:]) + comment := "" + if idx := strings.Index(line, "#"); idx >= 0 { + comment = strings.TrimRight(line[idx:], " ") + } + withComment := "" + if comment != "" { + if !strings.HasPrefix(comment, " ") { + withComment = " " + } + withComment += comment + } + + if strings.HasPrefix(rest, "auth_header_value:") { + lines[i] = fmt.Sprintf("%sauth_header_value: %q%s", prefix, mcp.AuthHeaderValue, withComment) + } else if strings.HasPrefix(rest, "auth_header:") { + lines[i] = fmt.Sprintf("%sauth_header: %q%s", prefix, mcp.AuthHeader, withComment) + } + } + + return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644) +} + +// EnsureMCPAuth 在 MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置 +func EnsureMCPAuth(path string, cfg *Config) error { + if !cfg.MCP.Enabled || strings.TrimSpace(cfg.MCP.AuthHeaderValue) != "" { + return nil + } + token, err := generateRandomToken() + if err != nil { + return fmt.Errorf("生成 MCP 鉴权密钥失败: %w", err) + } + cfg.MCP.AuthHeaderValue = token + if strings.TrimSpace(cfg.MCP.AuthHeader) == "" { + cfg.MCP.AuthHeader = "X-MCP-Token" + } + return persistMCPAuth(path, &cfg.MCP) +} + +// PrintMCPConfigJSON 向终端输出 MCP 配置的 JSON,可直接复制到 Cursor / Claude Code 的 mcp 配置中使用 +func PrintMCPConfigJSON(mcp MCPConfig) { + if !mcp.Enabled { + return + } + hostForURL := strings.TrimSpace(mcp.Host) + if hostForURL == "" || hostForURL == "0.0.0.0" { + hostForURL = "localhost" + } + url := fmt.Sprintf("http://%s:%d/mcp", hostForURL, mcp.Port) + headers := map[string]string{} + if mcp.AuthHeader != "" { + headers[mcp.AuthHeader] = mcp.AuthHeaderValue + } + serverEntry := map[string]interface{}{ + "url": url, + } + if len(headers) > 0 { + serverEntry["headers"] = headers + } + // Claude Code 需要 type: "http" + serverEntry["type"] = "http" + out := map[string]interface{}{ + "mcpServers": map[string]interface{}{ + "cyberstrike-ai": serverEntry, + }, + } + b, _ := json.MarshalIndent(out, "", " ") + fmt.Println("[CyberStrikeAI] MCP 配置(可复制到 Cursor / Claude Code 使用):") + fmt.Println(" Cursor: 放入 ~/.cursor/mcp.json 的 mcpServers,或项目 .cursor/mcp.json") + fmt.Println(" Claude Code: 放入 .mcp.json 或 ~/.claude.json 的 mcpServers") + fmt.Println("----------------------------------------------------------------") + fmt.Println(string(b)) + fmt.Println("----------------------------------------------------------------") +} + +// LoadToolsFromDir 从目录加载所有工具配置文件 +func LoadToolsFromDir(dir string) ([]ToolConfig, error) { + var tools []ToolConfig + + // 检查目录是否存在 + if _, err := os.Stat(dir); os.IsNotExist(err) { + return tools, nil // 目录不存在时返回空列表,不报错 + } + + // 读取目录中的所有 .yaml 和 .yml 文件 + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("读取工具目录失败: %w", err) + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") { + continue + } + + filePath := filepath.Join(dir, name) + tool, err := LoadToolFromFile(filePath) + if err != nil { + // 记录错误但继续加载其他文件 + fmt.Printf("警告: 加载工具配置文件 %s 失败: %v\n", filePath, err) + continue + } + + tools = append(tools, *tool) + } + + return tools, nil +} + +// LoadToolFromFile 从单个文件加载工具配置 +func LoadToolFromFile(path string) (*ToolConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取文件失败: %w", err) + } + + var tool ToolConfig + if err := yaml.Unmarshal(data, &tool); err != nil { + return nil, fmt.Errorf("解析工具配置失败: %w", err) + } + + // 验证必需字段 + if tool.Name == "" { + return nil, fmt.Errorf("工具名称不能为空") + } + if tool.Command == "" { + return nil, fmt.Errorf("工具命令不能为空") + } + + return &tool, nil +} + +// LoadRolesFromDir 从目录加载所有角色配置文件 +func LoadRolesFromDir(dir string) (map[string]RoleConfig, error) { + roles := make(map[string]RoleConfig) + + // 检查目录是否存在 + if _, err := os.Stat(dir); os.IsNotExist(err) { + return roles, nil // 目录不存在时返回空map,不报错 + } + + // 读取目录中的所有 .yaml 和 .yml 文件 + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("读取角色目录失败: %w", err) + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") { + continue + } + + filePath := filepath.Join(dir, name) + role, err := LoadRoleFromFile(filePath) + if err != nil { + // 记录错误但继续加载其他文件 + fmt.Printf("警告: 加载角色配置文件 %s 失败: %v\n", filePath, err) + continue + } + + // 使用角色名称作为key + roleName := role.Name + if roleName == "" { + // 如果角色名称为空,使用文件名(去掉扩展名)作为名称 + roleName = strings.TrimSuffix(strings.TrimSuffix(name, ".yaml"), ".yml") + role.Name = roleName + } + + roles[roleName] = *role + } + + return roles, nil +} + +// LoadRoleFromFile 从单个文件加载角色配置 +func LoadRoleFromFile(path string) (*RoleConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取文件失败: %w", err) + } + + var role RoleConfig + if err := yaml.Unmarshal(data, &role); err != nil { + return nil, fmt.Errorf("解析角色配置失败: %w", err) + } + + // 处理 icon 字段:如果包含 Unicode 转义格式(\U0001F3C6),转换为实际的 Unicode 字符 + // Go 的 yaml 库可能不会自动解析 \U 转义序列,需要手动转换 + if role.Icon != "" { + icon := role.Icon + // 去除可能的引号 + icon = strings.Trim(icon, `"`) + + // 检查是否是 Unicode 转义格式 \U0001F3C6(8位十六进制)或 \uXXXX(4位十六进制) + if len(icon) >= 3 && icon[0] == '\\' { + if icon[1] == 'U' && len(icon) >= 10 { + // \U0001F3C6 格式(8位十六进制) + if codePoint, err := strconv.ParseInt(icon[2:10], 16, 32); err == nil { + role.Icon = string(rune(codePoint)) + } + } else if icon[1] == 'u' && len(icon) >= 6 { + // \uXXXX 格式(4位十六进制) + if codePoint, err := strconv.ParseInt(icon[2:6], 16, 32); err == nil { + role.Icon = string(rune(codePoint)) + } + } + } + } + + // 验证必需字段 + if role.Name == "" { + // 如果名称为空,尝试从文件名获取 + baseName := filepath.Base(path) + role.Name = strings.TrimSuffix(strings.TrimSuffix(baseName, ".yaml"), ".yml") + } + + return &role, nil +} + +func Default() *Config { + return &Config{ + Server: ServerConfig{ + Host: "0.0.0.0", + Port: 8080, + }, + Log: LogConfig{ + Level: "info", + Output: "stdout", + }, + MCP: MCPConfig{ + Enabled: true, + Host: "0.0.0.0", + Port: 8081, + }, + OpenAI: OpenAIConfig{ + BaseURL: "https://api.openai.com/v1", + Model: "gpt-4", + MaxTotalTokens: 120000, + }, + Agent: AgentConfig{ + MaxIterations: 30, // 默认最大迭代次数 + ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用 + }, + Security: SecurityConfig{ + Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载 + ToolsDir: "tools", // 默认工具目录 + }, + Database: DatabaseConfig{ + Path: "data/conversations.db", + KnowledgeDBPath: "data/knowledge.db", // 默认知识库数据库路径 + }, + Auth: AuthConfig{ + SessionDurationHours: 12, + }, + Knowledge: KnowledgeConfig{ + Enabled: true, + BasePath: "knowledge_base", + Embedding: EmbeddingConfig{ + Provider: "openai", + Model: "text-embedding-3-small", + BaseURL: "https://api.openai.com/v1", + }, + Retrieval: RetrievalConfig{ + TopK: 5, + SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检 + }, + Indexing: IndexingConfig{ + ChunkStrategy: "markdown_then_recursive", + RequestTimeoutSeconds: 120, + ChunkSize: 768, // 增加到 768,更好的上下文保持 + ChunkOverlap: 50, + MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额 + BatchSize: 64, + PreferSourceFile: false, + MaxRPM: 100, // 默认 100 RPM,避免 429 错误 + RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM + MaxRetries: 3, + RetryDelayMs: 1000, + SubIndexes: nil, + }, + }, + } +} + +// KnowledgeConfig 知识库配置 +type KnowledgeConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用知识检索 + BasePath string `yaml:"base_path" json:"base_path"` // 知识库路径 + Embedding EmbeddingConfig `yaml:"embedding" json:"embedding"` + Retrieval RetrievalConfig `yaml:"retrieval" json:"retrieval"` + Indexing IndexingConfig `yaml:"indexing,omitempty" json:"indexing,omitempty"` // 索引构建配置 +} + +// IndexingConfig 索引构建配置(用于控制知识库索引构建时的行为) +type IndexingConfig struct { + // ChunkStrategy: "markdown_then_recursive"(默认,Eino Markdown 标题切分后再递归切)或 "recursive"(仅递归切分) + ChunkStrategy string `yaml:"chunk_strategy,omitempty" json:"chunk_strategy,omitempty"` + // RequestTimeoutSeconds 嵌入 HTTP 客户端超时(秒),0 表示使用默认 120 + RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"` + // 分块配置 + ChunkSize int `yaml:"chunk_size,omitempty" json:"chunk_size,omitempty"` // 每个块的最大 token 数(估算),默认 512 + ChunkOverlap int `yaml:"chunk_overlap,omitempty" json:"chunk_overlap,omitempty"` // 块之间的重叠 token 数,默认 50 + MaxChunksPerItem int `yaml:"max_chunks_per_item,omitempty" json:"max_chunks_per_item,omitempty"` // 单个知识项的最大块数量,0 表示不限制 + + // PreferSourceFile 为 true 时优先用 Eino FileLoader 从 file_path 读原文再索引(与库内 content 不一致时以磁盘为准) + PreferSourceFile bool `yaml:"prefer_source_file,omitempty" json:"prefer_source_file,omitempty"` + + // 速率限制配置(用于避免 API 速率限制) + RateLimitDelayMs int `yaml:"rate_limit_delay_ms,omitempty" json:"rate_limit_delay_ms,omitempty"` // 请求间隔时间(毫秒),0 表示不使用固定延迟 + MaxRPM int `yaml:"max_rpm,omitempty" json:"max_rpm,omitempty"` // 每分钟最大请求数,0 表示不限制 + + // 重试配置(用于处理临时错误) + MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // 最大重试次数,默认 3 + RetryDelayMs int `yaml:"retry_delay_ms,omitempty" json:"retry_delay_ms,omitempty"` // 重试间隔(毫秒),默认 1000 + + // BatchSize 嵌入批大小(SQLite 索引写入),0 表示默认 64 + BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"` + // SubIndexes 传入 Eino indexer.WithSubIndexes(逻辑分区标记,随 Document 元数据传递) + SubIndexes []string `yaml:"sub_indexes,omitempty" json:"sub_indexes,omitempty"` +} + +// EmbeddingConfig 嵌入配置 +type EmbeddingConfig struct { + Provider string `yaml:"provider" json:"provider"` // 嵌入模型提供商 + Model string `yaml:"model" json:"model"` // 模型名称 + BaseURL string `yaml:"base_url" json:"base_url"` // API Base URL + APIKey string `yaml:"api_key" json:"api_key"` // API Key(从OpenAI配置继承) +} + +// PostRetrieveConfig 检索后处理:固定对正文做规范化去重(最佳实践)、上下文预算截断;PrefetchTopK 用于多取候选再收敛到 top_k。 +type PostRetrieveConfig struct { + // PrefetchTopK 向量检索阶段最多保留的候选数(余弦序),应 ≥ top_k,0 表示与 top_k 相同;上限见知识库包内常量。 + PrefetchTopK int `yaml:"prefetch_top_k,omitempty" json:"prefetch_top_k,omitempty"` + // MaxContextChars 返回文档内容总 Unicode 字符数上限(整段 chunk,不截断半段);0 表示不限制。 + MaxContextChars int `yaml:"max_context_chars,omitempty" json:"max_context_chars,omitempty"` + // MaxContextTokens 返回文档内容总 token 上限(tiktoken,按嵌入模型名映射,失败则 cl100k_base);0 表示不限制。 + MaxContextTokens int `yaml:"max_context_tokens,omitempty" json:"max_context_tokens,omitempty"` +} + +// RetrievalConfig 检索配置 +type RetrievalConfig struct { + TopK int `yaml:"top_k" json:"top_k"` // 检索Top-K + SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 余弦相似度阈值 + // SubIndexFilter 非空时仅保留 sub_indexes 含该标签(逗号分隔之一)的行;sub_indexes 为空的旧行仍返回。 + SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"` + // PostRetrieve 检索后处理(去重、预算截断);重排通过代码注入 [knowledge.DocumentReranker]。 + PostRetrieve PostRetrieveConfig `yaml:"post_retrieve,omitempty" json:"post_retrieve,omitempty"` +} + +// RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代) +// 保留此类型以兼容旧代码,但建议直接使用 map[string]RoleConfig +type RolesConfig struct { + Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` +} + +// RoleConfig 单个角色配置 +type RoleConfig struct { + Name string `yaml:"name" json:"name"` // 角色名称 + Description string `yaml:"description" json:"description"` // 角色描述 + UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前) + Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选) + Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName") + MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代) + Skills []string `yaml:"skills,omitempty" json:"skills,omitempty"` // 关联的skills列表(skill名称列表,在执行任务前会读取这些skills的内容) + Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用 +} diff --git a/internal/database/attackchain.go b/internal/database/attackchain.go new file mode 100644 index 00000000..c8529e70 --- /dev/null +++ b/internal/database/attackchain.go @@ -0,0 +1,168 @@ +package database + +import ( + "database/sql" + "encoding/json" + "fmt" + + "go.uber.org/zap" +) + +// AttackChainNode 攻击链节点 +type AttackChainNode struct { + ID string `json:"id"` + Type string `json:"type"` // tool, vulnerability, target, exploit + Label string `json:"label"` + ToolExecutionID string `json:"tool_execution_id,omitempty"` + Metadata map[string]interface{} `json:"metadata"` + RiskScore int `json:"risk_score"` +} + +// AttackChainEdge 攻击链边 +type AttackChainEdge struct { + ID string `json:"id"` + Source string `json:"source"` + Target string `json:"target"` + Type string `json:"type"` // leads_to, exploits, enables, depends_on + Weight int `json:"weight"` +} + +// SaveAttackChainNode 保存攻击链节点 +func (db *DB) SaveAttackChainNode(conversationID, nodeID, nodeType, nodeName, toolExecutionID, metadata string, riskScore int) error { + var toolExecID sql.NullString + if toolExecutionID != "" { + toolExecID = sql.NullString{String: toolExecutionID, Valid: true} + } + + var metadataJSON sql.NullString + if metadata != "" { + metadataJSON = sql.NullString{String: metadata, Valid: true} + } + + query := ` + INSERT OR REPLACE INTO attack_chain_nodes + (id, conversation_id, node_type, node_name, tool_execution_id, metadata, risk_score, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + ` + + _, err := db.Exec(query, nodeID, conversationID, nodeType, nodeName, toolExecID, metadataJSON, riskScore) + if err != nil { + db.logger.Error("保存攻击链节点失败", zap.Error(err), zap.String("nodeId", nodeID)) + return err + } + + return nil +} + +// SaveAttackChainEdge 保存攻击链边 +func (db *DB) SaveAttackChainEdge(conversationID, edgeID, sourceNodeID, targetNodeID, edgeType string, weight int) error { + query := ` + INSERT OR REPLACE INTO attack_chain_edges + (id, conversation_id, source_node_id, target_node_id, edge_type, weight, created_at) + VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + ` + + _, err := db.Exec(query, edgeID, conversationID, sourceNodeID, targetNodeID, edgeType, weight) + if err != nil { + db.logger.Error("保存攻击链边失败", zap.Error(err), zap.String("edgeId", edgeID)) + return err + } + + return nil +} + +// LoadAttackChainNodes 加载攻击链节点 +func (db *DB) LoadAttackChainNodes(conversationID string) ([]AttackChainNode, error) { + query := ` + SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score + FROM attack_chain_nodes + WHERE conversation_id = ? + ORDER BY created_at ASC + ` + + rows, err := db.Query(query, conversationID) + if err != nil { + return nil, fmt.Errorf("查询攻击链节点失败: %w", err) + } + defer rows.Close() + + var nodes []AttackChainNode + for rows.Next() { + var node AttackChainNode + var toolExecID sql.NullString + var metadataJSON sql.NullString + + err := rows.Scan(&node.ID, &node.Type, &node.Label, &toolExecID, &metadataJSON, &node.RiskScore) + if err != nil { + db.logger.Warn("扫描攻击链节点失败", zap.Error(err)) + continue + } + + if toolExecID.Valid { + node.ToolExecutionID = toolExecID.String + } + + if metadataJSON.Valid && metadataJSON.String != "" { + if err := json.Unmarshal([]byte(metadataJSON.String), &node.Metadata); err != nil { + db.logger.Warn("解析节点元数据失败", zap.Error(err)) + node.Metadata = make(map[string]interface{}) + } + } else { + node.Metadata = make(map[string]interface{}) + } + + nodes = append(nodes, node) + } + + return nodes, nil +} + +// LoadAttackChainEdges 加载攻击链边 +func (db *DB) LoadAttackChainEdges(conversationID string) ([]AttackChainEdge, error) { + query := ` + SELECT id, source_node_id, target_node_id, edge_type, weight + FROM attack_chain_edges + WHERE conversation_id = ? + ORDER BY created_at ASC + ` + + rows, err := db.Query(query, conversationID) + if err != nil { + return nil, fmt.Errorf("查询攻击链边失败: %w", err) + } + defer rows.Close() + + var edges []AttackChainEdge + for rows.Next() { + var edge AttackChainEdge + + err := rows.Scan(&edge.ID, &edge.Source, &edge.Target, &edge.Type, &edge.Weight) + if err != nil { + db.logger.Warn("扫描攻击链边失败", zap.Error(err)) + continue + } + + edges = append(edges, edge) + } + + return edges, nil +} + +// DeleteAttackChain 删除对话的攻击链数据 +func (db *DB) DeleteAttackChain(conversationID string) error { + // 先删除边(因为有外键约束) + _, err := db.Exec("DELETE FROM attack_chain_edges WHERE conversation_id = ?", conversationID) + if err != nil { + db.logger.Warn("删除攻击链边失败", zap.Error(err)) + } + + // 再删除节点 + _, err = db.Exec("DELETE FROM attack_chain_nodes WHERE conversation_id = ?", conversationID) + if err != nil { + db.logger.Error("删除攻击链节点失败", zap.Error(err), zap.String("conversationId", conversationID)) + return err + } + + return nil +} + diff --git a/internal/database/batch_task.go b/internal/database/batch_task.go new file mode 100644 index 00000000..c774be65 --- /dev/null +++ b/internal/database/batch_task.go @@ -0,0 +1,537 @@ +package database + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "go.uber.org/zap" +) + +// BatchTaskQueueRow 批量任务队列数据库行 +type BatchTaskQueueRow struct { + ID string + Title sql.NullString + Role sql.NullString + AgentMode sql.NullString + ScheduleMode sql.NullString + CronExpr sql.NullString + NextRunAt sql.NullTime + ScheduleEnabled sql.NullInt64 + LastScheduleTriggerAt sql.NullTime + LastScheduleError sql.NullString + LastRunError sql.NullString + Status string + CreatedAt time.Time + StartedAt sql.NullTime + CompletedAt sql.NullTime + CurrentIndex int +} + +// BatchTaskRow 批量任务数据库行 +type BatchTaskRow struct { + ID string + QueueID string + Message string + ConversationID sql.NullString + Status string + StartedAt sql.NullTime + CompletedAt sql.NullTime + Error sql.NullString + Result sql.NullString +} + +// CreateBatchQueue 创建批量任务队列 +func (db *DB) CreateBatchQueue( + queueID string, + title string, + role string, + agentMode string, + scheduleMode string, + cronExpr string, + nextRunAt *time.Time, + tasks []map[string]interface{}, +) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("开始事务失败: %w", err) + } + defer tx.Rollback() + + now := time.Now() + var nextRunAtValue interface{} + if nextRunAt != nil { + nextRunAtValue = *nextRunAt + } + + _, err = tx.Exec( + "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0, + ) + if err != nil { + return fmt.Errorf("创建批量任务队列失败: %w", err) + } + + // 插入任务 + for _, task := range tasks { + taskID, ok := task["id"].(string) + if !ok { + continue + } + message, ok := task["message"].(string) + if !ok { + continue + } + + _, err = tx.Exec( + "INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)", + taskID, queueID, message, "pending", + ) + if err != nil { + return fmt.Errorf("创建批量任务失败: %w", err) + } + } + + return tx.Commit() +} + +// GetBatchQueue 获取批量任务队列 +func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { + var row BatchTaskQueueRow + var createdAt string + err := db.QueryRow( + "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", + queueID, + ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("查询批量任务队列失败: %w", err) + } + + parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) + if parseErr != nil { + // 尝试其他时间格式 + parsedTime, parseErr = time.Parse(time.RFC3339, createdAt) + if parseErr != nil { + db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr)) + parsedTime = time.Now() + } + } + row.CreatedAt = parsedTime + return &row, nil +} + +// GetAllBatchQueues 获取所有批量任务队列 +func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { + rows, err := db.Query( + "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", + ) + if err != nil { + return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) + } + defer rows.Close() + + var queues []*BatchTaskQueueRow + for rows.Next() { + var row BatchTaskQueueRow + var createdAt string + if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { + return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) + } + parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) + if parseErr != nil { + parsedTime, parseErr = time.Parse(time.RFC3339, createdAt) + if parseErr != nil { + db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr)) + parsedTime = time.Now() + } + } + row.CreatedAt = parsedTime + queues = append(queues, &row) + } + + return queues, nil +} + +// ListBatchQueues 列出批量任务队列(支持筛选和分页) +func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) { + query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" + args := []interface{}{} + + // 状态筛选 + if status != "" && status != "all" { + query += " AND status = ?" + args = append(args, status) + } + + // 关键字搜索(搜索队列ID和标题) + if keyword != "" { + query += " AND (id LIKE ? OR title LIKE ?)" + args = append(args, "%"+keyword+"%", "%"+keyword+"%") + } + + query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" + args = append(args, limit, offset) + + rows, err := db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) + } + defer rows.Close() + + var queues []*BatchTaskQueueRow + for rows.Next() { + var row BatchTaskQueueRow + var createdAt string + if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { + return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) + } + parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) + if parseErr != nil { + parsedTime, parseErr = time.Parse(time.RFC3339, createdAt) + if parseErr != nil { + db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr)) + parsedTime = time.Now() + } + } + row.CreatedAt = parsedTime + queues = append(queues, &row) + } + + return queues, nil +} + +// CountBatchQueues 统计批量任务队列总数(支持筛选条件) +func (db *DB) CountBatchQueues(status, keyword string) (int, error) { + query := "SELECT COUNT(*) FROM batch_task_queues WHERE 1=1" + args := []interface{}{} + + // 状态筛选 + if status != "" && status != "all" { + query += " AND status = ?" + args = append(args, status) + } + + // 关键字搜索(搜索队列ID和标题) + if keyword != "" { + query += " AND (id LIKE ? OR title LIKE ?)" + args = append(args, "%"+keyword+"%", "%"+keyword+"%") + } + + var count int + err := db.QueryRow(query, args...).Scan(&count) + if err != nil { + return 0, fmt.Errorf("统计批量任务队列总数失败: %w", err) + } + + return count, nil +} + +// GetBatchTasks 获取批量任务队列的所有任务 +func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) { + rows, err := db.Query( + "SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY id", + queueID, + ) + if err != nil { + return nil, fmt.Errorf("查询批量任务失败: %w", err) + } + defer rows.Close() + + var tasks []*BatchTaskRow + for rows.Next() { + var task BatchTaskRow + if err := rows.Scan( + &task.ID, &task.QueueID, &task.Message, &task.ConversationID, + &task.Status, &task.StartedAt, &task.CompletedAt, &task.Error, &task.Result, + ); err != nil { + return nil, fmt.Errorf("扫描批量任务失败: %w", err) + } + tasks = append(tasks, &task) + } + + return tasks, nil +} + +// UpdateBatchQueueStatus 更新批量任务队列状态 +func (db *DB) UpdateBatchQueueStatus(queueID, status string) error { + var err error + now := time.Now() + + if status == "running" { + _, err = db.Exec( + "UPDATE batch_task_queues SET status = ?, started_at = COALESCE(started_at, ?) WHERE id = ?", + status, now, queueID, + ) + } else if status == "completed" || status == "cancelled" { + _, err = db.Exec( + "UPDATE batch_task_queues SET status = ?, completed_at = COALESCE(completed_at, ?) WHERE id = ?", + status, now, queueID, + ) + } else { + _, err = db.Exec( + "UPDATE batch_task_queues SET status = ? WHERE id = ?", + status, queueID, + ) + } + + if err != nil { + return fmt.Errorf("更新批量任务队列状态失败: %w", err) + } + return nil +} + +// UpdateBatchTaskStatus 更新批量任务状态 +func (db *DB) UpdateBatchTaskStatus(queueID, taskID, status string, conversationID, result, errorMsg string) error { + var err error + now := time.Now() + + // 构建更新语句 + var updates []string + var args []interface{} + + updates = append(updates, "status = ?") + args = append(args, status) + + if conversationID != "" { + updates = append(updates, "conversation_id = ?") + args = append(args, conversationID) + } + + if result != "" { + updates = append(updates, "result = ?") + args = append(args, result) + } + + if errorMsg != "" { + updates = append(updates, "error = ?") + args = append(args, errorMsg) + } + + if status == "running" { + updates = append(updates, "started_at = COALESCE(started_at, ?)") + args = append(args, now) + } + + if status == "completed" || status == "failed" || status == "cancelled" { + updates = append(updates, "completed_at = COALESCE(completed_at, ?)") + args = append(args, now) + } + + args = append(args, queueID, taskID) + + // 构建SQL语句 + sql := "UPDATE batch_tasks SET " + for i, update := range updates { + if i > 0 { + sql += ", " + } + sql += update + } + sql += " WHERE queue_id = ? AND id = ?" + + _, err = db.Exec(sql, args...) + if err != nil { + return fmt.Errorf("更新批量任务状态失败: %w", err) + } + return nil +} + +// UpdateBatchQueueCurrentIndex 更新批量任务队列的当前索引 +func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) error { + _, err := db.Exec( + "UPDATE batch_task_queues SET current_index = ? WHERE id = ?", + currentIndex, queueID, + ) + if err != nil { + return fmt.Errorf("更新批量任务队列当前索引失败: %w", err) + } + return nil +} + +// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式 +func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error { + _, err := db.Exec( + "UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?", + title, role, agentMode, queueID, + ) + if err != nil { + return fmt.Errorf("更新批量任务队列元数据失败: %w", err) + } + return nil +} + +// UpdateBatchQueueSchedule 更新批量任务队列调度相关信息 +func (db *DB) UpdateBatchQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) error { + var nextRunAtValue interface{} + if nextRunAt != nil { + nextRunAtValue = *nextRunAt + } + _, err := db.Exec( + "UPDATE batch_task_queues SET schedule_mode = ?, cron_expr = ?, next_run_at = ? WHERE id = ?", + scheduleMode, cronExpr, nextRunAtValue, queueID, + ) + if err != nil { + return fmt.Errorf("更新批量任务调度配置失败: %w", err) + } + return nil +} + +// UpdateBatchQueueScheduleEnabled 是否允许 Cron 自动触发(手工「开始执行」不受影响) +func (db *DB) UpdateBatchQueueScheduleEnabled(queueID string, enabled bool) error { + v := 0 + if enabled { + v = 1 + } + _, err := db.Exec( + "UPDATE batch_task_queues SET schedule_enabled = ? WHERE id = ?", + v, queueID, + ) + if err != nil { + return fmt.Errorf("更新批量任务调度开关失败: %w", err) + } + return nil +} + +// RecordBatchQueueScheduledTriggerStart 记录一次由调度触发的开始时间并清空调度层错误 +func (db *DB) RecordBatchQueueScheduledTriggerStart(queueID string, at time.Time) error { + _, err := db.Exec( + "UPDATE batch_task_queues SET last_schedule_trigger_at = ?, last_schedule_error = NULL WHERE id = ?", + at, queueID, + ) + if err != nil { + return fmt.Errorf("记录调度触发时间失败: %w", err) + } + return nil +} + +// SetBatchQueueLastScheduleError 调度启动失败等原因(如状态不允许、重置失败) +func (db *DB) SetBatchQueueLastScheduleError(queueID, msg string) error { + _, err := db.Exec( + "UPDATE batch_task_queues SET last_schedule_error = ? WHERE id = ?", + msg, queueID, + ) + if err != nil { + return fmt.Errorf("写入调度错误信息失败: %w", err) + } + return nil +} + +// SetBatchQueueLastRunError 最近一轮执行中出现的子任务失败摘要(空串表示清空) +func (db *DB) SetBatchQueueLastRunError(queueID, msg string) error { + var v interface{} + if strings.TrimSpace(msg) == "" { + v = nil + } else { + v = msg + } + _, err := db.Exec( + "UPDATE batch_task_queues SET last_run_error = ? WHERE id = ?", + v, queueID, + ) + if err != nil { + return fmt.Errorf("写入最近运行错误失败: %w", err) + } + return nil +} + +// ResetBatchQueueForRerun 重置队列和任务状态用于下一轮调度执行 +func (db *DB) ResetBatchQueueForRerun(queueID string) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("开始事务失败: %w", err) + } + defer tx.Rollback() + + _, err = tx.Exec( + "UPDATE batch_task_queues SET status = ?, current_index = 0, started_at = NULL, completed_at = NULL, last_run_error = NULL, last_schedule_error = NULL WHERE id = ?", + "pending", queueID, + ) + if err != nil { + return fmt.Errorf("重置批量任务队列状态失败: %w", err) + } + + _, err = tx.Exec( + "UPDATE batch_tasks SET status = ?, conversation_id = NULL, started_at = NULL, completed_at = NULL, error = NULL, result = NULL WHERE queue_id = ?", + "pending", queueID, + ) + if err != nil { + return fmt.Errorf("重置批量任务状态失败: %w", err) + } + + return tx.Commit() +} + +// UpdateBatchTaskMessage 更新批量任务消息 +func (db *DB) UpdateBatchTaskMessage(queueID, taskID, message string) error { + _, err := db.Exec( + "UPDATE batch_tasks SET message = ? WHERE queue_id = ? AND id = ?", + message, queueID, taskID, + ) + if err != nil { + return fmt.Errorf("更新批量任务消息失败: %w", err) + } + return nil +} + +// AddBatchTask 添加任务到批量任务队列 +func (db *DB) AddBatchTask(queueID, taskID, message string) error { + _, err := db.Exec( + "INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)", + taskID, queueID, message, "pending", + ) + if err != nil { + return fmt.Errorf("添加批量任务失败: %w", err) + } + return nil +} + +// CancelPendingBatchTasks 批量取消队列中所有 pending 状态的任务(单条 SQL) +func (db *DB) CancelPendingBatchTasks(queueID string, completedAt time.Time) error { + _, err := db.Exec( + "UPDATE batch_tasks SET status = ?, completed_at = ? WHERE queue_id = ? AND status = ?", + "cancelled", completedAt, queueID, "pending", + ) + if err != nil { + return fmt.Errorf("批量取消 pending 任务失败: %w", err) + } + return nil +} + +// DeleteBatchTask 删除批量任务 +func (db *DB) DeleteBatchTask(queueID, taskID string) error { + _, err := db.Exec( + "DELETE FROM batch_tasks WHERE queue_id = ? AND id = ?", + queueID, taskID, + ) + if err != nil { + return fmt.Errorf("删除批量任务失败: %w", err) + } + return nil +} + +// DeleteBatchQueue 删除批量任务队列 +func (db *DB) DeleteBatchQueue(queueID string) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("开始事务失败: %w", err) + } + defer tx.Rollback() + + // 删除任务(外键会自动级联删除) + _, err = tx.Exec("DELETE FROM batch_tasks WHERE queue_id = ?", queueID) + if err != nil { + return fmt.Errorf("删除批量任务失败: %w", err) + } + + // 删除队列 + _, err = tx.Exec("DELETE FROM batch_task_queues WHERE id = ?", queueID) + if err != nil { + return fmt.Errorf("删除批量任务队列失败: %w", err) + } + + return tx.Commit() +} diff --git a/internal/database/conversation.go b/internal/database/conversation.go new file mode 100644 index 00000000..ca2b1f5a --- /dev/null +++ b/internal/database/conversation.go @@ -0,0 +1,758 @@ +package database + +import ( + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Conversation 对话 +type Conversation struct { + ID string `json:"id"` + Title string `json:"title"` + Pinned bool `json:"pinned"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + Messages []Message `json:"messages,omitempty"` +} + +// Message 消息 +type Message struct { + ID string `json:"id"` + ConversationID string `json:"conversationId"` + Role string `json:"role"` + Content string `json:"content"` + MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` + ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"` + CreatedAt time.Time `json:"createdAt"` +} + +// CreateConversation 创建新对话 +func (db *DB) CreateConversation(title string) (*Conversation, error) { + return db.CreateConversationWithWebshell("", title) +} + +// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话) +func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) { + id := uuid.New().String() + now := time.Now() + + var err error + if webshellConnectionID != "" { + _, err = db.Exec( + "INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)", + id, title, now, now, webshellConnectionID, + ) + } else { + _, err = db.Exec( + "INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)", + id, title, now, now, + ) + } + if err != nil { + return nil, fmt.Errorf("创建对话失败: %w", err) + } + + return &Conversation{ + ID: id, + Title: title, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化) +func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conversation, error) { + if connectionID == "" { + return nil, fmt.Errorf("connectionID is empty") + } + var conv Conversation + var createdAt, updatedAt string + var pinned int + err := db.QueryRow( + "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC LIMIT 1", + connectionID, + ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("查询对话失败: %w", err) + } + conv.Pinned = pinned != 0 + if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt); e == nil { + conv.CreatedAt = t + } else if t, e := time.Parse("2006-01-02 15:04:05", createdAt); e == nil { + conv.CreatedAt = t + } else { + conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil { + conv.UpdatedAt = t + } else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil { + conv.UpdatedAt = t + } else { + conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) + } + messages, err := db.GetMessages(conv.ID) + if err != nil { + return nil, fmt.Errorf("加载消息失败: %w", err) + } + conv.Messages = messages + + // 加载过程详情并附加到对应消息(与 GetConversation 一致,便于刷新后仍可查看执行过程) + processDetailsMap, err := db.GetProcessDetailsByConversation(conv.ID) + if err != nil { + db.logger.Warn("加载过程详情失败", zap.Error(err)) + processDetailsMap = make(map[string][]ProcessDetail) + } + for i := range conv.Messages { + if details, ok := processDetailsMap[conv.Messages[i].ID]; ok { + detailsJSON := make([]map[string]interface{}, len(details)) + for j, detail := range details { + var data interface{} + if detail.Data != "" { + if err := json.Unmarshal([]byte(detail.Data), &data); err != nil { + db.logger.Warn("解析过程详情数据失败", zap.Error(err)) + } + } + detailsJSON[j] = map[string]interface{}{ + "id": detail.ID, + "messageId": detail.MessageID, + "conversationId": detail.ConversationID, + "eventType": detail.EventType, + "message": detail.Message, + "data": data, + "createdAt": detail.CreatedAt, + } + } + conv.Messages[i].ProcessDetails = detailsJSON + } + } + + return &conv, nil +} + +// WebShellConversationItem 用于侧边栏列表,不含消息 +type WebShellConversationItem struct { + ID string `json:"id"` + Title string `json:"title"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// ListConversationsByWebshellConnectionID 列出该 WebShell 连接下的所有对话(按更新时间倒序),供侧边栏展示 +func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]WebShellConversationItem, error) { + if connectionID == "" { + return nil, nil + } + rows, err := db.Query( + "SELECT id, title, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC", + connectionID, + ) + if err != nil { + return nil, fmt.Errorf("查询对话列表失败: %w", err) + } + defer rows.Close() + var list []WebShellConversationItem + for rows.Next() { + var item WebShellConversationItem + var updatedAt string + if err := rows.Scan(&item.ID, &item.Title, &updatedAt); err != nil { + continue + } + if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil { + item.UpdatedAt = t + } else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil { + item.UpdatedAt = t + } else { + item.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) + } + list = append(list, item) + } + return list, rows.Err() +} + +// GetConversation 获取对话 +func (db *DB) GetConversation(id string) (*Conversation, error) { + var conv Conversation + var createdAt, updatedAt string + var pinned int + + err := db.QueryRow( + "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?", + id, + ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("对话不存在") + } + return nil, fmt.Errorf("查询对话失败: %w", err) + } + + // 尝试多种时间格式解析 + var err1, err2 error + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err1 != nil { + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err1 != nil { + conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if err2 != nil { + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) + } + if err2 != nil { + conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) + } + + conv.Pinned = pinned != 0 + + // 加载消息 + messages, err := db.GetMessages(id) + if err != nil { + return nil, fmt.Errorf("加载消息失败: %w", err) + } + conv.Messages = messages + + // 加载过程详情(按消息ID分组) + processDetailsMap, err := db.GetProcessDetailsByConversation(id) + if err != nil { + db.logger.Warn("加载过程详情失败", zap.Error(err)) + processDetailsMap = make(map[string][]ProcessDetail) + } + + // 将过程详情附加到对应的消息上 + for i := range conv.Messages { + if details, ok := processDetailsMap[conv.Messages[i].ID]; ok { + // 将ProcessDetail转换为JSON格式,以便前端使用 + detailsJSON := make([]map[string]interface{}, len(details)) + for j, detail := range details { + var data interface{} + if detail.Data != "" { + if err := json.Unmarshal([]byte(detail.Data), &data); err != nil { + db.logger.Warn("解析过程详情数据失败", zap.Error(err)) + } + } + detailsJSON[j] = map[string]interface{}{ + "id": detail.ID, + "messageId": detail.MessageID, + "conversationId": detail.ConversationID, + "eventType": detail.EventType, + "message": detail.Message, + "data": data, + "createdAt": detail.CreatedAt, + } + } + conv.Messages[i].ProcessDetails = detailsJSON + } + } + + return &conv, nil +} + +// GetConversationLite 获取对话(轻量版):包含 messages,但不加载 process_details。 +// 用于历史会话快速切换,避免一次性把大体量过程详情灌到前端导致卡顿。 +func (db *DB) GetConversationLite(id string) (*Conversation, error) { + var conv Conversation + var createdAt, updatedAt string + var pinned int + + err := db.QueryRow( + "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?", + id, + ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("对话不存在") + } + return nil, fmt.Errorf("查询对话失败: %w", err) + } + + // 尝试多种时间格式解析 + var err1, err2 error + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err1 != nil { + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err1 != nil { + conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if err2 != nil { + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) + } + if err2 != nil { + conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) + } + + conv.Pinned = pinned != 0 + + // 加载消息(不加载 process_details) + messages, err := db.GetMessages(id) + if err != nil { + return nil, fmt.Errorf("加载消息失败: %w", err) + } + conv.Messages = messages + return &conv, nil +} + +// ListConversations 列出所有对话 +func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) { + var rows *sql.Rows + var err error + + if search != "" { + // 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积 + searchPattern := "%" + search + "%" + rows, err = db.Query( + `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at + FROM conversations c + WHERE c.title LIKE ? + OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?) + ORDER BY c.updated_at DESC + LIMIT ? OFFSET ?`, + searchPattern, searchPattern, limit, offset, + ) + } else { + rows, err = db.Query( + "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?", + limit, offset, + ) + } + + if err != nil { + return nil, fmt.Errorf("查询对话列表失败: %w", err) + } + defer rows.Close() + + var conversations []*Conversation + for rows.Next() { + var conv Conversation + var createdAt, updatedAt string + var pinned int + + if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描对话失败: %w", err) + } + + // 尝试多种时间格式解析 + var err1, err2 error + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err1 != nil { + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err1 != nil { + conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if err2 != nil { + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) + } + if err2 != nil { + conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) + } + + conv.Pinned = pinned != 0 + + conversations = append(conversations, &conv) + } + + return conversations, nil +} + +// UpdateConversationTitle 更新对话标题 +func (db *DB) UpdateConversationTitle(id, title string) error { + // 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间 + _, err := db.Exec( + "UPDATE conversations SET title = ? WHERE id = ?", + title, id, + ) + if err != nil { + return fmt.Errorf("更新对话标题失败: %w", err) + } + return nil +} + +// UpdateConversationTime 更新对话时间 +func (db *DB) UpdateConversationTime(id string) error { + _, err := db.Exec( + "UPDATE conversations SET updated_at = ? WHERE id = ?", + time.Now(), id, + ) + if err != nil { + return fmt.Errorf("更新对话时间失败: %w", err) + } + return nil +} + +// DeleteConversation 删除对话及其所有相关数据 +// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除: +// - messages(消息) +// - process_details(过程详情) +// - attack_chain_nodes(攻击链节点) +// - attack_chain_edges(攻击链边) +// - vulnerabilities(漏洞) +// - conversation_group_mappings(分组映射) +// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL +func (db *DB) DeleteConversation(id string) error { + // 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除) + _, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id) + if err != nil { + db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err)) + // 不返回错误,继续删除对话 + } + + // 删除对话(外键CASCADE会自动删除其他相关数据) + _, err = db.Exec("DELETE FROM conversations WHERE id = ?", id) + if err != nil { + return fmt.Errorf("删除对话失败: %w", err) + } + + db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id)) + return nil +} + +// SaveReActData 保存最后一轮ReAct的输入和输出 +func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error { + _, err := db.Exec( + "UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?", + reactInput, reactOutput, time.Now(), conversationID, + ) + if err != nil { + return fmt.Errorf("保存ReAct数据失败: %w", err) + } + return nil +} + +// GetReActData 获取最后一轮ReAct的输入和输出 +func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) { + var input, output sql.NullString + err = db.QueryRow( + "SELECT last_react_input, last_react_output FROM conversations WHERE id = ?", + conversationID, + ).Scan(&input, &output) + if err != nil { + if err == sql.ErrNoRows { + return "", "", fmt.Errorf("对话不存在") + } + return "", "", fmt.Errorf("获取ReAct数据失败: %w", err) + } + + if input.Valid { + reactInput = input.String + } + if output.Valid { + reactOutput = output.String + } + + return reactInput, reactOutput, nil +} + +// ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。 +func (db *DB) ConversationHasToolProcessDetails(conversationID string) (bool, error) { + var n int + err := db.QueryRow( + `SELECT COUNT(*) FROM process_details WHERE conversation_id = ? AND event_type IN ('tool_call', 'tool_result')`, + conversationID, + ).Scan(&n) + if err != nil { + return false, fmt.Errorf("查询过程详情失败: %w", err) + } + return n > 0, nil +} + +// AddMessage 添加消息 +func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) { + id := uuid.New().String() + + var mcpIDsJSON string + if len(mcpExecutionIDs) > 0 { + jsonData, err := json.Marshal(mcpExecutionIDs) + if err != nil { + db.logger.Warn("序列化MCP执行ID失败", zap.Error(err)) + } else { + mcpIDsJSON = string(jsonData) + } + } + + _, err := db.Exec( + "INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)", + id, conversationID, role, content, mcpIDsJSON, time.Now(), + ) + if err != nil { + return nil, fmt.Errorf("添加消息失败: %w", err) + } + + // 更新对话时间 + if err := db.UpdateConversationTime(conversationID); err != nil { + db.logger.Warn("更新对话时间失败", zap.Error(err)) + } + + message := &Message{ + ID: id, + ConversationID: conversationID, + Role: role, + Content: content, + MCPExecutionIDs: mcpExecutionIDs, + CreatedAt: time.Now(), + } + + return message, nil +} + +// GetMessages 获取对话的所有消息 +func (db *DB) GetMessages(conversationID string) ([]Message, error) { + rows, err := db.Query( + "SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC", + conversationID, + ) + if err != nil { + return nil, fmt.Errorf("查询消息失败: %w", err) + } + defer rows.Close() + + var messages []Message + for rows.Next() { + var msg Message + var mcpIDsJSON sql.NullString + var createdAt string + + if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil { + return nil, fmt.Errorf("扫描消息失败: %w", err) + } + + // 尝试多种时间格式解析 + var err error + msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err != nil { + msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err != nil { + msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + // 解析MCP执行ID + if mcpIDsJSON.Valid && mcpIDsJSON.String != "" { + if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil { + db.logger.Warn("解析MCP执行ID失败", zap.Error(err)) + } + } + + messages = append(messages, msg) + } + + return messages, nil +} + +// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。 +// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。 +func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) { + idx := -1 + for i := range msgs { + if msgs[i].ID == anchorID { + idx = i + break + } + } + if idx < 0 { + return 0, 0, fmt.Errorf("message not found") + } + start = idx + for start > 0 && msgs[start].Role != "user" { + start-- + } + if start < len(msgs) && msgs[start].Role != "user" { + start = 0 + } + end = len(msgs) + for i := start + 1; i < len(msgs); i++ { + if msgs[i].Role == "user" { + end = i + break + } + } + return start, end, nil +} + +// DeleteConversationTurn 删除锚点所在轮次的全部消息(用户提问 + 该轮助手回复等),并清空 last_react_*,避免与消息表不一致。 +func (db *DB) DeleteConversationTurn(conversationID, anchorMessageID string) (deletedIDs []string, err error) { + msgs, err := db.GetMessages(conversationID) + if err != nil { + return nil, err + } + start, end, err := turnSliceRange(msgs, anchorMessageID) + if err != nil { + return nil, err + } + if start >= end { + return nil, fmt.Errorf("empty turn range") + } + deletedIDs = make([]string, 0, end-start) + for i := start; i < end; i++ { + deletedIDs = append(deletedIDs, msgs[i].ID) + } + + tx, err := db.Begin() + if err != nil { + return nil, fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + ph := strings.Repeat("?,", len(deletedIDs)) + ph = ph[:len(ph)-1] + args := make([]interface{}, 0, 1+len(deletedIDs)) + args = append(args, conversationID) + for _, id := range deletedIDs { + args = append(args, id) + } + res, err := tx.Exec( + "DELETE FROM messages WHERE conversation_id = ? AND id IN ("+ph+")", + args..., + ) + if err != nil { + return nil, fmt.Errorf("delete messages: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return nil, err + } + if int(n) != len(deletedIDs) { + return nil, fmt.Errorf("deleted count mismatch") + } + + _, err = tx.Exec( + `UPDATE conversations SET last_react_input = NULL, last_react_output = NULL, updated_at = ? WHERE id = ?`, + time.Now(), conversationID, + ) + if err != nil { + return nil, fmt.Errorf("clear react data: %w", err) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit: %w", err) + } + + db.logger.Info("conversation turn deleted", + zap.String("conversationId", conversationID), + zap.Strings("deletedMessageIds", deletedIDs), + zap.Int("count", len(deletedIDs)), + ) + return deletedIDs, nil +} + +// ProcessDetail 过程详情事件 +type ProcessDetail struct { + ID string `json:"id"` + MessageID string `json:"messageId"` + ConversationID string `json:"conversationId"` + EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error + Message string `json:"message"` + Data string `json:"data"` // JSON格式的数据 + CreatedAt time.Time `json:"createdAt"` +} + +// AddProcessDetail 添加过程详情事件 +func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error { + id := uuid.New().String() + + var dataJSON string + if data != nil { + jsonData, err := json.Marshal(data) + if err != nil { + db.logger.Warn("序列化过程详情数据失败", zap.Error(err)) + } else { + dataJSON = string(jsonData) + } + } + + _, err := db.Exec( + "INSERT INTO process_details (id, message_id, conversation_id, event_type, message, data, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, messageID, conversationID, eventType, message, dataJSON, time.Now(), + ) + if err != nil { + return fmt.Errorf("添加过程详情失败: %w", err) + } + + return nil +} + +// GetProcessDetails 获取消息的过程详情 +func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) { + rows, err := db.Query( + "SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC", + messageID, + ) + if err != nil { + return nil, fmt.Errorf("查询过程详情失败: %w", err) + } + defer rows.Close() + + var details []ProcessDetail + for rows.Next() { + var detail ProcessDetail + var createdAt string + + if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil { + return nil, fmt.Errorf("扫描过程详情失败: %w", err) + } + + // 尝试多种时间格式解析 + var err error + detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err != nil { + detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err != nil { + detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + details = append(details, detail) + } + + return details, nil +} + +// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组) +func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) { + rows, err := db.Query( + "SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC", + conversationID, + ) + if err != nil { + return nil, fmt.Errorf("查询过程详情失败: %w", err) + } + defer rows.Close() + + detailsMap := make(map[string][]ProcessDetail) + for rows.Next() { + var detail ProcessDetail + var createdAt string + + if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil { + return nil, fmt.Errorf("扫描过程详情失败: %w", err) + } + + // 尝试多种时间格式解析 + var err error + detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err != nil { + detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err != nil { + detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + detailsMap[detail.MessageID] = append(detailsMap[detail.MessageID], detail) + } + + return detailsMap, nil +} diff --git a/internal/database/conversation_turn_test.go b/internal/database/conversation_turn_test.go new file mode 100644 index 00000000..68743468 --- /dev/null +++ b/internal/database/conversation_turn_test.go @@ -0,0 +1,39 @@ +package database + +import ( + "testing" +) + +func TestTurnSliceRange(t *testing.T) { + mk := func(id, role string) Message { + return Message{ID: id, Role: role} + } + msgs := []Message{ + mk("u1", "user"), + mk("a1", "assistant"), + mk("u2", "user"), + mk("a2", "assistant"), + } + cases := []struct { + anchor string + start int + end int + }{ + {"u1", 0, 2}, + {"a1", 0, 2}, + {"u2", 2, 4}, + {"a2", 2, 4}, + } + for _, tc := range cases { + s, e, err := turnSliceRange(msgs, tc.anchor) + if err != nil { + t.Fatalf("anchor %s: %v", tc.anchor, err) + } + if s != tc.start || e != tc.end { + t.Fatalf("anchor %s: got [%d,%d) want [%d,%d)", tc.anchor, s, e, tc.start, tc.end) + } + } + if _, _, err := turnSliceRange(msgs, "nope"); err == nil { + t.Fatal("expected error for missing id") + } +} diff --git a/internal/database/database.go b/internal/database/database.go new file mode 100644 index 00000000..0e0ec524 --- /dev/null +++ b/internal/database/database.go @@ -0,0 +1,809 @@ +package database + +import ( + "database/sql" + "fmt" + "strings" + + _ "github.com/mattn/go-sqlite3" + "go.uber.org/zap" +) + +// DB 数据库连接 +type DB struct { + *sql.DB + logger *zap.Logger +} + +// NewDB 创建数据库连接 +func NewDB(dbPath string, logger *zap.Logger) (*DB, error) { + db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1") + if err != nil { + return nil, fmt.Errorf("打开数据库失败: %w", err) + } + + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("连接数据库失败: %w", err) + } + + database := &DB{ + DB: db, + logger: logger, + } + + // 初始化表 + if err := database.initTables(); err != nil { + return nil, fmt.Errorf("初始化表失败: %w", err) + } + + return database, nil +} + +// initTables 初始化数据库表 +func (db *DB) initTables() error { + // 创建对话表 + createConversationsTable := ` + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL, + last_react_input TEXT, + last_react_output TEXT + );` + + // 创建消息表 + createMessagesTable := ` + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + mcp_execution_ids TEXT, + created_at DATETIME NOT NULL, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE + );` + + // 创建过程详情表 + createProcessDetailsTable := ` + CREATE TABLE IF NOT EXISTS process_details ( + id TEXT PRIMARY KEY, + message_id TEXT NOT NULL, + conversation_id TEXT NOT NULL, + event_type TEXT NOT NULL, + message TEXT, + data TEXT, + created_at DATETIME NOT NULL, + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE + );` + + // 创建工具执行记录表 + createToolExecutionsTable := ` + CREATE TABLE IF NOT EXISTS tool_executions ( + id TEXT PRIMARY KEY, + tool_name TEXT NOT NULL, + arguments TEXT NOT NULL, + status TEXT NOT NULL, + result TEXT, + error TEXT, + start_time DATETIME NOT NULL, + end_time DATETIME, + duration_ms INTEGER, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + );` + + // 创建工具统计表 + createToolStatsTable := ` + CREATE TABLE IF NOT EXISTS tool_stats ( + tool_name TEXT PRIMARY KEY, + total_calls INTEGER NOT NULL DEFAULT 0, + success_calls INTEGER NOT NULL DEFAULT 0, + failed_calls INTEGER NOT NULL DEFAULT 0, + last_call_time DATETIME, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + );` + + // 创建Skills统计表 + createSkillStatsTable := ` + CREATE TABLE IF NOT EXISTS skill_stats ( + skill_name TEXT PRIMARY KEY, + total_calls INTEGER NOT NULL DEFAULT 0, + success_calls INTEGER NOT NULL DEFAULT 0, + failed_calls INTEGER NOT NULL DEFAULT 0, + last_call_time DATETIME, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + );` + + // 创建攻击链节点表 + createAttackChainNodesTable := ` + CREATE TABLE IF NOT EXISTS attack_chain_nodes ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + node_type TEXT NOT NULL, + node_name TEXT NOT NULL, + tool_execution_id TEXT, + metadata TEXT, + risk_score INTEGER DEFAULT 0, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (tool_execution_id) REFERENCES tool_executions(id) ON DELETE SET NULL + );` + + // 创建攻击链边表 + createAttackChainEdgesTable := ` + CREATE TABLE IF NOT EXISTS attack_chain_edges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + source_node_id TEXT NOT NULL, + target_node_id TEXT NOT NULL, + edge_type TEXT NOT NULL, + weight INTEGER DEFAULT 1, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (source_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE, + FOREIGN KEY (target_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE + );` + + // 创建知识检索日志表(保留在会话数据库中,因为有外键关联) + createKnowledgeRetrievalLogsTable := ` + CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs ( + id TEXT PRIMARY KEY, + conversation_id TEXT, + message_id TEXT, + query TEXT NOT NULL, + risk_type TEXT, + retrieved_items TEXT, + created_at DATETIME NOT NULL, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL, + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL + );` + + // 创建对话分组表 + createConversationGroupsTable := ` + CREATE TABLE IF NOT EXISTS conversation_groups ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + icon TEXT, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + );` + + // 创建对话分组映射表 + createConversationGroupMappingsTable := ` + CREATE TABLE IF NOT EXISTS conversation_group_mappings ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + group_id TEXT NOT NULL, + created_at DATETIME NOT NULL, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (group_id) REFERENCES conversation_groups(id) ON DELETE CASCADE, + UNIQUE(conversation_id, group_id) + );` + + // 创建漏洞表 + createVulnerabilitiesTable := ` + CREATE TABLE IF NOT EXISTS vulnerabilities ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + title TEXT NOT NULL, + description TEXT, + severity TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'open', + vulnerability_type TEXT, + target TEXT, + proof TEXT, + impact TEXT, + recommendation TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE + );` + + // 创建批量任务队列表 + createBatchTaskQueuesTable := ` + CREATE TABLE IF NOT EXISTS batch_task_queues ( + id TEXT PRIMARY KEY, + title TEXT, + role TEXT, + agent_mode TEXT NOT NULL DEFAULT 'single', + schedule_mode TEXT NOT NULL DEFAULT 'manual', + cron_expr TEXT, + next_run_at DATETIME, + schedule_enabled INTEGER NOT NULL DEFAULT 1, + last_schedule_trigger_at DATETIME, + last_schedule_error TEXT, + last_run_error TEXT, + status TEXT NOT NULL, + created_at DATETIME NOT NULL, + started_at DATETIME, + completed_at DATETIME, + current_index INTEGER NOT NULL DEFAULT 0 + );` + + // 创建批量任务表 + createBatchTasksTable := ` + CREATE TABLE IF NOT EXISTS batch_tasks ( + id TEXT PRIMARY KEY, + queue_id TEXT NOT NULL, + message TEXT NOT NULL, + conversation_id TEXT, + status TEXT NOT NULL, + started_at DATETIME, + completed_at DATETIME, + error TEXT, + result TEXT, + FOREIGN KEY (queue_id) REFERENCES batch_task_queues(id) ON DELETE CASCADE + );` + + // 创建 WebShell 连接表 + createWebshellConnectionsTable := ` + CREATE TABLE IF NOT EXISTS webshell_connections ( + id TEXT PRIMARY KEY, + url TEXT NOT NULL, + password TEXT NOT NULL DEFAULT '', + type TEXT NOT NULL DEFAULT 'php', + method TEXT NOT NULL DEFAULT 'post', + cmd_param TEXT NOT NULL DEFAULT '', + remark TEXT NOT NULL DEFAULT '', + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + );` + + // 创建 WebShell 连接扩展状态表(前端工作区/终端状态持久化) + createWebshellConnectionStatesTable := ` + CREATE TABLE IF NOT EXISTS webshell_connection_states ( + connection_id TEXT PRIMARY KEY, + state_json TEXT NOT NULL DEFAULT '{}', + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (connection_id) REFERENCES webshell_connections(id) ON DELETE CASCADE + );` + + // 创建索引 + createIndexes := ` + CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id); + CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at); + CREATE INDEX IF NOT EXISTS idx_process_details_message_id ON process_details(message_id); + CREATE INDEX IF NOT EXISTS idx_process_details_conversation_id ON process_details(conversation_id); + CREATE INDEX IF NOT EXISTS idx_tool_executions_tool_name ON tool_executions(tool_name); + CREATE INDEX IF NOT EXISTS idx_tool_executions_start_time ON tool_executions(start_time); + CREATE INDEX IF NOT EXISTS idx_tool_executions_status ON tool_executions(status); + CREATE INDEX IF NOT EXISTS idx_chain_nodes_conversation ON attack_chain_nodes(conversation_id); + CREATE INDEX IF NOT EXISTS idx_chain_edges_conversation ON attack_chain_edges(conversation_id); + CREATE INDEX IF NOT EXISTS idx_chain_edges_source ON attack_chain_edges(source_node_id); + CREATE INDEX IF NOT EXISTS idx_chain_edges_target ON attack_chain_edges(target_node_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at); + CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_conversation ON conversation_group_mappings(conversation_id); + CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id); + CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned); + CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id); + CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity); + CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status); + CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at); + CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id); + CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at); + CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title); + CREATE INDEX IF NOT EXISTS idx_webshell_connections_created_at ON webshell_connections(created_at); + CREATE INDEX IF NOT EXISTS idx_webshell_connection_states_updated_at ON webshell_connection_states(updated_at); + ` + + if _, err := db.Exec(createConversationsTable); err != nil { + return fmt.Errorf("创建conversations表失败: %w", err) + } + + if _, err := db.Exec(createMessagesTable); err != nil { + return fmt.Errorf("创建messages表失败: %w", err) + } + + if _, err := db.Exec(createProcessDetailsTable); err != nil { + return fmt.Errorf("创建process_details表失败: %w", err) + } + + if _, err := db.Exec(createToolExecutionsTable); err != nil { + return fmt.Errorf("创建tool_executions表失败: %w", err) + } + + if _, err := db.Exec(createToolStatsTable); err != nil { + return fmt.Errorf("创建tool_stats表失败: %w", err) + } + + if _, err := db.Exec(createSkillStatsTable); err != nil { + return fmt.Errorf("创建skill_stats表失败: %w", err) + } + + if _, err := db.Exec(createAttackChainNodesTable); err != nil { + return fmt.Errorf("创建attack_chain_nodes表失败: %w", err) + } + + if _, err := db.Exec(createAttackChainEdgesTable); err != nil { + return fmt.Errorf("创建attack_chain_edges表失败: %w", err) + } + + if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil { + return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err) + } + + if _, err := db.Exec(createConversationGroupsTable); err != nil { + return fmt.Errorf("创建conversation_groups表失败: %w", err) + } + + if _, err := db.Exec(createConversationGroupMappingsTable); err != nil { + return fmt.Errorf("创建conversation_group_mappings表失败: %w", err) + } + + if _, err := db.Exec(createVulnerabilitiesTable); err != nil { + return fmt.Errorf("创建vulnerabilities表失败: %w", err) + } + + if _, err := db.Exec(createBatchTaskQueuesTable); err != nil { + return fmt.Errorf("创建batch_task_queues表失败: %w", err) + } + + if _, err := db.Exec(createBatchTasksTable); err != nil { + return fmt.Errorf("创建batch_tasks表失败: %w", err) + } + + if _, err := db.Exec(createWebshellConnectionsTable); err != nil { + return fmt.Errorf("创建webshell_connections表失败: %w", err) + } + + if _, err := db.Exec(createWebshellConnectionStatesTable); err != nil { + return fmt.Errorf("创建webshell_connection_states表失败: %w", err) + } + + // 为已有表添加新字段(如果不存在)- 必须在创建索引之前 + if err := db.migrateConversationsTable(); err != nil { + db.logger.Warn("迁移conversations表失败", zap.Error(err)) + // 不返回错误,允许继续运行 + } + + if err := db.migrateConversationGroupsTable(); err != nil { + db.logger.Warn("迁移conversation_groups表失败", zap.Error(err)) + // 不返回错误,允许继续运行 + } + + if err := db.migrateConversationGroupMappingsTable(); err != nil { + db.logger.Warn("迁移conversation_group_mappings表失败", zap.Error(err)) + // 不返回错误,允许继续运行 + } + + if err := db.migrateBatchTaskQueuesTable(); err != nil { + db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err)) + // 不返回错误,允许继续运行 + } + + if _, err := db.Exec(createIndexes); err != nil { + return fmt.Errorf("创建索引失败: %w", err) + } + + db.logger.Info("数据库表初始化完成") + return nil +} + +// migrateConversationsTable 迁移conversations表,添加新字段 +func (db *DB) migrateConversationsTable() error { + // 检查last_react_input字段是否存在 + var count int + err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_input'").Scan(&count) + if err != nil { + // 如果查询失败,尝试添加字段 + if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); addErr != nil { + // 如果字段已存在,忽略错误(SQLite错误信息可能不同) + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_react_input字段失败", zap.Error(addErr)) + } + } + } else if count == 0 { + // 字段不存在,添加它 + if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); err != nil { + db.logger.Warn("添加last_react_input字段失败", zap.Error(err)) + } + } + + // 检查last_react_output字段是否存在 + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_output'").Scan(&count) + if err != nil { + // 如果查询失败,尝试添加字段 + if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); addErr != nil { + // 如果字段已存在,忽略错误 + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_react_output字段失败", zap.Error(addErr)) + } + } + } else if count == 0 { + // 字段不存在,添加它 + if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); err != nil { + db.logger.Warn("添加last_react_output字段失败", zap.Error(err)) + } + } + + // 检查pinned字段是否存在 + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='pinned'").Scan(&count) + if err != nil { + // 如果查询失败,尝试添加字段 + if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil { + // 如果字段已存在,忽略错误 + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加pinned字段失败", zap.Error(addErr)) + } + } + } else if count == 0 { + // 字段不存在,添加它 + if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil { + db.logger.Warn("添加pinned字段失败", zap.Error(err)) + } + } + + // 检查 webshell_connection_id 字段是否存在(WebShell AI 助手对话关联) + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='webshell_connection_id'").Scan(&count) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN webshell_connection_id TEXT"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加webshell_connection_id字段失败", zap.Error(addErr)) + } + } + } else if count == 0 { + if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN webshell_connection_id TEXT"); err != nil { + db.logger.Warn("添加webshell_connection_id字段失败", zap.Error(err)) + } + } + + return nil +} + +// migrateConversationGroupsTable 迁移conversation_groups表,添加新字段 +func (db *DB) migrateConversationGroupsTable() error { + // 检查pinned字段是否存在 + var count int + err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversation_groups') WHERE name='pinned'").Scan(&count) + if err != nil { + // 如果查询失败,尝试添加字段 + if _, addErr := db.Exec("ALTER TABLE conversation_groups ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil { + // 如果字段已存在,忽略错误 + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加pinned字段失败", zap.Error(addErr)) + } + } + } else if count == 0 { + // 字段不存在,添加它 + if _, err := db.Exec("ALTER TABLE conversation_groups ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil { + db.logger.Warn("添加pinned字段失败", zap.Error(err)) + } + } + + return nil +} + +// migrateConversationGroupMappingsTable 迁移conversation_group_mappings表,添加新字段 +func (db *DB) migrateConversationGroupMappingsTable() error { + // 检查pinned字段是否存在 + var count int + err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversation_group_mappings') WHERE name='pinned'").Scan(&count) + if err != nil { + // 如果查询失败,尝试添加字段 + if _, addErr := db.Exec("ALTER TABLE conversation_group_mappings ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil { + // 如果字段已存在,忽略错误 + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加pinned字段失败", zap.Error(addErr)) + } + } + } else if count == 0 { + // 字段不存在,添加它 + if _, err := db.Exec("ALTER TABLE conversation_group_mappings ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil { + db.logger.Warn("添加pinned字段失败", zap.Error(err)) + } + } + + return nil +} + +// migrateBatchTaskQueuesTable 迁移batch_task_queues表,补充新字段 +func (db *DB) migrateBatchTaskQueuesTable() error { + // 检查title字段是否存在 + var count int + err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='title'").Scan(&count) + if err != nil { + // 如果查询失败,尝试添加字段 + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); addErr != nil { + // 如果字段已存在,忽略错误 + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加title字段失败", zap.Error(addErr)) + } + } + } else if count == 0 { + // 字段不存在,添加它 + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); err != nil { + db.logger.Warn("添加title字段失败", zap.Error(err)) + } + } + + // 检查role字段是否存在 + var roleCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='role'").Scan(&roleCount) + if err != nil { + // 如果查询失败,尝试添加字段 + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); addErr != nil { + // 如果字段已存在,忽略错误 + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加role字段失败", zap.Error(addErr)) + } + } + } else if roleCount == 0 { + // 字段不存在,添加它 + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); err != nil { + db.logger.Warn("添加role字段失败", zap.Error(err)) + } + } + + // 检查agent_mode字段是否存在 + var agentModeCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='agent_mode'").Scan(&agentModeCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加agent_mode字段失败", zap.Error(addErr)) + } + } + } else if agentModeCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); err != nil { + db.logger.Warn("添加agent_mode字段失败", zap.Error(err)) + } + } + + // 检查schedule_mode字段是否存在 + var scheduleModeCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_mode'").Scan(&scheduleModeCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加schedule_mode字段失败", zap.Error(addErr)) + } + } + } else if scheduleModeCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); err != nil { + db.logger.Warn("添加schedule_mode字段失败", zap.Error(err)) + } + } + + // 检查cron_expr字段是否存在 + var cronExprCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='cron_expr'").Scan(&cronExprCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加cron_expr字段失败", zap.Error(addErr)) + } + } + } else if cronExprCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); err != nil { + db.logger.Warn("添加cron_expr字段失败", zap.Error(err)) + } + } + + // 检查next_run_at字段是否存在 + var nextRunAtCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='next_run_at'").Scan(&nextRunAtCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加next_run_at字段失败", zap.Error(addErr)) + } + } + } else if nextRunAtCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); err != nil { + db.logger.Warn("添加next_run_at字段失败", zap.Error(err)) + } + } + + // schedule_enabled:0=暂停 Cron 自动调度,1=允许(手工执行不受影响) + var scheduleEnCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_enabled'").Scan(&scheduleEnCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加schedule_enabled字段失败", zap.Error(addErr)) + } + } + } else if scheduleEnCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); err != nil { + db.logger.Warn("添加schedule_enabled字段失败", zap.Error(err)) + } + } + + var lastTrigCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_trigger_at'").Scan(&lastTrigCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(addErr)) + } + } + } else if lastTrigCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); err != nil { + db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(err)) + } + } + + var lastSchedErrCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_error'").Scan(&lastSchedErrCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_schedule_error字段失败", zap.Error(addErr)) + } + } + } else if lastSchedErrCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); err != nil { + db.logger.Warn("添加last_schedule_error字段失败", zap.Error(err)) + } + } + + var lastRunErrCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_run_error'").Scan(&lastRunErrCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_run_error字段失败", zap.Error(addErr)) + } + } + } else if lastRunErrCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); err != nil { + db.logger.Warn("添加last_run_error字段失败", zap.Error(err)) + } + } + + return nil +} + +// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) +func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { + sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1") + if err != nil { + return nil, fmt.Errorf("打开知识库数据库失败: %w", err) + } + + if err := sqlDB.Ping(); err != nil { + return nil, fmt.Errorf("连接知识库数据库失败: %w", err) + } + + database := &DB{ + DB: sqlDB, + logger: logger, + } + + // 初始化知识库表 + if err := database.initKnowledgeTables(); err != nil { + return nil, fmt.Errorf("初始化知识库表失败: %w", err) + } + + return database, nil +} + +// initKnowledgeTables 初始化知识库数据库表(只包含知识库相关的表) +func (db *DB) initKnowledgeTables() error { + // 创建知识库项表 + createKnowledgeBaseItemsTable := ` + CREATE TABLE IF NOT EXISTS knowledge_base_items ( + id TEXT PRIMARY KEY, + category TEXT NOT NULL, + title TEXT NOT NULL, + file_path TEXT NOT NULL, + content TEXT, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + );` + + // 创建知识库向量表 + createKnowledgeEmbeddingsTable := ` + CREATE TABLE IF NOT EXISTS knowledge_embeddings ( + id TEXT PRIMARY KEY, + item_id TEXT NOT NULL, + chunk_index INTEGER NOT NULL, + chunk_text TEXT NOT NULL, + embedding TEXT NOT NULL, + sub_indexes TEXT NOT NULL DEFAULT '', + embedding_model TEXT NOT NULL DEFAULT '', + embedding_dim INTEGER NOT NULL DEFAULT 0, + created_at DATETIME NOT NULL, + FOREIGN KEY (item_id) REFERENCES knowledge_base_items(id) ON DELETE CASCADE + );` + + // 创建知识检索日志表(在独立知识库数据库中,不使用外键约束,因为conversations和messages表可能不在这个数据库中) + createKnowledgeRetrievalLogsTable := ` + CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs ( + id TEXT PRIMARY KEY, + conversation_id TEXT, + message_id TEXT, + query TEXT NOT NULL, + risk_type TEXT, + retrieved_items TEXT, + created_at DATETIME NOT NULL + );` + + // 创建索引 + createIndexes := ` + CREATE INDEX IF NOT EXISTS idx_knowledge_items_category ON knowledge_base_items(category); + CREATE INDEX IF NOT EXISTS idx_knowledge_embeddings_item_id ON knowledge_embeddings(item_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at); + ` + + if _, err := db.Exec(createKnowledgeBaseItemsTable); err != nil { + return fmt.Errorf("创建knowledge_base_items表失败: %w", err) + } + + if _, err := db.Exec(createKnowledgeEmbeddingsTable); err != nil { + return fmt.Errorf("创建knowledge_embeddings表失败: %w", err) + } + + if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil { + return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err) + } + + if _, err := db.Exec(createIndexes); err != nil { + return fmt.Errorf("创建索引失败: %w", err) + } + + if err := db.migrateKnowledgeEmbeddingsColumns(); err != nil { + return fmt.Errorf("迁移 knowledge_embeddings 列失败: %w", err) + } + + db.logger.Info("知识库数据库表初始化完成") + return nil +} + +// migrateKnowledgeEmbeddingsColumns 为已有库补充 sub_indexes、embedding_model、embedding_dim。 +func (db *DB) migrateKnowledgeEmbeddingsColumns() error { + var n int + if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil { + return err + } + if n == 0 { + return nil + } + migrations := []struct { + col string + stmt string + }{ + {"sub_indexes", `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`}, + {"embedding_model", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`}, + {"embedding_dim", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`}, + } + for _, m := range migrations { + var colCount int + q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?` + if err := db.QueryRow(q, m.col).Scan(&colCount); err != nil { + return err + } + if colCount > 0 { + continue + } + if _, err := db.Exec(m.stmt); err != nil { + return err + } + } + return nil +} + +// Close 关闭数据库连接 +func (db *DB) Close() error { + return db.DB.Close() +} diff --git a/internal/database/group.go b/internal/database/group.go new file mode 100644 index 00000000..a3d32106 --- /dev/null +++ b/internal/database/group.go @@ -0,0 +1,449 @@ +package database + +import ( + "database/sql" + "fmt" + "time" + + "github.com/google/uuid" +) + +// ConversationGroup 对话分组 +type ConversationGroup struct { + ID string `json:"id"` + Name string `json:"name"` + Icon string `json:"icon"` + Pinned bool `json:"pinned"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// GroupExistsByName 检查分组名称是否已存在 +func (db *DB) GroupExistsByName(name string, excludeID string) (bool, error) { + var count int + var err error + + if excludeID != "" { + err = db.QueryRow( + "SELECT COUNT(*) FROM conversation_groups WHERE name = ? AND id != ?", + name, excludeID, + ).Scan(&count) + } else { + err = db.QueryRow( + "SELECT COUNT(*) FROM conversation_groups WHERE name = ?", + name, + ).Scan(&count) + } + + if err != nil { + return false, fmt.Errorf("检查分组名称失败: %w", err) + } + + return count > 0, nil +} + +// CreateGroup 创建分组 +func (db *DB) CreateGroup(name, icon string) (*ConversationGroup, error) { + // 检查名称是否已存在 + exists, err := db.GroupExistsByName(name, "") + if err != nil { + return nil, err + } + if exists { + return nil, fmt.Errorf("分组名称已存在") + } + + id := uuid.New().String() + now := time.Now() + + if icon == "" { + icon = "📁" + } + + _, err = db.Exec( + "INSERT INTO conversation_groups (id, name, icon, pinned, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)", + id, name, icon, 0, now, now, + ) + if err != nil { + return nil, fmt.Errorf("创建分组失败: %w", err) + } + + return &ConversationGroup{ + ID: id, + Name: name, + Icon: icon, + Pinned: false, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +// ListGroups 列出所有分组 +func (db *DB) ListGroups() ([]*ConversationGroup, error) { + rows, err := db.Query( + "SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups ORDER BY COALESCE(pinned, 0) DESC, created_at ASC", + ) + if err != nil { + return nil, fmt.Errorf("查询分组列表失败: %w", err) + } + defer rows.Close() + + var groups []*ConversationGroup + for rows.Next() { + var group ConversationGroup + var createdAt, updatedAt string + var pinned int + + if err := rows.Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描分组失败: %w", err) + } + + group.Pinned = pinned != 0 + + // 尝试多种时间格式解析 + var err1, err2 error + group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err1 != nil { + group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err1 != nil { + group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if err2 != nil { + group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) + } + if err2 != nil { + group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) + } + + groups = append(groups, &group) + } + + return groups, nil +} + +// GetGroup 获取分组 +func (db *DB) GetGroup(id string) (*ConversationGroup, error) { + var group ConversationGroup + var createdAt, updatedAt string + var pinned int + + err := db.QueryRow( + "SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups WHERE id = ?", + id, + ).Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("分组不存在") + } + return nil, fmt.Errorf("查询分组失败: %w", err) + } + + // 尝试多种时间格式解析 + var err1, err2 error + group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err1 != nil { + group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err1 != nil { + group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if err2 != nil { + group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) + } + if err2 != nil { + group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) + } + + group.Pinned = pinned != 0 + + return &group, nil +} + +// UpdateGroup 更新分组 +func (db *DB) UpdateGroup(id, name, icon string) error { + // 检查名称是否已存在(排除当前分组) + exists, err := db.GroupExistsByName(name, id) + if err != nil { + return err + } + if exists { + return fmt.Errorf("分组名称已存在") + } + + _, err = db.Exec( + "UPDATE conversation_groups SET name = ?, icon = ?, updated_at = ? WHERE id = ?", + name, icon, time.Now(), id, + ) + if err != nil { + return fmt.Errorf("更新分组失败: %w", err) + } + return nil +} + +// DeleteGroup 删除分组 +func (db *DB) DeleteGroup(id string) error { + _, err := db.Exec("DELETE FROM conversation_groups WHERE id = ?", id) + if err != nil { + return fmt.Errorf("删除分组失败: %w", err) + } + return nil +} + +// AddConversationToGroup 将对话添加到分组 +// 注意:一个对话只能属于一个分组,所以在添加新分组之前,会先删除该对话的所有旧分组关联 +func (db *DB) AddConversationToGroup(conversationID, groupID string) error { + // 先删除该对话的所有旧分组关联,确保一个对话只属于一个分组 + _, err := db.Exec( + "DELETE FROM conversation_group_mappings WHERE conversation_id = ?", + conversationID, + ) + if err != nil { + return fmt.Errorf("删除对话旧分组关联失败: %w", err) + } + + // 然后插入新的分组关联 + id := uuid.New().String() + _, err = db.Exec( + "INSERT INTO conversation_group_mappings (id, conversation_id, group_id, created_at) VALUES (?, ?, ?, ?)", + id, conversationID, groupID, time.Now(), + ) + if err != nil { + return fmt.Errorf("添加对话到分组失败: %w", err) + } + return nil +} + +// RemoveConversationFromGroup 从分组中移除对话 +func (db *DB) RemoveConversationFromGroup(conversationID, groupID string) error { + _, err := db.Exec( + "DELETE FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?", + conversationID, groupID, + ) + if err != nil { + return fmt.Errorf("从分组中移除对话失败: %w", err) + } + return nil +} + +// GetConversationsByGroup 获取分组中的所有对话 +func (db *DB) GetConversationsByGroup(groupID string) ([]*Conversation, error) { + rows, err := db.Query( + `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned + FROM conversations c + INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id + WHERE cgm.group_id = ? + ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC`, + groupID, + ) + if err != nil { + return nil, fmt.Errorf("查询分组对话失败: %w", err) + } + defer rows.Close() + + var conversations []*Conversation + for rows.Next() { + var conv Conversation + var createdAt, updatedAt string + var pinned int + var groupPinned int + + if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil { + return nil, fmt.Errorf("扫描对话失败: %w", err) + } + + // 尝试多种时间格式解析 + var err1, err2 error + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err1 != nil { + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err1 != nil { + conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if err2 != nil { + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) + } + if err2 != nil { + conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) + } + + conv.Pinned = pinned != 0 + + conversations = append(conversations, &conv) + } + + return conversations, nil +} + +// SearchConversationsByGroup 搜索分组中的对话(按标题和消息内容模糊匹配) +func (db *DB) SearchConversationsByGroup(groupID string, searchQuery string) ([]*Conversation, error) { + // 构建SQL查询,支持按标题和消息内容搜索 + // 使用 DISTINCT 避免因为一个对话有多条匹配消息而重复 + query := `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned + FROM conversations c + INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id + WHERE cgm.group_id = ?` + + args := []interface{}{groupID} + + // 如果有搜索关键词,添加标题和消息内容搜索条件 + if searchQuery != "" { + searchPattern := "%" + searchQuery + "%" + // 搜索标题或消息内容 + // 使用 LEFT JOIN 连接消息表,这样即使没有消息的对话也能被搜索到(通过标题) + query += ` AND ( + LOWER(c.title) LIKE LOWER(?) + OR EXISTS ( + SELECT 1 FROM messages m + WHERE m.conversation_id = c.id + AND LOWER(m.content) LIKE LOWER(?) + ) + )` + args = append(args, searchPattern, searchPattern) + } + + query += " ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC" + + rows, err := db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("搜索分组对话失败: %w", err) + } + defer rows.Close() + + var conversations []*Conversation + for rows.Next() { + var conv Conversation + var createdAt, updatedAt string + var pinned int + var groupPinned int + + if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil { + return nil, fmt.Errorf("扫描对话失败: %w", err) + } + + // 尝试多种时间格式解析 + var err1, err2 error + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err1 != nil { + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err1 != nil { + conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + } + + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if err2 != nil { + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) + } + if err2 != nil { + conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) + } + + conv.Pinned = pinned != 0 + + conversations = append(conversations, &conv) + } + + return conversations, nil +} + +// GetGroupByConversation 获取对话所属的分组 +func (db *DB) GetGroupByConversation(conversationID string) (string, error) { + var groupID string + err := db.QueryRow( + "SELECT group_id FROM conversation_group_mappings WHERE conversation_id = ? LIMIT 1", + conversationID, + ).Scan(&groupID) + if err != nil { + if err == sql.ErrNoRows { + return "", nil // 没有分组 + } + return "", fmt.Errorf("查询对话分组失败: %w", err) + } + return groupID, nil +} + +// UpdateConversationPinned 更新对话置顶状态 +func (db *DB) UpdateConversationPinned(id string, pinned bool) error { + pinnedValue := 0 + if pinned { + pinnedValue = 1 + } + // 注意:不更新 updated_at,因为置顶操作不应该改变对话的更新时间 + _, err := db.Exec( + "UPDATE conversations SET pinned = ? WHERE id = ?", + pinnedValue, id, + ) + if err != nil { + return fmt.Errorf("更新对话置顶状态失败: %w", err) + } + return nil +} + +// UpdateGroupPinned 更新分组置顶状态 +func (db *DB) UpdateGroupPinned(id string, pinned bool) error { + pinnedValue := 0 + if pinned { + pinnedValue = 1 + } + _, err := db.Exec( + "UPDATE conversation_groups SET pinned = ?, updated_at = ? WHERE id = ?", + pinnedValue, time.Now(), id, + ) + if err != nil { + return fmt.Errorf("更新分组置顶状态失败: %w", err) + } + return nil +} + +// GroupMapping 分组映射关系 +type GroupMapping struct { + ConversationID string `json:"conversationId"` + GroupID string `json:"groupId"` +} + +// GetAllGroupMappings 批量获取所有分组映射(消除 N+1 查询) +func (db *DB) GetAllGroupMappings() ([]GroupMapping, error) { + rows, err := db.Query("SELECT conversation_id, group_id FROM conversation_group_mappings") + if err != nil { + return nil, fmt.Errorf("查询分组映射失败: %w", err) + } + defer rows.Close() + + var mappings []GroupMapping + for rows.Next() { + var m GroupMapping + if err := rows.Scan(&m.ConversationID, &m.GroupID); err != nil { + return nil, fmt.Errorf("扫描分组映射失败: %w", err) + } + mappings = append(mappings, m) + } + + if mappings == nil { + mappings = []GroupMapping{} + } + return mappings, nil +} + +// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态 +func (db *DB) UpdateConversationPinnedInGroup(conversationID, groupID string, pinned bool) error { + pinnedValue := 0 + if pinned { + pinnedValue = 1 + } + _, err := db.Exec( + "UPDATE conversation_group_mappings SET pinned = ? WHERE conversation_id = ? AND group_id = ?", + pinnedValue, conversationID, groupID, + ) + if err != nil { + return fmt.Errorf("更新分组对话置顶状态失败: %w", err) + } + return nil +} diff --git a/internal/database/monitor.go b/internal/database/monitor.go new file mode 100644 index 00000000..bdfffb61 --- /dev/null +++ b/internal/database/monitor.go @@ -0,0 +1,537 @@ +package database + +import ( + "database/sql" + "encoding/json" + "strings" + "time" + + "cyberstrike-ai/internal/mcp" + + "go.uber.org/zap" +) + +// SaveToolExecution 保存工具执行记录 +func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error { + argsJSON, err := json.Marshal(exec.Arguments) + if err != nil { + db.logger.Warn("序列化执行参数失败", zap.Error(err)) + argsJSON = []byte("{}") + } + + var resultJSON sql.NullString + if exec.Result != nil { + resultBytes, err := json.Marshal(exec.Result) + if err != nil { + db.logger.Warn("序列化执行结果失败", zap.Error(err)) + } else { + resultJSON = sql.NullString{String: string(resultBytes), Valid: true} + } + } + + var errorText sql.NullString + if exec.Error != "" { + errorText = sql.NullString{String: exec.Error, Valid: true} + } + + var endTime sql.NullTime + if exec.EndTime != nil { + endTime = sql.NullTime{Time: *exec.EndTime, Valid: true} + } + + var durationMs sql.NullInt64 + if exec.Duration > 0 { + durationMs = sql.NullInt64{Int64: exec.Duration.Milliseconds(), Valid: true} + } + + query := ` + INSERT OR REPLACE INTO tool_executions + (id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + + _, err = db.Exec(query, + exec.ID, + exec.ToolName, + string(argsJSON), + exec.Status, + resultJSON, + errorText, + exec.StartTime, + endTime, + durationMs, + time.Now(), + ) + + if err != nil { + db.logger.Error("保存工具执行记录失败", zap.Error(err), zap.String("executionId", exec.ID)) + return err + } + + return nil +} + +// CountToolExecutions 统计工具执行记录总数 +func (db *DB) CountToolExecutions(status, toolName string) (int, error) { + query := `SELECT COUNT(*) FROM tool_executions` + args := []interface{}{} + conditions := []string{} + if status != "" { + conditions = append(conditions, "status = ?") + args = append(args, status) + } + if toolName != "" { + // 支持部分匹配(模糊搜索),不区分大小写 + conditions = append(conditions, "LOWER(tool_name) LIKE ?") + args = append(args, "%"+strings.ToLower(toolName)+"%") + } + if len(conditions) > 0 { + query += ` WHERE ` + conditions[0] + for i := 1; i < len(conditions); i++ { + query += ` AND ` + conditions[i] + } + } + var count int + err := db.QueryRow(query, args...).Scan(&count) + if err != nil { + return 0, err + } + return count, nil +} + +// LoadToolExecutions 加载所有工具执行记录(支持分页) +func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) { + return db.LoadToolExecutionsWithPagination(0, 1000, "", "") +} + +// LoadToolExecutionsWithPagination 分页加载工具执行记录 +// limit: 最大返回记录数,0 表示使用默认值 1000 +// offset: 跳过的记录数,用于分页 +// status: 状态筛选,空字符串表示不过滤 +// toolName: 工具名称筛选,空字符串表示不过滤 +func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) { + if limit <= 0 { + limit = 1000 // 默认限制 + } + if limit > 10000 { + limit = 10000 // 最大限制,防止一次性加载过多数据 + } + + query := ` + SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms + FROM tool_executions + ` + args := []interface{}{} + conditions := []string{} + if status != "" { + conditions = append(conditions, "status = ?") + args = append(args, status) + } + if toolName != "" { + // 支持部分匹配(模糊搜索),不区分大小写 + conditions = append(conditions, "LOWER(tool_name) LIKE ?") + args = append(args, "%"+strings.ToLower(toolName)+"%") + } + if len(conditions) > 0 { + query += ` WHERE ` + conditions[0] + for i := 1; i < len(conditions); i++ { + query += ` AND ` + conditions[i] + } + } + query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?` + args = append(args, limit, offset) + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var executions []*mcp.ToolExecution + for rows.Next() { + var exec mcp.ToolExecution + var argsJSON string + var resultJSON sql.NullString + var errorText sql.NullString + var endTime sql.NullTime + var durationMs sql.NullInt64 + + err := rows.Scan( + &exec.ID, + &exec.ToolName, + &argsJSON, + &exec.Status, + &resultJSON, + &errorText, + &exec.StartTime, + &endTime, + &durationMs, + ) + if err != nil { + db.logger.Warn("加载执行记录失败", zap.Error(err)) + continue + } + + // 解析参数 + if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil { + db.logger.Warn("解析执行参数失败", zap.Error(err)) + exec.Arguments = make(map[string]interface{}) + } + + // 解析结果 + if resultJSON.Valid && resultJSON.String != "" { + var result mcp.ToolResult + if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil { + db.logger.Warn("解析执行结果失败", zap.Error(err)) + } else { + exec.Result = &result + } + } + + // 设置错误 + if errorText.Valid { + exec.Error = errorText.String + } + + // 设置结束时间 + if endTime.Valid { + exec.EndTime = &endTime.Time + } + + // 设置持续时间 + if durationMs.Valid { + exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond + } + + executions = append(executions, &exec) + } + + return executions, nil +} + +// GetToolExecution 根据ID获取单条工具执行记录 +func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) { + query := ` + SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms + FROM tool_executions + WHERE id = ? + ` + + row := db.QueryRow(query, id) + + var exec mcp.ToolExecution + var argsJSON string + var resultJSON sql.NullString + var errorText sql.NullString + var endTime sql.NullTime + var durationMs sql.NullInt64 + + err := row.Scan( + &exec.ID, + &exec.ToolName, + &argsJSON, + &exec.Status, + &resultJSON, + &errorText, + &exec.StartTime, + &endTime, + &durationMs, + ) + if err != nil { + return nil, err + } + + if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil { + db.logger.Warn("解析执行参数失败", zap.Error(err)) + exec.Arguments = make(map[string]interface{}) + } + + if resultJSON.Valid && resultJSON.String != "" { + var result mcp.ToolResult + if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil { + db.logger.Warn("解析执行结果失败", zap.Error(err)) + } else { + exec.Result = &result + } + } + + if errorText.Valid { + exec.Error = errorText.String + } + + if endTime.Valid { + exec.EndTime = &endTime.Time + } + + if durationMs.Valid { + exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond + } + + return &exec, nil +} + +// DeleteToolExecution 删除工具执行记录 +func (db *DB) DeleteToolExecution(id string) error { + query := `DELETE FROM tool_executions WHERE id = ?` + _, err := db.Exec(query, id) + if err != nil { + db.logger.Error("删除工具执行记录失败", zap.Error(err), zap.String("executionId", id)) + return err + } + return nil +} + +// DeleteToolExecutions 批量删除工具执行记录 +func (db *DB) DeleteToolExecutions(ids []string) error { + if len(ids) == 0 { + return nil + } + + // 构建 IN 查询的占位符 + placeholders := make([]string, len(ids)) + args := make([]interface{}, len(ids)) + for i, id := range ids { + placeholders[i] = "?" + args[i] = id + } + + query := `DELETE FROM tool_executions WHERE id IN (` + strings.Join(placeholders, ",") + `)` + _, err := db.Exec(query, args...) + if err != nil { + db.logger.Error("批量删除工具执行记录失败", zap.Error(err), zap.Int("count", len(ids))) + return err + } + return nil +} + +// GetToolExecutionsByIds 根据ID列表获取工具执行记录(用于批量删除前获取统计信息) +func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error) { + if len(ids) == 0 { + return []*mcp.ToolExecution{}, nil + } + + // 构建 IN 查询的占位符 + placeholders := make([]string, len(ids)) + args := make([]interface{}, len(ids)) + for i, id := range ids { + placeholders[i] = "?" + args[i] = id + } + + query := ` + SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms + FROM tool_executions + WHERE id IN (` + strings.Join(placeholders, ",") + `) + ` + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var executions []*mcp.ToolExecution + for rows.Next() { + var exec mcp.ToolExecution + var argsJSON string + var resultJSON sql.NullString + var errorText sql.NullString + var endTime sql.NullTime + var durationMs sql.NullInt64 + + err := rows.Scan( + &exec.ID, + &exec.ToolName, + &argsJSON, + &exec.Status, + &resultJSON, + &errorText, + &exec.StartTime, + &endTime, + &durationMs, + ) + if err != nil { + db.logger.Warn("加载执行记录失败", zap.Error(err)) + continue + } + + // 解析参数 + if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil { + db.logger.Warn("解析执行参数失败", zap.Error(err)) + exec.Arguments = make(map[string]interface{}) + } + + // 解析结果 + if resultJSON.Valid && resultJSON.String != "" { + var result mcp.ToolResult + if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil { + db.logger.Warn("解析执行结果失败", zap.Error(err)) + } else { + exec.Result = &result + } + } + + // 设置错误 + if errorText.Valid { + exec.Error = errorText.String + } + + // 设置结束时间 + if endTime.Valid { + exec.EndTime = &endTime.Time + } + + // 设置持续时间 + if durationMs.Valid { + exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond + } + + executions = append(executions, &exec) + } + + return executions, nil +} + +// SaveToolStats 保存工具统计信息 +func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error { + var lastCallTime sql.NullTime + if stats.LastCallTime != nil { + lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true} + } + + query := ` + INSERT OR REPLACE INTO tool_stats + (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ` + + _, err := db.Exec(query, + toolName, + stats.TotalCalls, + stats.SuccessCalls, + stats.FailedCalls, + lastCallTime, + time.Now(), + ) + + if err != nil { + db.logger.Error("保存工具统计信息失败", zap.Error(err), zap.String("toolName", toolName)) + return err + } + + return nil +} + +// LoadToolStats 加载所有工具统计信息 +func (db *DB) LoadToolStats() (map[string]*mcp.ToolStats, error) { + query := ` + SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time + FROM tool_stats + ` + + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + stats := make(map[string]*mcp.ToolStats) + for rows.Next() { + var stat mcp.ToolStats + var lastCallTime sql.NullTime + + err := rows.Scan( + &stat.ToolName, + &stat.TotalCalls, + &stat.SuccessCalls, + &stat.FailedCalls, + &lastCallTime, + ) + if err != nil { + db.logger.Warn("加载统计信息失败", zap.Error(err)) + continue + } + + if lastCallTime.Valid { + stat.LastCallTime = &lastCallTime.Time + } + + stats[stat.ToolName] = &stat + } + + return stats, nil +} + +// UpdateToolStats 更新工具统计信息(累加模式) +func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error { + var lastCallTimeSQL sql.NullTime + if lastCallTime != nil { + lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true} + } + + query := ` + INSERT INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(tool_name) DO UPDATE SET + total_calls = total_calls + ?, + success_calls = success_calls + ?, + failed_calls = failed_calls + ?, + last_call_time = COALESCE(?, last_call_time), + updated_at = ? + ` + + _, err := db.Exec(query, + toolName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), + totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), + ) + + if err != nil { + db.logger.Error("更新工具统计信息失败", zap.Error(err), zap.String("toolName", toolName)) + return err + } + + return nil +} + +// DecreaseToolStats 减少工具统计信息(用于删除执行记录时) +// 如果统计信息变为0,则删除该统计记录 +func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error { + // 先更新统计信息 + query := ` + UPDATE tool_stats SET + total_calls = CASE WHEN total_calls - ? < 0 THEN 0 ELSE total_calls - ? END, + success_calls = CASE WHEN success_calls - ? < 0 THEN 0 ELSE success_calls - ? END, + failed_calls = CASE WHEN failed_calls - ? < 0 THEN 0 ELSE failed_calls - ? END, + updated_at = ? + WHERE tool_name = ? + ` + + _, err := db.Exec(query, totalCalls, totalCalls, successCalls, successCalls, failedCalls, failedCalls, time.Now(), toolName) + if err != nil { + db.logger.Error("减少工具统计信息失败", zap.Error(err), zap.String("toolName", toolName)) + return err + } + + // 检查更新后的 total_calls 是否为 0,如果是则删除该统计记录 + checkQuery := `SELECT total_calls FROM tool_stats WHERE tool_name = ?` + var newTotalCalls int + err = db.QueryRow(checkQuery, toolName).Scan(&newTotalCalls) + if err != nil { + // 如果查询失败(记录不存在),直接返回 + return nil + } + + // 如果 total_calls 为 0,删除该统计记录 + if newTotalCalls == 0 { + deleteQuery := `DELETE FROM tool_stats WHERE tool_name = ?` + _, err = db.Exec(deleteQuery, toolName) + if err != nil { + db.logger.Warn("删除零统计记录失败", zap.Error(err), zap.String("toolName", toolName)) + // 不返回错误,因为主要操作(更新统计)已成功 + } else { + db.logger.Info("已删除零统计记录", zap.String("toolName", toolName)) + } + } + + return nil +} diff --git a/internal/database/skill_stats.go b/internal/database/skill_stats.go new file mode 100644 index 00000000..24e15585 --- /dev/null +++ b/internal/database/skill_stats.go @@ -0,0 +1,142 @@ +package database + +import ( + "database/sql" + "time" + + "go.uber.org/zap" +) + +// SkillStats Skills统计信息 +type SkillStats struct { + SkillName string + TotalCalls int + SuccessCalls int + FailedCalls int + LastCallTime *time.Time +} + +// SaveSkillStats 保存Skills统计信息 +func (db *DB) SaveSkillStats(skillName string, stats *SkillStats) error { + var lastCallTime sql.NullTime + if stats.LastCallTime != nil { + lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true} + } + + query := ` + INSERT OR REPLACE INTO skill_stats + (skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ` + + _, err := db.Exec(query, + skillName, + stats.TotalCalls, + stats.SuccessCalls, + stats.FailedCalls, + lastCallTime, + time.Now(), + ) + + if err != nil { + db.logger.Error("保存Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName)) + return err + } + + return nil +} + +// LoadSkillStats 加载所有Skills统计信息 +func (db *DB) LoadSkillStats() (map[string]*SkillStats, error) { + query := ` + SELECT skill_name, total_calls, success_calls, failed_calls, last_call_time + FROM skill_stats + ` + + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + stats := make(map[string]*SkillStats) + for rows.Next() { + var stat SkillStats + var lastCallTime sql.NullTime + + err := rows.Scan( + &stat.SkillName, + &stat.TotalCalls, + &stat.SuccessCalls, + &stat.FailedCalls, + &lastCallTime, + ) + if err != nil { + db.logger.Warn("加载Skills统计信息失败", zap.Error(err)) + continue + } + + if lastCallTime.Valid { + stat.LastCallTime = &lastCallTime.Time + } + + stats[stat.SkillName] = &stat + } + + return stats, nil +} + +// UpdateSkillStats 更新Skills统计信息(累加模式) +func (db *DB) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error { + var lastCallTimeSQL sql.NullTime + if lastCallTime != nil { + lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true} + } + + query := ` + INSERT INTO skill_stats (skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(skill_name) DO UPDATE SET + total_calls = total_calls + ?, + success_calls = success_calls + ?, + failed_calls = failed_calls + ?, + last_call_time = COALESCE(?, last_call_time), + updated_at = ? + ` + + _, err := db.Exec(query, + skillName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), + totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), + ) + + if err != nil { + db.logger.Error("更新Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName)) + return err + } + + return nil +} + +// ClearSkillStats 清空所有Skills统计信息 +func (db *DB) ClearSkillStats() error { + query := `DELETE FROM skill_stats` + _, err := db.Exec(query) + if err != nil { + db.logger.Error("清空Skills统计信息失败", zap.Error(err)) + return err + } + db.logger.Info("已清空所有Skills统计信息") + return nil +} + +// ClearSkillStatsByName 清空指定skill的统计信息 +func (db *DB) ClearSkillStatsByName(skillName string) error { + query := `DELETE FROM skill_stats WHERE skill_name = ?` + _, err := db.Exec(query, skillName) + if err != nil { + db.logger.Error("清空指定skill统计信息失败", zap.Error(err), zap.String("skillName", skillName)) + return err + } + db.logger.Info("已清空指定skill统计信息", zap.String("skillName", skillName)) + return nil +} diff --git a/internal/database/vulnerability.go b/internal/database/vulnerability.go new file mode 100644 index 00000000..c4ec69b2 --- /dev/null +++ b/internal/database/vulnerability.go @@ -0,0 +1,281 @@ +package database + +import ( + "database/sql" + "fmt" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Vulnerability 漏洞 +type Vulnerability struct { + ID string `json:"id"` + ConversationID string `json:"conversation_id"` + Title string `json:"title"` + Description string `json:"description"` + Severity string `json:"severity"` // critical, high, medium, low, info + Status string `json:"status"` // open, confirmed, fixed, false_positive + Type string `json:"type"` + Target string `json:"target"` + Proof string `json:"proof"` + Impact string `json:"impact"` + Recommendation string `json:"recommendation"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// CreateVulnerability 创建漏洞 +func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) { + if vuln.ID == "" { + vuln.ID = uuid.New().String() + } + if vuln.Status == "" { + vuln.Status = "open" + } + now := time.Now() + if vuln.CreatedAt.IsZero() { + vuln.CreatedAt = now + } + vuln.UpdatedAt = now + + query := ` + INSERT INTO vulnerabilities ( + id, conversation_id, title, description, severity, status, + vulnerability_type, target, proof, impact, recommendation, + created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + + _, err := db.Exec( + query, + vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description, + vuln.Severity, vuln.Status, vuln.Type, vuln.Target, + vuln.Proof, vuln.Impact, vuln.Recommendation, + vuln.CreatedAt, vuln.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("创建漏洞失败: %w", err) + } + + return vuln, nil +} + +// GetVulnerability 获取漏洞 +func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { + var vuln Vulnerability + query := ` + SELECT id, conversation_id, title, description, severity, status, + vulnerability_type, target, proof, impact, recommendation, + created_at, updated_at + FROM vulnerabilities + WHERE id = ? + ` + + err := db.QueryRow(query, id).Scan( + &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, + &vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target, + &vuln.Proof, &vuln.Impact, &vuln.Recommendation, + &vuln.CreatedAt, &vuln.UpdatedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("漏洞不存在") + } + return nil, fmt.Errorf("获取漏洞失败: %w", err) + } + + return &vuln, nil +} + +// ListVulnerabilities 列出漏洞 +func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) { + query := ` + SELECT id, conversation_id, title, description, severity, status, + vulnerability_type, target, proof, impact, recommendation, + created_at, updated_at + FROM vulnerabilities + WHERE 1=1 + ` + args := []interface{}{} + + if id != "" { + query += " AND id = ?" + args = append(args, id) + } + if conversationID != "" { + query += " AND conversation_id = ?" + args = append(args, conversationID) + } + if severity != "" { + query += " AND severity = ?" + args = append(args, severity) + } + if status != "" { + query += " AND status = ?" + args = append(args, status) + } + + query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" + args = append(args, limit, offset) + + rows, err := db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("查询漏洞列表失败: %w", err) + } + defer rows.Close() + + var vulnerabilities []*Vulnerability + for rows.Next() { + var vuln Vulnerability + err := rows.Scan( + &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, + &vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target, + &vuln.Proof, &vuln.Impact, &vuln.Recommendation, + &vuln.CreatedAt, &vuln.UpdatedAt, + ) + if err != nil { + db.logger.Warn("扫描漏洞记录失败", zap.Error(err)) + continue + } + vulnerabilities = append(vulnerabilities, &vuln) + } + + return vulnerabilities, nil +} + +// CountVulnerabilities 统计漏洞总数(支持筛选条件) +func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) { + query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1" + args := []interface{}{} + + if id != "" { + query += " AND id = ?" + args = append(args, id) + } + if conversationID != "" { + query += " AND conversation_id = ?" + args = append(args, conversationID) + } + if severity != "" { + query += " AND severity = ?" + args = append(args, severity) + } + if status != "" { + query += " AND status = ?" + args = append(args, status) + } + + var count int + err := db.QueryRow(query, args...).Scan(&count) + if err != nil { + return 0, fmt.Errorf("统计漏洞总数失败: %w", err) + } + + return count, nil +} + +// UpdateVulnerability 更新漏洞 +func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error { + vuln.UpdatedAt = time.Now() + + query := ` + UPDATE vulnerabilities + SET title = ?, description = ?, severity = ?, status = ?, + vulnerability_type = ?, target = ?, proof = ?, impact = ?, + recommendation = ?, updated_at = ? + WHERE id = ? + ` + + _, err := db.Exec( + query, + vuln.Title, vuln.Description, vuln.Severity, vuln.Status, + vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, + vuln.Recommendation, vuln.UpdatedAt, id, + ) + if err != nil { + return fmt.Errorf("更新漏洞失败: %w", err) + } + + return nil +} + +// DeleteVulnerability 删除漏洞 +func (db *DB) DeleteVulnerability(id string) error { + _, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id) + if err != nil { + return fmt.Errorf("删除漏洞失败: %w", err) + } + return nil +} + +// GetVulnerabilityStats 获取漏洞统计 +func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) { + stats := make(map[string]interface{}) + + // 总漏洞数 + var totalCount int + query := "SELECT COUNT(*) FROM vulnerabilities" + args := []interface{}{} + if conversationID != "" { + query += " WHERE conversation_id = ?" + args = append(args, conversationID) + } + err := db.QueryRow(query, args...).Scan(&totalCount) + if err != nil { + return nil, fmt.Errorf("获取总漏洞数失败: %w", err) + } + stats["total"] = totalCount + + // 按严重程度统计 + severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities" + if conversationID != "" { + severityQuery += " WHERE conversation_id = ?" + } + severityQuery += " GROUP BY severity" + + rows, err := db.Query(severityQuery, args...) + if err != nil { + return nil, fmt.Errorf("获取严重程度统计失败: %w", err) + } + defer rows.Close() + + severityStats := make(map[string]int) + for rows.Next() { + var severity string + var count int + if err := rows.Scan(&severity, &count); err != nil { + continue + } + severityStats[severity] = count + } + stats["by_severity"] = severityStats + + // 按状态统计 + statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities" + if conversationID != "" { + statusQuery += " WHERE conversation_id = ?" + } + statusQuery += " GROUP BY status" + + rows, err = db.Query(statusQuery, args...) + if err != nil { + return nil, fmt.Errorf("获取状态统计失败: %w", err) + } + defer rows.Close() + + statusStats := make(map[string]int) + for rows.Next() { + var status string + var count int + if err := rows.Scan(&status, &count); err != nil { + continue + } + statusStats[status] = count + } + stats["by_status"] = statusStats + + return stats, nil +} + diff --git a/internal/database/webshell.go b/internal/database/webshell.go new file mode 100644 index 00000000..2ea25da7 --- /dev/null +++ b/internal/database/webshell.go @@ -0,0 +1,148 @@ +package database + +import ( + "database/sql" + "time" + + "go.uber.org/zap" +) + +// WebShellConnection WebShell 连接配置 +type WebShellConnection struct { + ID string `json:"id"` + URL string `json:"url"` + Password string `json:"password"` + Type string `json:"type"` + Method string `json:"method"` + CmdParam string `json:"cmdParam"` + Remark string `json:"remark"` + CreatedAt time.Time `json:"createdAt"` +} + +// GetWebshellConnectionState 获取连接关联的持久化状态 JSON,不存在时返回 "{}" +func (db *DB) GetWebshellConnectionState(connectionID string) (string, error) { + var stateJSON string + err := db.QueryRow(`SELECT state_json FROM webshell_connection_states WHERE connection_id = ?`, connectionID).Scan(&stateJSON) + if err == sql.ErrNoRows { + return "{}", nil + } + if err != nil { + db.logger.Error("查询 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID)) + return "", err + } + if stateJSON == "" { + stateJSON = "{}" + } + return stateJSON, nil +} + +// UpsertWebshellConnectionState 保存连接关联的持久化状态 JSON +func (db *DB) UpsertWebshellConnectionState(connectionID, stateJSON string) error { + if stateJSON == "" { + stateJSON = "{}" + } + query := ` + INSERT INTO webshell_connection_states (connection_id, state_json, updated_at) + VALUES (?, ?, ?) + ON CONFLICT(connection_id) DO UPDATE SET + state_json = excluded.state_json, + updated_at = excluded.updated_at + ` + if _, err := db.Exec(query, connectionID, stateJSON, time.Now()); err != nil { + db.logger.Error("保存 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID)) + return err + } + return nil +} + +// ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序 +func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) { + query := ` + SELECT id, url, password, type, method, cmd_param, remark, created_at + FROM webshell_connections + ORDER BY created_at DESC + ` + rows, err := db.Query(query) + if err != nil { + db.logger.Error("查询 WebShell 连接列表失败", zap.Error(err)) + return nil, err + } + defer rows.Close() + + var list []WebShellConnection + for rows.Next() { + var c WebShellConnection + err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt) + if err != nil { + db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err)) + continue + } + list = append(list, c) + } + return list, rows.Err() +} + +// GetWebshellConnection 根据 ID 获取一条连接 +func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) { + query := ` + SELECT id, url, password, type, method, cmd_param, remark, created_at + FROM webshell_connections WHERE id = ? + ` + var c WebShellConnection + err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + db.logger.Error("查询 WebShell 连接失败", zap.Error(err), zap.String("id", id)) + return nil, err + } + return &c, nil +} + +// CreateWebshellConnection 创建 WebShell 连接 +func (db *DB) CreateWebshellConnection(c *WebShellConnection) error { + query := ` + INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ` + _, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.CreatedAt) + if err != nil { + db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID)) + return err + } + return nil +} + +// UpdateWebshellConnection 更新 WebShell 连接 +func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error { + query := ` + UPDATE webshell_connections + SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ? + WHERE id = ? + ` + result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.ID) + if err != nil { + db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID)) + return err + } + affected, _ := result.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// DeleteWebshellConnection 删除 WebShell 连接 +func (db *DB) DeleteWebshellConnection(id string) error { + result, err := db.Exec(`DELETE FROM webshell_connections WHERE id = ?`, id) + if err != nil { + db.logger.Error("删除 WebShell 连接失败", zap.Error(err), zap.String("id", id)) + return err + } + affected, _ := result.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 00000000..97addc0c --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,68 @@ +package logger + +import ( + "os" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type Logger struct { + *zap.Logger +} + +func New(level, output string) *Logger { + var zapLevel zapcore.Level + switch level { + case "debug": + zapLevel = zapcore.DebugLevel + case "info": + zapLevel = zapcore.InfoLevel + case "warn": + zapLevel = zapcore.WarnLevel + case "error": + zapLevel = zapcore.ErrorLevel + default: + zapLevel = zapcore.InfoLevel + } + + config := zap.NewProductionConfig() + config.Level = zap.NewAtomicLevelAt(zapLevel) + config.EncoderConfig.TimeKey = "timestamp" + config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + + var writeSyncer zapcore.WriteSyncer + if output == "stdout" { + writeSyncer = zapcore.AddSync(os.Stdout) + } else { + file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + writeSyncer = zapcore.AddSync(os.Stdout) + } else { + writeSyncer = zapcore.AddSync(file) + } + } + + core := zapcore.NewCore( + zapcore.NewJSONEncoder(config.EncoderConfig), + writeSyncer, + zapLevel, + ) + + logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel)) + + return &Logger{Logger: logger} +} + +func (l *Logger) Fatal(msg string, fields ...interface{}) { + zapFields := make([]zap.Field, 0, len(fields)) + for _, f := range fields { + switch v := f.(type) { + case error: + zapFields = append(zapFields, zap.Error(v)) + default: + zapFields = append(zapFields, zap.Any("field", v)) + } + } + l.Logger.Fatal(msg, zapFields...) +} diff --git a/internal/mcp/builtin/constants.go b/internal/mcp/builtin/constants.go new file mode 100644 index 00000000..7e669ea1 --- /dev/null +++ b/internal/mcp/builtin/constants.go @@ -0,0 +1,105 @@ +package builtin + +// 内置工具名称常量 +// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串 +const ( + // 漏洞管理工具 + ToolRecordVulnerability = "record_vulnerability" + + // 知识库工具 + ToolListKnowledgeRiskTypes = "list_knowledge_risk_types" + ToolSearchKnowledgeBase = "search_knowledge_base" + + // WebShell 助手工具(AI 在 WebShell 管理 - AI 助手 中使用) + ToolWebshellExec = "webshell_exec" + ToolWebshellFileList = "webshell_file_list" + ToolWebshellFileRead = "webshell_file_read" + ToolWebshellFileWrite = "webshell_file_write" + + // WebShell 连接管理工具(用于通过 MCP 管理 webshell 连接) + ToolManageWebshellList = "manage_webshell_list" + ToolManageWebshellAdd = "manage_webshell_add" + ToolManageWebshellUpdate = "manage_webshell_update" + ToolManageWebshellDelete = "manage_webshell_delete" + ToolManageWebshellTest = "manage_webshell_test" + + // 批量任务队列(与 Web 端批量任务一致,供模型创建/启停/查询队列) + ToolBatchTaskList = "batch_task_list" + ToolBatchTaskGet = "batch_task_get" + ToolBatchTaskCreate = "batch_task_create" + ToolBatchTaskStart = "batch_task_start" + ToolBatchTaskRerun = "batch_task_rerun" + ToolBatchTaskPause = "batch_task_pause" + ToolBatchTaskDelete = "batch_task_delete" + ToolBatchTaskUpdateMetadata = "batch_task_update_metadata" + ToolBatchTaskUpdateSchedule = "batch_task_update_schedule" + ToolBatchTaskScheduleEnabled = "batch_task_schedule_enabled" + ToolBatchTaskAdd = "batch_task_add_task" + ToolBatchTaskUpdate = "batch_task_update_task" + ToolBatchTaskRemove = "batch_task_remove_task" +) + +// IsBuiltinTool 检查工具名称是否是内置工具 +func IsBuiltinTool(toolName string) bool { + switch toolName { + case ToolRecordVulnerability, + ToolListKnowledgeRiskTypes, + ToolSearchKnowledgeBase, + ToolWebshellExec, + ToolWebshellFileList, + ToolWebshellFileRead, + ToolWebshellFileWrite, + ToolManageWebshellList, + ToolManageWebshellAdd, + ToolManageWebshellUpdate, + ToolManageWebshellDelete, + ToolManageWebshellTest, + ToolBatchTaskList, + ToolBatchTaskGet, + ToolBatchTaskCreate, + ToolBatchTaskStart, + ToolBatchTaskRerun, + ToolBatchTaskPause, + ToolBatchTaskDelete, + ToolBatchTaskUpdateMetadata, + ToolBatchTaskUpdateSchedule, + ToolBatchTaskScheduleEnabled, + ToolBatchTaskAdd, + ToolBatchTaskUpdate, + ToolBatchTaskRemove: + return true + default: + return false + } +} + +// GetAllBuiltinTools 返回所有内置工具名称列表 +func GetAllBuiltinTools() []string { + return []string{ + ToolRecordVulnerability, + ToolListKnowledgeRiskTypes, + ToolSearchKnowledgeBase, + ToolWebshellExec, + ToolWebshellFileList, + ToolWebshellFileRead, + ToolWebshellFileWrite, + ToolManageWebshellList, + ToolManageWebshellAdd, + ToolManageWebshellUpdate, + ToolManageWebshellDelete, + ToolManageWebshellTest, + ToolBatchTaskList, + ToolBatchTaskGet, + ToolBatchTaskCreate, + ToolBatchTaskStart, + ToolBatchTaskRerun, + ToolBatchTaskPause, + ToolBatchTaskDelete, + ToolBatchTaskUpdateMetadata, + ToolBatchTaskUpdateSchedule, + ToolBatchTaskScheduleEnabled, + ToolBatchTaskAdd, + ToolBatchTaskUpdate, + ToolBatchTaskRemove, + } +} diff --git a/internal/mcp/client_sdk.go b/internal/mcp/client_sdk.go new file mode 100644 index 00000000..59b513b2 --- /dev/null +++ b/internal/mcp/client_sdk.go @@ -0,0 +1,551 @@ +// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性 +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + + "github.com/google/uuid" + "github.com/modelcontextprotocol/go-sdk/mcp" + "go.uber.org/zap" +) + +const ( + clientName = "CyberStrikeAI" + clientVersion = "1.0.0" +) + +// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口 +type sdkClient struct { + session *mcp.ClientSession + client *mcp.Client + logger *zap.Logger + mu sync.RWMutex + status string // "disconnected", "connecting", "connected", "error" +} + +// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用) +func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient { + return &sdkClient{ + session: session, + client: client, + logger: logger, + status: "connected", + } +} + +// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient +type lazySDKClient struct { + serverCfg config.ExternalMCPServerConfig + logger *zap.Logger + inner ExternalMCPClient // 连接成功后为 *sdkClient + mu sync.RWMutex + status string +} + +func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient { + return &lazySDKClient{ + serverCfg: serverCfg, + logger: logger, + status: "connecting", + } +} + +func (c *lazySDKClient) setStatus(s string) { + c.mu.Lock() + defer c.mu.Unlock() + c.status = s +} + +func (c *lazySDKClient) GetStatus() string { + c.mu.RLock() + defer c.mu.RUnlock() + if c.inner != nil { + return c.inner.GetStatus() + } + return c.status +} + +func (c *lazySDKClient) IsConnected() bool { + c.mu.RLock() + inner := c.inner + c.mu.RUnlock() + if inner != nil { + return inner.IsConnected() + } + return false +} + +func (c *lazySDKClient) Initialize(ctx context.Context) error { + c.mu.Lock() + if c.inner != nil { + c.mu.Unlock() + return nil + } + c.mu.Unlock() + + inner, err := createSDKClient(ctx, c.serverCfg, c.logger) + if err != nil { + c.setStatus("error") + return err + } + + c.mu.Lock() + c.inner = inner + c.mu.Unlock() + c.setStatus("connected") + return nil +} + +func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) { + c.mu.RLock() + inner := c.inner + c.mu.RUnlock() + if inner == nil { + return nil, fmt.Errorf("未连接") + } + return inner.ListTools(ctx) +} + +func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { + c.mu.RLock() + inner := c.inner + c.mu.RUnlock() + if inner == nil { + return nil, fmt.Errorf("未连接") + } + return inner.CallTool(ctx, name, args) +} + +func (c *lazySDKClient) Close() error { + c.mu.Lock() + inner := c.inner + c.inner = nil + c.mu.Unlock() + c.setStatus("disconnected") + if inner != nil { + return inner.Close() + } + return nil +} + +func (c *sdkClient) setStatus(s string) { + c.mu.Lock() + defer c.mu.Unlock() + c.status = s +} + +func (c *sdkClient) GetStatus() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.status +} + +func (c *sdkClient) IsConnected() bool { + return c.GetStatus() == "connected" +} + +func (c *sdkClient) Initialize(ctx context.Context) error { + // sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接 + // 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成 + return nil +} + +func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) { + if c.session == nil { + return nil, fmt.Errorf("未连接") + } + res, err := c.session.ListTools(ctx, nil) + if err != nil { + return nil, err + } + if res == nil { + return nil, nil + } + return sdkToolsToOur(res.Tools), nil +} + +func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { + if c.session == nil { + return nil, fmt.Errorf("未连接") + } + params := &mcp.CallToolParams{ + Name: name, + Arguments: args, + } + res, err := c.session.CallTool(ctx, params) + if err != nil { + return nil, err + } + return sdkCallToolResultToOurs(res), nil +} + +func (c *sdkClient) Close() error { + c.setStatus("disconnected") + if c.session != nil { + err := c.session.Close() + c.session = nil + return err + } + return nil +} + +// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool +func sdkToolsToOur(tools []*mcp.Tool) []Tool { + if len(tools) == 0 { + return nil + } + out := make([]Tool, 0, len(tools)) + for _, t := range tools { + if t == nil { + continue + } + schema := make(map[string]interface{}) + if t.InputSchema != nil { + // SDK InputSchema 可能为 *jsonschema.Schema 或 map,统一转为 map + if m, ok := t.InputSchema.(map[string]interface{}); ok { + schema = m + } else { + _ = json.Unmarshal(mustJSON(t.InputSchema), &schema) + } + } + desc := t.Description + shortDesc := desc + if t.Annotations != nil && t.Annotations.Title != "" { + shortDesc = t.Annotations.Title + } + out = append(out, Tool{ + Name: t.Name, + Description: desc, + ShortDescription: shortDesc, + InputSchema: schema, + }) + } + return out +} + +// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult +func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult { + if res == nil { + return &ToolResult{Content: []Content{}} + } + content := sdkContentToOurs(res.Content) + return &ToolResult{ + Content: content, + IsError: res.IsError, + } +} + +func sdkContentToOurs(list []mcp.Content) []Content { + if len(list) == 0 { + return nil + } + out := make([]Content, 0, len(list)) + for _, c := range list { + switch v := c.(type) { + case *mcp.TextContent: + out = append(out, Content{Type: "text", Text: v.Text}) + default: + out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)}) + } + } + return out +} + +func mustJSON(v interface{}) []byte { + b, _ := json.Marshal(v) + return b +} + +// simpleHTTPClient 简单 JSON-RPC over HTTP:每次请求一次 POST、响应在 body。实现 ExternalMCPClient。 +// 用于自建 MCP(如 http://127.0.0.1:8081/mcp)或其它仅支持简单 POST 的端点。 +type simpleHTTPClient struct { + url string + client *http.Client + logger *zap.Logger + mu sync.RWMutex + status string +} + +func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) { + c := &simpleHTTPClient{ + url: url, + client: httpClientWithTimeoutAndHeaders(timeout, headers), + logger: logger, + status: "connecting", + } + if err := c.initialize(ctx); err != nil { + return nil, err + } + c.mu.Lock() + c.status = "connected" + c.mu.Unlock() + return c, nil +} + +func (c *simpleHTTPClient) setStatus(s string) { + c.mu.Lock() + defer c.mu.Unlock() + c.status = s +} + +func (c *simpleHTTPClient) GetStatus() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.status +} + +func (c *simpleHTTPClient) IsConnected() bool { + return c.GetStatus() == "connected" +} + +func (c *simpleHTTPClient) Initialize(context.Context) error { + return nil // 已在 newSimpleHTTPClient 中完成 +} + +func (c *simpleHTTPClient) initialize(ctx context.Context) error { + params := InitializeRequest{ + ProtocolVersion: ProtocolVersion, + Capabilities: make(map[string]interface{}), + ClientInfo: ClientInfo{Name: clientName, Version: clientVersion}, + } + paramsJSON, _ := json.Marshal(params) + req := &Message{ + ID: MessageID{value: "1"}, + Method: "initialize", + Version: "2.0", + Params: paramsJSON, + } + resp, err := c.sendRequest(ctx, req) + if err != nil { + return fmt.Errorf("initialize: %w", err) + } + if resp.Error != nil { + return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code) + } + // 发送 notifications/initialized(协议要求) + notify := &Message{ + ID: MessageID{value: nil}, + Method: "notifications/initialized", + Version: "2.0", + Params: json.RawMessage("{}"), + } + _ = c.sendNotification(notify) + return nil +} + +func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) { + body, err := json.Marshal(msg) + if err != nil { + return nil, err + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Content-Type", "application/json") + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b)) + } + var out Message + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, err + } + return &out, nil +} + +func (c *simpleHTTPClient) sendNotification(msg *Message) error { + body, _ := json.Marshal(msg) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body)) + httpReq.Header.Set("Content-Type", "application/json") + resp, err := c.client.Do(httpReq) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) { + req := &Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/list", + Version: "2.0", + Params: json.RawMessage("{}"), + } + resp, err := c.sendRequest(ctx, req) + if err != nil { + return nil, err + } + if resp.Error != nil { + return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code) + } + var listResp ListToolsResponse + if err := json.Unmarshal(resp.Result, &listResp); err != nil { + return nil, err + } + return listResp.Tools, nil +} + +func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { + params := CallToolRequest{Name: name, Arguments: args} + paramsJSON, _ := json.Marshal(params) + req := &Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/call", + Version: "2.0", + Params: paramsJSON, + } + resp, err := c.sendRequest(ctx, req) + if err != nil { + return nil, err + } + if resp.Error != nil { + return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code) + } + var callResp CallToolResponse + if err := json.Unmarshal(resp.Result, &callResp); err != nil { + return nil, err + } + return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil +} + +func (c *simpleHTTPClient) Close() error { + c.setStatus("disconnected") + return nil +} + +// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient +// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。 +func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) { + timeout := time.Duration(serverCfg.Timeout) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second + } + + transport := serverCfg.Transport + if transport == "" { + if serverCfg.Command != "" { + transport = "stdio" + } else if serverCfg.URL != "" { + transport = "http" + } else { + return nil, fmt.Errorf("配置缺少 command 或 url") + } + } + + client := mcp.NewClient(&mcp.Implementation{ + Name: clientName, + Version: clientVersion, + }, nil) + + var t mcp.Transport + switch transport { + case "stdio": + if serverCfg.Command == "" { + return nil, fmt.Errorf("stdio 模式需要配置 command") + } + // 必须用 exec.Command 而非 CommandContext:doConnect 返回后 ctx 会被 cancel, + // 若用 CommandContext(ctx) 会立刻杀掉子进程,导致 ListTools 等后续请求失败、显示 0 工具 + cmd := exec.Command(serverCfg.Command, serverCfg.Args...) + if len(serverCfg.Env) > 0 { + cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...) + } + t = &mcp.CommandTransport{Command: cmd} + case "sse": + if serverCfg.URL == "" { + return nil, fmt.Errorf("sse 模式需要配置 url") + } + httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers) + t = &mcp.SSEClientTransport{ + Endpoint: serverCfg.URL, + HTTPClient: httpClient, + } + case "http": + if serverCfg.URL == "" { + return nil, fmt.Errorf("http 模式需要配置 url") + } + httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers) + t = &mcp.StreamableClientTransport{ + Endpoint: serverCfg.URL, + HTTPClient: httpClient, + } + case "simple_http": + // 简单 JSON-RPC HTTP:每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp) + if serverCfg.URL == "" { + return nil, fmt.Errorf("simple_http 模式需要配置 url") + } + return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger) + default: + return nil, fmt.Errorf("不支持的传输模式: %s", transport) + } + + session, err := client.Connect(ctx, t, nil) + if err != nil { + return nil, fmt.Errorf("连接失败: %w", err) + } + + return newSDKClientFromSession(session, client, logger), nil +} + +func envMapToSlice(env map[string]string) []string { + m := make(map[string]string) + for _, s := range os.Environ() { + if i := strings.IndexByte(s, '='); i > 0 { + m[s[:i]] = s[i+1:] + } + } + for k, v := range env { + m[k] = v + } + out := make([]string, 0, len(m)) + for k, v := range m { + out = append(out, k+"="+v) + } + return out +} + +func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client { + transport := http.DefaultTransport + if len(headers) > 0 { + transport = &headerRoundTripper{ + headers: headers, + base: http.DefaultTransport, + } + } + return &http.Client{ + Timeout: timeout, + Transport: transport, + } +} + +type headerRoundTripper struct { + headers map[string]string + base http.RoundTripper +} + +func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + for k, v := range h.headers { + req.Header.Set(k, v) + } + return h.base.RoundTrip(req) +} diff --git a/internal/mcp/external_manager.go b/internal/mcp/external_manager.go new file mode 100644 index 00000000..1d9c3164 --- /dev/null +++ b/internal/mcp/external_manager.go @@ -0,0 +1,1105 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + + "github.com/google/uuid" + + "go.uber.org/zap" +) + +// ExternalMCPManager 外部MCP管理器 +type ExternalMCPManager struct { + clients map[string]ExternalMCPClient + configs map[string]config.ExternalMCPServerConfig + logger *zap.Logger + storage MonitorStorage // 可选的持久化存储 + executions map[string]*ToolExecution // 执行记录 + stats map[string]*ToolStats // 工具统计信息 + errors map[string]string // 错误信息 + toolCounts map[string]int // 工具数量缓存 + toolCountsMu sync.RWMutex // 工具数量缓存的锁 + toolCache map[string][]Tool // 工具列表缓存:MCP名称 -> 工具列表 + toolCacheMu sync.RWMutex // 工具列表缓存的锁 + stopRefresh chan struct{} // 停止后台刷新的信号 + refreshWg sync.WaitGroup // 等待后台刷新goroutine完成 + mu sync.RWMutex +} + +// NewExternalMCPManager 创建外部MCP管理器 +func NewExternalMCPManager(logger *zap.Logger) *ExternalMCPManager { + return NewExternalMCPManagerWithStorage(logger, nil) +} + +// NewExternalMCPManagerWithStorage 创建外部MCP管理器(带持久化存储) +func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage) *ExternalMCPManager { + manager := &ExternalMCPManager{ + clients: make(map[string]ExternalMCPClient), + configs: make(map[string]config.ExternalMCPServerConfig), + logger: logger, + storage: storage, + executions: make(map[string]*ToolExecution), + stats: make(map[string]*ToolStats), + errors: make(map[string]string), + toolCounts: make(map[string]int), + toolCache: make(map[string][]Tool), + stopRefresh: make(chan struct{}), + } + // 启动后台刷新工具数量的goroutine + manager.startToolCountRefresh() + return manager +} + +// LoadConfigs 加载配置 +func (m *ExternalMCPManager) LoadConfigs(cfg *config.ExternalMCPConfig) { + m.mu.Lock() + defer m.mu.Unlock() + + if cfg == nil || cfg.Servers == nil { + return + } + + m.configs = make(map[string]config.ExternalMCPServerConfig) + for name, serverCfg := range cfg.Servers { + m.configs[name] = serverCfg + } +} + +// GetConfigs 获取所有配置 +func (m *ExternalMCPManager) GetConfigs() map[string]config.ExternalMCPServerConfig { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]config.ExternalMCPServerConfig) + for k, v := range m.configs { + result[k] = v + } + return result +} + +// AddOrUpdateConfig 添加或更新配置 +func (m *ExternalMCPManager) AddOrUpdateConfig(name string, serverCfg config.ExternalMCPServerConfig) error { + m.mu.Lock() + defer m.mu.Unlock() + + // 如果已存在客户端,先关闭 + if client, exists := m.clients[name]; exists { + client.Close() + delete(m.clients, name) + } + + m.configs[name] = serverCfg + + // 如果启用,自动连接 + if m.isEnabled(serverCfg) { + go m.connectClient(name, serverCfg) + } + + return nil +} + +// RemoveConfig 移除配置 +func (m *ExternalMCPManager) RemoveConfig(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + + // 关闭客户端 + if client, exists := m.clients[name]; exists { + client.Close() + delete(m.clients, name) + } + + delete(m.configs, name) + + // 清理工具数量缓存 + m.toolCountsMu.Lock() + delete(m.toolCounts, name) + m.toolCountsMu.Unlock() + + // 清理工具列表缓存 + m.toolCacheMu.Lock() + delete(m.toolCache, name) + m.toolCacheMu.Unlock() + + return nil +} + +// StartClient 启动客户端 +func (m *ExternalMCPManager) StartClient(name string) error { + m.mu.Lock() + serverCfg, exists := m.configs[name] + m.mu.Unlock() + + if !exists { + return fmt.Errorf("配置不存在: %s", name) + } + + // 检查是否已经有连接的客户端 + m.mu.RLock() + existingClient, hasClient := m.clients[name] + m.mu.RUnlock() + + if hasClient { + // 检查客户端是否已连接 + if existingClient.IsConnected() { + // 客户端已连接,直接返回成功(目标状态已达成) + // 更新配置为启用(确保配置一致) + m.mu.Lock() + serverCfg.ExternalMCPEnable = true + m.configs[name] = serverCfg + m.mu.Unlock() + return nil + } + // 如果有客户端但未连接,先关闭 + existingClient.Close() + m.mu.Lock() + delete(m.clients, name) + m.mu.Unlock() + } + + // 更新配置为启用 + m.mu.Lock() + serverCfg.ExternalMCPEnable = true + m.configs[name] = serverCfg + // 清除之前的错误信息(重新启动时) + delete(m.errors, name) + m.mu.Unlock() + + // 立即创建客户端并设置为"connecting"状态,这样前端可以立即看到状态 + client := m.createClient(serverCfg) + if client == nil { + return fmt.Errorf("无法创建客户端:不支持的传输模式") + } + + // 设置状态为connecting + m.setClientStatus(client, "connecting") + + // 立即保存客户端,这样前端查询时就能看到"connecting"状态 + m.mu.Lock() + m.clients[name] = client + m.mu.Unlock() + + // 在后台异步进行实际连接 + go func() { + if err := m.doConnect(name, serverCfg, client); err != nil { + m.logger.Error("连接外部MCP客户端失败", + zap.String("name", name), + zap.Error(err), + ) + // 连接失败,设置状态为error并保存错误信息 + m.setClientStatus(client, "error") + m.mu.Lock() + m.errors[name] = err.Error() + m.mu.Unlock() + // 触发工具数量刷新(连接失败,工具数量应为0) + m.triggerToolCountRefresh() + } else { + // 连接成功,清除错误信息 + m.mu.Lock() + delete(m.errors, name) + m.mu.Unlock() + // 立即刷新工具数量和工具列表缓存 + m.triggerToolCountRefresh() + m.refreshToolCache(name, client) + // 2 秒后再刷新一次,覆盖 SSE/Streamable 等需稍等就绪的远端 + go func() { + time.Sleep(2 * time.Second) + m.triggerToolCountRefresh() + m.refreshToolCache(name, client) + }() + } + }() + + return nil +} + +// StopClient 停止客户端 +func (m *ExternalMCPManager) StopClient(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + + serverCfg, exists := m.configs[name] + if !exists { + return fmt.Errorf("配置不存在: %s", name) + } + + // 关闭客户端 + if client, exists := m.clients[name]; exists { + client.Close() + delete(m.clients, name) + } + + // 清除错误信息 + delete(m.errors, name) + + // 更新工具数量缓存(停止后工具数量为0) + m.toolCountsMu.Lock() + m.toolCounts[name] = 0 + m.toolCountsMu.Unlock() + + // 更新配置为禁用 + serverCfg.ExternalMCPEnable = false + m.configs[name] = serverCfg + + return nil +} + +// GetClient 获取客户端 +func (m *ExternalMCPManager) GetClient(name string) (ExternalMCPClient, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + client, exists := m.clients[name] + return client, exists +} + +// GetError 获取错误信息 +func (m *ExternalMCPManager) GetError(name string) string { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.errors[name] +} + +// GetAllTools 获取所有外部MCP的工具 +// 优先从已连接的客户端获取,如果连接断开则返回缓存的工具列表 +// 策略: +// - error 状态:不使用缓存,直接跳过(配置错误或服务不可用) +// - disconnected/connecting 状态:使用缓存(临时断开) +// - connected 状态:正常获取,失败时降级使用缓存 +func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) { + m.mu.RLock() + clients := make(map[string]ExternalMCPClient) + for k, v := range m.clients { + clients[k] = v + } + m.mu.RUnlock() + + var allTools []Tool + var hasError bool + var lastError error + + // 使用较短的超时时间进行快速检查(3秒),避免阻塞 + quickCtx, quickCancel := context.WithTimeout(ctx, 3*time.Second) + defer quickCancel() + + for name, client := range clients { + tools, err := m.getToolsForClient(name, client, quickCtx) + if err != nil { + // 记录错误,但继续处理其他客户端 + hasError = true + if lastError == nil { + lastError = err + } + continue + } + + // 为工具添加前缀,避免冲突 + for _, tool := range tools { + tool.Name = fmt.Sprintf("%s::%s", name, tool.Name) + allTools = append(allTools, tool) + } + } + + // 如果有错误但至少返回了一些工具,不返回错误(部分成功) + if hasError && len(allTools) == 0 { + return nil, fmt.Errorf("获取外部MCP工具失败: %w", lastError) + } + + return allTools, nil +} + +// getToolsForClient 获取指定客户端的工具列表 +// 返回工具列表和错误(如果完全无法获取) +func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPClient, ctx context.Context) ([]Tool, error) { + status := client.GetStatus() + + // error 状态:不使用缓存,直接返回错误 + if status == "error" { + m.logger.Debug("跳过连接失败的外部MCP(不使用缓存)", + zap.String("name", name), + zap.String("status", status), + ) + return nil, fmt.Errorf("外部MCP连接失败: %s", name) + } + + // 已连接:尝试获取最新工具列表 + if client.IsConnected() { + tools, err := client.ListTools(ctx) + if err != nil { + // 获取失败,尝试使用缓存 + return m.getCachedTools(name, "连接正常但获取失败", err) + } + + // 获取成功,更新缓存 + m.updateToolCache(name, tools) + return tools, nil + } + + // 未连接:根据状态决定是否使用缓存 + if status == "disconnected" || status == "connecting" { + return m.getCachedTools(name, fmt.Sprintf("客户端临时断开(状态: %s)", status), nil) + } + + // 其他未知状态,不使用缓存 + m.logger.Debug("跳过外部MCP(未知状态)", + zap.String("name", name), + zap.String("status", status), + ) + return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status) +} + +// getCachedTools 获取缓存的工具列表 +func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) { + m.toolCacheMu.RLock() + cachedTools, hasCache := m.toolCache[name] + m.toolCacheMu.RUnlock() + + if hasCache && len(cachedTools) > 0 { + m.logger.Debug("使用缓存的工具列表", + zap.String("name", name), + zap.String("reason", reason), + zap.Int("count", len(cachedTools)), + zap.Error(originalErr), + ) + return cachedTools, nil + } + + // 无缓存,返回错误 + if originalErr != nil { + return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr) + } + return nil, fmt.Errorf("外部MCP无缓存工具: %s", name) +} + +// updateToolCache 更新工具列表缓存 +func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) { + m.toolCacheMu.Lock() + m.toolCache[name] = tools + m.toolCacheMu.Unlock() + + // 如果返回空列表,记录警告 + if len(tools) == 0 { + m.logger.Warn("外部MCP返回空工具列表", + zap.String("name", name), + zap.String("hint", "服务可能暂时不可用,工具列表为空"), + ) + } else { + m.logger.Debug("工具列表缓存已更新", + zap.String("name", name), + zap.Int("count", len(tools)), + ) + } +} + +// CallTool 调用外部MCP工具(返回执行ID) +func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) { + // 解析工具名称:name::toolName + var mcpName, actualToolName string + if idx := findSubstring(toolName, "::"); idx > 0 { + mcpName = toolName[:idx] + actualToolName = toolName[idx+2:] + } else { + return nil, "", fmt.Errorf("无效的工具名称格式: %s", toolName) + } + + client, exists := m.GetClient(mcpName) + if !exists { + return nil, "", fmt.Errorf("外部MCP客户端不存在: %s", mcpName) + } + + // 检查连接状态,如果未连接或状态为error,不允许调用 + if !client.IsConnected() { + status := client.GetStatus() + if status == "error" { + // 获取错误信息(如果有) + errorMsg := m.GetError(mcpName) + if errorMsg != "" { + return nil, "", fmt.Errorf("外部MCP连接失败: %s (错误: %s)", mcpName, errorMsg) + } + return nil, "", fmt.Errorf("外部MCP连接失败: %s", mcpName) + } + return nil, "", fmt.Errorf("外部MCP客户端未连接: %s (状态: %s)", mcpName, status) + } + + // 创建执行记录 + executionID := uuid.New().String() + execution := &ToolExecution{ + ID: executionID, + ToolName: toolName, // 使用完整工具名称(包含MCP名称) + Arguments: args, + Status: "running", + StartTime: time.Now(), + } + + m.mu.Lock() + m.executions[executionID] = execution + // 如果内存中的执行记录超过限制,清理最旧的记录 + m.cleanupOldExecutions() + m.mu.Unlock() + + if m.storage != nil { + if err := m.storage.SaveToolExecution(execution); err != nil { + m.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) + } + } + + // 调用工具 + result, err := client.CallTool(ctx, actualToolName, args) + + // 更新执行记录 + m.mu.Lock() + now := time.Now() + execution.EndTime = &now + execution.Duration = now.Sub(execution.StartTime) + + if err != nil { + execution.Status = "failed" + execution.Error = err.Error() + } else if result != nil && result.IsError { + execution.Status = "failed" + if len(result.Content) > 0 { + execution.Error = result.Content[0].Text + } else { + execution.Error = "工具执行返回错误结果" + } + execution.Result = result + } else { + execution.Status = "completed" + if result == nil { + result = &ToolResult{ + Content: []Content{ + {Type: "text", Text: "工具执行完成,但未返回结果"}, + }, + } + } + execution.Result = result + } + m.mu.Unlock() + + if m.storage != nil { + if err := m.storage.SaveToolExecution(execution); err != nil { + m.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) + } + } + + // 更新统计信息 + failed := err != nil || (result != nil && result.IsError) + m.updateStats(toolName, failed) + + // 如果使用存储,从内存中删除(已持久化) + if m.storage != nil { + m.mu.Lock() + delete(m.executions, executionID) + m.mu.Unlock() + } + + if err != nil { + return nil, executionID, err + } + + return result, executionID, nil +} + +// cleanupOldExecutions 清理旧的执行记录(保持内存中的记录数量在限制内) +func (m *ExternalMCPManager) cleanupOldExecutions() { + const maxExecutionsInMemory = 1000 + if len(m.executions) <= maxExecutionsInMemory { + return + } + + // 按开始时间排序,删除最旧的记录 + type execTime struct { + id string + startTime time.Time + } + var execs []execTime + for id, exec := range m.executions { + execs = append(execs, execTime{id: id, startTime: exec.StartTime}) + } + + // 按时间排序 + for i := 0; i < len(execs)-1; i++ { + for j := i + 1; j < len(execs); j++ { + if execs[i].startTime.After(execs[j].startTime) { + execs[i], execs[j] = execs[j], execs[i] + } + } + } + + // 删除最旧的记录 + toDelete := len(m.executions) - maxExecutionsInMemory + for i := 0; i < toDelete && i < len(execs); i++ { + delete(m.executions, execs[i].id) + } +} + +// GetExecution 获取执行记录(先从内存查找,再从数据库查找) +func (m *ExternalMCPManager) GetExecution(id string) (*ToolExecution, bool) { + m.mu.RLock() + exec, exists := m.executions[id] + m.mu.RUnlock() + + if exists { + return exec, true + } + + if m.storage != nil { + exec, err := m.storage.GetToolExecution(id) + if err == nil { + return exec, true + } + } + + return nil, false +} + +// updateStats 更新统计信息 +func (m *ExternalMCPManager) updateStats(toolName string, failed bool) { + now := time.Now() + if m.storage != nil { + totalCalls := 1 + successCalls := 0 + failedCalls := 0 + if failed { + failedCalls = 1 + } else { + successCalls = 1 + } + if err := m.storage.UpdateToolStats(toolName, totalCalls, successCalls, failedCalls, &now); err != nil { + m.logger.Warn("保存统计信息到数据库失败", zap.Error(err)) + } + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.stats[toolName] == nil { + m.stats[toolName] = &ToolStats{ + ToolName: toolName, + } + } + + stats := m.stats[toolName] + stats.TotalCalls++ + stats.LastCallTime = &now + + if failed { + stats.FailedCalls++ + } else { + stats.SuccessCalls++ + } +} + +// GetStats 获取MCP服务器统计信息 +func (m *ExternalMCPManager) GetStats() map[string]interface{} { + m.mu.RLock() + defer m.mu.RUnlock() + + total := len(m.configs) + enabled := 0 + disabled := 0 + connected := 0 + + for name, cfg := range m.configs { + if m.isEnabled(cfg) { + enabled++ + if client, exists := m.clients[name]; exists && client.IsConnected() { + connected++ + } + } else { + disabled++ + } + } + + return map[string]interface{}{ + "total": total, + "enabled": enabled, + "disabled": disabled, + "connected": connected, + } +} + +// GetToolStats 获取工具统计信息(合并内存和数据库) +// 只返回外部MCP工具的统计信息(工具名称包含 "::") +func (m *ExternalMCPManager) GetToolStats() map[string]*ToolStats { + result := make(map[string]*ToolStats) + + // 从数据库加载统计信息(如果使用数据库存储) + if m.storage != nil { + dbStats, err := m.storage.LoadToolStats() + if err == nil { + // 只保留外部MCP工具的统计信息(工具名称包含 "::") + for k, v := range dbStats { + if findSubstring(k, "::") > 0 { + result[k] = v + } + } + } else { + m.logger.Warn("从数据库加载统计信息失败", zap.Error(err)) + } + } + + // 合并内存中的统计信息 + m.mu.RLock() + for k, v := range m.stats { + // 如果数据库中已有该工具的统计信息,合并它们 + if existing, exists := result[k]; exists { + // 创建新的统计信息对象,避免修改共享对象 + merged := &ToolStats{ + ToolName: k, + TotalCalls: existing.TotalCalls + v.TotalCalls, + SuccessCalls: existing.SuccessCalls + v.SuccessCalls, + FailedCalls: existing.FailedCalls + v.FailedCalls, + } + // 使用最新的调用时间 + if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) { + merged.LastCallTime = v.LastCallTime + } else if existing.LastCallTime != nil { + timeCopy := *existing.LastCallTime + merged.LastCallTime = &timeCopy + } + result[k] = merged + } else { + // 如果数据库中没有,直接使用内存中的统计信息 + statCopy := *v + result[k] = &statCopy + } + } + m.mu.RUnlock() + + return result +} + +// GetToolCount 获取指定外部MCP的工具数量(从缓存读取,不阻塞) +func (m *ExternalMCPManager) GetToolCount(name string) (int, error) { + // 先从缓存读取 + m.toolCountsMu.RLock() + if count, exists := m.toolCounts[name]; exists { + m.toolCountsMu.RUnlock() + return count, nil + } + m.toolCountsMu.RUnlock() + + // 如果缓存中没有,检查客户端状态 + client, exists := m.GetClient(name) + if !exists { + return 0, fmt.Errorf("客户端不存在: %s", name) + } + + if !client.IsConnected() { + // 未连接,缓存为0 + m.toolCountsMu.Lock() + m.toolCounts[name] = 0 + m.toolCountsMu.Unlock() + return 0, nil + } + + // 如果已连接但缓存中没有,触发异步刷新并返回0(避免阻塞) + m.triggerToolCountRefresh() + return 0, nil +} + +// GetToolCounts 获取所有外部MCP的工具数量(从缓存读取,不阻塞) +func (m *ExternalMCPManager) GetToolCounts() map[string]int { + m.toolCountsMu.RLock() + defer m.toolCountsMu.RUnlock() + + // 返回缓存的副本,避免外部修改 + result := make(map[string]int) + for k, v := range m.toolCounts { + result[k] = v + } + return result +} + +// refreshToolCounts 刷新工具数量缓存(后台异步执行) +func (m *ExternalMCPManager) refreshToolCounts() { + m.mu.RLock() + clients := make(map[string]ExternalMCPClient) + for k, v := range m.clients { + clients[k] = v + } + m.mu.RUnlock() + + newCounts := make(map[string]int) + + // 使用goroutine并发获取每个客户端的工具数量,避免串行阻塞 + type countResult struct { + name string + count int + } + resultChan := make(chan countResult, len(clients)) + + for name, client := range clients { + go func(n string, c ExternalMCPClient) { + if !c.IsConnected() { + resultChan <- countResult{name: n, count: 0} + return + } + + // 使用合理的超时时间(15秒),既能应对网络延迟,又不会过长阻塞 + // 由于这是后台异步刷新,超时不会影响前端响应 + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + tools, err := c.ListTools(ctx) + cancel() + + if err != nil { + errStr := err.Error() + // SSE 连接 EOF:远端可能关闭了流或未按规范在流上推送响应,仅首次用 Warn 提示 + if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") { + m.logger.Warn("获取外部MCP工具数量失败(SSE 流已关闭或服务端未在流上返回 tools/list 响应)", + zap.String("name", n), + zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"), + zap.Error(err), + ) + } else { + m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list", + zap.String("name", n), + zap.Error(err), + ) + } + resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值 + return + } + + resultChan <- countResult{name: n, count: len(tools)} + }(name, client) + } + + // 收集结果 + m.toolCountsMu.RLock() + oldCounts := make(map[string]int) + for k, v := range m.toolCounts { + oldCounts[k] = v + } + m.toolCountsMu.RUnlock() + + for i := 0; i < len(clients); i++ { + result := <-resultChan + if result.count >= 0 { + newCounts[result.name] = result.count + } else { + // 获取失败,保留旧值 + if oldCount, exists := oldCounts[result.name]; exists { + newCounts[result.name] = oldCount + } else { + newCounts[result.name] = 0 + } + } + } + + // 更新缓存 + m.toolCountsMu.Lock() + // 更新所有获取到的值 + for name, count := range newCounts { + m.toolCounts[name] = count + } + // 对于未连接的客户端,设置为0 + for name, client := range clients { + if !client.IsConnected() { + m.toolCounts[name] = 0 + } + } + m.toolCountsMu.Unlock() +} + +// refreshToolCache 刷新指定MCP的工具列表缓存 +func (m *ExternalMCPManager) refreshToolCache(name string, client ExternalMCPClient) { + if !client.IsConnected() { + return + } + + // 检查状态,如果是error状态,不更新缓存 + status := client.GetStatus() + if status == "error" { + m.logger.Debug("跳过刷新工具列表缓存(连接失败)", + zap.String("name", name), + zap.String("status", status), + ) + return + } + + // 使用较短的超时时间(5秒) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + tools, err := client.ListTools(ctx) + if err != nil { + m.logger.Debug("刷新工具列表缓存失败", + zap.String("name", name), + zap.Error(err), + ) + // 刷新失败时不更新缓存,保留旧缓存(如果有) + return + } + + // 使用统一的缓存更新方法 + m.updateToolCache(name, tools) +} + +// startToolCountRefresh 启动后台刷新工具数量的goroutine +func (m *ExternalMCPManager) startToolCountRefresh() { + m.refreshWg.Add(1) + go func() { + defer m.refreshWg.Done() + ticker := time.NewTicker(10 * time.Second) // 每10秒刷新一次 + defer ticker.Stop() + + // 立即执行一次刷新 + m.refreshToolCounts() + + for { + select { + case <-ticker.C: + m.refreshToolCounts() + case <-m.stopRefresh: + return + } + } + }() +} + +// triggerToolCountRefresh 触发立即刷新工具数量(异步) +func (m *ExternalMCPManager) triggerToolCountRefresh() { + go m.refreshToolCounts() +} + +// createClient 创建客户端(不连接)。统一使用官方 MCP Go SDK 的 lazy 客户端,连接在 Initialize 时完成。 +func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient { + transport := serverCfg.Transport + if transport == "" { + if serverCfg.Command != "" { + transport = "stdio" + } else if serverCfg.URL != "" { + transport = "http" + } else { + return nil + } + } + + switch transport { + case "http": + if serverCfg.URL == "" { + return nil + } + return newLazySDKClient(serverCfg, m.logger) + case "simple_http": + // 简单 HTTP(一次 POST 一次响应),用于自建 MCP 等 + if serverCfg.URL == "" { + return nil + } + return newLazySDKClient(serverCfg, m.logger) + case "stdio": + if serverCfg.Command == "" { + return nil + } + return newLazySDKClient(serverCfg, m.logger) + case "sse": + if serverCfg.URL == "" { + return nil + } + return newLazySDKClient(serverCfg, m.logger) + default: + return nil + } +} + +// doConnect 执行实际连接 +func (m *ExternalMCPManager) doConnect(name string, serverCfg config.ExternalMCPServerConfig, client ExternalMCPClient) error { + timeout := time.Duration(serverCfg.Timeout) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second + } + + // 初始化连接 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := client.Initialize(ctx); err != nil { + return err + } + + m.logger.Info("外部MCP客户端已连接", + zap.String("name", name), + ) + + return nil +} + +// setClientStatus 设置客户端状态(通过类型断言) +func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status string) { + if c, ok := client.(*lazySDKClient); ok { + c.setStatus(status) + } +} + +// connectClient 连接客户端(异步)- 保留用于向后兼容 +func (m *ExternalMCPManager) connectClient(name string, serverCfg config.ExternalMCPServerConfig) error { + client := m.createClient(serverCfg) + if client == nil { + return fmt.Errorf("无法创建客户端:不支持的传输模式") + } + + // 设置状态为connecting + m.setClientStatus(client, "connecting") + + // 初始化连接 + timeout := time.Duration(serverCfg.Timeout) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := client.Initialize(ctx); err != nil { + m.logger.Error("初始化外部MCP客户端失败", + zap.String("name", name), + zap.Error(err), + ) + return err + } + + // 保存客户端 + m.mu.Lock() + m.clients[name] = client + m.mu.Unlock() + + m.logger.Info("外部MCP客户端已连接", + zap.String("name", name), + ) + + // 连接成功,触发工具数量刷新和工具列表缓存刷新 + m.triggerToolCountRefresh() + m.mu.RLock() + if client, exists := m.clients[name]; exists { + m.refreshToolCache(name, client) + } + m.mu.RUnlock() + + return nil +} + +// isEnabled 检查是否启用 +func (m *ExternalMCPManager) isEnabled(cfg config.ExternalMCPServerConfig) bool { + // 优先使用 ExternalMCPEnable 字段 + // 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容) + if cfg.ExternalMCPEnable { + return true + } + // 向后兼容:检查旧字段 + if cfg.Disabled { + return false + } + if cfg.Enabled { + return true + } + // 都没有设置,默认为启用 + return true +} + +// findSubstring 查找子字符串(简单实现) +func findSubstring(s, substr string) int { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} + +// StartAllEnabled 启动所有启用的客户端 +func (m *ExternalMCPManager) StartAllEnabled() { + m.mu.RLock() + configs := make(map[string]config.ExternalMCPServerConfig) + for k, v := range m.configs { + configs[k] = v + } + m.mu.RUnlock() + + for name, cfg := range configs { + if m.isEnabled(cfg) { + go func(n string, c config.ExternalMCPServerConfig) { + if err := m.connectClient(n, c); err != nil { + // 检查是否是连接被拒绝的错误(服务可能还没启动) + errStr := strings.ToLower(err.Error()) + isConnectionRefused := strings.Contains(errStr, "connection refused") || + strings.Contains(errStr, "dial tcp") || + strings.Contains(errStr, "connect: connection refused") + + if isConnectionRefused { + // 连接被拒绝,说明目标服务可能还没启动,这是正常的 + // 使用 Warn 级别,提示用户这是正常的,可以通过手动启动或等待服务启动后自动连接 + fields := []zap.Field{ + zap.String("name", n), + zap.String("message", "目标服务可能尚未启动,这是正常的。服务启动后可通过界面手动连接,或等待自动重试"), + zap.Error(err), + } + + // 根据传输模式添加相应的信息 + transport := c.Transport + if transport == "" { + if c.Command != "" { + transport = "stdio" + } else if c.URL != "" { + transport = "http" + } + } + + if transport == "http" && c.URL != "" { + fields = append(fields, zap.String("url", c.URL)) + } else if transport == "stdio" && c.Command != "" { + fields = append(fields, zap.String("command", c.Command)) + } + + m.logger.Warn("外部MCP服务暂未就绪", fields...) + } else { + // 其他错误,使用 Error 级别 + m.logger.Error("启动外部MCP客户端失败", + zap.String("name", n), + zap.Error(err), + ) + } + } + }(name, cfg) + } + } +} + +// StopAll 停止所有客户端 +func (m *ExternalMCPManager) StopAll() { + m.mu.Lock() + defer m.mu.Unlock() + + for name, client := range m.clients { + client.Close() + delete(m.clients, name) + } + + // 清理所有工具数量缓存 + m.toolCountsMu.Lock() + m.toolCounts = make(map[string]int) + m.toolCountsMu.Unlock() + + // 清理所有工具列表缓存 + m.toolCacheMu.Lock() + m.toolCache = make(map[string][]Tool) + m.toolCacheMu.Unlock() + + // 停止后台刷新(使用 select 避免重复关闭 channel) + select { + case <-m.stopRefresh: + // 已经关闭,不需要再次关闭 + default: + close(m.stopRefresh) + m.refreshWg.Wait() + } +} diff --git a/internal/mcp/external_manager_test.go b/internal/mcp/external_manager_test.go new file mode 100644 index 00000000..d4c49851 --- /dev/null +++ b/internal/mcp/external_manager_test.go @@ -0,0 +1,239 @@ +package mcp + +import ( + "context" + "testing" + "time" + + "cyberstrike-ai/internal/config" + + "go.uber.org/zap" +) + +func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + // 测试添加stdio配置 + stdioCfg := config.ExternalMCPServerConfig{ + Command: "python3", + Args: []string{"/path/to/script.py"}, + Transport: "stdio", + Description: "Test stdio MCP", + Timeout: 30, + Enabled: true, + } + + err := manager.AddOrUpdateConfig("test-stdio", stdioCfg) + if err != nil { + t.Fatalf("添加stdio配置失败: %v", err) + } + + // 测试添加HTTP配置 + httpCfg := config.ExternalMCPServerConfig{ + Transport: "http", + URL: "http://127.0.0.1:8081/mcp", + Description: "Test HTTP MCP", + Timeout: 30, + Enabled: false, + } + + err = manager.AddOrUpdateConfig("test-http", httpCfg) + if err != nil { + t.Fatalf("添加HTTP配置失败: %v", err) + } + + // 验证配置已保存 + configs := manager.GetConfigs() + if len(configs) != 2 { + t.Fatalf("期望2个配置,实际%d个", len(configs)) + } + + if configs["test-stdio"].Command != stdioCfg.Command { + t.Errorf("stdio配置命令不匹配") + } + + if configs["test-http"].URL != httpCfg.URL { + t.Errorf("HTTP配置URL不匹配") + } +} + +func TestExternalMCPManager_RemoveConfig(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + cfg := config.ExternalMCPServerConfig{ + Command: "python3", + Transport: "stdio", + Enabled: false, + } + + manager.AddOrUpdateConfig("test-remove", cfg) + + // 移除配置 + err := manager.RemoveConfig("test-remove") + if err != nil { + t.Fatalf("移除配置失败: %v", err) + } + + configs := manager.GetConfigs() + if _, exists := configs["test-remove"]; exists { + t.Error("配置应该已被移除") + } +} + +func TestExternalMCPManager_GetStats(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + // 添加多个配置 + manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: true, + }) + + manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ + URL: "http://127.0.0.1:8081/mcp", + Enabled: true, + }) + + manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: false, + Disabled: true, // 明确设置为禁用 + }) + + stats := manager.GetStats() + + if stats["total"].(int) != 3 { + t.Errorf("期望总数3,实际%d", stats["total"]) + } + + if stats["enabled"].(int) != 2 { + t.Errorf("期望启用数2,实际%d", stats["enabled"]) + } + + if stats["disabled"].(int) != 1 { + t.Errorf("期望停用数1,实际%d", stats["disabled"]) + } +} + +func TestExternalMCPManager_LoadConfigs(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + externalMCPConfig := config.ExternalMCPConfig{ + Servers: map[string]config.ExternalMCPServerConfig{ + "loaded1": { + Command: "python3", + Enabled: true, + }, + "loaded2": { + URL: "http://127.0.0.1:8081/mcp", + Enabled: false, + }, + }, + } + + manager.LoadConfigs(&externalMCPConfig) + + configs := manager.GetConfigs() + if len(configs) != 2 { + t.Fatalf("期望2个配置,实际%d个", len(configs)) + } + + if configs["loaded1"].Command != "python3" { + t.Error("配置1加载失败") + } + + if configs["loaded2"].URL != "http://127.0.0.1:8081/mcp" { + t.Error("配置2加载失败") + } +} + +// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态 +func TestLazySDKClient_InitializeFails(t *testing.T) { + logger := zap.NewNop() + // 使用不存在的 HTTP 地址,Initialize 应失败 + cfg := config.ExternalMCPServerConfig{ + Transport: "http", + URL: "http://127.0.0.1:19999/nonexistent", + Timeout: 2, + } + c := newLazySDKClient(cfg, logger) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := c.Initialize(ctx) + if err == nil { + t.Fatal("expected error when connecting to invalid server") + } + if c.GetStatus() != "error" { + t.Errorf("expected status error, got %s", c.GetStatus()) + } + c.Close() +} + +func TestExternalMCPManager_StartStopClient(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + // 添加一个禁用的配置 + cfg := config.ExternalMCPServerConfig{ + Command: "python3", + Transport: "stdio", + Enabled: false, + } + + manager.AddOrUpdateConfig("test-start-stop", cfg) + + // 尝试启动(可能会失败,因为没有真实的服务器) + err := manager.StartClient("test-start-stop") + if err != nil { + t.Logf("启动失败(可能是没有服务器): %v", err) + } + + // 停止 + err = manager.StopClient("test-start-stop") + if err != nil { + t.Fatalf("停止失败: %v", err) + } + + // 验证配置已更新为禁用 + configs := manager.GetConfigs() + if configs["test-start-stop"].Enabled { + t.Error("配置应该已被禁用") + } +} + +func TestExternalMCPManager_CallTool(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + // 测试调用不存在的工具 + _, _, err := manager.CallTool(context.Background(), "nonexistent::tool", map[string]interface{}{}) + if err == nil { + t.Error("应该返回错误") + } + + // 测试无效的工具名称格式 + _, _, err = manager.CallTool(context.Background(), "invalid-tool-name", map[string]interface{}{}) + if err == nil { + t.Error("应该返回错误(无效格式)") + } +} + +func TestExternalMCPManager_GetAllTools(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + ctx := context.Background() + tools, err := manager.GetAllTools(ctx) + if err != nil { + t.Fatalf("获取工具列表失败: %v", err) + } + + // 如果没有连接的客户端,应该返回空列表 + if len(tools) != 0 { + t.Logf("获取到%d个工具", len(tools)) + } +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go new file mode 100644 index 00000000..37670ba6 --- /dev/null +++ b/internal/mcp/server.go @@ -0,0 +1,1237 @@ +package mcp + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "sort" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// MonitorStorage 监控数据存储接口 +type MonitorStorage interface { + SaveToolExecution(exec *ToolExecution) error + LoadToolExecutions() ([]*ToolExecution, error) + GetToolExecution(id string) (*ToolExecution, error) + SaveToolStats(toolName string, stats *ToolStats) error + LoadToolStats() (map[string]*ToolStats, error) + UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error +} + +// Server MCP服务器 +type Server struct { + tools map[string]ToolHandler + toolDefs map[string]Tool // 工具定义 + executions map[string]*ToolExecution + stats map[string]*ToolStats + prompts map[string]*Prompt // 提示词模板 + resources map[string]*Resource // 资源 + storage MonitorStorage // 可选的持久化存储 + mu sync.RWMutex + logger *zap.Logger + maxExecutionsInMemory int // 内存中最大执行记录数 + sseClients map[string]*sseClient +} + +type sseClient struct { + id string + send chan []byte +} + +// ToolHandler 工具处理函数 +type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error) + +// NewServer 创建新的MCP服务器 +func NewServer(logger *zap.Logger) *Server { + return NewServerWithStorage(logger, nil) +} + +// NewServerWithStorage 创建新的MCP服务器(带持久化存储) +func NewServerWithStorage(logger *zap.Logger, storage MonitorStorage) *Server { + s := &Server{ + tools: make(map[string]ToolHandler), + toolDefs: make(map[string]Tool), + executions: make(map[string]*ToolExecution), + stats: make(map[string]*ToolStats), + prompts: make(map[string]*Prompt), + resources: make(map[string]*Resource), + storage: storage, + logger: logger, + maxExecutionsInMemory: 1000, // 默认最多在内存中保留1000条执行记录 + sseClients: make(map[string]*sseClient), + } + + // 初始化默认提示词和资源 + s.initDefaultPrompts() + s.initDefaultResources() + + return s +} + +// RegisterTool 注册工具 +func (s *Server) RegisterTool(tool Tool, handler ToolHandler) { + s.mu.Lock() + defer s.mu.Unlock() + s.tools[tool.Name] = handler + s.toolDefs[tool.Name] = tool + + // 自动为工具创建资源文档 + resourceURI := fmt.Sprintf("tool://%s", tool.Name) + s.resources[resourceURI] = &Resource{ + URI: resourceURI, + Name: fmt.Sprintf("%s工具文档", tool.Name), + Description: tool.Description, + MimeType: "text/plain", + } +} + +// ClearTools 清空所有工具(用于重新加载配置) +func (s *Server) ClearTools() { + s.mu.Lock() + defer s.mu.Unlock() + + // 清空工具和工具定义 + s.tools = make(map[string]ToolHandler) + s.toolDefs = make(map[string]Tool) + + // 清空工具相关的资源(保留其他资源) + newResources := make(map[string]*Resource) + for uri, resource := range s.resources { + // 保留非工具资源 + if !strings.HasPrefix(uri, "tool://") { + newResources[uri] = resource + } + } + s.resources = newResources +} + +// HandleHTTP 处理HTTP请求 +func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && strings.Contains(r.Header.Get("Accept"), "text/event-stream") { + s.handleSSE(w, r) + return + } + + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // 官方 MCP SSE 规范:带 sessionid 的 POST 表示消息发往该 SSE 会话,响应通过 SSE 流返回 + if sessionID := r.URL.Query().Get("sessionid"); sessionID != "" { + s.serveSSESessionMessage(w, r, sessionID) + return + } + + // 简单 POST:请求体为 JSON-RPC,响应在 body 中返回 + body, err := io.ReadAll(r.Body) + if err != nil { + s.sendError(w, nil, -32700, "Parse error", err.Error()) + return + } + + var msg Message + if err := json.Unmarshal(body, &msg); err != nil { + s.sendError(w, nil, -32700, "Parse error", err.Error()) + return + } + + response := s.handleMessage(&msg) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// serveSSESessionMessage 处理发往 SSE 会话的 POST:读取 JSON-RPC 请求,处理后将响应通过该会话的 SSE 流推送 +func (s *Server) serveSSESessionMessage(w http.ResponseWriter, r *http.Request, sessionID string) { + s.mu.RLock() + client, exists := s.sseClients[sessionID] + s.mu.RUnlock() + if !exists || client == nil { + http.Error(w, "session not found", http.StatusNotFound) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + + var msg Message + if err := json.Unmarshal(body, &msg); err != nil { + http.Error(w, "failed to parse body", http.StatusBadRequest) + return + } + + response := s.handleMessage(&msg) + if response == nil { + w.WriteHeader(http.StatusAccepted) + return + } + + respBytes, err := json.Marshal(response) + if err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) + return + } + + select { + case client.send <- respBytes: + w.WriteHeader(http.StatusAccepted) + default: + http.Error(w, "session send buffer full", http.StatusServiceUnavailable) + } +} + +// handleSSE 处理 SSE 连接,兼容官方 MCP 2024-11-05 SSE 规范: +// 1. 首个事件必须为 event: endpoint,data 为客户端 POST 消息的 URL(含 sessionid) +// 2. 后续事件为 event: message,data 为 JSON-RPC 响应 +func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + sessionID := uuid.New().String() + client := &sseClient{ + id: sessionID, + send: make(chan []byte, 32), + } + + s.addSSEClient(client) + defer s.removeSSEClient(client.id) + + // 官方规范:首个事件为 endpoint,data 为消息端点 URL(客户端将向该 URL POST 请求) + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + if r.URL.Scheme != "" { + scheme = r.URL.Scheme + } + endpointURL := fmt.Sprintf("%s://%s%s?sessionid=%s", scheme, r.Host, r.URL.Path, sessionID) + fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpointURL) + flusher.Flush() + + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for { + select { + case <-r.Context().Done(): + return + case msg, ok := <-client.send: + if !ok { + return + } + fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg) + flusher.Flush() + case <-ticker.C: + fmt.Fprintf(w, ": ping\n\n") + flusher.Flush() + } + } +} + +// addSSEClient 注册SSE客户端 +func (s *Server) addSSEClient(client *sseClient) { + s.mu.Lock() + defer s.mu.Unlock() + s.sseClients[client.id] = client +} + +// removeSSEClient 移除SSE客户端 +func (s *Server) removeSSEClient(id string) { + s.mu.Lock() + defer s.mu.Unlock() + if client, exists := s.sseClients[id]; exists { + close(client.send) + delete(s.sseClients, id) + } +} + +// handleMessage 处理MCP消息 +func (s *Server) handleMessage(msg *Message) *Message { + // 检查是否是通知(notification)- 通知没有id字段,不需要响应 + isNotification := msg.ID.Value() == nil || msg.ID.String() == "" + + // 如果不是通知且ID为空,生成新的UUID + if !isNotification && msg.ID.String() == "" { + msg.ID = MessageID{value: uuid.New().String()} + } + + switch msg.Method { + case "initialize": + return s.handleInitialize(msg) + case "tools/list": + return s.handleListTools(msg) + case "tools/call": + return s.handleCallTool(msg) + case "prompts/list": + return s.handleListPrompts(msg) + case "prompts/get": + return s.handleGetPrompt(msg) + case "resources/list": + return s.handleListResources(msg) + case "resources/read": + return s.handleReadResource(msg) + case "sampling/request": + return s.handleSamplingRequest(msg) + case "notifications/initialized": + // 通知类型,不需要响应 + s.logger.Debug("收到 initialized 通知") + return nil + case "": + // 空方法名,可能是通知,不返回错误 + if isNotification { + s.logger.Debug("收到无方法名的通知消息") + return nil + } + fallthrough + default: + // 如果是通知,不返回错误响应 + if isNotification { + s.logger.Debug("收到未知通知", zap.String("method", msg.Method)) + return nil + } + // 对于请求,返回方法未找到错误 + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32601, Message: "Method not found"}, + } + } +} + +// handleInitialize 处理初始化请求 +func (s *Server) handleInitialize(msg *Message) *Message { + var req InitializeRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + response := InitializeResponse{ + ProtocolVersion: ProtocolVersion, + Capabilities: ServerCapabilities{ + Tools: map[string]interface{}{ + "listChanged": true, + }, + Prompts: map[string]interface{}{ + "listChanged": true, + }, + Resources: map[string]interface{}{ + "subscribe": true, + "listChanged": true, + }, + Sampling: map[string]interface{}{}, + }, + ServerInfo: ServerInfo{ + Name: "CyberStrikeAI", + Version: "1.0.0", + }, + } + + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// handleListTools 处理列出工具请求 +func (s *Server) handleListTools(msg *Message) *Message { + s.mu.RLock() + tools := make([]Tool, 0, len(s.toolDefs)) + for _, tool := range s.toolDefs { + tools = append(tools, tool) + } + s.mu.RUnlock() + s.logger.Debug("tools/list 请求", zap.Int("返回工具数", len(tools))) + + response := ListToolsResponse{Tools: tools} + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// handleCallTool 处理工具调用请求 +func (s *Server) handleCallTool(msg *Message) *Message { + var req CallToolRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + executionID := uuid.New().String() + execution := &ToolExecution{ + ID: executionID, + ToolName: req.Name, + Arguments: req.Arguments, + Status: "running", + StartTime: time.Now(), + } + + s.mu.Lock() + s.executions[executionID] = execution + // 如果内存中的执行记录超过限制,清理最旧的记录 + s.cleanupOldExecutions() + s.mu.Unlock() + + if s.storage != nil { + if err := s.storage.SaveToolExecution(execution); err != nil { + s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) + } + } + + s.mu.RLock() + handler, exists := s.tools[req.Name] + s.mu.RUnlock() + + if !exists { + execution.Status = "failed" + execution.Error = "Tool not found" + now := time.Now() + execution.EndTime = &now + execution.Duration = now.Sub(execution.StartTime) + + if s.storage != nil { + if err := s.storage.SaveToolExecution(execution); err != nil { + s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) + } + s.mu.Lock() + delete(s.executions, executionID) + s.mu.Unlock() + } + + s.updateStats(req.Name, true) + + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32601, Message: "Tool not found"}, + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + s.logger.Info("开始执行工具", + zap.String("toolName", req.Name), + zap.Any("arguments", req.Arguments), + ) + + result, err := handler(ctx, req.Arguments) + now := time.Now() + var failed bool + var finalResult *ToolResult + + s.mu.Lock() + execution.EndTime = &now + execution.Duration = now.Sub(execution.StartTime) + + if err != nil { + execution.Status = "failed" + execution.Error = err.Error() + failed = true + } else if result != nil && result.IsError { + execution.Status = "failed" + if len(result.Content) > 0 { + execution.Error = result.Content[0].Text + } else { + execution.Error = "工具执行返回错误结果" + } + execution.Result = result + failed = true + } else { + execution.Status = "completed" + if result == nil { + result = &ToolResult{ + Content: []Content{ + {Type: "text", Text: "工具执行完成,但未返回结果"}, + }, + } + } + execution.Result = result + failed = false + } + + finalResult = execution.Result + s.mu.Unlock() + + if s.storage != nil { + if err := s.storage.SaveToolExecution(execution); err != nil { + s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) + } + } + + s.updateStats(req.Name, failed) + + if s.storage != nil { + s.mu.Lock() + delete(s.executions, executionID) + s.mu.Unlock() + } + + if err != nil { + s.logger.Error("工具执行失败", + zap.String("toolName", req.Name), + zap.Error(err), + ) + + errorResult, _ := json.Marshal(CallToolResponse{ + Content: []Content{ + {Type: "text", Text: fmt.Sprintf("工具执行失败: %v", err)}, + }, + IsError: true, + }) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: errorResult, + } + } + + if finalResult != nil && finalResult.IsError { + s.logger.Warn("工具执行返回错误结果", + zap.String("toolName", req.Name), + ) + + errorResult, _ := json.Marshal(CallToolResponse{ + Content: finalResult.Content, + IsError: true, + }) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: errorResult, + } + } + + if finalResult == nil { + finalResult = &ToolResult{ + Content: []Content{ + {Type: "text", Text: "工具执行完成,但未返回结果"}, + }, + } + } + + resultJSON, _ := json.Marshal(CallToolResponse{ + Content: finalResult.Content, + IsError: false, + }) + + s.logger.Info("工具执行完成", + zap.String("toolName", req.Name), + zap.Bool("isError", finalResult.IsError), + ) + + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: resultJSON, + } +} + +// updateStats 更新统计信息 +func (s *Server) updateStats(toolName string, failed bool) { + now := time.Now() + if s.storage != nil { + totalCalls := 1 + successCalls := 0 + failedCalls := 0 + if failed { + failedCalls = 1 + } else { + successCalls = 1 + } + if err := s.storage.UpdateToolStats(toolName, totalCalls, successCalls, failedCalls, &now); err != nil { + s.logger.Warn("保存统计信息到数据库失败", zap.Error(err)) + } + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.stats[toolName] == nil { + s.stats[toolName] = &ToolStats{ + ToolName: toolName, + } + } + + stats := s.stats[toolName] + stats.TotalCalls++ + stats.LastCallTime = &now + + if failed { + stats.FailedCalls++ + } else { + stats.SuccessCalls++ + } +} + +// GetExecution 获取执行记录(先从内存查找,再从数据库查找) +func (s *Server) GetExecution(id string) (*ToolExecution, bool) { + s.mu.RLock() + exec, exists := s.executions[id] + s.mu.RUnlock() + + if exists { + return exec, true + } + + if s.storage != nil { + exec, err := s.storage.GetToolExecution(id) + if err == nil { + return exec, true + } + } + + return nil, false +} + +// loadHistoricalData 从数据库加载历史数据 +func (s *Server) loadHistoricalData() { + if s.storage == nil { + return + } + + // 加载历史执行记录(最近1000条) + executions, err := s.storage.LoadToolExecutions() + if err != nil { + s.logger.Warn("加载历史执行记录失败", zap.Error(err)) + } else { + s.mu.Lock() + for _, exec := range executions { + // 只加载最近 maxExecutionsInMemory 条,避免内存占用过大 + if len(s.executions) < s.maxExecutionsInMemory { + s.executions[exec.ID] = exec + } else { + break + } + } + s.mu.Unlock() + s.logger.Info("加载历史执行记录", zap.Int("count", len(executions))) + } + + // 加载历史统计信息 + stats, err := s.storage.LoadToolStats() + if err != nil { + s.logger.Warn("加载历史统计信息失败", zap.Error(err)) + } else { + s.mu.Lock() + for k, v := range stats { + s.stats[k] = v + } + s.mu.Unlock() + s.logger.Info("加载历史统计信息", zap.Int("count", len(stats))) + } +} + +// GetAllExecutions 获取所有执行记录(合并内存和数据库) +func (s *Server) GetAllExecutions() []*ToolExecution { + if s.storage != nil { + dbExecutions, err := s.storage.LoadToolExecutions() + if err == nil { + execMap := make(map[string]*ToolExecution) + for _, exec := range dbExecutions { + if _, exists := execMap[exec.ID]; !exists { + execMap[exec.ID] = exec + } + } + + s.mu.RLock() + for id, exec := range s.executions { + if _, exists := execMap[id]; !exists { + execMap[id] = exec + } + } + s.mu.RUnlock() + + result := make([]*ToolExecution, 0, len(execMap)) + for _, exec := range execMap { + result = append(result, exec) + } + return result + } else { + s.logger.Warn("从数据库加载执行记录失败", zap.Error(err)) + } + } + + s.mu.RLock() + defer s.mu.RUnlock() + + memExecutions := make([]*ToolExecution, 0, len(s.executions)) + for _, exec := range s.executions { + memExecutions = append(memExecutions, exec) + } + return memExecutions +} + +// GetStats 获取统计信息(合并内存和数据库) +func (s *Server) GetStats() map[string]*ToolStats { + if s.storage != nil { + dbStats, err := s.storage.LoadToolStats() + if err == nil { + return dbStats + } + s.logger.Warn("从数据库加载统计信息失败", zap.Error(err)) + } + + s.mu.RLock() + defer s.mu.RUnlock() + + memStats := make(map[string]*ToolStats) + for k, v := range s.stats { + statCopy := *v + memStats[k] = &statCopy + } + + return memStats +} + +// GetAllTools 获取所有已注册的工具(用于Agent动态获取工具列表) +func (s *Server) GetAllTools() []Tool { + s.mu.RLock() + defer s.mu.RUnlock() + + tools := make([]Tool, 0, len(s.toolDefs)) + for _, tool := range s.toolDefs { + tools = append(tools, tool) + } + return tools +} + +// CallTool 直接调用工具(用于内部调用) +func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) { + s.mu.RLock() + handler, exists := s.tools[toolName] + s.mu.RUnlock() + + if !exists { + return nil, "", fmt.Errorf("工具 %s 未找到", toolName) + } + + // 创建执行记录 + executionID := uuid.New().String() + execution := &ToolExecution{ + ID: executionID, + ToolName: toolName, + Arguments: args, + Status: "running", + StartTime: time.Now(), + } + + s.mu.Lock() + s.executions[executionID] = execution + // 如果内存中的执行记录超过限制,清理最旧的记录 + s.cleanupOldExecutions() + s.mu.Unlock() + + if s.storage != nil { + if err := s.storage.SaveToolExecution(execution); err != nil { + s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) + } + } + + result, err := handler(ctx, args) + + s.mu.Lock() + now := time.Now() + execution.EndTime = &now + execution.Duration = now.Sub(execution.StartTime) + var failed bool + var finalResult *ToolResult + + if err != nil { + execution.Status = "failed" + execution.Error = err.Error() + failed = true + } else if result != nil && result.IsError { + execution.Status = "failed" + if len(result.Content) > 0 { + execution.Error = result.Content[0].Text + } else { + execution.Error = "工具执行返回错误结果" + } + execution.Result = result + failed = true + finalResult = result + } else { + execution.Status = "completed" + if result == nil { + result = &ToolResult{ + Content: []Content{ + {Type: "text", Text: "工具执行完成,但未返回结果"}, + }, + } + } + execution.Result = result + finalResult = result + failed = false + } + + if finalResult == nil { + finalResult = execution.Result + } + s.mu.Unlock() + + if s.storage != nil { + if err := s.storage.SaveToolExecution(execution); err != nil { + s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) + } + } + + s.updateStats(toolName, failed) + + if s.storage != nil { + s.mu.Lock() + delete(s.executions, executionID) + s.mu.Unlock() + } + + if err != nil { + return nil, executionID, err + } + + return finalResult, executionID, nil +} + +// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长 +func (s *Server) cleanupOldExecutions() { + if len(s.executions) <= s.maxExecutionsInMemory { + return + } + + // 按开始时间排序,找出最旧的记录 + type execWithTime struct { + id string + startTime time.Time + } + execs := make([]execWithTime, 0, len(s.executions)) + for id, exec := range s.executions { + execs = append(execs, execWithTime{ + id: id, + startTime: exec.StartTime, + }) + } + + // 使用 sort 包进行高效排序(最旧的在前) + sort.Slice(execs, func(i, j int) bool { + return execs[i].startTime.Before(execs[j].startTime) + }) + + // 删除最旧的记录,保留 maxExecutionsInMemory 条 + toDelete := len(s.executions) - s.maxExecutionsInMemory + for i := 0; i < toDelete; i++ { + delete(s.executions, execs[i].id) + } + + s.logger.Debug("清理旧的执行记录", + zap.Int("before", len(execs)), + zap.Int("after", len(s.executions)), + zap.Int("deleted", toDelete), + ) +} + +// initDefaultPrompts 初始化默认提示词模板 +func (s *Server) initDefaultPrompts() { + s.mu.Lock() + defer s.mu.Unlock() + + // 网络安全测试提示词 + s.prompts["security_scan"] = &Prompt{ + Name: "security_scan", + Description: "生成网络安全扫描任务的提示词", + Arguments: []PromptArgument{ + {Name: "target", Description: "扫描目标(IP地址或域名)", Required: true}, + {Name: "scan_type", Description: "扫描类型(port, vuln, web等)", Required: false}, + }, + } + + // 渗透测试提示词 + s.prompts["penetration_test"] = &Prompt{ + Name: "penetration_test", + Description: "生成渗透测试任务的提示词", + Arguments: []PromptArgument{ + {Name: "target", Description: "测试目标", Required: true}, + {Name: "scope", Description: "测试范围", Required: false}, + }, + } +} + +// initDefaultResources 初始化默认资源 +// 注意:工具资源现在在 RegisterTool 时自动创建,此函数保留用于其他非工具资源 +func (s *Server) initDefaultResources() { + // 工具资源已改为在 RegisterTool 时自动创建,无需在此硬编码 +} + +// handleListPrompts 处理列出提示词请求 +func (s *Server) handleListPrompts(msg *Message) *Message { + s.mu.RLock() + prompts := make([]Prompt, 0, len(s.prompts)) + for _, prompt := range s.prompts { + prompts = append(prompts, *prompt) + } + s.mu.RUnlock() + + response := ListPromptsResponse{ + Prompts: prompts, + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// handleGetPrompt 处理获取提示词请求 +func (s *Server) handleGetPrompt(msg *Message) *Message { + var req GetPromptRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + s.mu.RLock() + prompt, exists := s.prompts[req.Name] + s.mu.RUnlock() + + if !exists { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32601, Message: "Prompt not found"}, + } + } + + // 根据提示词名称生成消息 + messages := s.generatePromptMessages(prompt, req.Arguments) + + response := GetPromptResponse{ + Messages: messages, + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// generatePromptMessages 生成提示词消息 +func (s *Server) generatePromptMessages(prompt *Prompt, args map[string]interface{}) []PromptMessage { + messages := []PromptMessage{} + + switch prompt.Name { + case "security_scan": + target, _ := args["target"].(string) + scanType, _ := args["scan_type"].(string) + if scanType == "" { + scanType = "comprehensive" + } + + content := fmt.Sprintf(`请对目标 %s 执行%s安全扫描。包括: +1. 端口扫描和服务识别 +2. 漏洞检测 +3. Web应用安全测试 +4. 生成详细的安全报告`, target, scanType) + + messages = append(messages, PromptMessage{ + Role: "user", + Content: content, + }) + + case "penetration_test": + target, _ := args["target"].(string) + scope, _ := args["scope"].(string) + + content := fmt.Sprintf(`请对目标 %s 执行渗透测试。`, target) + if scope != "" { + content += fmt.Sprintf("测试范围:%s", scope) + } + content += "\n请按照OWASP Top 10进行全面的安全测试。" + + messages = append(messages, PromptMessage{ + Role: "user", + Content: content, + }) + + default: + messages = append(messages, PromptMessage{ + Role: "user", + Content: "请执行安全测试任务", + }) + } + + return messages +} + +// handleListResources 处理列出资源请求 +func (s *Server) handleListResources(msg *Message) *Message { + s.mu.RLock() + resources := make([]Resource, 0, len(s.resources)) + for _, resource := range s.resources { + resources = append(resources, *resource) + } + s.mu.RUnlock() + + response := ListResourcesResponse{ + Resources: resources, + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// handleReadResource 处理读取资源请求 +func (s *Server) handleReadResource(msg *Message) *Message { + var req ReadResourceRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + s.mu.RLock() + resource, exists := s.resources[req.URI] + s.mu.RUnlock() + + if !exists { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32601, Message: "Resource not found"}, + } + } + + // 生成资源内容 + content := s.generateResourceContent(resource) + + response := ReadResourceResponse{ + Contents: []ResourceContent{content}, + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// generateResourceContent 生成资源内容 +func (s *Server) generateResourceContent(resource *Resource) ResourceContent { + content := ResourceContent{ + URI: resource.URI, + MimeType: resource.MimeType, + } + + // 如果是工具资源,生成详细文档 + if strings.HasPrefix(resource.URI, "tool://") { + toolName := strings.TrimPrefix(resource.URI, "tool://") + content.Text = s.generateToolDocumentation(toolName, resource) + } else { + // 其他资源使用描述或默认内容 + content.Text = resource.Description + } + + return content +} + +// generateToolDocumentation 生成工具文档 +// 注意:硬编码的工具文档已移除,现在只使用工具定义中的信息 +func (s *Server) generateToolDocumentation(toolName string, resource *Resource) string { + // 获取工具定义以获取更详细的信息 + s.mu.RLock() + tool, hasTool := s.toolDefs[toolName] + s.mu.RUnlock() + + // 使用工具定义中的描述信息 + if hasTool { + doc := fmt.Sprintf("%s\n\n", resource.Description) + if tool.InputSchema != nil { + if props, ok := tool.InputSchema["properties"].(map[string]interface{}); ok { + doc += "参数说明:\n" + for paramName, paramInfo := range props { + if paramMap, ok := paramInfo.(map[string]interface{}); ok { + if desc, ok := paramMap["description"].(string); ok { + doc += fmt.Sprintf("- %s: %s\n", paramName, desc) + } + } + } + } + } + return doc + } + return resource.Description +} + +// handleSamplingRequest 处理采样请求 +func (s *Server) handleSamplingRequest(msg *Message) *Message { + var req SamplingRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + // 注意:采样功能通常需要连接到实际的LLM服务 + // 这里返回一个占位符响应,实际实现需要集成LLM API + s.logger.Warn("Sampling request received but not fully implemented", + zap.Any("request", req), + ) + + response := SamplingResponse{ + Content: []SamplingContent{ + { + Type: "text", + Text: "采样功能需要配置LLM服务。请使用Agent Loop API进行AI对话。", + }, + }, + StopReason: "length", + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// RegisterPrompt 注册提示词模板 +func (s *Server) RegisterPrompt(prompt *Prompt) { + s.mu.Lock() + defer s.mu.Unlock() + s.prompts[prompt.Name] = prompt +} + +// RegisterResource 注册资源 +func (s *Server) RegisterResource(resource *Resource) { + s.mu.Lock() + defer s.mu.Unlock() + s.resources[resource.URI] = resource +} + +// HandleStdio 处理标准输入输出(用于 stdio 传输模式) +// MCP 协议使用换行分隔的 JSON-RPC 消息;管道下需每次写入后 Flush,否则客户端会读不到响应 +func (s *Server) HandleStdio() error { + decoder := json.NewDecoder(os.Stdin) + stdout := bufio.NewWriter(os.Stdout) + encoder := json.NewEncoder(stdout) + // 注意:不设置缩进,MCP 协议期望紧凑的 JSON 格式 + + for { + var msg Message + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF { + break + } + // 日志输出到 stderr,避免干扰 stdout 的 JSON-RPC 通信 + s.logger.Error("读取消息失败", zap.Error(err)) + // 发送错误响应 + errorMsg := Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32700, Message: "Parse error", Data: err.Error()}, + } + if err := encoder.Encode(errorMsg); err != nil { + return fmt.Errorf("发送错误响应失败: %w", err) + } + if err := stdout.Flush(); err != nil { + return fmt.Errorf("刷新 stdout 失败: %w", err) + } + continue + } + + // 处理消息 + response := s.handleMessage(&msg) + + // 如果是通知(response 为 nil),不需要发送响应 + if response == nil { + continue + } + + // 发送响应 + if err := encoder.Encode(response); err != nil { + return fmt.Errorf("发送响应失败: %w", err) + } + if err := stdout.Flush(); err != nil { + return fmt.Errorf("刷新 stdout 失败: %w", err) + } + } + + return nil +} + +// sendError 发送错误响应 +func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) { + var msgID MessageID + if id != nil { + msgID = MessageID{value: id} + } + response := Message{ + ID: msgID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: code, Message: message, Data: data}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} diff --git a/internal/mcp/types.go b/internal/mcp/types.go new file mode 100644 index 00000000..393717b9 --- /dev/null +++ b/internal/mcp/types.go @@ -0,0 +1,295 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "time" +) + +// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现) +type ExternalMCPClient interface { + Initialize(ctx context.Context) error + ListTools(ctx context.Context) ([]Tool, error) + CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) + Close() error + IsConnected() bool + GetStatus() string +} + +// MCP消息类型 +const ( + MessageTypeRequest = "request" + MessageTypeResponse = "response" + MessageTypeError = "error" + MessageTypeNotify = "notify" +) + +// MCP协议版本 +const ProtocolVersion = "2024-11-05" + +// MessageID 表示JSON-RPC 2.0的id字段,可以是字符串、数字或null +type MessageID struct { + value interface{} +} + +// UnmarshalJSON 自定义反序列化,支持字符串、数字和null +func (m *MessageID) UnmarshalJSON(data []byte) error { + // 尝试解析为null + if string(data) == "null" { + m.value = nil + return nil + } + + // 尝试解析为字符串 + var str string + if err := json.Unmarshal(data, &str); err == nil { + m.value = str + return nil + } + + // 尝试解析为数字 + var num json.Number + if err := json.Unmarshal(data, &num); err == nil { + m.value = num + return nil + } + + return fmt.Errorf("invalid id type") +} + +// MarshalJSON 自定义序列化 +func (m MessageID) MarshalJSON() ([]byte, error) { + if m.value == nil { + return []byte("null"), nil + } + return json.Marshal(m.value) +} + +// String 返回字符串表示 +func (m MessageID) String() string { + if m.value == nil { + return "" + } + return fmt.Sprintf("%v", m.value) +} + +// Value 返回原始值 +func (m MessageID) Value() interface{} { + return m.value +} + +// Message 表示MCP消息(符合JSON-RPC 2.0规范) +type Message struct { + ID MessageID `json:"id,omitempty"` + Type string `json:"-"` // 内部使用,不序列化到JSON + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *Error `json:"error,omitempty"` + Version string `json:"jsonrpc,omitempty"` // JSON-RPC 2.0 版本标识 +} + +// Error 表示MCP错误 +type Error struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// Tool 表示MCP工具定义 +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` // 详细描述 + ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗) + InputSchema map[string]interface{} `json:"inputSchema"` +} + +// ToolCall 表示工具调用 +type ToolCall struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// ToolResult 表示工具执行结果 +type ToolResult struct { + Content []Content `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// Content 表示内容 +type Content struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// InitializeRequest 初始化请求 +type InitializeRequest struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]interface{} `json:"capabilities"` + ClientInfo ClientInfo `json:"clientInfo"` +} + +// ClientInfo 客户端信息 +type ClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// InitializeResponse 初始化响应 +type InitializeResponse struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo ServerInfo `json:"serverInfo"` +} + +// ServerCapabilities 服务器能力 +type ServerCapabilities struct { + Tools map[string]interface{} `json:"tools,omitempty"` + Prompts map[string]interface{} `json:"prompts,omitempty"` + Resources map[string]interface{} `json:"resources,omitempty"` + Sampling map[string]interface{} `json:"sampling,omitempty"` +} + +// ServerInfo 服务器信息 +type ServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// ListToolsRequest 列出工具请求 +type ListToolsRequest struct{} + +// ListToolsResponse 列出工具响应 +type ListToolsResponse struct { + Tools []Tool `json:"tools"` +} + +// ListPromptsResponse 列出提示词响应 +type ListPromptsResponse struct { + Prompts []Prompt `json:"prompts"` +} + +// ListResourcesResponse 列出资源响应 +type ListResourcesResponse struct { + Resources []Resource `json:"resources"` +} + +// CallToolRequest 调用工具请求 +type CallToolRequest struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// CallToolResponse 调用工具响应 +type CallToolResponse struct { + Content []Content `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// ToolExecution 工具执行记录 +type ToolExecution struct { + ID string `json:"id"` + ToolName string `json:"toolName"` + Arguments map[string]interface{} `json:"arguments"` + Status string `json:"status"` // pending, running, completed, failed + Result *ToolResult `json:"result,omitempty"` + Error string `json:"error,omitempty"` + StartTime time.Time `json:"startTime"` + EndTime *time.Time `json:"endTime,omitempty"` + Duration time.Duration `json:"duration,omitempty"` +} + +// ToolStats 工具统计信息 +type ToolStats struct { + ToolName string `json:"toolName"` + TotalCalls int `json:"totalCalls"` + SuccessCalls int `json:"successCalls"` + FailedCalls int `json:"failedCalls"` + LastCallTime *time.Time `json:"lastCallTime,omitempty"` +} + +// Prompt 提示词模板 +type Prompt struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Arguments []PromptArgument `json:"arguments,omitempty"` +} + +// PromptArgument 提示词参数 +type PromptArgument struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Required bool `json:"required,omitempty"` +} + +// GetPromptRequest 获取提示词请求 +type GetPromptRequest struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +// GetPromptResponse 获取提示词响应 +type GetPromptResponse struct { + Messages []PromptMessage `json:"messages"` +} + +// PromptMessage 提示词消息 +type PromptMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// Resource 资源 +type Resource struct { + URI string `json:"uri"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` +} + +// ReadResourceRequest 读取资源请求 +type ReadResourceRequest struct { + URI string `json:"uri"` +} + +// ReadResourceResponse 读取资源响应 +type ReadResourceResponse struct { + Contents []ResourceContent `json:"contents"` +} + +// ResourceContent 资源内容 +type ResourceContent struct { + URI string `json:"uri"` + MimeType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` + Blob string `json:"blob,omitempty"` +} + +// SamplingRequest 采样请求 +type SamplingRequest struct { + Messages []SamplingMessage `json:"messages"` + Model string `json:"model,omitempty"` + MaxTokens int `json:"maxTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` +} + +// SamplingMessage 采样消息 +type SamplingMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// SamplingResponse 采样响应 +type SamplingResponse struct { + Content []SamplingContent `json:"content"` + Model string `json:"model,omitempty"` + StopReason string `json:"stopReason,omitempty"` +} + +// SamplingContent 采样内容 +type SamplingContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} diff --git a/internal/multiagent/eino_skills.go b/internal/multiagent/eino_skills.go new file mode 100644 index 00000000..73dafe5a --- /dev/null +++ b/internal/multiagent/eino_skills.go @@ -0,0 +1,85 @@ +package multiagent + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "cyberstrike-ai/internal/config" + + localbk "github.com/cloudwego/eino-ext/adk/backend/local" + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/middlewares/filesystem" + "github.com/cloudwego/eino/adk/middlewares/skill" + "go.uber.org/zap" +) + +// prepareEinoSkills builds Eino official skill backend + middleware, and a shared local disk backend +// for skill discovery and (optionally) filesystem/execute tools. Returns nils when disabled or dir missing. +func prepareEinoSkills( + ctx context.Context, + skillsDir string, + ma *config.MultiAgentConfig, + logger *zap.Logger, +) (loc *localbk.Local, skillMW adk.ChatModelAgentMiddleware, fsTools bool, err error) { + if ma == nil || ma.EinoSkills.Disable { + return nil, nil, false, nil + } + root := strings.TrimSpace(skillsDir) + if root == "" { + if logger != nil { + logger.Warn("eino skills: skills_dir empty, skip") + } + return nil, nil, false, nil + } + abs, err := filepath.Abs(root) + if err != nil { + return nil, nil, false, fmt.Errorf("skills_dir abs: %w", err) + } + if st, err := os.Stat(abs); err != nil || !st.IsDir() { + if logger != nil { + logger.Warn("eino skills: directory missing, skip", zap.String("dir", abs), zap.Error(err)) + } + return nil, nil, false, nil + } + + loc, err = localbk.NewBackend(ctx, &localbk.Config{}) + if err != nil { + return nil, nil, false, fmt.Errorf("eino local backend: %w", err) + } + + skillBE, err := skill.NewBackendFromFilesystem(ctx, &skill.BackendFromFilesystemConfig{ + Backend: loc, + BaseDir: abs, + }) + if err != nil { + return nil, nil, false, fmt.Errorf("eino skill filesystem backend: %w", err) + } + + sc := &skill.Config{Backend: skillBE} + if name := strings.TrimSpace(ma.EinoSkills.SkillToolName); name != "" { + sc.SkillToolName = &name + } + skillMW, err = skill.NewMiddleware(ctx, sc) + if err != nil { + return nil, nil, false, fmt.Errorf("eino skill middleware: %w", err) + } + + fsTools = ma.EinoSkills.EinoSkillFilesystemToolsEffective() + return loc, skillMW, fsTools, nil +} + +// subAgentFilesystemMiddleware returns filesystem middleware for a sub-agent when Deep itself +// does not set Backend (fsTools false on orchestrator) but we still want tools on subs — not used; +// when orchestrator has Backend, builtin FS is only on outer agent; subs need explicit FS for parity. +func subAgentFilesystemMiddleware(ctx context.Context, loc *localbk.Local) (adk.ChatModelAgentMiddleware, error) { + if loc == nil { + return nil, nil + } + return filesystem.New(ctx, &filesystem.MiddlewareConfig{ + Backend: loc, + StreamingShell: loc, + }) +} diff --git a/internal/multiagent/eino_summarize.go b/internal/multiagent/eino_summarize.go new file mode 100644 index 00000000..81260109 --- /dev/null +++ b/internal/multiagent/eino_summarize.go @@ -0,0 +1,140 @@ +package multiagent + +import ( + "context" + "fmt" + "strings" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/config" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/middlewares/summarization" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// einoSummarizeUserInstruction 与单 Agent MemoryCompressor 目标一致:压缩时保留渗透关键信息。 +const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。 + +必须保留:已确认漏洞与攻击路径、工具输出中的核心发现、凭证与认证细节、架构与薄弱点、当前进度、失败尝试与死路、策略决策。 +保留精确技术细节(URL、路径、参数、Payload、版本号、报错原文可摘要但要点不丢)。 +将冗长扫描输出概括为结论;重复发现合并表述。 + +输出须使后续代理能无缝继续同一授权测试任务。` + +// newEinoSummarizationMiddleware 使用 Eino ADK Summarization 中间件(见 https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/)。 +// 触发阈值与单 Agent MemoryCompressor 一致:当估算 token 超过 openai.max_total_tokens 的 90% 时摘要。 +func newEinoSummarizationMiddleware( + ctx context.Context, + summaryModel model.BaseChatModel, + appCfg *config.Config, + logger *zap.Logger, +) (adk.ChatModelAgentMiddleware, error) { + if summaryModel == nil || appCfg == nil { + return nil, fmt.Errorf("multiagent: summarization 需要 model 与配置") + } + maxTotal := appCfg.OpenAI.MaxTotalTokens + if maxTotal <= 0 { + maxTotal = 120000 + } + trigger := int(float64(maxTotal) * 0.9) + if trigger < 4096 { + trigger = maxTotal + if trigger < 4096 { + trigger = 4096 + } + } + preserveMax := trigger / 3 + if preserveMax < 2048 { + preserveMax = 2048 + } + + modelName := strings.TrimSpace(appCfg.OpenAI.Model) + if modelName == "" { + modelName = "gpt-4o" + } + + mw, err := summarization.New(ctx, &summarization.Config{ + Model: summaryModel, + Trigger: &summarization.TriggerCondition{ + ContextTokens: trigger, + }, + TokenCounter: einoSummarizationTokenCounter(modelName), + UserInstruction: einoSummarizeUserInstruction, + EmitInternalEvents: false, + PreserveUserMessages: &summarization.PreserveUserMessages{ + Enabled: true, + MaxTokens: preserveMax, + }, + Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error { + if logger == nil { + return nil + } + logger.Info("eino summarization 已压缩上下文", + zap.Int("messages_before", len(before.Messages)), + zap.Int("messages_after", len(after.Messages)), + zap.Int("max_total_tokens", maxTotal), + zap.Int("trigger_context_tokens", trigger), + ) + return nil + }, + }) + if err != nil { + return nil, fmt.Errorf("summarization.New: %w", err) + } + return mw, nil +} + +func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc { + tc := agent.NewTikTokenCounter() + return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) { + var sb strings.Builder + for _, msg := range input.Messages { + if msg == nil { + continue + } + sb.WriteString(string(msg.Role)) + sb.WriteByte('\n') + if msg.Content != "" { + sb.WriteString(msg.Content) + sb.WriteByte('\n') + } + if msg.ReasoningContent != "" { + sb.WriteString(msg.ReasoningContent) + sb.WriteByte('\n') + } + if len(msg.ToolCalls) > 0 { + if b, err := sonic.Marshal(msg.ToolCalls); err == nil { + sb.Write(b) + sb.WriteByte('\n') + } + } + for _, part := range msg.UserInputMultiContent { + if part.Type == schema.ChatMessagePartTypeText && part.Text != "" { + sb.WriteString(part.Text) + sb.WriteByte('\n') + } + } + } + for _, tl := range input.Tools { + if tl == nil { + continue + } + cp := *tl + cp.Extra = nil + if text, err := sonic.MarshalString(cp); err == nil { + sb.WriteString(text) + sb.WriteByte('\n') + } + } + text := sb.String() + n, err := tc.Count(openAIModel, text) + if err != nil { + return (len(text) + 3) / 4, nil + } + return n, nil + } +} diff --git a/internal/multiagent/no_nested_task.go b/internal/multiagent/no_nested_task.go new file mode 100644 index 00000000..09ad28e9 --- /dev/null +++ b/internal/multiagent/no_nested_task.go @@ -0,0 +1,62 @@ +package multiagent + +import ( + "context" + "strings" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/tool" +) + +// noNestedTaskMiddleware 禁止在已经处于 task(sub-agent) 执行链中再次调用 task, +// 避免子代理再次委派子代理造成的无限委派/递归。 +// +// 通过在 ctx 中设置临时标记来实现嵌套检测:外层 task 调用会先标记 ctx, +// 子代理内再调用 task 时会命中该标记并拒绝。 +type noNestedTaskMiddleware struct { + adk.BaseChatModelAgentMiddleware +} + +type nestedTaskCtxKey struct{} + +func newNoNestedTaskMiddleware() adk.ChatModelAgentMiddleware { + return &noNestedTaskMiddleware{} +} + +func (m *noNestedTaskMiddleware) WrapInvokableToolCall( + ctx context.Context, + endpoint adk.InvokableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.InvokableToolCallEndpoint, error) { + if tCtx == nil || strings.TrimSpace(tCtx.Name) == "" { + return endpoint, nil + } + // Deep 内置 task 工具名固定为 "task";为兼容可能的大小写/空白,仅做不区分大小写匹配。 + if !strings.EqualFold(strings.TrimSpace(tCtx.Name), "task") { + return endpoint, nil + } + + // 已在 task 执行链中:拒绝继续委派,直接报错让上层快速终止。 + if ctx != nil { + if v, ok := ctx.Value(nestedTaskCtxKey{}).(bool); ok && v { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + // Important: return a tool result text (not an error) to avoid hard-stopping the whole multi-agent run. + // The nested task is still prevented from spawning another sub-agent, so recursion is avoided. + _ = argumentsInJSON + _ = opts + return "Nested task delegation is forbidden (already inside a sub-agent delegation chain) to avoid infinite delegation. Please continue the work using the current agent's tools.", nil + }, nil + } + } + + // 标记当前 task 调用链,确保子代理内的再次 task 调用能检测到嵌套。 + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + ctx2 := ctx + if ctx2 == nil { + ctx2 = context.Background() + } + ctx2 = context.WithValue(ctx2, nestedTaskCtxKey{}, true) + return endpoint(ctx2, argumentsInJSON, opts...) + }, nil +} + diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go new file mode 100644 index 00000000..835b6eee --- /dev/null +++ b/internal/multiagent/runner.go @@ -0,0 +1,1068 @@ +// Package multiagent 使用 CloudWeGo Eino 的 DeepAgent(adk/prebuilt/deep)编排多代理,MCP 工具经 einomcp 桥接到现有 Agent。 +package multiagent + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/agents" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/einomcp" + "cyberstrike-ai/internal/openai" + + einoopenai "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/filesystem" + "github.com/cloudwego/eino/adk/prebuilt/deep" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// RunResult 与单 Agent 循环结果字段对齐,便于复用存储与 SSE 收尾逻辑。 +type RunResult struct { + Response string + MCPExecutionIDs []string + LastReActInput string + LastReActOutput string +} + +// toolCallPendingInfo tracks a tool_call emitted to the UI so we can later +// correlate tool_result events (even when the framework omits ToolCallID) and +// avoid leaving the UI stuck in "running" state on recoverable errors. +type toolCallPendingInfo struct { + ToolCallID string + ToolName string + EinoAgent string + EinoRole string +} + +// RunDeepAgent 使用 Eino DeepAgent 执行一轮对话(流式事件通过 progress 回调输出)。 +func RunDeepAgent( + ctx context.Context, + appCfg *config.Config, + ma *config.MultiAgentConfig, + ag *agent.Agent, + logger *zap.Logger, + conversationID string, + userMessage string, + history []agent.ChatMessage, + roleTools []string, + progress func(eventType, message string, data interface{}), + agentsMarkdownDir string, +) (*RunResult, error) { + if appCfg == nil || ma == nil || ag == nil { + return nil, fmt.Errorf("multiagent: 配置或 Agent 为空") + } + + effectiveSubs := ma.SubAgents + var orch *agents.OrchestratorMarkdown + if strings.TrimSpace(agentsMarkdownDir) != "" { + load, merr := agents.LoadMarkdownAgentsDir(agentsMarkdownDir) + if merr != nil { + if logger != nil { + logger.Warn("加载 agents 目录 Markdown 失败,沿用 config 中的 sub_agents", zap.Error(merr)) + } + } else { + effectiveSubs = agents.MergeYAMLAndMarkdown(ma.SubAgents, load.SubAgents) + orch = load.Orchestrator + } + } + if ma.WithoutGeneralSubAgent && len(effectiveSubs) == 0 { + return nil, fmt.Errorf("multi_agent.without_general_sub_agent 为 true 时,必须在 multi_agent.sub_agents 或 agents 目录 Markdown 中配置至少一个子代理") + } + + einoLoc, einoSkillMW, einoFSTools, einoErr := prepareEinoSkills(ctx, appCfg.SkillsDir, ma, logger) + if einoErr != nil { + return nil, einoErr + } + + holder := &einomcp.ConversationHolder{} + holder.Set(conversationID) + + var mcpIDsMu sync.Mutex + var mcpIDs []string + recorder := func(id string) { + if id == "" { + return + } + mcpIDsMu.Lock() + mcpIDs = append(mcpIDs, id) + mcpIDsMu.Unlock() + } + + // 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。 + snapshotMCPIDs := func() []string { + mcpIDsMu.Lock() + defer mcpIDsMu.Unlock() + out := make([]string, len(mcpIDs)) + copy(out, mcpIDs) + return out + } + + mainDefs := ag.ToolsForRole(roleTools) + toolOutputChunk := func(toolName, toolCallID, chunk string) { + // When toolCallId is missing, frontend ignores tool_result_delta. + if progress == nil || toolCallID == "" { + return + } + progress("tool_result_delta", chunk, map[string]interface{}{ + "toolName": toolName, + "toolCallId": toolCallID, + // index/total/iteration are optional for UI; we don't know them in this bridge. + "index": 0, + "total": 0, + "iteration": 0, + "source": "eino", + }) + } + + mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk) + if err != nil { + return nil, err + } + + httpClient := &http.Client{ + Timeout: 30 * time.Minute, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 300 * time.Second, + KeepAlive: 300 * time.Second, + }).DialContext, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 30 * time.Second, + ResponseHeaderTimeout: 60 * time.Minute, + }, + } + + // 若配置为 Claude provider,注入自动桥接 transport,对 Eino 透明走 Anthropic Messages API + httpClient = openai.NewEinoHTTPClient(&appCfg.OpenAI, httpClient) + + baseModelCfg := &einoopenai.ChatModelConfig{ + APIKey: appCfg.OpenAI.APIKey, + BaseURL: strings.TrimSuffix(appCfg.OpenAI.BaseURL, "/"), + Model: appCfg.OpenAI.Model, + HTTPClient: httpClient, + } + + deepMaxIter := ma.MaxIteration + if deepMaxIter <= 0 { + deepMaxIter = appCfg.Agent.MaxIterations + } + if deepMaxIter <= 0 { + deepMaxIter = 40 + } + + subDefaultIter := ma.SubAgentMaxIterations + if subDefaultIter <= 0 { + subDefaultIter = 20 + } + + subAgents := make([]adk.Agent, 0, len(effectiveSubs)) + for _, sub := range effectiveSubs { + id := strings.TrimSpace(sub.ID) + if id == "" { + return nil, fmt.Errorf("multi_agent.sub_agents 中存在空的 id") + } + name := strings.TrimSpace(sub.Name) + if name == "" { + name = id + } + desc := strings.TrimSpace(sub.Description) + if desc == "" { + desc = fmt.Sprintf("Specialist agent %s for penetration testing workflow.", id) + } + instr := strings.TrimSpace(sub.Instruction) + if instr == "" { + instr = "你是 CyberStrikeAI 中的专业子代理,在授权渗透测试场景下协助完成用户委托的子任务。优先使用可用工具获取证据,回答简洁专业。" + } + + roleTools := sub.RoleTools + bind := strings.TrimSpace(sub.BindRole) + if bind != "" && appCfg.Roles != nil { + if r, ok := appCfg.Roles[bind]; ok && r.Enabled { + if len(roleTools) == 0 && len(r.Tools) > 0 { + roleTools = r.Tools + } + if len(r.Skills) > 0 { + var b strings.Builder + b.WriteString(instr) + b.WriteString("\n\n本角色推荐优先通过 Eino `skill` 工具(渐进式披露)加载的技能包 name:") + for i, s := range r.Skills { + if i > 0 { + b.WriteString("、") + } + b.WriteString(s) + } + b.WriteString("。") + instr = b.String() + } + } + } + + subModel, err := einoopenai.NewChatModel(ctx, baseModelCfg) + if err != nil { + return nil, fmt.Errorf("子代理 %q ChatModel: %w", id, err) + } + + subDefs := ag.ToolsForRole(roleTools) + subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk) + if err != nil { + return nil, fmt.Errorf("子代理 %q 工具: %w", id, err) + } + + subMax := sub.MaxIterations + if subMax <= 0 { + subMax = subDefaultIter + } + + subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, logger) + if err != nil { + return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err) + } + + var subHandlers []adk.ChatModelAgentMiddleware + if einoSkillMW != nil { + if einoFSTools && einoLoc != nil { + subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc) + if fsErr != nil { + return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr) + } + subHandlers = append(subHandlers, subFs) + } + subHandlers = append(subHandlers, einoSkillMW) + } + subHandlers = append(subHandlers, subSumMw) + + sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: id, + Description: desc, + Instruction: instr, + Model: subModel, + ToolsConfig: adk.ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: subTools, + UnknownToolsHandler: einomcp.UnknownToolReminderHandler(), + ToolCallMiddlewares: []compose.ToolMiddleware{ + {Invokable: softRecoveryToolCallMiddleware()}, + }, + }, + EmitInternalEvents: true, + }, + MaxIterations: subMax, + Handlers: subHandlers, + }) + if err != nil { + return nil, fmt.Errorf("子代理 %q: %w", id, err) + } + subAgents = append(subAgents, sa) + } + + mainModel, err := einoopenai.NewChatModel(ctx, baseModelCfg) + if err != nil { + return nil, fmt.Errorf("Deep 主模型: %w", err) + } + + mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, logger) + if err != nil { + return nil, fmt.Errorf("Deep 主代理 summarization 中间件: %w", err) + } + + // 与 deep.Config.Name 一致。子代理的 assistant 正文也会经 EmitInternalEvents 流出,若全部当主回复会重复(编排器总结 + 子代理原文)。 + orchestratorName := "cyberstrike-deep" + orchDescription := "Coordinates specialist agents and MCP tools for authorized security testing." + orchInstruction := strings.TrimSpace(ma.OrchestratorInstruction) + if orch != nil { + if strings.TrimSpace(orch.EinoName) != "" { + orchestratorName = strings.TrimSpace(orch.EinoName) + } + if d := strings.TrimSpace(orch.Description); d != "" { + orchDescription = d + } + if ins := strings.TrimSpace(orch.Instruction); ins != "" { + orchInstruction = ins + } + } + var deepBackend filesystem.Backend + var deepShell filesystem.StreamingShell + if einoLoc != nil && einoFSTools { + deepBackend = einoLoc + deepShell = einoLoc + } + + deepHandlers := []adk.ChatModelAgentMiddleware{} + if einoSkillMW != nil { + deepHandlers = append(deepHandlers, einoSkillMW) + } + deepHandlers = append(deepHandlers, newNoNestedTaskMiddleware(), mainSumMw) + + da, err := deep.New(ctx, &deep.Config{ + Name: orchestratorName, + Description: orchDescription, + ChatModel: mainModel, + Instruction: orchInstruction, + SubAgents: subAgents, + WithoutGeneralSubAgent: ma.WithoutGeneralSubAgent, + WithoutWriteTodos: ma.WithoutWriteTodos, + MaxIteration: deepMaxIter, + Backend: deepBackend, + StreamingShell: deepShell, + // 防止 sub-agent 再调用 task(再委派 sub-agent),形成无限委派链。 + Handlers: deepHandlers, + ToolsConfig: adk.ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: mainTools, + UnknownToolsHandler: einomcp.UnknownToolReminderHandler(), + ToolCallMiddlewares: []compose.ToolMiddleware{ + {Invokable: softRecoveryToolCallMiddleware()}, + }, + }, + EmitInternalEvents: true, + }, + }) + if err != nil { + return nil, fmt.Errorf("deep.New: %w", err) + } + + baseMsgs := historyToMessages(history) + baseMsgs = append(baseMsgs, schema.UserMessage(userMessage)) + + streamsMainAssistant := func(agent string) bool { + return agent == "" || agent == orchestratorName + } + einoRoleTag := func(agent string) string { + if streamsMainAssistant(agent) { + return "orchestrator" + } + return "sub" + } + + var lastRunMsgs []adk.Message + var lastAssistant string + + // retryHints tracks the corrective hint to append for each retry attempt. + // Index i corresponds to the hint that will be appended on attempt i+1. + var retryHints []adk.Message + +attemptLoop: + for attempt := 0; attempt < maxToolCallRecoveryAttempts; attempt++ { + msgs := make([]adk.Message, 0, len(baseMsgs)+len(retryHints)) + msgs = append(msgs, baseMsgs...) + msgs = append(msgs, retryHints...) + + if attempt > 0 { + mcpIDsMu.Lock() + mcpIDs = mcpIDs[:0] + mcpIDsMu.Unlock() + } + + // 仅保留主代理最后一次 assistant 输出;每轮重试重置,避免拼接失败轮次的片段。 + lastAssistant = "" + var reasoningStreamSeq int64 + var einoSubReplyStreamSeq int64 + toolEmitSeen := make(map[string]struct{}) + var einoMainRound int + var einoLastAgent string + subAgentToolStep := make(map[string]int) + // Track tool calls emitted in this attempt so we can: + // - attach toolCallId to tool_result when framework omits it + // - flush running tool calls as failed when a recoverable tool execution error happens + pendingByID := make(map[string]toolCallPendingInfo) + pendingQueueByAgent := make(map[string][]string) + markPending := func(tc toolCallPendingInfo) { + if tc.ToolCallID == "" { + return + } + pendingByID[tc.ToolCallID] = tc + pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID) + } + popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) { + q := pendingQueueByAgent[agentName] + for len(q) > 0 { + id := q[0] + q = q[1:] + pendingQueueByAgent[agentName] = q + if tc, ok := pendingByID[id]; ok { + delete(pendingByID, id) + return tc, true + } + } + return toolCallPendingInfo{}, false + } + removePendingByID := func(toolCallID string) { + if toolCallID == "" { + return + } + delete(pendingByID, toolCallID) + // queue cleanup is lazy in popNextPendingForAgent + } + flushAllPendingAsFailed := func(err error) { + if progress == nil { + pendingByID = make(map[string]toolCallPendingInfo) + pendingQueueByAgent = make(map[string][]string) + return + } + msg := "" + if err != nil { + msg = err.Error() + } + for _, tc := range pendingByID { + toolName := tc.ToolName + if strings.TrimSpace(toolName) == "" { + toolName = "unknown" + } + progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{ + "toolName": toolName, + "success": false, + "isError": true, + "result": msg, + "resultPreview": msg, + "toolCallId": tc.ToolCallID, + "conversationId": conversationID, + "einoAgent": tc.EinoAgent, + "einoRole": tc.EinoRole, + "source": "eino", + }) + } + pendingByID = make(map[string]toolCallPendingInfo) + pendingQueueByAgent = make(map[string][]string) + } + + runner := adk.NewRunner(ctx, adk.RunnerConfig{ + Agent: da, + EnableStreaming: true, + }) + iter := runner.Run(ctx, msgs) + + for { + ev, ok := iter.Next() + if !ok { + lastRunMsgs = msgs + break attemptLoop + } + if ev == nil { + continue + } + if ev.Err != nil { + canRetry := attempt+1 < maxToolCallRecoveryAttempts + + // Recoverable: API-level JSON argument validation error. + if canRetry && isRecoverableToolCallArgumentsJSONError(ev.Err) { + if logger != nil { + logger.Warn("eino: recoverable tool-call JSON error from model/API", zap.Error(ev.Err), zap.Int("attempt", attempt)) + } + retryHints = append(retryHints, toolCallArgumentsJSONRetryHint()) + if progress != nil { + progress("eino_recovery", toolCallArgumentsJSONRecoveryTimelineMessage(attempt), map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + "einoRetry": attempt, + "runIndex": attempt + 1, + "maxRuns": maxToolCallRecoveryAttempts, + "reason": "invalid_tool_arguments_json", + }) + } + continue attemptLoop + } + + // Recoverable: tool execution error (unknown sub-agent, tool not found, bad JSON in args, etc.). + if canRetry && isRecoverableToolExecutionError(ev.Err) { + if logger != nil { + logger.Warn("eino: recoverable tool execution error, will retry with corrective hint", + zap.Error(ev.Err), zap.Int("attempt", attempt)) + } + // Ensure UI/tool timeline doesn't get stuck at "running" for tool calls that + // will never receive a proper tool_result due to the recoverable error. + flushAllPendingAsFailed(ev.Err) + retryHints = append(retryHints, toolExecutionRetryHint()) + if progress != nil { + progress("eino_recovery", toolExecutionRecoveryTimelineMessage(attempt), map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + "einoRetry": attempt, + "runIndex": attempt + 1, + "maxRuns": maxToolCallRecoveryAttempts, + "reason": "tool_execution_error", + }) + } + continue attemptLoop + } + + // Non-recoverable error. + flushAllPendingAsFailed(ev.Err) + if progress != nil { + progress("error", ev.Err.Error(), map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + }) + } + return nil, ev.Err + } + if ev.AgentName != "" && progress != nil { + if streamsMainAssistant(ev.AgentName) { + if einoMainRound == 0 { + einoMainRound = 1 + progress("iteration", "", map[string]interface{}{ + "iteration": 1, + "einoScope": "main", + "einoRole": "orchestrator", + "einoAgent": orchestratorName, + "conversationId": conversationID, + "source": "eino", + }) + } else if einoLastAgent != "" && !streamsMainAssistant(einoLastAgent) { + einoMainRound++ + progress("iteration", "", map[string]interface{}{ + "iteration": einoMainRound, + "einoScope": "main", + "einoRole": "orchestrator", + "einoAgent": orchestratorName, + "conversationId": conversationID, + "source": "eino", + }) + } + } + einoLastAgent = ev.AgentName + progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{ + "conversationId": conversationID, + "einoAgent": ev.AgentName, + "einoRole": einoRoleTag(ev.AgentName), + }) + } + if ev.Output == nil || ev.Output.MessageOutput == nil { + continue + } + mv := ev.Output.MessageOutput + + if mv.IsStreaming && mv.MessageStream != nil { + streamHeaderSent := false + var reasoningStreamID string + var toolStreamFragments []schema.ToolCall + var subAssistantBuf strings.Builder + var subReplyStreamID string + var mainAssistantBuf strings.Builder + for { + chunk, rerr := mv.MessageStream.Recv() + if rerr != nil { + if errors.Is(rerr, io.EOF) { + break + } + if logger != nil { + logger.Warn("eino stream recv", zap.Error(rerr)) + } + break + } + if chunk == nil { + continue + } + if progress != nil && strings.TrimSpace(chunk.ReasoningContent) != "" { + if reasoningStreamID == "" { + reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1)) + progress("thinking_stream_start", " ", map[string]interface{}{ + "streamId": reasoningStreamID, + "source": "eino", + "einoAgent": ev.AgentName, + "einoRole": einoRoleTag(ev.AgentName), + }) + } + progress("thinking_stream_delta", chunk.ReasoningContent, map[string]interface{}{ + "streamId": reasoningStreamID, + }) + } + if chunk.Content != "" { + if progress != nil && streamsMainAssistant(ev.AgentName) { + if !streamHeaderSent { + progress("response_start", "", map[string]interface{}{ + "conversationId": conversationID, + "mcpExecutionIds": snapshotMCPIDs(), + "messageGeneratedBy": "eino:" + ev.AgentName, + "einoRole": "orchestrator", + }) + streamHeaderSent = true + } + progress("response_delta", chunk.Content, map[string]interface{}{ + "conversationId": conversationID, + "mcpExecutionIds": snapshotMCPIDs(), + "einoRole": "orchestrator", + }) + mainAssistantBuf.WriteString(chunk.Content) + } else if !streamsMainAssistant(ev.AgentName) { + if progress != nil { + if subReplyStreamID == "" { + subReplyStreamID = fmt.Sprintf("eino-sub-reply-%s-%d", conversationID, atomic.AddInt64(&einoSubReplyStreamSeq, 1)) + progress("eino_agent_reply_stream_start", "", map[string]interface{}{ + "streamId": subReplyStreamID, + "einoAgent": ev.AgentName, + "einoRole": "sub", + "conversationId": conversationID, + "source": "eino", + }) + } + progress("eino_agent_reply_stream_delta", chunk.Content, map[string]interface{}{ + "streamId": subReplyStreamID, + "conversationId": conversationID, + }) + } + subAssistantBuf.WriteString(chunk.Content) + } + } + // 收集流式 tool_calls 全部分片;arguments 在最后一帧常为 "",需按 index/id 合并后才能展示 subagent_type/description。 + if len(chunk.ToolCalls) > 0 { + toolStreamFragments = append(toolStreamFragments, chunk.ToolCalls...) + } + } + if streamsMainAssistant(ev.AgentName) { + if s := strings.TrimSpace(mainAssistantBuf.String()); s != "" { + lastAssistant = s + } + } + if subAssistantBuf.Len() > 0 && progress != nil { + if s := strings.TrimSpace(subAssistantBuf.String()); s != "" { + if subReplyStreamID != "" { + progress("eino_agent_reply_stream_end", s, map[string]interface{}{ + "streamId": subReplyStreamID, + "einoAgent": ev.AgentName, + "einoRole": "sub", + "conversationId": conversationID, + "source": "eino", + }) + } else { + progress("eino_agent_reply", s, map[string]interface{}{ + "conversationId": conversationID, + "einoAgent": ev.AgentName, + "einoRole": "sub", + "source": "eino", + }) + } + } + } + var lastToolChunk *schema.Message + if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 { + lastToolChunk = &schema.Message{ToolCalls: merged} + } + tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending) + continue + } + + msg, gerr := mv.GetMessage() + if gerr != nil || msg == nil { + continue + } + tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending) + + if mv.Role == schema.Assistant { + if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" { + progress("thinking", strings.TrimSpace(msg.ReasoningContent), map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + "einoAgent": ev.AgentName, + "einoRole": einoRoleTag(ev.AgentName), + }) + } + body := strings.TrimSpace(msg.Content) + if body != "" { + if streamsMainAssistant(ev.AgentName) { + if progress != nil { + progress("response_start", "", map[string]interface{}{ + "conversationId": conversationID, + "mcpExecutionIds": snapshotMCPIDs(), + "messageGeneratedBy": "eino:" + ev.AgentName, + "einoRole": "orchestrator", + }) + progress("response_delta", body, map[string]interface{}{ + "conversationId": conversationID, + "mcpExecutionIds": snapshotMCPIDs(), + "einoRole": "orchestrator", + }) + } + lastAssistant = body + } else if progress != nil { + progress("eino_agent_reply", body, map[string]interface{}{ + "conversationId": conversationID, + "einoAgent": ev.AgentName, + "einoRole": "sub", + "source": "eino", + }) + } + } + } + + if mv.Role == schema.Tool && progress != nil { + toolName := msg.ToolName + if toolName == "" { + toolName = mv.ToolName + } + + // bridge 工具在 res.IsError=true 时会返回带前缀的内容;这里解析为 success/isError,避免前端误判为成功。 + content := msg.Content + isErr := false + if strings.HasPrefix(content, einomcp.ToolErrorPrefix) { + isErr = true + content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix) + } + + preview := content + if len(preview) > 200 { + preview = preview[:200] + "..." + } + data := map[string]interface{}{ + "toolName": toolName, + "success": !isErr, + "isError": isErr, + "result": content, + "resultPreview": preview, + "conversationId": conversationID, + "einoAgent": ev.AgentName, + "einoRole": einoRoleTag(ev.AgentName), + "source": "eino", + } + toolCallID := strings.TrimSpace(msg.ToolCallID) + // Some framework paths (e.g. UnknownToolsHandler) may omit ToolCallID on tool messages. + // Infer from the tool_call emission order for this agent to keep UI state consistent. + if toolCallID == "" { + // In some internal tool execution paths, ev.AgentName may be empty for tool-role + // messages. Try several fallbacks to avoid leaving UI tool_call status stuck. + if inferred, ok := popNextPendingForAgent(ev.AgentName); ok { + toolCallID = inferred.ToolCallID + } else if inferred, ok := popNextPendingForAgent(orchestratorName); ok { + toolCallID = inferred.ToolCallID + } else if inferred, ok := popNextPendingForAgent(""); ok { + toolCallID = inferred.ToolCallID + } else { + // last resort: pick any pending toolCallID + for id := range pendingByID { + toolCallID = id + delete(pendingByID, id) + break + } + } + } else { + removePendingByID(toolCallID) + } + if toolCallID != "" { + data["toolCallId"] = toolCallID + } + progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data) + } + } + } + + mcpIDsMu.Lock() + ids := append([]string(nil), mcpIDs...) + mcpIDsMu.Unlock() + + histJSON, _ := json.Marshal(lastRunMsgs) + cleaned := strings.TrimSpace(lastAssistant) + cleaned = dedupeRepeatedParagraphs(cleaned, 80) + cleaned = dedupeParagraphsByLineFingerprint(cleaned, 100) + out := &RunResult{ + Response: cleaned, + MCPExecutionIDs: ids, + LastReActInput: string(histJSON), + LastReActOutput: cleaned, + } + if out.Response == "" { + out.Response = "(Eino DeepAgent 已完成,但未捕获到助手文本输出。请查看过程详情或日志。)" + out.LastReActOutput = out.Response + } + return out, nil +} + +func historyToMessages(history []agent.ChatMessage) []adk.Message { + if len(history) == 0 { + return nil + } + // 放宽条数上限:跨轮历史交给 Eino Summarization(阈值对齐 openai.max_total_tokens)在调用模型前压缩,避免在入队前硬截断为 40 条。 + const maxHistoryMessages = 300 + start := 0 + if len(history) > maxHistoryMessages { + start = len(history) - maxHistoryMessages + } + out := make([]adk.Message, 0, len(history[start:])) + for _, h := range history[start:] { + switch h.Role { + case "user": + if strings.TrimSpace(h.Content) != "" { + out = append(out, schema.UserMessage(h.Content)) + } + case "assistant": + if strings.TrimSpace(h.Content) == "" && len(h.ToolCalls) > 0 { + continue + } + if strings.TrimSpace(h.Content) != "" { + out = append(out, schema.AssistantMessage(h.Content, nil)) + } + default: + continue + } + } + return out +} + +// mergeStreamingToolCallFragments 将流式多帧的 ToolCall 按 index 合并 arguments(与 schema.concatToolCalls 行为一致)。 +func mergeStreamingToolCallFragments(fragments []schema.ToolCall) []schema.ToolCall { + if len(fragments) == 0 { + return nil + } + m, err := schema.ConcatMessages([]*schema.Message{{ToolCalls: fragments}}) + if err != nil || m == nil { + return fragments + } + return m.ToolCalls +} + +// mergeMessageToolCalls 非流式路径上若仍带分片式 tool_calls,合并后再上报 UI。 +func mergeMessageToolCalls(msg *schema.Message) *schema.Message { + if msg == nil || len(msg.ToolCalls) == 0 { + return msg + } + m, err := schema.ConcatMessages([]*schema.Message{msg}) + if err != nil || m == nil { + return msg + } + out := *msg + out.ToolCalls = m.ToolCalls + return &out +} + +// toolCallStableID 用于流式阶段去重;OpenAI 流式常先给 index 后补 id。 +func toolCallStableID(tc schema.ToolCall) string { + if tc.ID != "" { + return tc.ID + } + if tc.Index != nil { + return fmt.Sprintf("idx:%d", *tc.Index) + } + return "" +} + +// toolCallDisplayName 避免前端「未知工具」:DeepAgent 内置 task 等可能延迟写入 function.name。 +func toolCallDisplayName(tc schema.ToolCall) string { + if n := strings.TrimSpace(tc.Function.Name); n != "" { + return n + } + if n := strings.TrimSpace(tc.Type); n != "" && !strings.EqualFold(n, "function") { + return n + } + return "task" +} + +// toolCallsSignatureFlush 用于去重键;无 id/index 时用占位 pos,避免流末帧缺 id 时整条工具事件丢失。 +func toolCallsSignatureFlush(msg *schema.Message) string { + if msg == nil || len(msg.ToolCalls) == 0 { + return "" + } + parts := make([]string, 0, len(msg.ToolCalls)) + for i, tc := range msg.ToolCalls { + id := toolCallStableID(tc) + if id == "" { + id = fmt.Sprintf("pos:%d", i) + } + parts = append(parts, id+"|"+toolCallDisplayName(tc)) + } + sort.Strings(parts) + return strings.Join(parts, ";") +} + +// toolCallsRichSignature 用于去重:同一次流式已上报后,紧随其后的非流式消息常带相同 tool_calls。 +func toolCallsRichSignature(msg *schema.Message) string { + base := toolCallsSignatureFlush(msg) + if base == "" { + return "" + } + parts := make([]string, 0, len(msg.ToolCalls)) + for _, tc := range msg.ToolCalls { + id := toolCallStableID(tc) + arg := tc.Function.Arguments + if len(arg) > 240 { + arg = arg[:240] + } + parts = append(parts, id+":"+arg) + } + sort.Strings(parts) + return base + "|" + strings.Join(parts, ";") +} + +func tryEmitToolCallsOnce( + msg *schema.Message, + agentName, orchestratorName, conversationID string, + progress func(string, string, interface{}), + seen map[string]struct{}, + subAgentToolStep map[string]int, + markPending func(toolCallPendingInfo), +) { + if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil { + return + } + if toolCallsSignatureFlush(msg) == "" { + return + } + sig := agentName + "\x1e" + toolCallsRichSignature(msg) + if _, ok := seen[sig]; ok { + return + } + seen[sig] = struct{}{} + emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, progress, subAgentToolStep, markPending) +} + +func emitToolCallsFromMessage( + msg *schema.Message, + agentName, orchestratorName, conversationID string, + progress func(string, string, interface{}), + subAgentToolStep map[string]int, + markPending func(toolCallPendingInfo), +) { + if msg == nil || len(msg.ToolCalls) == 0 || progress == nil { + return + } + if subAgentToolStep == nil { + subAgentToolStep = make(map[string]int) + } + isSubToolRound := agentName != "" && agentName != orchestratorName + if isSubToolRound { + subAgentToolStep[agentName]++ + n := subAgentToolStep[agentName] + progress("iteration", "", map[string]interface{}{ + "iteration": n, + "einoScope": "sub", + "einoRole": "sub", + "einoAgent": agentName, + "conversationId": conversationID, + "source": "eino", + }) + } + role := "orchestrator" + if isSubToolRound { + role = "sub" + } + progress("tool_calls_detected", fmt.Sprintf("检测到 %d 个工具调用", len(msg.ToolCalls)), map[string]interface{}{ + "count": len(msg.ToolCalls), + "conversationId": conversationID, + "source": "eino", + "einoAgent": agentName, + "einoRole": role, + }) + for idx, tc := range msg.ToolCalls { + argStr := strings.TrimSpace(tc.Function.Arguments) + if argStr == "" && len(tc.Extra) > 0 { + if b, mErr := json.Marshal(tc.Extra); mErr == nil { + argStr = string(b) + } + } + var argsObj map[string]interface{} + if argStr != "" { + if uErr := json.Unmarshal([]byte(argStr), &argsObj); uErr != nil || argsObj == nil { + argsObj = map[string]interface{}{"_raw": argStr} + } + } + display := toolCallDisplayName(tc) + toolCallID := tc.ID + if toolCallID == "" && tc.Index != nil { + toolCallID = fmt.Sprintf("eino-stream-%d", *tc.Index) + } + // Record pending tool calls for later tool_result correlation / recovery flushing. + // We intentionally record even for unknown tools to avoid "running" badge getting stuck. + if markPending != nil && toolCallID != "" { + markPending(toolCallPendingInfo{ + ToolCallID: toolCallID, + ToolName: display, + EinoAgent: agentName, + EinoRole: role, + }) + } + progress("tool_call", fmt.Sprintf("正在调用工具: %s", display), map[string]interface{}{ + "toolName": display, + "arguments": argStr, + "argumentsObj": argsObj, + "toolCallId": toolCallID, + "index": idx + 1, + "total": len(msg.ToolCalls), + "conversationId": conversationID, + "source": "eino", + "einoAgent": agentName, + "einoRole": role, + }) + } +} + +// dedupeRepeatedParagraphs 去掉完全相同的连续/重复段落,缓解多代理各自复述同一列表。 +func dedupeRepeatedParagraphs(s string, minLen int) string { + if s == "" || minLen <= 0 { + return s + } + paras := strings.Split(s, "\n\n") + var out []string + seen := make(map[string]bool) + for _, p := range paras { + t := strings.TrimSpace(p) + if len(t) < minLen { + out = append(out, p) + continue + } + if seen[t] { + continue + } + seen[t] = true + out = append(out, p) + } + return strings.TrimSpace(strings.Join(out, "\n\n")) +} + +// dedupeParagraphsByLineFingerprint 去掉「正文行集合相同」的重复段落(开场白略不同也会合并),缓解多代理各写一遍目录清单。 +func dedupeParagraphsByLineFingerprint(s string, minParaLen int) string { + if s == "" || minParaLen <= 0 { + return s + } + paras := strings.Split(s, "\n\n") + var out []string + seen := make(map[string]bool) + for _, p := range paras { + t := strings.TrimSpace(p) + if len(t) < minParaLen { + out = append(out, p) + continue + } + fp := paragraphLineFingerprint(t) + // 指纹仅在「≥4 条非空行」时有效;单行/短段落长回复(如自我介绍)fp 为空,必须保留,否则会误删全文并触发「未捕获到助手文本」占位。 + if fp == "" { + out = append(out, p) + continue + } + if seen[fp] { + continue + } + seen[fp] = true + out = append(out, p) + } + return strings.TrimSpace(strings.Join(out, "\n\n")) +} + +func paragraphLineFingerprint(t string) string { + lines := strings.Split(t, "\n") + norm := make([]string, 0, len(lines)) + for _, L := range lines { + s := strings.TrimSpace(L) + if s == "" { + continue + } + norm = append(norm, s) + } + if len(norm) < 4 { + return "" + } + sort.Strings(norm) + return strings.Join(norm, "\x1e") +} diff --git a/internal/multiagent/tool_args_json_retry.go b/internal/multiagent/tool_args_json_retry.go new file mode 100644 index 00000000..d6d79971 --- /dev/null +++ b/internal/multiagent/tool_args_json_retry.go @@ -0,0 +1,51 @@ +package multiagent + +import ( + "fmt" + "strings" + + "github.com/cloudwego/eino/schema" +) + +// maxToolCallRecoveryAttempts 含首次运行:首次 + 自动重试次数。 +// 例如为 3 表示最多共 3 次完整 DeepAgent 运行(2 次失败后各追加一条纠错提示)。 +// 该常量同时用于 JSON 参数错误和工具执行错误(如子代理名称不存在)的恢复重试。 +const maxToolCallRecoveryAttempts = 5 + +// toolCallArgumentsJSONRetryHint 追加在用户消息后,提示模型输出合法 JSON 工具参数(部分云厂商会在流式阶段校验 arguments)。 +func toolCallArgumentsJSONRetryHint() *schema.Message { + return schema.UserMessage(`[系统提示] 上一次输出中,工具调用的 function.arguments 不是合法 JSON,接口已拒绝。请重新生成:每个 tool call 的 arguments 必须是完整、可解析的 JSON 对象字符串(键名用双引号,无多余逗号,括号配对)。不要输出截断或不完整的 JSON。 + +[System] Your previous tool call used invalid JSON in function.arguments and was rejected by the API. Regenerate with strictly valid JSON objects only (double-quoted keys, matched braces, no trailing commas).`) +} + +// toolCallArgumentsJSONRecoveryTimelineMessage 供 eino_recovery 事件落库与前端时间线展示。 +func toolCallArgumentsJSONRecoveryTimelineMessage(attempt int) string { + return fmt.Sprintf( + "接口拒绝了无效的工具参数 JSON。已向对话追加系统提示并要求模型重新生成合法的 function.arguments。"+ + "当前为第 %d/%d 轮完整运行。\n\n"+ + "The API rejected invalid JSON in tool arguments. A system hint was appended. This is full run %d of %d.", + attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts, + ) +} + +// isRecoverableToolCallArgumentsJSONError 判断是否为「工具参数非合法 JSON」类流式错误,可通过追加提示后重跑一轮。 +func isRecoverableToolCallArgumentsJSONError(err error) bool { + if err == nil { + return false + } + s := strings.ToLower(err.Error()) + if !strings.Contains(s, "json") { + return false + } + if strings.Contains(s, "function.arguments") || strings.Contains(s, "function arguments") { + return true + } + if strings.Contains(s, "invalidparameter") && strings.Contains(s, "json") { + return true + } + if strings.Contains(s, "must be in json format") { + return true + } + return false +} diff --git a/internal/multiagent/tool_args_json_retry_test.go b/internal/multiagent/tool_args_json_retry_test.go new file mode 100644 index 00000000..41264eb0 --- /dev/null +++ b/internal/multiagent/tool_args_json_retry_test.go @@ -0,0 +1,17 @@ +package multiagent + +import ( + "errors" + "testing" +) + +func TestIsRecoverableToolCallArgumentsJSONError(t *testing.T) { + yes := errors.New(`failed to receive stream chunk: error, <400> InternalError.Algo.InvalidParameter: The "function.arguments" parameter of the code model must be in JSON format.`) + if !isRecoverableToolCallArgumentsJSONError(yes) { + t.Fatal("expected recoverable for function.arguments + JSON") + } + no := errors.New("unrelated network failure") + if isRecoverableToolCallArgumentsJSONError(no) { + t.Fatal("expected not recoverable") + } +} diff --git a/internal/multiagent/tool_error_middleware.go b/internal/multiagent/tool_error_middleware.go new file mode 100644 index 00000000..10158fc2 --- /dev/null +++ b/internal/multiagent/tool_error_middleware.go @@ -0,0 +1,131 @@ +package multiagent + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/cloudwego/eino/compose" +) + +// softRecoveryToolCallMiddleware returns an InvokableToolMiddleware that catches +// specific recoverable errors from tool execution (JSON parse errors, tool-not-found, +// etc.) and converts them into soft errors: nil error + descriptive error content +// returned to the LLM. This allows the model to self-correct within the same +// iteration rather than crashing the entire graph and requiring a full replay. +// +// Without this middleware, a JSON parse failure in any tool's InvokableRun propagates +// as a hard error through the Eino ToolsNode → [NodeRunError] → ev.Err, which +// either triggers the full-replay retry loop (expensive) or terminates the run +// entirely once retries are exhausted. With it, the LLM simply sees an error message +// in the tool result and can adjust its next tool call accordingly. +func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware { + return func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + output, err := next(ctx, input) + if err == nil { + return output, nil + } + if !isSoftRecoverableToolError(err) { + return output, err + } + // Convert the hard error into a soft error: the LLM will see this + // message as the tool's output and can self-correct. + msg := buildSoftRecoveryMessage(input.Name, input.Arguments, err) + return &compose.ToolOutput{Result: msg}, nil + } + } +} + +// isSoftRecoverableToolError determines whether a tool execution error should be +// silently converted to a tool-result message rather than crashing the graph. +func isSoftRecoverableToolError(err error) bool { + if err == nil { + return false + } + s := strings.ToLower(err.Error()) + + // JSON unmarshal/parse failures — the model generated truncated or malformed arguments. + if isJSONRelatedError(s) { + return true + } + + // Sub-agent type not found (from deep/task_tool.go) + if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") { + return true + } + + // Tool not found in ToolsNode indexes + if strings.Contains(s, "tool") && strings.Contains(s, "not found") { + return true + } + + return false +} + +// isJSONRelatedError checks whether an error string indicates a JSON parsing problem. +func isJSONRelatedError(lower string) bool { + if !strings.Contains(lower, "json") { + return false + } + jsonIndicators := []string{ + "unexpected end of json", + "unmarshal", + "invalid character", + "cannot unmarshal", + "invalid tool arguments", + "failed to unmarshal", + "must be in json format", + "unexpected eof", + } + for _, ind := range jsonIndicators { + if strings.Contains(lower, ind) { + return true + } + } + return false +} + +// buildSoftRecoveryMessage creates a bilingual error message that the LLM can act on. +func buildSoftRecoveryMessage(toolName, arguments string, err error) string { + // Truncate arguments preview to avoid flooding the context. + argPreview := arguments + if len(argPreview) > 300 { + argPreview = argPreview[:300] + "... (truncated)" + } + + // Try to determine if it's specifically a JSON parse error for a friendlier message. + errStr := err.Error() + var jsonErr *json.SyntaxError + isJSONErr := strings.Contains(strings.ToLower(errStr), "json") || + strings.Contains(strings.ToLower(errStr), "unmarshal") + _ = jsonErr // suppress unused + + if isJSONErr { + return fmt.Sprintf( + "[Tool Error] The arguments for tool '%s' are not valid JSON and could not be parsed.\n"+ + "Error: %s\n"+ + "Arguments received: %s\n\n"+ + "Please fix the JSON (ensure double-quoted keys, matched braces/brackets, no trailing commas, "+ + "no truncation) and call the tool again.\n\n"+ + "[工具错误] 工具 '%s' 的参数不是合法 JSON,无法解析。\n"+ + "错误:%s\n"+ + "收到的参数:%s\n\n"+ + "请修正 JSON(确保双引号键名、括号配对、无尾部逗号、无截断),然后重新调用工具。", + toolName, errStr, argPreview, + toolName, errStr, argPreview, + ) + } + + return fmt.Sprintf( + "[Tool Error] Tool '%s' execution failed: %s\n"+ + "Arguments: %s\n\n"+ + "Please review the available tools and their expected arguments, then retry.\n\n"+ + "[工具错误] 工具 '%s' 执行失败:%s\n"+ + "参数:%s\n\n"+ + "请检查可用工具及其参数要求,然后重试。", + toolName, errStr, argPreview, + toolName, errStr, argPreview, + ) +} diff --git a/internal/multiagent/tool_error_middleware_test.go b/internal/multiagent/tool_error_middleware_test.go new file mode 100644 index 00000000..d87e417b --- /dev/null +++ b/internal/multiagent/tool_error_middleware_test.go @@ -0,0 +1,166 @@ +package multiagent + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/cloudwego/eino/compose" +) + +func TestIsSoftRecoverableToolError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "unexpected end of JSON input", + err: errors.New("unexpected end of JSON input"), + expected: true, + }, + { + name: "failed to unmarshal task tool input json", + err: errors.New("failed to unmarshal task tool input json: unexpected end of JSON input"), + expected: true, + }, + { + name: "invalid tool arguments JSON", + err: errors.New("invalid tool arguments JSON: unexpected end of JSON input"), + expected: true, + }, + { + name: "json invalid character", + err: errors.New(`invalid character '}' looking for beginning of value in JSON`), + expected: true, + }, + { + name: "subagent type not found", + err: errors.New("subagent type recon_agent not found"), + expected: true, + }, + { + name: "tool not found", + err: errors.New("tool nmap_scan not found in toolsNode indexes"), + expected: true, + }, + { + name: "unrelated network error", + err: errors.New("connection refused"), + expected: false, + }, + { + name: "context cancelled", + err: context.Canceled, + expected: false, + }, + { + name: "real json unmarshal error", + err: func() error { + var v map[string]interface{} + return json.Unmarshal([]byte(`{"key": `), &v) + }(), + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isSoftRecoverableToolError(tt.err) + if got != tt.expected { + t.Errorf("isSoftRecoverableToolError(%v) = %v, want %v", tt.err, got, tt.expected) + } + }) + } +} + +func TestSoftRecoveryToolCallMiddleware_PassesThrough(t *testing.T) { + mw := softRecoveryToolCallMiddleware() + called := false + next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + called = true + return &compose.ToolOutput{Result: "success"}, nil + } + wrapped := mw(next) + out, err := wrapped(context.Background(), &compose.ToolInput{ + Name: "test_tool", + Arguments: `{"key": "value"}`, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Fatal("next endpoint was not called") + } + if out.Result != "success" { + t.Fatalf("expected 'success', got %q", out.Result) + } +} + +func TestSoftRecoveryToolCallMiddleware_ConvertsJSONError(t *testing.T) { + mw := softRecoveryToolCallMiddleware() + next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + return nil, errors.New("failed to unmarshal task tool input json: unexpected end of JSON input") + } + wrapped := mw(next) + out, err := wrapped(context.Background(), &compose.ToolInput{ + Name: "task", + Arguments: `{"subagent_type": "recon`, + }) + if err != nil { + t.Fatalf("expected nil error (soft recovery), got: %v", err) + } + if out == nil || out.Result == "" { + t.Fatal("expected non-empty recovery message") + } + if !containsAll(out.Result, "[Tool Error]", "task", "JSON") { + t.Fatalf("recovery message missing expected content: %s", out.Result) + } +} + +func TestSoftRecoveryToolCallMiddleware_PropagatesNonRecoverable(t *testing.T) { + mw := softRecoveryToolCallMiddleware() + origErr := errors.New("connection timeout to remote server") + next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + return nil, origErr + } + wrapped := mw(next) + _, err := wrapped(context.Background(), &compose.ToolInput{ + Name: "test_tool", + Arguments: `{}`, + }) + if err == nil { + t.Fatal("expected error to propagate for non-recoverable errors") + } + if err != origErr { + t.Fatalf("expected original error, got: %v", err) + } +} + +func containsAll(s string, subs ...string) bool { + for _, sub := range subs { + if !contains(s, sub) { + return false + } + } + return true +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && searchString(s, sub) +} + +func searchString(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/multiagent/tool_execution_retry.go b/internal/multiagent/tool_execution_retry.go new file mode 100644 index 00000000..c79f8a66 --- /dev/null +++ b/internal/multiagent/tool_execution_retry.go @@ -0,0 +1,76 @@ +package multiagent + +import ( + "fmt" + "strings" + + "github.com/cloudwego/eino/schema" +) + +// isRecoverableToolExecutionError detects tool-level execution errors that can be +// recovered by retrying with a corrective hint. These errors originate from eino +// framework internals (e.g. task_tool.go, tool_node.go) when the LLM produces +// invalid tool calls such as non-existent sub-agent types, malformed JSON arguments, +// or unregistered tool names. +func isRecoverableToolExecutionError(err error) bool { + if err == nil { + return false + } + s := strings.ToLower(err.Error()) + + // Sub-agent type not found (from deep/task_tool.go) + if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") { + return true + } + + // Tool not found in toolsNode indexes (from compose/tool_node.go, when UnknownToolsHandler is nil) + if strings.Contains(s, "tool") && strings.Contains(s, "not found") { + return true + } + + // Invalid tool arguments JSON (from einomcp/mcp_tools.go or eino internals) + if strings.Contains(s, "invalid tool arguments json") { + return true + } + + // Failed to unmarshal task tool input json (from deep/task_tool.go) + if strings.Contains(s, "failed to unmarshal") && strings.Contains(s, "json") { + return true + } + + // Generic tool call stream/invoke failure wrapping the above + if (strings.Contains(s, "failed to stream tool call") || strings.Contains(s, "failed to invoke tool")) && + (strings.Contains(s, "not found") || strings.Contains(s, "json") || strings.Contains(s, "unmarshal")) { + return true + } + + return false +} + +// toolExecutionRetryHint returns a user message appended to the conversation to prompt +// the LLM to correct its tool call after a tool execution error. +func toolExecutionRetryHint() *schema.Message { + return schema.UserMessage(`[System] Your previous tool call failed because: +- The tool or sub-agent name you used does not exist, OR +- The tool call arguments were not valid JSON. + +Please carefully review the available tools and sub-agents listed in your context, use only exact registered names (case-sensitive), and ensure all arguments are well-formed JSON objects. Then retry your action. + +[系统提示] 上一次工具调用失败,可能原因: +- 你使用的工具名或子代理名称不存在; +- 工具调用参数不是合法 JSON。 + +请仔细检查上下文中列出的可用工具和子代理名称(须完全匹配、区分大小写),确保所有参数均为合法的 JSON 对象,然后重新执行。`) +} + +// toolExecutionRecoveryTimelineMessage returns a message for the eino_recovery event +// displayed in the UI timeline when a tool execution error triggers a retry. +func toolExecutionRecoveryTimelineMessage(attempt int) string { + return fmt.Sprintf( + "工具调用执行失败(工具/子代理名称不存在或参数 JSON 无效)。已向对话追加纠错提示并要求模型重新生成。"+ + "当前为第 %d/%d 轮完整运行。\n\n"+ + "Tool call execution failed (unknown tool/sub-agent name or invalid JSON arguments). "+ + "A corrective hint was appended. This is full run %d of %d.", + attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts, + ) +} diff --git a/internal/skillpackage/content.go b/internal/skillpackage/content.go new file mode 100644 index 00000000..851a5238 --- /dev/null +++ b/internal/skillpackage/content.go @@ -0,0 +1,165 @@ +package skillpackage + +import ( + "fmt" + "regexp" + "strings" +) + +var reH2 = regexp.MustCompile(`(?m)^##\s+(.+)$`) + +const summaryContentRunes = 6000 + +type markdownSection struct { + Heading string + Title string + Content string +} + +func splitMarkdownSections(body string) []markdownSection { + body = strings.TrimSpace(body) + if body == "" { + return nil + } + idxs := reH2.FindAllStringIndex(body, -1) + titles := reH2.FindAllStringSubmatch(body, -1) + if len(idxs) == 0 { + return []markdownSection{{ + Heading: "", + Title: "_body", + Content: body, + }} + } + var out []markdownSection + for i := range idxs { + title := strings.TrimSpace(titles[i][1]) + start := idxs[i][0] + end := len(body) + if i+1 < len(idxs) { + end = idxs[i+1][0] + } + chunk := strings.TrimSpace(body[start:end]) + out = append(out, markdownSection{ + Heading: "## " + title, + Title: title, + Content: chunk, + }) + } + return out +} + +func deriveSections(body string) []SkillSection { + md := splitMarkdownSections(body) + out := make([]SkillSection, 0, len(md)) + for _, ms := range md { + if ms.Title == "_body" { + continue + } + out = append(out, SkillSection{ + ID: slugifySectionID(ms.Title), + Title: ms.Title, + Heading: ms.Heading, + Level: 2, + }) + } + return out +} + +func slugifySectionID(title string) string { + title = strings.TrimSpace(strings.ToLower(title)) + if title == "" { + return "section" + } + var b strings.Builder + for _, r := range title { + switch { + case r >= 'a' && r <= 'z', r >= '0' && r <= '9': + b.WriteRune(r) + case r == ' ', r == '-', r == '_': + b.WriteRune('-') + } + } + s := strings.Trim(b.String(), "-") + if s == "" { + return "section" + } + return s +} + +func findSectionContent(sections []markdownSection, sec string) string { + sec = strings.TrimSpace(sec) + if sec == "" { + return "" + } + want := strings.ToLower(sec) + for _, s := range sections { + if strings.EqualFold(slugifySectionID(s.Title), want) || strings.EqualFold(s.Title, sec) { + return s.Content + } + if strings.EqualFold(strings.ReplaceAll(s.Title, " ", "-"), want) { + return s.Content + } + } + return "" +} + +func buildSummaryMarkdown(name, description string, tags []string, scripts []SkillScriptInfo, sections []SkillSection, body string) string { + var b strings.Builder + if description != "" { + b.WriteString(description) + b.WriteString("\n\n") + } + if len(tags) > 0 { + b.WriteString("**Tags**: ") + b.WriteString(strings.Join(tags, ", ")) + b.WriteString("\n\n") + } + if len(scripts) > 0 { + b.WriteString("### Bundled scripts\n\n") + for _, sc := range scripts { + line := "- `" + sc.RelPath + "`" + if sc.Description != "" { + line += " — " + sc.Description + } + b.WriteString(line) + b.WriteString("\n") + } + b.WriteString("\n") + } + if len(sections) > 0 { + b.WriteString("### Sections\n\n") + for _, sec := range sections { + line := "- **" + sec.ID + "**" + if sec.Title != "" && sec.Title != sec.ID { + line += ": " + sec.Title + } + b.WriteString(line) + b.WriteString("\n") + } + b.WriteString("\n") + } + mdSecs := splitMarkdownSections(body) + preview := body + if len(mdSecs) > 0 && mdSecs[0].Title != "_body" { + preview = mdSecs[0].Content + } + b.WriteString("### Preview (SKILL.md)\n\n") + b.WriteString(truncateRunes(strings.TrimSpace(preview), summaryContentRunes)) + b.WriteString("\n\n---\n\n_(Summary for admin UI. Agents use Eino `skill` tool for full SKILL.md progressive loading.)_") + if name != "" { + b.WriteString(fmt.Sprintf("\n\n_Skill name: %s_", name)) + } + return b.String() +} + +func truncateRunes(s string, max int) string { + if max <= 0 || s == "" { + return s + } + r := []rune(s) + if len(r) <= max { + return s + } + return string(r[:max]) + "…" +} + diff --git a/internal/skillpackage/frontmatter.go b/internal/skillpackage/frontmatter.go new file mode 100644 index 00000000..620f698d --- /dev/null +++ b/internal/skillpackage/frontmatter.go @@ -0,0 +1,114 @@ +package skillpackage + +import ( + "fmt" + "strings" + + "gopkg.in/yaml.v3" +) + +// ExtractSkillMDFrontMatterYAML returns the YAML source inside the first --- ... --- block and the markdown body. +func ExtractSkillMDFrontMatterYAML(raw []byte) (fmYAML string, body string, err error) { + text := strings.TrimPrefix(string(raw), "\ufeff") + if strings.TrimSpace(text) == "" { + return "", "", fmt.Errorf("SKILL.md is empty") + } + lines := strings.Split(text, "\n") + if len(lines) < 2 || strings.TrimSpace(lines[0]) != "---" { + return "", "", fmt.Errorf("SKILL.md must start with YAML front matter (---) per Agent Skills standard") + } + var fmLines []string + i := 1 + for i < len(lines) { + if strings.TrimSpace(lines[i]) == "---" { + break + } + fmLines = append(fmLines, lines[i]) + i++ + } + if i >= len(lines) { + return "", "", fmt.Errorf("SKILL.md: front matter must end with a line containing only ---") + } + body = strings.Join(lines[i+1:], "\n") + body = strings.TrimSpace(body) + fmYAML = strings.Join(fmLines, "\n") + return fmYAML, body, nil +} + +// ParseSkillMD parses SKILL.md YAML head + body. +func ParseSkillMD(raw []byte) (*SkillManifest, string, error) { + fmYAML, body, err := ExtractSkillMDFrontMatterYAML(raw) + if err != nil { + return nil, "", err + } + var m SkillManifest + if err := yaml.Unmarshal([]byte(fmYAML), &m); err != nil { + return nil, "", fmt.Errorf("SKILL.md front matter: %w", err) + } + return &m, body, nil +} + +type skillFrontMatterExport struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + License string `yaml:"license,omitempty"` + Compatibility string `yaml:"compatibility,omitempty"` + Metadata map[string]any `yaml:"metadata,omitempty"` + AllowedTools string `yaml:"allowed-tools,omitempty"` +} + +// BuildSkillMD serializes SKILL.md per agentskills.io. +func BuildSkillMD(m *SkillManifest, body string) ([]byte, error) { + if m == nil { + return nil, fmt.Errorf("nil manifest") + } + fm := skillFrontMatterExport{ + Name: strings.TrimSpace(m.Name), + Description: strings.TrimSpace(m.Description), + License: strings.TrimSpace(m.License), + Compatibility: strings.TrimSpace(m.Compatibility), + AllowedTools: strings.TrimSpace(m.AllowedTools), + } + if len(m.Metadata) > 0 { + fm.Metadata = m.Metadata + } + head, err := yaml.Marshal(&fm) + if err != nil { + return nil, err + } + s := strings.TrimSpace(string(head)) + out := "---\n" + s + "\n---\n\n" + strings.TrimSpace(body) + "\n" + return []byte(out), nil +} + +func manifestTags(m *SkillManifest) []string { + if m == nil || m.Metadata == nil { + return nil + } + var out []string + if raw, ok := m.Metadata["tags"]; ok { + switch v := raw.(type) { + case []any: + for _, x := range v { + if s, ok := x.(string); ok && s != "" { + out = append(out, s) + } + } + case []string: + out = append(out, v...) + } + } + return out +} + +func versionFromMetadata(m *SkillManifest) string { + if m == nil || m.Metadata == nil { + return "" + } + if v, ok := m.Metadata["version"]; ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" +} diff --git a/internal/skillpackage/io.go b/internal/skillpackage/io.go new file mode 100644 index 00000000..f89f4506 --- /dev/null +++ b/internal/skillpackage/io.go @@ -0,0 +1,200 @@ +package skillpackage + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" +) + +const ( + maxPackageFiles = 4000 + maxPackageDepth = 24 + maxScriptsDepth = 24 + defaultMaxRead = 10 << 20 +) + +// SafeRelPath resolves rel inside root (no ..). +func SafeRelPath(root, rel string) (string, error) { + rel = strings.TrimSpace(rel) + rel = filepath.ToSlash(rel) + rel = strings.TrimPrefix(rel, "/") + if rel == "" || rel == "." { + return "", fmt.Errorf("empty resource path") + } + if strings.Contains(rel, "..") { + return "", fmt.Errorf("invalid path %q", rel) + } + abs := filepath.Join(root, filepath.FromSlash(rel)) + cleanRoot := filepath.Clean(root) + cleanAbs := filepath.Clean(abs) + relOut, err := filepath.Rel(cleanRoot, cleanAbs) + if err != nil || relOut == ".." || strings.HasPrefix(relOut, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("path escapes skill directory: %q", rel) + } + return cleanAbs, nil +} + +// ListPackageFiles lists files under a skill directory. +func ListPackageFiles(skillsRoot, skillID string) ([]PackageFileInfo, error) { + root := SkillDir(skillsRoot, skillID) + if _, err := ResolveSKILLPath(root); err != nil { + return nil, fmt.Errorf("skill %q: %w", skillID, err) + } + var out []PackageFileInfo + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + rel, e := filepath.Rel(root, path) + if e != nil { + return e + } + if rel == "." { + return nil + } + depth := strings.Count(rel, string(os.PathSeparator)) + if depth > maxPackageDepth { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } + if strings.HasPrefix(d.Name(), ".") { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } + if len(out) >= maxPackageFiles { + return fmt.Errorf("skill package exceeds %d files", maxPackageFiles) + } + fi, err := d.Info() + if err != nil { + return err + } + out = append(out, PackageFileInfo{ + Path: filepath.ToSlash(rel), + Size: fi.Size(), + IsDir: d.IsDir(), + }) + return nil + }) + return out, err +} + +// ReadPackageFile reads a file relative to the skill package. +func ReadPackageFile(skillsRoot, skillID, relPath string, maxBytes int64) ([]byte, error) { + if maxBytes <= 0 { + maxBytes = defaultMaxRead + } + root := SkillDir(skillsRoot, skillID) + abs, err := SafeRelPath(root, relPath) + if err != nil { + return nil, err + } + fi, err := os.Stat(abs) + if err != nil { + return nil, err + } + if fi.IsDir() { + return nil, fmt.Errorf("path is a directory") + } + if fi.Size() > maxBytes { + return readFileHead(abs, maxBytes) + } + return os.ReadFile(abs) +} + +// WritePackageFile writes a file inside the skill package. +func WritePackageFile(skillsRoot, skillID, relPath string, content []byte) error { + root := SkillDir(skillsRoot, skillID) + if _, err := ResolveSKILLPath(root); err != nil { + return fmt.Errorf("skill %q: %w", skillID, err) + } + abs, err := SafeRelPath(root, relPath) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(abs), 0755); err != nil { + return err + } + return os.WriteFile(abs, content, 0644) +} + +func readFileHead(path string, max int64) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + buf := make([]byte, max) + n, err := f.Read(buf) + if err != nil && n == 0 { + return nil, err + } + return buf[:n], nil +} + +func listScripts(skillsRoot, skillID string) ([]SkillScriptInfo, error) { + root := filepath.Join(SkillDir(skillsRoot, skillID), "scripts") + st, err := os.Stat(root) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + if !st.IsDir() { + return nil, nil + } + var out []SkillScriptInfo + err = filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + rel, e := filepath.Rel(root, path) + if e != nil { + return e + } + if rel == "." { + return nil + } + if d.IsDir() { + if strings.HasPrefix(d.Name(), ".") { + return filepath.SkipDir + } + if strings.Count(rel, string(os.PathSeparator)) >= maxScriptsDepth { + return filepath.SkipDir + } + return nil + } + if strings.HasPrefix(d.Name(), ".") { + return nil + } + relSkill := filepath.Join("scripts", rel) + full := filepath.Join(root, rel) + fi, err := os.Stat(full) + if err != nil || fi.IsDir() { + return nil + } + out = append(out, SkillScriptInfo{ + Name: filepath.Base(rel), + RelPath: filepath.ToSlash(relSkill), + Size: fi.Size(), + }) + return nil + }) + return out, err +} + +func countNonDirFiles(files []PackageFileInfo) int { + n := 0 + for _, f := range files { + if !f.IsDir && f.Path != "SKILL.md" { + n++ + } + } + return n +} diff --git a/internal/skillpackage/layout.go b/internal/skillpackage/layout.go new file mode 100644 index 00000000..0da7395a --- /dev/null +++ b/internal/skillpackage/layout.go @@ -0,0 +1,66 @@ +package skillpackage + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// SkillDir returns the absolute path to a skill package directory. +func SkillDir(skillsRoot, skillID string) string { + return filepath.Join(skillsRoot, skillID) +} + +// ResolveSKILLPath returns SKILL.md path or error if missing. +func ResolveSKILLPath(skillPath string) (string, error) { + md := filepath.Join(skillPath, "SKILL.md") + if st, err := os.Stat(md); err != nil || st.IsDir() { + return "", fmt.Errorf("missing SKILL.md in %q (Agent Skills standard)", filepath.Base(skillPath)) + } + return md, nil +} + +// SkillsRootFromConfig resolves cfg.SkillsDir relative to the config file directory. +func SkillsRootFromConfig(skillsDir string, configPath string) string { + if skillsDir == "" { + skillsDir = "skills" + } + configDir := filepath.Dir(configPath) + if !filepath.IsAbs(skillsDir) { + skillsDir = filepath.Join(configDir, skillsDir) + } + return skillsDir +} + +// DirLister satisfies handler.SkillsManager for role UI (lists package directory names). +type DirLister struct { + SkillsRoot string +} + +// ListSkills implements the role handler dependency. +func (d DirLister) ListSkills() ([]string, error) { + return ListSkillDirNames(d.SkillsRoot) +} + +// ListSkillDirNames returns subdirectory names under skillsRoot that contain SKILL.md. +func ListSkillDirNames(skillsRoot string) ([]string, error) { + if _, err := os.Stat(skillsRoot); os.IsNotExist(err) { + return nil, nil + } + entries, err := os.ReadDir(skillsRoot) + if err != nil { + return nil, fmt.Errorf("read skills directory: %w", err) + } + var names []string + for _, entry := range entries { + if !entry.IsDir() || strings.HasPrefix(entry.Name(), ".") { + continue + } + skillPath := filepath.Join(skillsRoot, entry.Name()) + if _, err := ResolveSKILLPath(skillPath); err == nil { + names = append(names, entry.Name()) + } + } + return names, nil +} diff --git a/internal/skillpackage/service.go b/internal/skillpackage/service.go new file mode 100644 index 00000000..52dbe90a --- /dev/null +++ b/internal/skillpackage/service.go @@ -0,0 +1,155 @@ +package skillpackage + +import ( + "fmt" + "os" + "sort" + "strings" +) + +// ListSkillSummaries scans skillsRoot and returns index rows for the admin API. +func ListSkillSummaries(skillsRoot string) ([]SkillSummary, error) { + names, err := ListSkillDirNames(skillsRoot) + if err != nil { + return nil, err + } + sort.Strings(names) + out := make([]SkillSummary, 0, len(names)) + for _, dirName := range names { + su, err := loadSummary(skillsRoot, dirName) + if err != nil { + continue + } + out = append(out, su) + } + return out, nil +} + +func loadSummary(skillsRoot, dirName string) (SkillSummary, error) { + skillPath := SkillDir(skillsRoot, dirName) + mdPath, err := ResolveSKILLPath(skillPath) + if err != nil { + return SkillSummary{}, err + } + raw, err := os.ReadFile(mdPath) + if err != nil { + return SkillSummary{}, err + } + man, _, err := ParseSkillMD(raw) + if err != nil { + return SkillSummary{}, err + } + if err := ValidateAgentSkillManifestInPackage(man, dirName); err != nil { + return SkillSummary{}, err + } + fi, err := os.Stat(mdPath) + if err != nil { + return SkillSummary{}, err + } + pfiles, err := ListPackageFiles(skillsRoot, dirName) + if err != nil { + return SkillSummary{}, err + } + nFiles := 0 + for _, p := range pfiles { + if !p.IsDir { + nFiles++ + } + } + scripts, err := listScripts(skillsRoot, dirName) + if err != nil { + return SkillSummary{}, err + } + ver := versionFromMetadata(man) + return SkillSummary{ + ID: dirName, + DirName: dirName, + Name: man.Name, + Description: man.Description, + Version: ver, + Path: skillPath, + Tags: manifestTags(man), + ScriptCount: len(scripts), + FileCount: nFiles, + FileSize: fi.Size(), + ModTime: fi.ModTime().Format("2006-01-02 15:04:05"), + Progressive: true, + }, nil +} + +// LoadOptions mirrors legacy API query params for the web admin. +type LoadOptions struct { + Depth string // summary | full + Section string +} + +// LoadSkill returns manifest + body + package listing for admin. +func LoadSkill(skillsRoot, skillID string, opt LoadOptions) (*SkillView, error) { + skillPath := SkillDir(skillsRoot, skillID) + mdPath, err := ResolveSKILLPath(skillPath) + if err != nil { + return nil, err + } + raw, err := os.ReadFile(mdPath) + if err != nil { + return nil, err + } + man, body, err := ParseSkillMD(raw) + if err != nil { + return nil, err + } + if err := ValidateAgentSkillManifestInPackage(man, skillID); err != nil { + return nil, err + } + pfiles, err := ListPackageFiles(skillsRoot, skillID) + if err != nil { + return nil, err + } + scripts, err := listScripts(skillsRoot, skillID) + if err != nil { + return nil, err + } + sort.Slice(scripts, func(i, j int) bool { return scripts[i].RelPath < scripts[j].RelPath }) + sections := deriveSections(body) + ver := versionFromMetadata(man) + v := &SkillView{ + DirName: skillID, + Name: man.Name, + Description: man.Description, + Content: body, + Path: skillPath, + Version: ver, + Tags: manifestTags(man), + Scripts: scripts, + Sections: sections, + PackageFiles: pfiles, + } + depth := strings.ToLower(strings.TrimSpace(opt.Depth)) + if depth == "" { + depth = "full" + } + sec := strings.TrimSpace(opt.Section) + if sec != "" { + mds := splitMarkdownSections(body) + chunk := findSectionContent(mds, sec) + if chunk == "" { + v.Content = fmt.Sprintf("_(section %q not found in SKILL.md for skill %s)_", sec, skillID) + } else { + v.Content = chunk + } + return v, nil + } + if depth == "summary" { + v.Content = buildSummaryMarkdown(man.Name, man.Description, v.Tags, scripts, sections, body) + } + return v, nil +} + +// ReadScriptText returns file content as string (for HTTP resource_path). +func ReadScriptText(skillsRoot, skillID, relPath string, maxBytes int64) (string, error) { + b, err := ReadPackageFile(skillsRoot, skillID, relPath, maxBytes) + if err != nil { + return "", err + } + return string(b), nil +} diff --git a/internal/skillpackage/types.go b/internal/skillpackage/types.go new file mode 100644 index 00000000..bf313425 --- /dev/null +++ b/internal/skillpackage/types.go @@ -0,0 +1,67 @@ +// Package skillpackage provides filesystem-backed Agent Skills layout (SKILL.md + package files) +// for HTTP admin APIs. Runtime discovery and progressive loading for agents use Eino ADK skill middleware. +package skillpackage + +// SkillManifest is parsed from SKILL.md front matter (https://agentskills.io/specification.md). +type SkillManifest struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + License string `yaml:"license,omitempty"` + Compatibility string `yaml:"compatibility,omitempty"` + Metadata map[string]any `yaml:"metadata,omitempty"` + AllowedTools string `yaml:"allowed-tools,omitempty"` +} + +// SkillSummary is API metadata for one skill directory. +type SkillSummary struct { + ID string `json:"id"` + DirName string `json:"dir_name"` + Name string `json:"name"` + Description string `json:"description"` + Version string `json:"version"` + Path string `json:"path"` + Tags []string `json:"tags"` + Triggers []string `json:"triggers,omitempty"` + ScriptCount int `json:"script_count"` + FileCount int `json:"file_count"` + FileSize int64 `json:"file_size"` + ModTime string `json:"mod_time"` + Progressive bool `json:"progressive"` +} + +// SkillScriptInfo describes a file under scripts/. +type SkillScriptInfo struct { + Name string `json:"name"` + RelPath string `json:"rel_path"` + Description string `json:"description,omitempty"` + Size int64 `json:"size"` +} + +// SkillSection is derived from ## headings in SKILL.md. +type SkillSection struct { + ID string `json:"id"` + Title string `json:"title"` + Heading string `json:"heading"` + Level int `json:"level"` +} + +// PackageFileInfo describes one file inside a package. +type PackageFileInfo struct { + Path string `json:"path"` + Size int64 `json:"size"` + IsDir bool `json:"is_dir,omitempty"` +} + +// SkillView is a loaded package for admin / API. +type SkillView struct { + DirName string `json:"dir_name"` + Name string `json:"name"` + Description string `json:"description"` + Content string `json:"content"` + Path string `json:"path"` + Version string `json:"version"` + Tags []string `json:"tags"` + Scripts []SkillScriptInfo `json:"scripts,omitempty"` + Sections []SkillSection `json:"sections,omitempty"` + PackageFiles []PackageFileInfo `json:"package_files,omitempty"` +} diff --git a/internal/skillpackage/validate.go b/internal/skillpackage/validate.go new file mode 100644 index 00000000..79d8255c --- /dev/null +++ b/internal/skillpackage/validate.go @@ -0,0 +1,102 @@ +package skillpackage + +import ( + "fmt" + "strings" + "unicode/utf8" + + "gopkg.in/yaml.v3" +) + +var agentSkillsSpecFrontMatterKeys = map[string]struct{}{ + "name": {}, "description": {}, "license": {}, "compatibility": {}, + "metadata": {}, "allowed-tools": {}, +} + +// ValidateAgentSkillManifest enforces Agent Skills rules for name and description. +func ValidateAgentSkillManifest(m *SkillManifest) error { + if m == nil { + return fmt.Errorf("skill manifest is nil") + } + if strings.TrimSpace(m.Name) == "" { + return fmt.Errorf("SKILL.md front matter: name is required") + } + if strings.TrimSpace(m.Description) == "" { + return fmt.Errorf("SKILL.md front matter: description is required") + } + if utf8.RuneCountInString(m.Name) > 64 { + return fmt.Errorf("name exceeds 64 characters (Agent Skills limit)") + } + if utf8.RuneCountInString(m.Description) > 1024 { + return fmt.Errorf("description exceeds 1024 characters (Agent Skills limit)") + } + if m.Name != strings.ToLower(m.Name) { + return fmt.Errorf("name must be lowercase (Agent Skills)") + } + for _, r := range m.Name { + if !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-') { + return fmt.Errorf("name must contain only lowercase letters, numbers, hyphens (Agent Skills)") + } + } + if strings.HasPrefix(m.Name, "-") || strings.HasSuffix(m.Name, "-") { + return fmt.Errorf("name must not start or end with a hyphen (Agent Skills spec)") + } + if strings.Contains(m.Name, "--") { + return fmt.Errorf("name must not contain consecutive hyphens (Agent Skills spec)") + } + lname := strings.ToLower(m.Name) + if strings.Contains(lname, "anthropic") || strings.Contains(lname, "claude") { + return fmt.Errorf("name must not contain reserved words anthropic or claude") + } + return nil +} + +// ValidateAgentSkillManifestInPackage checks manifest and that name matches package directory. +func ValidateAgentSkillManifestInPackage(m *SkillManifest, packageDirName string) error { + if err := ValidateAgentSkillManifest(m); err != nil { + return err + } + if strings.TrimSpace(packageDirName) == "" { + return nil + } + if m.Name != packageDirName { + return fmt.Errorf("SKILL.md name %q must match directory name %q (Agent Skills spec)", m.Name, packageDirName) + } + return nil +} + +// ValidateOfficialFrontMatterTopLevelKeys rejects keys not in the open spec. +func ValidateOfficialFrontMatterTopLevelKeys(fmYAML string) error { + var top map[string]interface{} + if err := yaml.Unmarshal([]byte(fmYAML), &top); err != nil { + return fmt.Errorf("SKILL.md front matter: %w", err) + } + for k := range top { + if _, ok := agentSkillsSpecFrontMatterKeys[k]; !ok { + return fmt.Errorf("SKILL.md front matter: unsupported key %q (allowed: name, description, license, compatibility, metadata, allowed-tools — see https://agentskills.io/specification.md)", k) + } + } + return nil +} + +// ValidateSkillMDPackage validates SKILL.md bytes for writes. +func ValidateSkillMDPackage(raw []byte, packageDirName string) error { + fmYAML, body, err := ExtractSkillMDFrontMatterYAML(raw) + if err != nil { + return err + } + if err := ValidateOfficialFrontMatterTopLevelKeys(fmYAML); err != nil { + return err + } + if strings.TrimSpace(body) == "" { + return fmt.Errorf("SKILL.md: markdown body after front matter must not be empty") + } + var fm SkillManifest + if err := yaml.Unmarshal([]byte(fmYAML), &fm); err != nil { + return fmt.Errorf("SKILL.md front matter: %w", err) + } + if c := strings.TrimSpace(fm.Compatibility); c != "" && utf8.RuneCountInString(c) > 500 { + return fmt.Errorf("compatibility exceeds 500 characters (Agent Skills spec)") + } + return ValidateAgentSkillManifestInPackage(&fm, packageDirName) +} diff --git a/internal/storage/result_storage.go b/internal/storage/result_storage.go new file mode 100644 index 00000000..85a8b7b3 --- /dev/null +++ b/internal/storage/result_storage.go @@ -0,0 +1,297 @@ +package storage + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "go.uber.org/zap" +) + +// ResultStorage 结果存储接口 +type ResultStorage interface { + // SaveResult 保存工具执行结果 + SaveResult(executionID string, toolName string, result string) error + + // GetResult 获取完整结果 + GetResult(executionID string) (string, error) + + // GetResultPage 分页获取结果 + GetResultPage(executionID string, page int, limit int) (*ResultPage, error) + + // SearchResult 搜索结果 + // useRegex: 如果为 true,将 keyword 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配 + SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) + + // FilterResult 过滤结果 + // useRegex: 如果为 true,将 filter 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配 + FilterResult(executionID string, filter string, useRegex bool) ([]string, error) + + // GetResultMetadata 获取结果元信息 + GetResultMetadata(executionID string) (*ResultMetadata, error) + + // GetResultPath 获取结果文件路径 + GetResultPath(executionID string) string + + // DeleteResult 删除结果 + DeleteResult(executionID string) error +} + +// ResultPage 分页结果 +type ResultPage struct { + Lines []string `json:"lines"` + Page int `json:"page"` + Limit int `json:"limit"` + TotalLines int `json:"total_lines"` + TotalPages int `json:"total_pages"` +} + +// ResultMetadata 结果元信息 +type ResultMetadata struct { + ExecutionID string `json:"execution_id"` + ToolName string `json:"tool_name"` + TotalSize int `json:"total_size"` + TotalLines int `json:"total_lines"` + CreatedAt time.Time `json:"created_at"` +} + +// FileResultStorage 基于文件的结果存储实现 +type FileResultStorage struct { + baseDir string + logger *zap.Logger + mu sync.RWMutex +} + +// NewFileResultStorage 创建新的文件结果存储 +func NewFileResultStorage(baseDir string, logger *zap.Logger) (*FileResultStorage, error) { + // 确保目录存在 + if err := os.MkdirAll(baseDir, 0755); err != nil { + return nil, fmt.Errorf("创建存储目录失败: %w", err) + } + + return &FileResultStorage{ + baseDir: baseDir, + logger: logger, + }, nil +} + +// getResultPath 获取结果文件路径 +func (s *FileResultStorage) getResultPath(executionID string) string { + return filepath.Join(s.baseDir, executionID+".txt") +} + +// getMetadataPath 获取元数据文件路径 +func (s *FileResultStorage) getMetadataPath(executionID string) string { + return filepath.Join(s.baseDir, executionID+".meta.json") +} + +// SaveResult 保存工具执行结果 +func (s *FileResultStorage) SaveResult(executionID string, toolName string, result string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // 保存结果文件 + resultPath := s.getResultPath(executionID) + if err := os.WriteFile(resultPath, []byte(result), 0644); err != nil { + return fmt.Errorf("保存结果文件失败: %w", err) + } + + // 计算统计信息 + lines := strings.Split(result, "\n") + metadata := &ResultMetadata{ + ExecutionID: executionID, + ToolName: toolName, + TotalSize: len(result), + TotalLines: len(lines), + CreatedAt: time.Now(), + } + + // 保存元数据 + metadataPath := s.getMetadataPath(executionID) + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf("序列化元数据失败: %w", err) + } + + if err := os.WriteFile(metadataPath, metadataJSON, 0644); err != nil { + return fmt.Errorf("保存元数据文件失败: %w", err) + } + + s.logger.Info("保存工具执行结果", + zap.String("executionID", executionID), + zap.String("toolName", toolName), + zap.Int("size", len(result)), + zap.Int("lines", len(lines)), + ) + + return nil +} + +// GetResult 获取完整结果 +func (s *FileResultStorage) GetResult(executionID string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + resultPath := s.getResultPath(executionID) + data, err := os.ReadFile(resultPath) + if err != nil { + if os.IsNotExist(err) { + return "", fmt.Errorf("结果不存在: %s", executionID) + } + return "", fmt.Errorf("读取结果文件失败: %w", err) + } + + return string(data), nil +} + +// GetResultMetadata 获取结果元信息 +func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetadata, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + metadataPath := s.getMetadataPath(executionID) + data, err := os.ReadFile(metadataPath) + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("结果不存在: %s", executionID) + } + return nil, fmt.Errorf("读取元数据文件失败: %w", err) + } + + var metadata ResultMetadata + if err := json.Unmarshal(data, &metadata); err != nil { + return nil, fmt.Errorf("解析元数据失败: %w", err) + } + + return &metadata, nil +} + +// GetResultPage 分页获取结果 +func (s *FileResultStorage) GetResultPage(executionID string, page int, limit int) (*ResultPage, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + // 获取完整结果 + result, err := s.GetResult(executionID) + if err != nil { + return nil, err + } + + // 分割为行 + lines := strings.Split(result, "\n") + totalLines := len(lines) + + // 计算分页 + totalPages := (totalLines + limit - 1) / limit + if page < 1 { + page = 1 + } + if page > totalPages && totalPages > 0 { + page = totalPages + } + + // 计算起始和结束索引 + start := (page - 1) * limit + end := start + limit + if end > totalLines { + end = totalLines + } + + // 提取指定页的行 + var pageLines []string + if start < totalLines { + pageLines = lines[start:end] + } else { + pageLines = []string{} + } + + return &ResultPage{ + Lines: pageLines, + Page: page, + Limit: limit, + TotalLines: totalLines, + TotalPages: totalPages, + }, nil +} + +// SearchResult 搜索结果 +func (s *FileResultStorage) SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + // 获取完整结果 + result, err := s.GetResult(executionID) + if err != nil { + return nil, err + } + + // 如果使用正则表达式,先编译正则 + var regex *regexp.Regexp + if useRegex { + compiledRegex, err := regexp.Compile(keyword) + if err != nil { + return nil, fmt.Errorf("无效的正则表达式: %w", err) + } + regex = compiledRegex + } + + // 分割为行并搜索 + lines := strings.Split(result, "\n") + var matchedLines []string + + for _, line := range lines { + var matched bool + if useRegex { + matched = regex.MatchString(line) + } else { + matched = strings.Contains(line, keyword) + } + + if matched { + matchedLines = append(matchedLines, line) + } + } + + return matchedLines, nil +} + +// FilterResult 过滤结果 +func (s *FileResultStorage) FilterResult(executionID string, filter string, useRegex bool) ([]string, error) { + // 过滤和搜索逻辑相同,都是查找包含关键词的行 + return s.SearchResult(executionID, filter, useRegex) +} + +// GetResultPath 获取结果文件路径 +func (s *FileResultStorage) GetResultPath(executionID string) string { + return s.getResultPath(executionID) +} + +// DeleteResult 删除结果 +func (s *FileResultStorage) DeleteResult(executionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + resultPath := s.getResultPath(executionID) + metadataPath := s.getMetadataPath(executionID) + + // 删除结果文件 + if err := os.Remove(resultPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("删除结果文件失败: %w", err) + } + + // 删除元数据文件 + if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("删除元数据文件失败: %w", err) + } + + s.logger.Info("删除工具执行结果", + zap.String("executionID", executionID), + ) + + return nil +} diff --git a/internal/storage/result_storage_test.go b/internal/storage/result_storage_test.go new file mode 100644 index 00000000..51305c92 --- /dev/null +++ b/internal/storage/result_storage_test.go @@ -0,0 +1,453 @@ +package storage + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "go.uber.org/zap" +) + +// setupTestStorage 创建测试用的存储实例 +func setupTestStorage(t *testing.T) (*FileResultStorage, string) { + tmpDir := filepath.Join(os.TempDir(), "test_result_storage_"+time.Now().Format("20060102_150405")) + logger := zap.NewNop() + + storage, err := NewFileResultStorage(tmpDir, logger) + if err != nil { + t.Fatalf("创建测试存储失败: %v", err) + } + + return storage, tmpDir +} + +// cleanupTestStorage 清理测试数据 +func cleanupTestStorage(t *testing.T, tmpDir string) { + if err := os.RemoveAll(tmpDir); err != nil { + t.Logf("清理测试目录失败: %v", err) + } +} + +func TestNewFileResultStorage(t *testing.T) { + tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405")) + defer cleanupTestStorage(t, tmpDir) + + logger := zap.NewNop() + storage, err := NewFileResultStorage(tmpDir, logger) + if err != nil { + t.Fatalf("创建存储失败: %v", err) + } + + if storage == nil { + t.Fatal("存储实例为nil") + } + + // 验证目录已创建 + if _, err := os.Stat(tmpDir); os.IsNotExist(err) { + t.Fatal("存储目录未创建") + } +} + +func TestFileResultStorage_SaveResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_001" + toolName := "nmap_scan" + result := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 验证结果文件存在 + resultPath := filepath.Join(tmpDir, executionID+".txt") + if _, err := os.Stat(resultPath); os.IsNotExist(err) { + t.Fatal("结果文件未创建") + } + + // 验证元数据文件存在 + metadataPath := filepath.Join(tmpDir, executionID+".meta.json") + if _, err := os.Stat(metadataPath); os.IsNotExist(err) { + t.Fatal("元数据文件未创建") + } +} + +func TestFileResultStorage_GetResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_002" + toolName := "test_tool" + expectedResult := "Test result content\nLine 2\nLine 3" + + // 先保存结果 + err := storage.SaveResult(executionID, toolName, expectedResult) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 获取结果 + result, err := storage.GetResult(executionID) + if err != nil { + t.Fatalf("获取结果失败: %v", err) + } + + if result != expectedResult { + t.Errorf("结果不匹配。期望: %q, 实际: %q", expectedResult, result) + } + + // 测试不存在的执行ID + _, err = storage.GetResult("nonexistent_id") + if err == nil { + t.Fatal("应该返回错误") + } +} + +func TestFileResultStorage_GetResultMetadata(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_003" + toolName := "test_tool" + result := "Line 1\nLine 2\nLine 3" + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 获取元数据 + metadata, err := storage.GetResultMetadata(executionID) + if err != nil { + t.Fatalf("获取元数据失败: %v", err) + } + + if metadata.ExecutionID != executionID { + t.Errorf("执行ID不匹配。期望: %s, 实际: %s", executionID, metadata.ExecutionID) + } + + if metadata.ToolName != toolName { + t.Errorf("工具名称不匹配。期望: %s, 实际: %s", toolName, metadata.ToolName) + } + + if metadata.TotalSize != len(result) { + t.Errorf("总大小不匹配。期望: %d, 实际: %d", len(result), metadata.TotalSize) + } + + expectedLines := len(strings.Split(result, "\n")) + if metadata.TotalLines != expectedLines { + t.Errorf("总行数不匹配。期望: %d, 实际: %d", expectedLines, metadata.TotalLines) + } + + // 验证创建时间在合理范围内 + now := time.Now() + if metadata.CreatedAt.After(now) || metadata.CreatedAt.Before(now.Add(-time.Second)) { + t.Errorf("创建时间不在合理范围内: %v", metadata.CreatedAt) + } +} + +func TestFileResultStorage_GetResultPage(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_004" + toolName := "test_tool" + // 创建包含10行的结果 + lines := make([]string, 10) + for i := 0; i < 10; i++ { + lines[i] = fmt.Sprintf("Line %d", i+1) + } + result := strings.Join(lines, "\n") + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 测试第一页(每页3行) + page, err := storage.GetResultPage(executionID, 1, 3) + if err != nil { + t.Fatalf("获取第一页失败: %v", err) + } + + if page.Page != 1 { + t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page) + } + + if page.Limit != 3 { + t.Errorf("每页行数不匹配。期望: 3, 实际: %d", page.Limit) + } + + if page.TotalLines != 10 { + t.Errorf("总行数不匹配。期望: 10, 实际: %d", page.TotalLines) + } + + if page.TotalPages != 4 { + t.Errorf("总页数不匹配。期望: 4, 实际: %d", page.TotalPages) + } + + if len(page.Lines) != 3 { + t.Errorf("第一页行数不匹配。期望: 3, 实际: %d", len(page.Lines)) + } + + if page.Lines[0] != "Line 1" { + t.Errorf("第一行内容不匹配。期望: Line 1, 实际: %s", page.Lines[0]) + } + + // 测试第二页 + page2, err := storage.GetResultPage(executionID, 2, 3) + if err != nil { + t.Fatalf("获取第二页失败: %v", err) + } + + if len(page2.Lines) != 3 { + t.Errorf("第二页行数不匹配。期望: 3, 实际: %d", len(page2.Lines)) + } + + if page2.Lines[0] != "Line 4" { + t.Errorf("第二页第一行内容不匹配。期望: Line 4, 实际: %s", page2.Lines[0]) + } + + // 测试最后一页(可能不满一页) + page4, err := storage.GetResultPage(executionID, 4, 3) + if err != nil { + t.Fatalf("获取第四页失败: %v", err) + } + + if len(page4.Lines) != 1 { + t.Errorf("第四页行数不匹配。期望: 1, 实际: %d", len(page4.Lines)) + } + + // 测试超出范围的页码(应该返回最后一页) + page5, err := storage.GetResultPage(executionID, 5, 3) + if err != nil { + t.Fatalf("获取第五页失败: %v", err) + } + + // 超出范围的页码会被修正为最后一页,所以应该返回最后一页的内容 + if page5.Page != 4 { + t.Errorf("超出范围的页码应该被修正为最后一页。期望: 4, 实际: %d", page5.Page) + } + + // 最后一页应该只有1行 + if len(page5.Lines) != 1 { + t.Errorf("最后一页应该只有1行。实际: %d行", len(page5.Lines)) + } +} + +func TestFileResultStorage_SearchResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_005" + toolName := "test_tool" + result := "Line 1: error occurred\nLine 2: success\nLine 3: error again\nLine 4: ok" + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 搜索包含"error"的行(简单字符串匹配) + matchedLines, err := storage.SearchResult(executionID, "error", false) + if err != nil { + t.Fatalf("搜索失败: %v", err) + } + + if len(matchedLines) != 2 { + t.Errorf("搜索结果数量不匹配。期望: 2, 实际: %d", len(matchedLines)) + } + + // 验证搜索结果内容 + for i, line := range matchedLines { + if !strings.Contains(line, "error") { + t.Errorf("搜索结果第%d行不包含关键词: %s", i+1, line) + } + } + + // 测试搜索不存在的关键词 + noMatch, err := storage.SearchResult(executionID, "nonexistent", false) + if err != nil { + t.Fatalf("搜索失败: %v", err) + } + + if len(noMatch) != 0 { + t.Errorf("搜索不存在的关键词应该返回空结果。实际: %d行", len(noMatch)) + } + + // 测试正则表达式搜索 + regexMatched, err := storage.SearchResult(executionID, "error.*again", true) + if err != nil { + t.Fatalf("正则搜索失败: %v", err) + } + + if len(regexMatched) != 1 { + t.Errorf("正则搜索结果数量不匹配。期望: 1, 实际: %d", len(regexMatched)) + } +} + +func TestFileResultStorage_FilterResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_006" + toolName := "test_tool" + result := "Line 1: warning message\nLine 2: info message\nLine 3: warning again\nLine 4: debug message" + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 过滤包含"warning"的行(简单字符串匹配) + filteredLines, err := storage.FilterResult(executionID, "warning", false) + if err != nil { + t.Fatalf("过滤失败: %v", err) + } + + if len(filteredLines) != 2 { + t.Errorf("过滤结果数量不匹配。期望: 2, 实际: %d", len(filteredLines)) + } + + // 验证过滤结果内容 + for i, line := range filteredLines { + if !strings.Contains(line, "warning") { + t.Errorf("过滤结果第%d行不包含关键词: %s", i+1, line) + } + } +} + +func TestFileResultStorage_DeleteResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_007" + toolName := "test_tool" + result := "Test result" + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 验证文件存在 + resultPath := filepath.Join(tmpDir, executionID+".txt") + metadataPath := filepath.Join(tmpDir, executionID+".meta.json") + + if _, err := os.Stat(resultPath); os.IsNotExist(err) { + t.Fatal("结果文件不存在") + } + + if _, err := os.Stat(metadataPath); os.IsNotExist(err) { + t.Fatal("元数据文件不存在") + } + + // 删除结果 + err = storage.DeleteResult(executionID) + if err != nil { + t.Fatalf("删除结果失败: %v", err) + } + + // 验证文件已删除 + if _, err := os.Stat(resultPath); !os.IsNotExist(err) { + t.Fatal("结果文件未被删除") + } + + if _, err := os.Stat(metadataPath); !os.IsNotExist(err) { + t.Fatal("元数据文件未被删除") + } + + // 测试删除不存在的执行ID(应该不报错) + err = storage.DeleteResult("nonexistent_id") + if err != nil { + t.Errorf("删除不存在的执行ID不应该报错: %v", err) + } +} + +func TestFileResultStorage_ConcurrentAccess(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + // 并发保存多个结果 + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func(id int) { + executionID := fmt.Sprintf("test_exec_%d", id) + toolName := "test_tool" + result := fmt.Sprintf("Result %d\nLine 2\nLine 3", id) + + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Errorf("并发保存失败 (ID: %s): %v", executionID, err) + } + + // 并发读取 + _, err = storage.GetResult(executionID) + if err != nil { + t.Errorf("并发读取失败 (ID: %s): %v", executionID, err) + } + + done <- true + }(i) + } + + // 等待所有goroutine完成 + for i := 0; i < 10; i++ { + <-done + } +} + +func TestFileResultStorage_LargeResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_large" + toolName := "test_tool" + + // 创建大结果(1000行) + lines := make([]string, 1000) + for i := 0; i < 1000; i++ { + lines[i] = fmt.Sprintf("Line %d: This is a test line with some content", i+1) + } + result := strings.Join(lines, "\n") + + // 保存大结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存大结果失败: %v", err) + } + + // 验证元数据 + metadata, err := storage.GetResultMetadata(executionID) + if err != nil { + t.Fatalf("获取元数据失败: %v", err) + } + + if metadata.TotalLines != 1000 { + t.Errorf("总行数不匹配。期望: 1000, 实际: %d", metadata.TotalLines) + } + + // 测试分页查询大结果 + page, err := storage.GetResultPage(executionID, 1, 100) + if err != nil { + t.Fatalf("获取第一页失败: %v", err) + } + + if page.TotalPages != 10 { + t.Errorf("总页数不匹配。期望: 10, 实际: %d", page.TotalPages) + } + + if len(page.Lines) != 100 { + t.Errorf("第一页行数不匹配。期望: 100, 实际: %d", len(page.Lines)) + } +}