diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 36261379..fa82a3a9 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -336,10 +336,10 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error { // AgentLoopResult Agent Loop执行结果 type AgentLoopResult struct { - Response string - MCPExecutionIDs []string - LastReActInput string // 最后一轮ReAct的输入(压缩后的messages,JSON格式) - LastReActOutput string // 最终大模型的输出 + Response string + MCPExecutionIDs []string + LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messages,JSON;与 multiagent.RunResult 字段对齐) + LastAgentTraceOutput string // 最终助手输出文本 } // ProgressCallback 进度回调函数类型 @@ -471,7 +471,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his } // 用于保存当前的messages,以便在异常情况下也能保存ReAct输入 - var currentReActInput string + var currentAgentTraceInput string maxIterations := a.maxIterations thinkingStreamSeq := 0 @@ -490,9 +490,9 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his if err != nil { a.logger.Warn("序列化ReAct输入失败", zap.Error(err)) } else { - currentReActInput = string(messagesJSON) + currentAgentTraceInput = string(messagesJSON) // 更新result中的值,确保始终保存最新的ReAct输入(压缩后的) - result.LastReActInput = currentReActInput + result.LastAgentTraceInput = currentAgentTraceInput } // 检查上下文是否已取消 @@ -500,13 +500,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his case <-ctx.Done(): // 上下文被取消(可能是用户主动暂停或其他原因) a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err())) - result.LastReActInput = currentReActInput + result.LastAgentTraceInput = currentAgentTraceInput if ctx.Err() == context.Canceled { result.Response = "任务已被取消。" } else { result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err()) } - result.LastReActOutput = result.Response + result.LastAgentTraceOutput = result.Response return result, ctx.Err() default: } @@ -600,10 +600,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his }) if err != nil { // API调用失败,保存当前的ReAct输入和错误信息作为输出 - result.LastReActInput = currentReActInput + result.LastAgentTraceInput = currentAgentTraceInput errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err) result.Response = errorMsg - result.LastReActOutput = errorMsg + result.LastAgentTraceOutput = errorMsg a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err)) return result, fmt.Errorf("调用OpenAI失败: %w", err) } @@ -629,19 +629,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his continue } // OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出 - result.LastReActInput = currentReActInput + result.LastAgentTraceInput = currentAgentTraceInput errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message) result.Response = errorMsg - result.LastReActOutput = errorMsg + result.LastAgentTraceOutput = errorMsg return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) } if len(response.Choices) == 0 { // 没有收到响应,保存当前的ReAct输入和错误信息作为输出 - result.LastReActInput = currentReActInput + result.LastAgentTraceInput = currentAgentTraceInput errorMsg := "没有收到响应" result.Response = errorMsg - result.LastReActOutput = errorMsg + result.LastAgentTraceOutput = errorMsg return result, fmt.Errorf("没有收到响应") } @@ -816,7 +816,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his }) if strings.TrimSpace(streamText) != "" { result.Response = streamText - result.LastReActOutput = result.Response + result.LastAgentTraceOutput = result.Response sendProgress("progress", "总结生成完成", nil) return result, nil } @@ -863,14 +863,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his }) if strings.TrimSpace(streamText) != "" { result.Response = streamText - result.LastReActOutput = result.Response + result.LastAgentTraceOutput = result.Response sendProgress("progress", "总结生成完成", nil) return result, nil } // 如果获取总结失败,使用当前回复作为结果 if choice.Message.Content != "" { result.Response = choice.Message.Content - result.LastReActOutput = result.Response + result.LastAgentTraceOutput = result.Response return result, nil } // 如果都没有内容,跳出循环,让后续逻辑处理 @@ -881,7 +881,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his if choice.FinishReason == "stop" { sendProgress("progress", "正在生成最终回复...", nil) result.Response = choice.Message.Content - result.LastReActOutput = result.Response + result.LastAgentTraceOutput = result.Response return result, nil } } @@ -910,14 +910,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his }) if strings.TrimSpace(streamText) != "" { result.Response = streamText - result.LastReActOutput = result.Response + result.LastAgentTraceOutput = result.Response sendProgress("progress", "总结生成完成", nil) return result, nil } // 如果无法生成总结,返回友好的提示 result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations) - result.LastReActOutput = result.Response + result.LastAgentTraceOutput = result.Response return result, nil } diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go index fcbcfa64..26df9ce3 100644 --- a/internal/agent/agent_test.go +++ b/internal/agent/agent_test.go @@ -18,62 +18,62 @@ import ( func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) { logger := zap.NewNop() mcpServer := mcp.NewServer(logger) - + openAICfg := &config.OpenAIConfig{ APIKey: "test-key", BaseURL: "https://api.test.com/v1", Model: "test-model", } - + agentCfg := &config.AgentConfig{ MaxIterations: 10, LargeResultThreshold: 100, // 设置较小的阈值便于测试 ResultStorageDir: "", } - + agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10) - + // 创建测试存储 tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405")) testStorage, err := storage.NewFileResultStorage(tmpDir, logger) if err != nil { t.Fatalf("创建测试存储失败: %v", err) } - + agent.SetResultStorage(testStorage) - + return agent, testStorage } func TestAgent_FormatMinimalNotification(t *testing.T) { agent, testStorage := setupTestAgent(t) _ = testStorage // 避免未使用变量警告 - + executionID := "test_exec_001" toolName := "nmap_scan" size := 50000 lineCount := 1000 filePath := "tmp/test_exec_001.txt" - + notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath) - + // 验证通知包含必要信息 if !strings.Contains(notification, executionID) { t.Errorf("通知中应该包含执行ID: %s", executionID) } - + if !strings.Contains(notification, toolName) { t.Errorf("通知中应该包含工具名称: %s", toolName) } - + if !strings.Contains(notification, "50000") { t.Errorf("通知中应该包含大小信息") } - + if !strings.Contains(notification, "1000") { t.Errorf("通知中应该包含行数信息") } - + if !strings.Contains(notification, "query_execution_result") { t.Errorf("通知中应该包含查询工具的使用说明") } @@ -81,7 +81,7 @@ func TestAgent_FormatMinimalNotification(t *testing.T) { func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) { agent, _ := setupTestAgent(t) - + // 创建模拟的MCP工具结果(大结果) largeResult := &mcp.ToolResult{ Content: []mcp.Content{ @@ -92,59 +92,59 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) { }, IsError: false, } - + // 模拟MCP服务器返回大结果 // 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器 // 为了简化测试,我们直接测试结果处理逻辑 - + // 设置阈值 agent.mu.Lock() agent.largeResultThreshold = 1000 // 设置较小的阈值 agent.mu.Unlock() - + // 创建执行ID executionID := "test_exec_large_001" toolName := "test_tool" - + // 格式化结果 var resultText strings.Builder for _, content := range largeResult.Content { resultText.WriteString(content.Text) resultText.WriteString("\n") } - + resultStr := resultText.String() resultSize := len(resultStr) - + // 检测大结果并保存 agent.mu.RLock() threshold := agent.largeResultThreshold storage := agent.resultStorage agent.mu.RUnlock() - + if resultSize > threshold && storage != nil { // 保存大结果 err := storage.SaveResult(executionID, toolName, resultStr) if err != nil { t.Fatalf("保存大结果失败: %v", err) } - + // 生成通知 lines := strings.Split(resultStr, "\n") filePath := storage.GetResultPath(executionID) notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath) - + // 验证通知格式 if !strings.Contains(notification, executionID) { t.Errorf("通知中应该包含执行ID") } - + // 验证结果已保存 savedResult, err := storage.GetResult(executionID) if err != nil { t.Fatalf("获取保存的结果失败: %v", err) } - + if savedResult != resultStr { t.Errorf("保存的结果与原始结果不匹配") } @@ -155,7 +155,7 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) { func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) { agent, _ := setupTestAgent(t) - + // 创建小结果 smallResult := &mcp.ToolResult{ Content: []mcp.Content{ @@ -166,32 +166,32 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) { }, IsError: false, } - + // 设置较大的阈值 agent.mu.Lock() agent.largeResultThreshold = 100000 // 100KB agent.mu.Unlock() - + // 格式化结果 var resultText strings.Builder for _, content := range smallResult.Content { resultText.WriteString(content.Text) resultText.WriteString("\n") } - + resultStr := resultText.String() resultSize := len(resultStr) - + // 检测大结果 agent.mu.RLock() threshold := agent.largeResultThreshold storage := agent.resultStorage agent.mu.RUnlock() - + if resultSize > threshold && storage != nil { t.Fatal("小结果不应该被保存") } - + // 小结果应该直接返回 if resultSize <= threshold { // 这是预期的行为 @@ -203,26 +203,26 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) { func TestAgent_SetResultStorage(t *testing.T) { agent, _ := setupTestAgent(t) - + // 创建新的存储 tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405")) newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop()) if err != nil { t.Fatalf("创建新存储失败: %v", err) } - + // 设置新存储 agent.SetResultStorage(newStorage) - + // 验证存储已更新 agent.mu.RLock() currentStorage := agent.resultStorage agent.mu.RUnlock() - + if currentStorage != newStorage { t.Fatal("存储未正确更新") } - + // 清理 os.RemoveAll(tmpDir) } @@ -230,24 +230,24 @@ func TestAgent_SetResultStorage(t *testing.T) { func TestAgent_NewAgent_DefaultValues(t *testing.T) { logger := zap.NewNop() mcpServer := mcp.NewServer(logger) - + openAICfg := &config.OpenAIConfig{ APIKey: "test-key", BaseURL: "https://api.test.com/v1", Model: "test-model", } - + // 测试默认配置 agent := NewAgent(openAICfg, nil, mcpServer, nil, logger, 0) - + if agent.maxIterations != 30 { t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations) } - + agent.mu.RLock() threshold := agent.largeResultThreshold agent.mu.RUnlock() - + if threshold != 50*1024 { t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold) } @@ -256,31 +256,30 @@ func TestAgent_NewAgent_DefaultValues(t *testing.T) { func TestAgent_NewAgent_CustomConfig(t *testing.T) { logger := zap.NewNop() mcpServer := mcp.NewServer(logger) - + openAICfg := &config.OpenAIConfig{ APIKey: "test-key", BaseURL: "https://api.test.com/v1", Model: "test-model", } - + agentCfg := &config.AgentConfig{ MaxIterations: 20, LargeResultThreshold: 100 * 1024, // 100KB ResultStorageDir: "custom_tmp", } - + agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15) - + if agent.maxIterations != 15 { t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations) } - + agent.mu.RLock() threshold := agent.largeResultThreshold agent.mu.RUnlock() - + if threshold != 100*1024 { t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold) } } - diff --git a/internal/agents/markdown.go b/internal/agents/markdown.go index ab44ab04..b3aa8a0f 100644 --- a/internal/agents/markdown.go +++ b/internal/agents/markdown.go @@ -256,11 +256,11 @@ func orchestratorConfigFromOrchestrator(o *OrchestratorMarkdown) config.MultiAge return config.MultiAgentSubConfig{} } return config.MultiAgentSubConfig{ - ID: o.EinoName, - Name: o.DisplayName, - Description: o.Description, - Instruction: o.Instruction, - Kind: "orchestrator", + ID: o.EinoName, + Name: o.DisplayName, + Description: o.Description, + Instruction: o.Instruction, + Kind: "orchestrator", } } diff --git a/internal/handler/agent.go b/internal/handler/agent.go index a5342c87..9af3b67d 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -497,10 +497,10 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { defer h.hitlManager.DeactivateConversation(conversationID) } - // 优先尝试从保存的ReAct数据恢复历史上下文 - agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) + // 优先尝试从保存的代理轨迹恢复历史上下文 + agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID) if err != nil { - h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) + h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err)) // 回退到使用数据库消息表 historyMessages, err := h.db.GetMessages(conversationID) if err != nil { @@ -518,7 +518,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) } } else { - h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) + h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) } // 校验附件数量(非流式) @@ -613,12 +613,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { if err != nil { h.logger.Error("Agent Loop执行失败", zap.Error(err)) - // 即使执行失败,也尝试保存ReAct数据(如果result中有) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil { - h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr)) + // 即使执行失败,也尝试保存代理轨迹(如果 result 中有) + if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") { + if saveErr := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); saveErr != nil { + h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(saveErr)) } else { - h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) + h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID)) } } @@ -634,12 +634,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { // 因为AI已经生成了回复,用户应该能看到 } - // 保存最后一轮ReAct的输入和输出 - if result.LastReActInput != "" || result.LastReActOutput != "" { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存ReAct数据失败", zap.Error(err)) + // 保存最后一轮代理轨迹与助手输出 + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.Error(err)) } else { - h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) + h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID)) } } @@ -666,7 +666,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI } } - agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) + agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID) if err != nil { historyMessages, getErr := h.db.GetMessages(conversationID) if getErr != nil { @@ -722,6 +722,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI "deep", ) if errMA != nil { + h.persistEinoAgentTraceForResume(conversationID, resultMA) errMsg := "执行失败: " + errMA.Error() if assistantMessageID != "" { _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID) @@ -747,8 +748,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) } } - if resultMA.LastReActInput != "" || resultMA.LastReActOutput != "" { - _ = h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput) + if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" { + _ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput) } return resultMA.Response, conversationID, nil } @@ -782,8 +783,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) } } - if result.LastReActInput != "" || result.LastReActOutput != "" { - _ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput) + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + _ = h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput) } return result.Response, conversationID, nil } @@ -1359,10 +1360,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { } ssePublishConversationID = conversationID - // 优先尝试从保存的ReAct数据恢复历史上下文 - agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) + // 优先尝试从保存的代理轨迹恢复历史上下文 + agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID) if err != nil { - h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) + h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err)) // 回退到使用数据库消息表 historyMessages, err := h.db.GetMessages(conversationID) if err != nil { @@ -1380,7 +1381,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) } } else { - h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) + h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) } // 校验附件数量 @@ -1579,12 +1580,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) } - // 即使任务被取消,也尝试保存ReAct数据(如果result中有) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err)) + // 即使任务被取消,也尝试保存代理轨迹(如果 result 中有) + if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存取消任务的代理轨迹失败", zap.Error(err)) } else { - h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID)) + h.logger.Info("已保存取消任务的代理轨迹", zap.String("conversationId", conversationID)) } } @@ -1614,12 +1615,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) } - // 即使任务超时,也尝试保存ReAct数据(如果result中有) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err)) + // 即使任务超时,也尝试保存代理轨迹(如果 result 中有) + if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存超时任务的代理轨迹失败", zap.Error(err)) } else { - h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID)) + h.logger.Info("已保存超时任务的代理轨迹", zap.String("conversationId", conversationID)) } } @@ -1649,12 +1650,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil) } - // 即使任务失败,也尝试保存ReAct数据(如果result中有) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err)) + // 即使任务失败,也尝试保存代理轨迹(如果 result 中有) + if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(err)) } else { - h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) + h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID)) } } @@ -1694,12 +1695,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { } } - // 保存最后一轮ReAct的输入和输出 - if result.LastReActInput != "" || result.LastReActOutput != "" { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存ReAct数据失败", zap.Error(err)) + // 保存最后一轮代理轨迹与助手输出 + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.Error(err)) } else { - h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) + h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID)) } } @@ -2499,6 +2500,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { cancel() if runErr != nil { + if useRunResult { + h.persistEinoAgentTraceForResume(conversationID, resultMA) + } // 检查是否是取消错误 // 1. 直接检查是否是 context.Canceled(包括包装后的错误) // 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字 @@ -2542,14 +2546,14 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg)) } } - // 保存ReAct数据(如果存在) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + // 保存代理轨迹(如果存在) + if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) } - } else if useRunResult && resultMA != nil && (resultMA.LastReActInput != "" || resultMA.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput); err != nil { - h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + } else if useRunResult && resultMA != nil && (resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "") { + if err := h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) } } h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID) @@ -2581,13 +2585,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { if useRunResult { resText = resultMA.Response mcpIDs = resultMA.MCPExecutionIDs - lastIn = resultMA.LastReActInput - lastOut = resultMA.LastReActOutput + lastIn = resultMA.LastAgentTraceInput + lastOut = resultMA.LastAgentTraceOutput } else { resText = result.Response mcpIDs = result.MCPExecutionIDs - lastIn = result.LastReActInput - lastOut = result.LastReActOutput + lastIn = result.LastAgentTraceInput + lastOut = result.LastAgentTraceOutput } // 更新助手消息内容 @@ -2618,12 +2622,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { } } - // 保存ReAct数据 + // 保存代理轨迹 if lastIn != "" || lastOut != "" { - if err := h.db.SaveReActData(conversationID, lastIn, lastOut); err != nil { - h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) } else { - h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) + h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) } } @@ -2642,36 +2646,33 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { } } -// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文 -// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表 -func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) { - // 获取保存的ReAct输入和输出 - reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID) +// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。 +// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。 +func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) { + traceInputJSON, assistantOut, err := h.db.GetAgentTrace(conversationID) if err != nil { - return nil, fmt.Errorf("获取ReAct数据失败: %w", err) + return nil, fmt.Errorf("获取代理轨迹失败: %w", err) } - // 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致) - if reactInputJSON == "" { - return nil, fmt.Errorf("ReAct数据为空,将使用消息表") + if traceInputJSON == "" { + return nil, fmt.Errorf("代理轨迹为空,将使用消息表") } - dataSource := "database_last_react_input" + dataSource := "database_last_agent_trace" - // 解析JSON格式的messages数组 var messagesArray []map[string]interface{} - if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil { - return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err) + if err := json.Unmarshal([]byte(traceInputJSON), &messagesArray); err != nil { + return nil, fmt.Errorf("解析代理轨迹 JSON 失败: %w", err) } messageCount := len(messagesArray) - h.logger.Info("使用保存的ReAct数据恢复历史上下文", + h.logger.Info("使用保存的代理轨迹恢复历史上下文", zap.String("conversationId", conversationID), zap.String("dataSource", dataSource), - zap.Int("reactInputSize", len(reactInputJSON)), + zap.Int("traceInputSize", len(traceInputJSON)), zap.Int("messageCount", messageCount), - zap.Int("reactOutputSize", len(reactOutput)), + zap.Int("assistantOutSize", len(assistantOut)), ) // fmt.Println("messagesArray:", messagesArray)//debug @@ -2755,53 +2756,44 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent. agentMessages = append(agentMessages, msg) } - // 如果存在last_react_output,需要将其作为最后一条assistant消息 - // 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出 - if reactOutput != "" { - // 检查最后一条消息是否是assistant消息且没有tool_calls - // 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复 + // 若存在 last_react_output(助手摘要),合并为最后一条 assistant(与保存格式一致) + if assistantOut != "" { if len(agentMessages) > 0 { lastMsg := &agentMessages[len(agentMessages)-1] if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 { - // 最后一条是assistant消息且没有tool_calls,用最终输出更新其content - lastMsg.Content = reactOutput + lastMsg.Content = assistantOut } else { - // 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息 agentMessages = append(agentMessages, agent.ChatMessage{ Role: "assistant", - Content: reactOutput, + Content: assistantOut, }) } } else { - // 如果没有消息,直接添加最终输出 agentMessages = append(agentMessages, agent.ChatMessage{ Role: "assistant", - Content: reactOutput, + Content: assistantOut, }) } } if len(agentMessages) == 0 { - return nil, fmt.Errorf("从ReAct数据解析的消息为空") + return nil, fmt.Errorf("从代理轨迹解析的消息为空") } - // 修复可能存在的失配tool消息,避免OpenAI报错 - // 这可以防止出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误 if h.agent != nil { if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed { - h.logger.Info("修复了从ReAct数据恢复的历史消息中的失配tool消息", + h.logger.Info("修复了从代理轨迹恢复的历史消息中的失配 tool 消息", zap.String("conversationId", conversationID), ) } } - h.logger.Info("从ReAct数据恢复历史消息完成", + h.logger.Info("从代理轨迹恢复历史消息完成", zap.String("conversationId", conversationID), zap.String("dataSource", dataSource), zap.Int("originalMessageCount", messageCount), zap.Int("finalMessageCount", len(agentMessages)), - zap.Bool("hasReactOutput", reactOutput != ""), + zap.Bool("hasAssistantOut", assistantOut != ""), ) - fmt.Println("agentMessages:", agentMessages) //debug return agentMessages, nil } diff --git a/internal/handler/attackchain.go b/internal/handler/attackchain.go index 2b78b9bf..837516e8 100644 --- a/internal/handler/attackchain.go +++ b/internal/handler/attackchain.go @@ -83,7 +83,7 @@ func (h *AttackChainHandler) GetAttackChain(c *gin.Context) { // 使用锁机制防止同一对话的并发生成 lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) lock := lockInterface.(*sync.Mutex) - + // 尝试获取锁,如果正在生成则返回错误 acquired := lock.TryLock() if !acquired { @@ -144,7 +144,7 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) { // 使用锁机制防止并发生成 lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) lock := lockInterface.(*sync.Mutex) - + acquired := lock.TryLock() if !acquired { h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID)) @@ -170,4 +170,3 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) { c.JSON(http.StatusOK, chain) } - diff --git a/internal/handler/conversation.go b/internal/handler/conversation.go index 4bb72bbe..0bd538ec 100644 --- a/internal/handler/conversation.go +++ b/internal/handler/conversation.go @@ -230,4 +230,3 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) { "message": "ok", }) } - diff --git a/internal/handler/eino_single_agent.go b/internal/handler/eino_single_agent.go index 36fc3c1b..a3ed3e6c 100644 --- a/internal/handler/eino_single_agent.go +++ b/internal/handler/eino_single_agent.go @@ -175,6 +175,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { ) if runErr != nil { + h.persistEinoAgentTraceForResume(conversationID, result) cause := context.Cause(baseCtx) if errors.Is(cause, ErrTaskCancelled) { taskStatus = "cancelled" @@ -239,9 +240,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { ) } - if result.LastReActInput != "" || result.LastReActOutput != "" { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.Error(err)) } } @@ -306,6 +307,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) { progressCallback, ) if runErr != nil { + h.persistEinoAgentTraceForResume(prep.ConversationID, result) c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()}) return } @@ -323,8 +325,8 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) { prep.AssistantMessageID, ) } - if result.LastReActInput != "" || result.LastReActOutput != "" { - _ = h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput) + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + _ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput) } c.JSON(http.StatusOK, gin.H{ diff --git a/internal/handler/external_mcp_test.go b/internal/handler/external_mcp_test.go index e52eeced..16c64ff7 100644 --- a/internal/handler/external_mcp_test.go +++ b/internal/handler/external_mcp_test.go @@ -247,7 +247,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) { // 先添加一个配置 configObj := config.ExternalMCPServerConfig{ - Command: "python3", + Command: "python3", ExternalMCPEnable: true, } handler.manager.AddOrUpdateConfig("test-delete", configObj) @@ -276,11 +276,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { // 添加多个配置 handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{ - Command: "python3", + Command: "python3", ExternalMCPEnable: true, }) handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{ - URL: "http://127.0.0.1:8081/mcp", + URL: "http://127.0.0.1:8081/mcp", ExternalMCPEnable: false, }) @@ -319,15 +319,15 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) { // 添加配置 handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ - Command: "python3", + Command: "python3", ExternalMCPEnable: true, }) handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ - URL: "http://127.0.0.1:8081/mcp", + URL: "http://127.0.0.1:8081/mcp", ExternalMCPEnable: true, }) handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ - Command: "python3", + Command: "python3", }) req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil) @@ -360,7 +360,7 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) { // 添加一个禁用的配置 handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{ - Command: "python3", + Command: "python3", }) // 测试启动(可能会失败,因为没有真实的服务器) @@ -416,7 +416,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) { router, _, _ := setupTestRouter() configObj := config.ExternalMCPServerConfig{ - Command: "python3", + Command: "python3", ExternalMCPEnable: true, } @@ -459,14 +459,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) { // 先添加配置 config1 := config.ExternalMCPServerConfig{ - Command: "python3", + Command: "python3", ExternalMCPEnable: true, } handler.manager.AddOrUpdateConfig("test-update", config1) // 更新配置 config2 := config.ExternalMCPServerConfig{ - URL: "http://127.0.0.1:8081/mcp", + URL: "http://127.0.0.1:8081/mcp", ExternalMCPEnable: true, } diff --git a/internal/handler/markdown_agents.go b/internal/handler/markdown_agents.go index 2341aaaf..bc7abb47 100644 --- a/internal/handler/markdown_agents.go +++ b/internal/handler/markdown_agents.go @@ -131,16 +131,16 @@ func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) { } type markdownAgentBody struct { - Filename string `json:"filename"` - ID string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Tools []string `json:"tools"` - Instruction string `json:"instruction"` - BindRole string `json:"bind_role"` - MaxIterations int `json:"max_iterations"` - Kind string `json:"kind"` - Raw string `json:"raw"` + Filename string `json:"filename"` + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Tools []string `json:"tools"` + Instruction string `json:"instruction"` + BindRole string `json:"bind_role"` + MaxIterations int `json:"max_iterations"` + Kind string `json:"kind"` + Raw string `json:"raw"` } // CreateMarkdownAgent POST /api/multi-agent/markdown-agents diff --git a/internal/handler/monitor.go b/internal/handler/monitor.go index c337c374..17e6c79c 100644 --- a/internal/handler/monitor.go +++ b/internal/handler/monitor.go @@ -42,11 +42,11 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) { type MonitorResponse struct { Executions []*mcp.ToolExecution `json:"executions"` Stats map[string]*mcp.ToolStats `json:"stats"` - Timestamp time.Time `json:"timestamp"` - Total int `json:"total,omitempty"` - Page int `json:"page,omitempty"` - PageSize int `json:"page_size,omitempty"` - TotalPages int `json:"total_pages,omitempty"` + Timestamp time.Time `json:"timestamp"` + Total int `json:"total,omitempty"` + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + TotalPages int `json:"total_pages,omitempty"` } // Monitor 获取监控信息 @@ -213,7 +213,6 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats { return stats } - // GetExecution 获取特定执行记录 func (h *MonitorHandler) GetExecution(c *gin.Context) { id := c.Param("id") @@ -416,5 +415,3 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) { h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs))) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) } - - diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go index 0684ad87..8d8d896f 100644 --- a/internal/handler/multi_agent.go +++ b/internal/handler/multi_agent.go @@ -185,6 +185,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { ) if runErr != nil { + h.persistEinoAgentTraceForResume(conversationID, result) cause := context.Cause(baseCtx) if errors.Is(cause, ErrTaskCancelled) { taskStatus = "cancelled" @@ -249,9 +250,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { ) } - if result.LastReActInput != "" || result.LastReActOutput != "" { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.Error(err)) } } @@ -318,6 +319,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { strings.TrimSpace(req.Orchestration), ) if runErr != nil { + h.persistEinoAgentTraceForResume(prep.ConversationID, result) h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) errMsg := "执行失败: " + runErr.Error() if prep.AssistantMessageID != "" { @@ -341,9 +343,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { ) } - if result.LastReActInput != "" || result.LastReActOutput != "" { - if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.Error(err)) } } @@ -355,6 +357,19 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { }) } +// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。 +func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) { + if h == nil || result == nil { + return + } + if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" { + return + } + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err)) + } +} + func multiAgentHTTPErrorStatus(err error) (int, string) { msg := err.Error() switch { diff --git a/internal/handler/multi_agent_prepare.go b/internal/handler/multi_agent_prepare.go index ca03f4a7..0e9cd43b 100644 --- a/internal/handler/multi_agent_prepare.go +++ b/internal/handler/multi_agent_prepare.go @@ -49,7 +49,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr } } - agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) + agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID) if err != nil { historyMessages, getErr := h.db.GetMessages(conversationID) if getErr != nil { diff --git a/internal/handler/openapi.go b/internal/handler/openapi.go index 45216daa..1b6dc4d4 100644 --- a/internal/handler/openapi.go +++ b/internal/handler/openapi.go @@ -4445,7 +4445,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"messageId"}, "properties": map[string]interface{}{ "messageId": map[string]interface{}{ @@ -4689,7 +4689,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"scheduleEnabled"}, "properties": map[string]interface{}{ "scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"}, @@ -4761,7 +4761,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"query"}, "properties": map[string]interface{}{ "query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""}, @@ -4810,7 +4810,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"text"}, "properties": map[string]interface{}{ "text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"}, @@ -4853,7 +4853,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"api_key", "model"}, "properties": map[string]interface{}{ "provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude)", "example": "openai"}, @@ -4900,7 +4900,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"command"}, "properties": map[string]interface{}{ "command": map[string]interface{}{"type": "string", "description": "要执行的命令"}, @@ -4943,7 +4943,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"command"}, "properties": map[string]interface{}{ "command": map[string]interface{}{"type": "string", "description": "要执行的命令"}, @@ -5027,7 +5027,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"url"}, "properties": map[string]interface{}{ "url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, @@ -5231,7 +5231,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"url", "command"}, "properties": map[string]interface{}{ "url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, @@ -5277,7 +5277,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"url", "action", "path"}, "properties": map[string]interface{}{ "url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, @@ -5339,14 +5339,14 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "items": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ - "relativePath": map[string]interface{}{"type": "string"}, - "absolutePath": map[string]interface{}{"type": "string"}, - "name": map[string]interface{}{"type": "string"}, - "size": map[string]interface{}{"type": "integer"}, - "modifiedUnix": map[string]interface{}{"type": "integer"}, - "date": map[string]interface{}{"type": "string"}, + "relativePath": map[string]interface{}{"type": "string"}, + "absolutePath": map[string]interface{}{"type": "string"}, + "name": map[string]interface{}{"type": "string"}, + "size": map[string]interface{}{"type": "integer"}, + "modifiedUnix": map[string]interface{}{"type": "integer"}, + "date": map[string]interface{}{"type": "string"}, "conversationId": map[string]interface{}{"type": "string"}, - "subPath": map[string]interface{}{"type": "string"}, + "subPath": map[string]interface{}{"type": "string"}, }, }, }, @@ -5369,7 +5369,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "multipart/form-data": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"file"}, "properties": map[string]interface{}{ "file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"}, @@ -5410,7 +5410,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"path"}, "properties": map[string]interface{}{ "path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, @@ -5485,7 +5485,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"path", "content"}, "properties": map[string]interface{}{ "path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, @@ -5512,7 +5512,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"name"}, "properties": map[string]interface{}{ "parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"}, @@ -5552,7 +5552,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"path", "newName"}, "properties": map[string]interface{}{ "path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"}, @@ -5646,7 +5646,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"platform", "text"}, "properties": map[string]interface{}{ "platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}}, @@ -5712,7 +5712,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"name"}, "properties": map[string]interface{}{ "filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"}, @@ -5932,7 +5932,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"path"}, "properties": map[string]interface{}{ "path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, @@ -5974,7 +5974,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "content": map[string]interface{}{ "application/json": map[string]interface{}{ "schema": map[string]interface{}{ - "type": "object", + "type": "object", "required": []string{"ids"}, "properties": map[string]interface{}{ "ids": map[string]interface{}{ diff --git a/internal/handler/openapi_i18n.go b/internal/handler/openapi_i18n.go index 250842cd..953c9d2a 100644 --- a/internal/handler/openapi_i18n.go +++ b/internal/handler/openapi_i18n.go @@ -26,7 +26,7 @@ var apiDocI18nSummaryToKey = map[string]string{ "创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup", "删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup", "从分组移除对话": "removeConversationFromGroup", - "列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats", + "列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats", "获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability", "列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole", "获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill", @@ -52,9 +52,9 @@ var apiDocI18nSummaryToKey = map[string]string{ "重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata", "修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled", "获取所有分组映射": "getAllGroupMappings", - "FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse", + "FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse", "测试OpenAI API连接": "testOpenAI", - "执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS", + "执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS", "列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection", "更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection", "获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState", @@ -69,7 +69,7 @@ var apiDocI18nSummaryToKey = map[string]string{ "获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent", "列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile", "批量获取工具名称": "batchGetToolNames", - "获取知识库统计": "getKnowledgeStats", + "获取知识库统计": "getKnowledgeStats", } var apiDocI18nResponseDescToKey = map[string]string{ @@ -78,7 +78,7 @@ var apiDocI18nResponseDescToKey = map[string]string{ "对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty", "请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound", "请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig", - "请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed", + "请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed", "登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess", "密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid", "对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess", @@ -89,7 +89,7 @@ var apiDocI18nResponseDescToKey = map[string]string{ "消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events)": "streamResponse", // 新增缺失端点响应 "参数错误或删除失败": "badRequestOrDeleteFailed", - "参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun", + "参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun", "参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess", "搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult", "执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished", diff --git a/internal/handler/robot.go b/internal/handler/robot.go index a7b8f3a7..7d701fc6 100644 --- a/internal/handler/robot.go +++ b/internal/handler/robot.go @@ -28,20 +28,20 @@ import ( ) const ( - robotCmdHelp = "帮助" - robotCmdList = "列表" - robotCmdListAlt = "对话列表" - robotCmdSwitch = "切换" - robotCmdContinue = "继续" - robotCmdNew = "新对话" - robotCmdClear = "清空" - robotCmdCurrent = "当前" - robotCmdStop = "停止" - robotCmdRoles = "角色" - robotCmdRolesList = "角色列表" - robotCmdSwitchRole = "切换角色" - robotCmdDelete = "删除" - robotCmdVersion = "版本" + robotCmdHelp = "帮助" + robotCmdList = "列表" + robotCmdListAlt = "对话列表" + robotCmdSwitch = "切换" + robotCmdContinue = "继续" + robotCmdNew = "新对话" + robotCmdClear = "清空" + robotCmdCurrent = "当前" + robotCmdStop = "停止" + robotCmdRoles = "角色" + robotCmdRolesList = "角色列表" + robotCmdSwitchRole = "切换角色" + robotCmdDelete = "删除" + robotCmdVersion = "版本" ) // RobotHandler 企业微信/钉钉/飞书等机器人回调处理 diff --git a/internal/handler/skills.go b/internal/handler/skills.go index 09e5e0ff..52f2dc99 100644 --- a/internal/handler/skills.go +++ b/internal/handler/skills.go @@ -65,19 +65,19 @@ func (h *SkillsHandler) GetSkills(c *gin.Context) { allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries)) for _, s := range allSummaries { skillInfo := map[string]interface{}{ - "id": s.ID, - "name": s.Name, - "dir_name": s.DirName, - "description": s.Description, - "version": s.Version, - "path": s.Path, - "tags": s.Tags, - "triggers": s.Triggers, - "script_count": s.ScriptCount, - "file_count": s.FileCount, - "progressive": s.Progressive, - "file_size": s.FileSize, - "mod_time": s.ModTime, + "id": s.ID, + "name": s.Name, + "dir_name": s.DirName, + "description": s.Description, + "version": s.Version, + "path": s.Path, + "tags": s.Tags, + "triggers": s.Triggers, + "script_count": s.ScriptCount, + "file_count": s.FileCount, + "progressive": s.Progressive, + "file_size": s.FileSize, + "mod_time": s.ModTime, } allSkillsInfo = append(allSkillsInfo, skillInfo) } diff --git a/internal/handler/terminal_ws_unix.go b/internal/handler/terminal_ws_unix.go index eaa5df67..0f446d83 100644 --- a/internal/handler/terminal_ws_unix.go +++ b/internal/handler/terminal_ws_unix.go @@ -109,4 +109,3 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) { <-doneChan } - diff --git a/internal/handler/vulnerability.go b/internal/handler/vulnerability.go index d9531976..f9a578bd 100644 --- a/internal/handler/vulnerability.go +++ b/internal/handler/vulnerability.go @@ -28,7 +28,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability // CreateVulnerabilityRequest 创建漏洞请求 type CreateVulnerabilityRequest struct { - ConversationID string `json:"conversation_id" binding:"required"` + ConversationID string `json:"conversation_id" binding:"required"` ConversationTag string `json:"conversation_tag"` TaskTag string `json:"task_tag"` Title string `json:"title" binding:"required"` @@ -51,18 +51,18 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) { } vuln := &database.Vulnerability{ - ConversationID: req.ConversationID, + ConversationID: req.ConversationID, ConversationTag: req.ConversationTag, TaskTag: req.TaskTag, - Title: req.Title, - Description: req.Description, - Severity: req.Severity, - Status: req.Status, - Type: req.Type, - Target: req.Target, - Proof: req.Proof, - Impact: req.Impact, - Recommendation: req.Recommendation, + Title: req.Title, + Description: req.Description, + Severity: req.Severity, + Status: req.Status, + Type: req.Type, + Target: req.Target, + Proof: req.Proof, + Impact: req.Impact, + Recommendation: req.Recommendation, } created, err := h.db.CreateVulnerability(vuln) @@ -172,15 +172,15 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { type UpdateVulnerabilityRequest struct { ConversationTag string `json:"conversation_tag"` TaskTag string `json:"task_tag"` - Title string `json:"title"` - Description string `json:"description"` - Severity string `json:"severity"` - Status string `json:"status"` - Type string `json:"type"` - Target string `json:"target"` - Proof string `json:"proof"` - Impact string `json:"impact"` - Recommendation string `json:"recommendation"` + Title string `json:"title"` + Description string `json:"description"` + Severity string `json:"severity"` + Status string `json:"status"` + Type string `json:"type"` + Target string `json:"target"` + Proof string `json:"proof"` + Impact string `json:"impact"` + Recommendation string `json:"recommendation"` } // UpdateVulnerability 更新漏洞 @@ -460,4 +460,3 @@ func sanitizeExportName(raw string) string { replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-") return replacer.Replace(name) } - diff --git a/internal/knowledge/eino_meta.go b/internal/knowledge/eino_meta.go index 2ae419c4..0ee7c41b 100644 --- a/internal/knowledge/eino_meta.go +++ b/internal/knowledge/eino_meta.go @@ -16,9 +16,9 @@ const ( // DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo]. const ( - DSLRiskType = "risk_type" - DSLSimilarityThreshold = "similarity_threshold" - DSLSubIndexFilter = "sub_index_filter" + DSLRiskType = "risk_type" + DSLSimilarityThreshold = "similarity_threshold" + DSLSubIndexFilter = "sub_index_filter" ) // FormatEmbeddingInput matches the historical indexing format so existing embeddings diff --git a/internal/knowledge/index_pipeline.go b/internal/knowledge/index_pipeline.go index de5d466e..a9b9a4c4 100644 --- a/internal/knowledge/index_pipeline.go +++ b/internal/knowledge/index_pipeline.go @@ -8,8 +8,8 @@ import ( "cyberstrike-ai/internal/config" - "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) diff --git a/internal/knowledge/indexer.go b/internal/knowledge/indexer.go index 390835c6..aeb6b9ff 100644 --- a/internal/knowledge/indexer.go +++ b/internal/knowledge/indexer.go @@ -11,9 +11,9 @@ import ( "cyberstrike-ai/internal/config" fileloader "github.com/cloudwego/eino-ext/components/document/loader/file" - "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" "go.uber.org/zap" ) @@ -35,14 +35,14 @@ type Indexer struct { lastErrorTime time.Time errorCount int - rebuildMu sync.RWMutex - isRebuilding bool - rebuildTotalItems int - rebuildCurrent int - rebuildFailed int - rebuildStartTime time.Time - rebuildLastItemID string - rebuildLastChunks int + rebuildMu sync.RWMutex + isRebuilding bool + rebuildTotalItems int + rebuildCurrent int + rebuildFailed int + rebuildStartTime time.Time + rebuildLastItemID string + rebuildLastChunks int } // NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。 diff --git a/internal/knowledge/types.go b/internal/knowledge/types.go index 80d0eb5f..42e35e76 100644 --- a/internal/knowledge/types.go +++ b/internal/knowledge/types.go @@ -108,9 +108,9 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) { // CategoryWithItems 分类及其下的知识项(用于按分类分页) type CategoryWithItems struct { - Category string `json:"category"` // 分类名称 - ItemCount int `json:"itemCount"` // 该分类下的知识项总数 - Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表 + Category string `json:"category"` // 分类名称 + ItemCount int `json:"itemCount"` // 该分类下的知识项总数 + Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表 } // SearchRequest 搜索请求 diff --git a/internal/openai/claude_bridge.go b/internal/openai/claude_bridge.go index ca3a608a..f61e642d 100644 --- a/internal/openai/claude_bridge.go +++ b/internal/openai/claude_bridge.go @@ -192,13 +192,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) { fnName, _ := fn["name"].(string) fnArgs, _ := fn["arguments"] - // 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝 - if strings.TrimSpace(fnName) == "" { - fnName = "unknown_function" - } - if strings.TrimSpace(tcID) == "" { - tcID = fmt.Sprintf("call_%d", time.Now().UnixNano()) - } + // 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝 + if strings.TrimSpace(fnName) == "" { + fnName = "unknown_function" + } + if strings.TrimSpace(tcID) == "" { + tcID = fmt.Sprintf("call_%d", time.Now().UnixNano()) + } var inputRaw json.RawMessage switch v := fnArgs.(type) { diff --git a/internal/openai/openai.go b/internal/openai/openai.go index 7d813d1c..10faf565 100644 --- a/internal/openai/openai.go +++ b/internal/openai/openai.go @@ -281,9 +281,9 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{}, // StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。 type StreamToolCall struct { - Index int - ID string - Type string + Index int + ID string + Type string FunctionName string FunctionArgsStr string } @@ -348,10 +348,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls( Arguments string `json:"arguments,omitempty"` } type toolCallDelta struct { - Index int `json:"index,omitempty"` - ID string `json:"id,omitempty"` - Type string `json:"type,omitempty"` - Function toolCallFunctionDelta `json:"function,omitempty"` + Index int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function toolCallFunctionDelta `json:"function,omitempty"` } type streamDelta2 struct { Content string `json:"content,omitempty"` @@ -371,10 +371,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls( } type toolCallAccum struct { - id string - typ string - name string - args strings.Builder + id string + typ string + name string + args strings.Builder } toolCallAccums := make(map[int]*toolCallAccum) @@ -475,9 +475,9 @@ func (c *Client) ChatCompletionStreamWithToolCalls( for _, idx := range indices { acc := toolCallAccums[idx] tc := StreamToolCall{ - Index: idx, - ID: acc.id, - Type: acc.typ, + Index: idx, + ID: acc.id, + Type: acc.typ, FunctionName: acc.name, FunctionArgsStr: acc.args.String(), } diff --git a/internal/security/executor_test.go b/internal/security/executor_test.go index 2885fcb4..6286c5e7 100644 --- a/internal/security/executor_test.go +++ b/internal/security/executor_test.go @@ -19,11 +19,11 @@ import ( func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) { logger := zap.NewNop() mcpServer := mcp.NewServer(logger) - + cfg := &config.SecurityConfig{ Tools: []config.ToolConfig{}, } - + executor := NewExecutor(cfg, mcpServer, logger) return executor, mcpServer } @@ -32,12 +32,12 @@ func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) { func setupTestStorage(t *testing.T) *storage.FileResultStorage { tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405")) logger := zap.NewNop() - + storage, err := storage.NewFileResultStorage(tmpDir, logger) if err != nil { t.Fatalf("创建测试存储失败: %v", err) } - + return storage } @@ -45,46 +45,46 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) { executor, _ := setupTestExecutor(t) testStorage := setupTestStorage(t) executor.SetResultStorage(testStorage) - + // 准备测试数据 executionID := "test_exec_001" toolName := "nmap_scan" result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred" - + // 保存测试结果 err := testStorage.SaveResult(executionID, toolName, result) if err != nil { t.Fatalf("保存测试结果失败: %v", err) } - + ctx := context.Background() - + // 测试1: 基本查询(第一页) args := map[string]interface{}{ "execution_id": executionID, "page": float64(1), "limit": float64(2), } - + toolResult, err := executor.executeQueryExecutionResult(ctx, args) if err != nil { t.Fatalf("执行查询失败: %v", err) } - + if toolResult.IsError { t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text) } - + // 验证结果包含预期内容 resultText := toolResult.Content[0].Text if !strings.Contains(resultText, executionID) { t.Errorf("结果中应该包含执行ID: %s", executionID) } - + if !strings.Contains(resultText, "第 1/") { t.Errorf("结果中应该包含分页信息") } - + // 测试2: 搜索功能 args2 := map[string]interface{}{ "execution_id": executionID, @@ -92,21 +92,21 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) { "page": float64(1), "limit": float64(10), } - + toolResult2, err := executor.executeQueryExecutionResult(ctx, args2) if err != nil { t.Fatalf("执行搜索失败: %v", err) } - + if toolResult2.IsError { t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text) } - + resultText2 := toolResult2.Content[0].Text if !strings.Contains(resultText2, "error") { t.Errorf("搜索结果中应该包含关键词: error") } - + // 测试3: 过滤功能 args3 := map[string]interface{}{ "execution_id": executionID, @@ -114,46 +114,46 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) { "page": float64(1), "limit": float64(10), } - + toolResult3, err := executor.executeQueryExecutionResult(ctx, args3) if err != nil { t.Fatalf("执行过滤失败: %v", err) } - + if toolResult3.IsError { t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text) } - + resultText3 := toolResult3.Content[0].Text if !strings.Contains(resultText3, "Port") { t.Errorf("过滤结果中应该包含关键词: Port") } - + // 测试4: 缺少必需参数 args4 := map[string]interface{}{ "page": float64(1), } - + toolResult4, err := executor.executeQueryExecutionResult(ctx, args4) if err != nil { t.Fatalf("执行查询失败: %v", err) } - + if !toolResult4.IsError { t.Fatal("缺少execution_id应该返回错误") } - + // 测试5: 不存在的执行ID args5 := map[string]interface{}{ "execution_id": "nonexistent_id", "page": float64(1), } - + toolResult5, err := executor.executeQueryExecutionResult(ctx, args5) if err != nil { t.Fatalf("执行查询失败: %v", err) } - + if !toolResult5.IsError { t.Fatal("不存在的执行ID应该返回错误") } @@ -161,22 +161,22 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) { func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) { executor, _ := setupTestExecutor(t) - + ctx := context.Background() args := map[string]interface{}{ "test": "value", } - + // 测试未知的内部工具类型 toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args) if err != nil { t.Fatalf("执行内部工具失败: %v", err) } - + if !toolResult.IsError { t.Fatal("未知的工具类型应该返回错误") } - + if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") { t.Errorf("错误消息应该包含'未知的内部工具类型'") } @@ -185,21 +185,21 @@ func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) { func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) { executor, _ := setupTestExecutor(t) // 不设置存储,测试未初始化的情况 - + ctx := context.Background() args := map[string]interface{}{ "execution_id": "test_id", } - + toolResult, err := executor.executeQueryExecutionResult(ctx, args) if err != nil { t.Fatalf("执行查询失败: %v", err) } - + if !toolResult.IsError { t.Fatal("未初始化的存储应该返回错误") } - + if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") { t.Errorf("错误消息应该包含'结果存储未初始化'") } @@ -207,7 +207,7 @@ func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) { func TestPaginateLines(t *testing.T) { lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"} - + // 测试第一页 page := paginateLines(lines, 1, 2) if page.Page != 1 { @@ -225,7 +225,7 @@ func TestPaginateLines(t *testing.T) { if len(page.Lines) != 2 { t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines)) } - + // 测试第二页 page2 := paginateLines(lines, 2, 2) if len(page2.Lines) != 2 { @@ -234,13 +234,13 @@ func TestPaginateLines(t *testing.T) { if page2.Lines[0] != "Line 3" { t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0]) } - + // 测试最后一页 page3 := paginateLines(lines, 3, 2) if len(page3.Lines) != 1 { t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines)) } - + // 测试超出范围的页码(应该返回最后一页) page4 := paginateLines(lines, 4, 2) if page4.Page != 3 { @@ -249,13 +249,13 @@ func TestPaginateLines(t *testing.T) { if len(page4.Lines) != 1 { t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines)) } - + // 测试无效页码(小于1) page0 := paginateLines(lines, 0, 2) if page0.Page != 1 { t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page) } - + // 测试空列表 emptyPage := paginateLines([]string{}, 1, 10) if emptyPage.TotalLines != 0 { @@ -265,4 +265,3 @@ func TestPaginateLines(t *testing.T) { t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines)) } } - diff --git a/internal/security/ratelimit.go b/internal/security/ratelimit.go index 1c959237..71795710 100644 --- a/internal/security/ratelimit.go +++ b/internal/security/ratelimit.go @@ -16,10 +16,10 @@ type rateLimitEntry struct { // RateLimiter 基于 IP 的滑动窗口速率限制器 type RateLimiter struct { - mu sync.Mutex - entries map[string]*rateLimitEntry - limit int // 窗口内允许的最大请求数 - window time.Duration // 窗口时长 + mu sync.Mutex + entries map[string]*rateLimitEntry + limit int // 窗口内允许的最大请求数 + window time.Duration // 窗口时长 } // NewRateLimiter 创建速率限制器