Add files via upload

This commit is contained in:
公明
2026-04-28 11:37:52 +08:00
committed by GitHub
parent 3b3d094dc4
commit b53cae3a02
26 changed files with 374 additions and 374 deletions
+21 -21
View File
@@ -336,10 +336,10 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
// AgentLoopResult Agent Loop执行结果
type AgentLoopResult struct {
Response string
MCPExecutionIDs []string
LastReActInput string // 最后一轮ReAct的输入(压缩后的messagesJSON格式
LastReActOutput string // 最终大模型的输出
Response string
MCPExecutionIDs []string
LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messagesJSON;与 multiagent.RunResult 字段对齐
LastAgentTraceOutput string // 最终助手输出文本
}
// ProgressCallback 进度回调函数类型
@@ -471,7 +471,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
var currentReActInput string
var currentAgentTraceInput string
maxIterations := a.maxIterations
thinkingStreamSeq := 0
@@ -490,9 +490,9 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if err != nil {
a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
} else {
currentReActInput = string(messagesJSON)
currentAgentTraceInput = string(messagesJSON)
// 更新result中的值,确保始终保存最新的ReAct输入(压缩后的)
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
}
// 检查上下文是否已取消
@@ -500,13 +500,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
case <-ctx.Done():
// 上下文被取消(可能是用户主动暂停或其他原因)
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
if ctx.Err() == context.Canceled {
result.Response = "任务已被取消。"
} else {
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
}
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
return result, ctx.Err()
default:
}
@@ -600,10 +600,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
})
if err != nil {
// API调用失败,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
result.Response = errorMsg
result.LastReActOutput = errorMsg
result.LastAgentTraceOutput = errorMsg
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
return result, fmt.Errorf("调用OpenAI失败: %w", err)
}
@@ -629,19 +629,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
continue
}
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
result.Response = errorMsg
result.LastReActOutput = errorMsg
result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
}
if len(response.Choices) == 0 {
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := "没有收到响应"
result.Response = errorMsg
result.LastReActOutput = errorMsg
result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("没有收到响应")
}
@@ -816,7 +816,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
})
if strings.TrimSpace(streamText) != "" {
result.Response = streamText
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
@@ -863,14 +863,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
})
if strings.TrimSpace(streamText) != "" {
result.Response = streamText
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
// 如果获取总结失败,使用当前回复作为结果
if choice.Message.Content != "" {
result.Response = choice.Message.Content
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
return result, nil
}
// 如果都没有内容,跳出循环,让后续逻辑处理
@@ -881,7 +881,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil)
result.Response = choice.Message.Content
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
return result, nil
}
}
@@ -910,14 +910,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
})
if strings.TrimSpace(streamText) != "" {
result.Response = streamText
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
// 如果无法生成总结,返回友好的提示
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
return result, nil
}
+48 -49
View File
@@ -18,62 +18,62 @@ import (
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
agentCfg := &config.AgentConfig{
MaxIterations: 10,
LargeResultThreshold: 100, // 设置较小的阈值便于测试
ResultStorageDir: "",
}
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
// 创建测试存储
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
agent.SetResultStorage(testStorage)
return agent, testStorage
}
func TestAgent_FormatMinimalNotification(t *testing.T) {
agent, testStorage := setupTestAgent(t)
_ = testStorage // 避免未使用变量警告
executionID := "test_exec_001"
toolName := "nmap_scan"
size := 50000
lineCount := 1000
filePath := "tmp/test_exec_001.txt"
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
// 验证通知包含必要信息
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID: %s", executionID)
}
if !strings.Contains(notification, toolName) {
t.Errorf("通知中应该包含工具名称: %s", toolName)
}
if !strings.Contains(notification, "50000") {
t.Errorf("通知中应该包含大小信息")
}
if !strings.Contains(notification, "1000") {
t.Errorf("通知中应该包含行数信息")
}
if !strings.Contains(notification, "query_execution_result") {
t.Errorf("通知中应该包含查询工具的使用说明")
}
@@ -81,7 +81,7 @@ func TestAgent_FormatMinimalNotification(t *testing.T) {
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建模拟的MCP工具结果(大结果)
largeResult := &mcp.ToolResult{
Content: []mcp.Content{
@@ -92,59 +92,59 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
},
IsError: false,
}
// 模拟MCP服务器返回大结果
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
// 为了简化测试,我们直接测试结果处理逻辑
// 设置阈值
agent.mu.Lock()
agent.largeResultThreshold = 1000 // 设置较小的阈值
agent.mu.Unlock()
// 创建执行ID
executionID := "test_exec_large_001"
toolName := "test_tool"
// 格式化结果
var resultText strings.Builder
for _, content := range largeResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果并保存
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
// 保存大结果
err := storage.SaveResult(executionID, toolName, resultStr)
if err != nil {
t.Fatalf("保存大结果失败: %v", err)
}
// 生成通知
lines := strings.Split(resultStr, "\n")
filePath := storage.GetResultPath(executionID)
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
// 验证通知格式
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID")
}
// 验证结果已保存
savedResult, err := storage.GetResult(executionID)
if err != nil {
t.Fatalf("获取保存的结果失败: %v", err)
}
if savedResult != resultStr {
t.Errorf("保存的结果与原始结果不匹配")
}
@@ -155,7 +155,7 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建小结果
smallResult := &mcp.ToolResult{
Content: []mcp.Content{
@@ -166,32 +166,32 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
},
IsError: false,
}
// 设置较大的阈值
agent.mu.Lock()
agent.largeResultThreshold = 100000 // 100KB
agent.mu.Unlock()
// 格式化结果
var resultText strings.Builder
for _, content := range smallResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
t.Fatal("小结果不应该被保存")
}
// 小结果应该直接返回
if resultSize <= threshold {
// 这是预期的行为
@@ -203,26 +203,26 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
func TestAgent_SetResultStorage(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建新的存储
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
if err != nil {
t.Fatalf("创建新存储失败: %v", err)
}
// 设置新存储
agent.SetResultStorage(newStorage)
// 验证存储已更新
agent.mu.RLock()
currentStorage := agent.resultStorage
agent.mu.RUnlock()
if currentStorage != newStorage {
t.Fatal("存储未正确更新")
}
// 清理
os.RemoveAll(tmpDir)
}
@@ -230,24 +230,24 @@ func TestAgent_SetResultStorage(t *testing.T) {
func TestAgent_NewAgent_DefaultValues(t *testing.T) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
// 测试默认配置
agent := NewAgent(openAICfg, nil, mcpServer, nil, logger, 0)
if agent.maxIterations != 30 {
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
}
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 50*1024 {
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
}
@@ -256,31 +256,30 @@ func TestAgent_NewAgent_DefaultValues(t *testing.T) {
func TestAgent_NewAgent_CustomConfig(t *testing.T) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
agentCfg := &config.AgentConfig{
MaxIterations: 20,
LargeResultThreshold: 100 * 1024, // 100KB
ResultStorageDir: "custom_tmp",
}
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
if agent.maxIterations != 15 {
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
}
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 100*1024 {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
}
}
+5 -5
View File
@@ -256,11 +256,11 @@ func orchestratorConfigFromOrchestrator(o *OrchestratorMarkdown) config.MultiAge
return config.MultiAgentSubConfig{}
}
return config.MultiAgentSubConfig{
ID: o.EinoName,
Name: o.DisplayName,
Description: o.Description,
Instruction: o.Instruction,
Kind: "orchestrator",
ID: o.EinoName,
Name: o.DisplayName,
Description: o.Description,
Instruction: o.Instruction,
Kind: "orchestrator",
}
}
+84 -92
View File
@@ -497,10 +497,10 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
defer h.hitlManager.DeactivateConversation(conversationID)
}
// 优先尝试从保存的ReAct数据恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
// 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
@@ -518,7 +518,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
}
} else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 校验附件数量(非流式)
@@ -613,12 +613,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
// 即使执行失败,也尝试保存ReAct数据(如果result中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr))
// 即使执行失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if saveErr := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); saveErr != nil {
h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(saveErr))
} else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -634,12 +634,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
// 因为AI已经生成了回复,用户应该能看到
}
// 保存最后一轮ReAct的输入和输出
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
// 保存最后一轮代理轨迹与助手输出
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -666,7 +666,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
}
}
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil {
@@ -722,6 +722,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
"deep",
)
if errMA != nil {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
errMsg := "执行失败: " + errMA.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
@@ -747,8 +748,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
}
}
if resultMA.LastReActInput != "" || resultMA.LastReActOutput != "" {
_ = h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput)
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
}
return resultMA.Response, conversationID, nil
}
@@ -782,8 +783,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
}
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
_ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput)
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
}
return result.Response, conversationID, nil
}
@@ -1359,10 +1360,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
}
ssePublishConversationID = conversationID
// 优先尝试从保存的ReAct数据恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
// 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
@@ -1380,7 +1381,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
}
} else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 校验附件数量
@@ -1579,12 +1580,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
}
// 即使任务被取消,也尝试保存ReAct数据(如果result中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err))
// 即使任务被取消,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存取消任务的代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -1614,12 +1615,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
}
// 即使任务超时,也尝试保存ReAct数据(如果result中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err))
// 即使任务超时,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存超时任务的代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存超时任务的代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -1649,12 +1650,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
}
// 即使任务失败,也尝试保存ReAct数据(如果result中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err))
// 即使任务失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -1694,12 +1695,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
}
}
// 保存最后一轮ReAct的输入和输出
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
// 保存最后一轮代理轨迹与助手输出
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -2499,6 +2500,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
cancel()
if runErr != nil {
if useRunResult {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
// 检查是否是取消错误
// 1. 直接检查是否是 context.Canceled(包括包装后的错误)
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
@@ -2542,14 +2546,14 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
}
}
// 保存ReAct数据(如果存在)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
// 保存代理轨迹(如果存在)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
} else if useRunResult && resultMA != nil && (resultMA.LastReActInput != "" || resultMA.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else if useRunResult && resultMA != nil && (resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
@@ -2581,13 +2585,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
if useRunResult {
resText = resultMA.Response
mcpIDs = resultMA.MCPExecutionIDs
lastIn = resultMA.LastReActInput
lastOut = resultMA.LastReActOutput
lastIn = resultMA.LastAgentTraceInput
lastOut = resultMA.LastAgentTraceOutput
} else {
resText = result.Response
mcpIDs = result.MCPExecutionIDs
lastIn = result.LastReActInput
lastOut = result.LastReActOutput
lastIn = result.LastAgentTraceInput
lastOut = result.LastAgentTraceOutput
}
// 更新助手消息内容
@@ -2618,12 +2622,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
}
}
// 保存ReAct数据
// 保存代理轨迹
if lastIn != "" || lastOut != "" {
if err := h.db.SaveReActData(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
}
}
@@ -2642,36 +2646,33 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
}
}
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退消息表
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
// 获取保存的ReAct输入和输出
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
// 逻辑与攻击链一致:优先用保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
traceInputJSON, assistantOut, err := h.db.GetAgentTrace(conversationID)
if err != nil {
return nil, fmt.Errorf("获取ReAct数据失败: %w", err)
return nil, fmt.Errorf("获取代理轨迹失败: %w", err)
}
// 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致)
if reactInputJSON == "" {
return nil, fmt.Errorf("ReAct数据为空,将使用消息表")
if traceInputJSON == "" {
return nil, fmt.Errorf("代理轨迹为空,将使用消息表")
}
dataSource := "database_last_react_input"
dataSource := "database_last_agent_trace"
// 解析JSON格式的messages数组
var messagesArray []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil {
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err)
if err := json.Unmarshal([]byte(traceInputJSON), &messagesArray); err != nil {
return nil, fmt.Errorf("解析代理轨迹 JSON 失败: %w", err)
}
messageCount := len(messagesArray)
h.logger.Info("使用保存的ReAct数据恢复历史上下文",
h.logger.Info("使用保存的代理轨迹恢复历史上下文",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)),
zap.Int("traceInputSize", len(traceInputJSON)),
zap.Int("messageCount", messageCount),
zap.Int("reactOutputSize", len(reactOutput)),
zap.Int("assistantOutSize", len(assistantOut)),
)
// fmt.Println("messagesArray:", messagesArray)//debug
@@ -2755,53 +2756,44 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
agentMessages = append(agentMessages, msg)
}
// 如果存在last_react_output,需要将其作为最后一条assistant消息
// 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出
if reactOutput != "" {
// 检查最后一条消息是否是assistant消息且没有tool_calls
// 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复
// 存在 last_react_output(助手摘要),合并为最后一条 assistant(与保存格式一致)
if assistantOut != "" {
if len(agentMessages) > 0 {
lastMsg := &agentMessages[len(agentMessages)-1]
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
// 最后一条是assistant消息且没有tool_calls,用最终输出更新其content
lastMsg.Content = reactOutput
lastMsg.Content = assistantOut
} else {
// 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息
agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant",
Content: reactOutput,
Content: assistantOut,
})
}
} else {
// 如果没有消息,直接添加最终输出
agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant",
Content: reactOutput,
Content: assistantOut,
})
}
}
if len(agentMessages) == 0 {
return nil, fmt.Errorf("从ReAct数据解析的消息为空")
return nil, fmt.Errorf("从代理轨迹解析的消息为空")
}
// 修复可能存在的失配tool消息,避免OpenAI报错
// 这可以防止出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误
if h.agent != nil {
if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed {
h.logger.Info("修复了从ReAct数据恢复的历史消息中的失配tool消息",
h.logger.Info("修复了从代理轨迹恢复的历史消息中的失配 tool 消息",
zap.String("conversationId", conversationID),
)
}
}
h.logger.Info("从ReAct数据恢复历史消息完成",
h.logger.Info("从代理轨迹恢复历史消息完成",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("originalMessageCount", messageCount),
zap.Int("finalMessageCount", len(agentMessages)),
zap.Bool("hasReactOutput", reactOutput != ""),
zap.Bool("hasAssistantOut", assistantOut != ""),
)
fmt.Println("agentMessages:", agentMessages) //debug
return agentMessages, nil
}
+2 -3
View File
@@ -83,7 +83,7 @@ func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
// 使用锁机制防止同一对话的并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
// 尝试获取锁,如果正在生成则返回错误
acquired := lock.TryLock()
if !acquired {
@@ -144,7 +144,7 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
// 使用锁机制防止并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
acquired := lock.TryLock()
if !acquired {
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
@@ -170,4 +170,3 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
c.JSON(http.StatusOK, chain)
}
-1
View File
@@ -230,4 +230,3 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
"message": "ok",
})
}
+7 -5
View File
@@ -175,6 +175,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
)
if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled"
@@ -239,9 +240,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
}
}
@@ -306,6 +307,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
progressCallback,
)
if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
return
}
@@ -323,8 +325,8 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
prep.AssistantMessageID,
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
_ = h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput)
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
}
c.JSON(http.StatusOK, gin.H{
+10 -10
View File
@@ -247,7 +247,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
// 先添加一个配置
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
}
handler.manager.AddOrUpdateConfig("test-delete", configObj)
@@ -276,11 +276,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
// 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: false,
})
@@ -319,15 +319,15 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
// 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
})
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
@@ -360,7 +360,7 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
// 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
})
// 测试启动(可能会失败,因为没有真实的服务器)
@@ -416,7 +416,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
router, _, _ := setupTestRouter()
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
}
@@ -459,14 +459,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
// 先添加配置
config1 := config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
}
handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置
config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true,
}
+10 -10
View File
@@ -131,16 +131,16 @@ func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) {
}
type markdownAgentBody struct {
Filename string `json:"filename"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Tools []string `json:"tools"`
Instruction string `json:"instruction"`
BindRole string `json:"bind_role"`
MaxIterations int `json:"max_iterations"`
Kind string `json:"kind"`
Raw string `json:"raw"`
Filename string `json:"filename"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Tools []string `json:"tools"`
Instruction string `json:"instruction"`
BindRole string `json:"bind_role"`
MaxIterations int `json:"max_iterations"`
Kind string `json:"kind"`
Raw string `json:"raw"`
}
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents
+5 -8
View File
@@ -42,11 +42,11 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"`
Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"`
Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"`
}
// Monitor 获取监控信息
@@ -213,7 +213,6 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
return stats
}
// GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id")
@@ -416,5 +415,3 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
}
+21 -6
View File
@@ -185,6 +185,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
)
if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled"
@@ -249,9 +250,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
}
}
@@ -318,6 +319,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
strings.TrimSpace(req.Orchestration),
)
if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
errMsg := "执行失败: " + runErr.Error()
if prep.AssistantMessageID != "" {
@@ -341,9 +343,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
}
}
@@ -355,6 +357,19 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
})
}
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
if h == nil || result == nil {
return
}
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
return
}
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
}
}
func multiAgentHTTPErrorStatus(err error) (int, string) {
msg := err.Error()
switch {
+1 -1
View File
@@ -49,7 +49,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
}
}
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil {
+26 -26
View File
@@ -4445,7 +4445,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"messageId"},
"properties": map[string]interface{}{
"messageId": map[string]interface{}{
@@ -4689,7 +4689,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"scheduleEnabled"},
"properties": map[string]interface{}{
"scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"},
@@ -4761,7 +4761,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"query"},
"properties": map[string]interface{}{
"query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""},
@@ -4810,7 +4810,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"text"},
"properties": map[string]interface{}{
"text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"},
@@ -4853,7 +4853,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"api_key", "model"},
"properties": map[string]interface{}{
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude", "example": "openai"},
@@ -4900,7 +4900,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"command"},
"properties": map[string]interface{}{
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
@@ -4943,7 +4943,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"command"},
"properties": map[string]interface{}{
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
@@ -5027,7 +5027,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"url"},
"properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5231,7 +5231,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"url", "command"},
"properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5277,7 +5277,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"url", "action", "path"},
"properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5339,14 +5339,14 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"items": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"relativePath": map[string]interface{}{"type": "string"},
"absolutePath": map[string]interface{}{"type": "string"},
"name": map[string]interface{}{"type": "string"},
"size": map[string]interface{}{"type": "integer"},
"modifiedUnix": map[string]interface{}{"type": "integer"},
"date": map[string]interface{}{"type": "string"},
"relativePath": map[string]interface{}{"type": "string"},
"absolutePath": map[string]interface{}{"type": "string"},
"name": map[string]interface{}{"type": "string"},
"size": map[string]interface{}{"type": "integer"},
"modifiedUnix": map[string]interface{}{"type": "integer"},
"date": map[string]interface{}{"type": "string"},
"conversationId": map[string]interface{}{"type": "string"},
"subPath": map[string]interface{}{"type": "string"},
"subPath": map[string]interface{}{"type": "string"},
},
},
},
@@ -5369,7 +5369,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"multipart/form-data": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"file"},
"properties": map[string]interface{}{
"file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"},
@@ -5410,7 +5410,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"path"},
"properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5485,7 +5485,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"path", "content"},
"properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5512,7 +5512,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"name"},
"properties": map[string]interface{}{
"parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"},
@@ -5552,7 +5552,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"path", "newName"},
"properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"},
@@ -5646,7 +5646,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"platform", "text"},
"properties": map[string]interface{}{
"platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}},
@@ -5712,7 +5712,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"name"},
"properties": map[string]interface{}{
"filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"},
@@ -5932,7 +5932,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"path"},
"properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5974,7 +5974,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"ids"},
"properties": map[string]interface{}{
"ids": map[string]interface{}{
+6 -6
View File
@@ -26,7 +26,7 @@ var apiDocI18nSummaryToKey = map[string]string{
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
"从分组移除对话": "removeConversationFromGroup",
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
@@ -52,9 +52,9 @@ var apiDocI18nSummaryToKey = map[string]string{
"重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata",
"修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled",
"获取所有分组映射": "getAllGroupMappings",
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
"测试OpenAI API连接": "testOpenAI",
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
"列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection",
"更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection",
"获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState",
@@ -69,7 +69,7 @@ var apiDocI18nSummaryToKey = map[string]string{
"获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent",
"列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile",
"批量获取工具名称": "batchGetToolNames",
"获取知识库统计": "getKnowledgeStats",
"获取知识库统计": "getKnowledgeStats",
}
var apiDocI18nResponseDescToKey = map[string]string{
@@ -78,7 +78,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
@@ -89,7 +89,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events": "streamResponse",
// 新增缺失端点响应
"参数错误或删除失败": "badRequestOrDeleteFailed",
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
"参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess",
"搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult",
"执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished",
+14 -14
View File
@@ -28,20 +28,20 @@ import (
)
const (
robotCmdHelp = "帮助"
robotCmdList = "列表"
robotCmdListAlt = "对话列表"
robotCmdSwitch = "切换"
robotCmdContinue = "继续"
robotCmdNew = "新对话"
robotCmdClear = "清空"
robotCmdCurrent = "当前"
robotCmdStop = "停止"
robotCmdRoles = "角色"
robotCmdRolesList = "角色列表"
robotCmdSwitchRole = "切换角色"
robotCmdDelete = "删除"
robotCmdVersion = "版本"
robotCmdHelp = "帮助"
robotCmdList = "列表"
robotCmdListAlt = "对话列表"
robotCmdSwitch = "切换"
robotCmdContinue = "继续"
robotCmdNew = "新对话"
robotCmdClear = "清空"
robotCmdCurrent = "当前"
robotCmdStop = "停止"
robotCmdRoles = "角色"
robotCmdRolesList = "角色列表"
robotCmdSwitchRole = "切换角色"
robotCmdDelete = "删除"
robotCmdVersion = "版本"
)
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
+13 -13
View File
@@ -65,19 +65,19 @@ func (h *SkillsHandler) GetSkills(c *gin.Context) {
allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries))
for _, s := range allSummaries {
skillInfo := map[string]interface{}{
"id": s.ID,
"name": s.Name,
"dir_name": s.DirName,
"description": s.Description,
"version": s.Version,
"path": s.Path,
"tags": s.Tags,
"triggers": s.Triggers,
"script_count": s.ScriptCount,
"file_count": s.FileCount,
"progressive": s.Progressive,
"file_size": s.FileSize,
"mod_time": s.ModTime,
"id": s.ID,
"name": s.Name,
"dir_name": s.DirName,
"description": s.Description,
"version": s.Version,
"path": s.Path,
"tags": s.Tags,
"triggers": s.Triggers,
"script_count": s.ScriptCount,
"file_count": s.FileCount,
"progressive": s.Progressive,
"file_size": s.FileSize,
"mod_time": s.ModTime,
}
allSkillsInfo = append(allSkillsInfo, skillInfo)
}
-1
View File
@@ -109,4 +109,3 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
<-doneChan
}
+20 -21
View File
@@ -28,7 +28,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
// CreateVulnerabilityRequest 创建漏洞请求
type CreateVulnerabilityRequest struct {
ConversationID string `json:"conversation_id" binding:"required"`
ConversationID string `json:"conversation_id" binding:"required"`
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title" binding:"required"`
@@ -51,18 +51,18 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
}
vuln := &database.Vulnerability{
ConversationID: req.ConversationID,
ConversationID: req.ConversationID,
ConversationTag: req.ConversationTag,
TaskTag: req.TaskTag,
Title: req.Title,
Description: req.Description,
Severity: req.Severity,
Status: req.Status,
Type: req.Type,
Target: req.Target,
Proof: req.Proof,
Impact: req.Impact,
Recommendation: req.Recommendation,
Title: req.Title,
Description: req.Description,
Severity: req.Severity,
Status: req.Status,
Type: req.Type,
Target: req.Target,
Proof: req.Proof,
Impact: req.Impact,
Recommendation: req.Recommendation,
}
created, err := h.db.CreateVulnerability(vuln)
@@ -172,15 +172,15 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
type UpdateVulnerabilityRequest struct {
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
}
// UpdateVulnerability 更新漏洞
@@ -460,4 +460,3 @@ func sanitizeExportName(raw string) string {
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
return replacer.Replace(name)
}
+3 -3
View File
@@ -16,9 +16,9 @@ const (
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
const (
DSLRiskType = "risk_type"
DSLSimilarityThreshold = "similarity_threshold"
DSLSubIndexFilter = "sub_index_filter"
DSLRiskType = "risk_type"
DSLSimilarityThreshold = "similarity_threshold"
DSLSubIndexFilter = "sub_index_filter"
)
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
+1 -1
View File
@@ -8,8 +8,8 @@ import (
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
+9 -9
View File
@@ -11,9 +11,9 @@ import (
"cyberstrike-ai/internal/config"
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
@@ -35,14 +35,14 @@ type Indexer struct {
lastErrorTime time.Time
errorCount int
rebuildMu sync.RWMutex
isRebuilding bool
rebuildTotalItems int
rebuildCurrent int
rebuildFailed int
rebuildStartTime time.Time
rebuildLastItemID string
rebuildLastChunks int
rebuildMu sync.RWMutex
isRebuilding bool
rebuildTotalItems int
rebuildCurrent int
rebuildFailed int
rebuildStartTime time.Time
rebuildLastItemID string
rebuildLastChunks int
}
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
+3 -3
View File
@@ -108,9 +108,9 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
type CategoryWithItems struct {
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
}
// SearchRequest 搜索请求
+7 -7
View File
@@ -192,13 +192,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
fnName, _ := fn["name"].(string)
fnArgs, _ := fn["arguments"]
// 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝
if strings.TrimSpace(fnName) == "" {
fnName = "unknown_function"
}
if strings.TrimSpace(tcID) == "" {
tcID = fmt.Sprintf("call_%d", time.Now().UnixNano())
}
// 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝
if strings.TrimSpace(fnName) == "" {
fnName = "unknown_function"
}
if strings.TrimSpace(tcID) == "" {
tcID = fmt.Sprintf("call_%d", time.Now().UnixNano())
}
var inputRaw json.RawMessage
switch v := fnArgs.(type) {
+14 -14
View File
@@ -281,9 +281,9 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{},
// StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。
type StreamToolCall struct {
Index int
ID string
Type string
Index int
ID string
Type string
FunctionName string
FunctionArgsStr string
}
@@ -348,10 +348,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
Arguments string `json:"arguments,omitempty"`
}
type toolCallDelta struct {
Index int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function toolCallFunctionDelta `json:"function,omitempty"`
Index int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function toolCallFunctionDelta `json:"function,omitempty"`
}
type streamDelta2 struct {
Content string `json:"content,omitempty"`
@@ -371,10 +371,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
}
type toolCallAccum struct {
id string
typ string
name string
args strings.Builder
id string
typ string
name string
args strings.Builder
}
toolCallAccums := make(map[int]*toolCallAccum)
@@ -475,9 +475,9 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
for _, idx := range indices {
acc := toolCallAccums[idx]
tc := StreamToolCall{
Index: idx,
ID: acc.id,
Type: acc.typ,
Index: idx,
ID: acc.id,
Type: acc.typ,
FunctionName: acc.name,
FunctionArgsStr: acc.args.String(),
}
+40 -41
View File
@@ -19,11 +19,11 @@ import (
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
cfg := &config.SecurityConfig{
Tools: []config.ToolConfig{},
}
executor := NewExecutor(cfg, mcpServer, logger)
return executor, mcpServer
}
@@ -32,12 +32,12 @@ func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
logger := zap.NewNop()
storage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
return storage
}
@@ -45,46 +45,46 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
executor, _ := setupTestExecutor(t)
testStorage := setupTestStorage(t)
executor.SetResultStorage(testStorage)
// 准备测试数据
executionID := "test_exec_001"
toolName := "nmap_scan"
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
// 保存测试结果
err := testStorage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存测试结果失败: %v", err)
}
ctx := context.Background()
// 测试1: 基本查询(第一页)
args := map[string]interface{}{
"execution_id": executionID,
"page": float64(1),
"limit": float64(2),
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if toolResult.IsError {
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
}
// 验证结果包含预期内容
resultText := toolResult.Content[0].Text
if !strings.Contains(resultText, executionID) {
t.Errorf("结果中应该包含执行ID: %s", executionID)
}
if !strings.Contains(resultText, "第 1/") {
t.Errorf("结果中应该包含分页信息")
}
// 测试2: 搜索功能
args2 := map[string]interface{}{
"execution_id": executionID,
@@ -92,21 +92,21 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
"page": float64(1),
"limit": float64(10),
}
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
if err != nil {
t.Fatalf("执行搜索失败: %v", err)
}
if toolResult2.IsError {
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
}
resultText2 := toolResult2.Content[0].Text
if !strings.Contains(resultText2, "error") {
t.Errorf("搜索结果中应该包含关键词: error")
}
// 测试3: 过滤功能
args3 := map[string]interface{}{
"execution_id": executionID,
@@ -114,46 +114,46 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
"page": float64(1),
"limit": float64(10),
}
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
if err != nil {
t.Fatalf("执行过滤失败: %v", err)
}
if toolResult3.IsError {
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
}
resultText3 := toolResult3.Content[0].Text
if !strings.Contains(resultText3, "Port") {
t.Errorf("过滤结果中应该包含关键词: Port")
}
// 测试4: 缺少必需参数
args4 := map[string]interface{}{
"page": float64(1),
}
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult4.IsError {
t.Fatal("缺少execution_id应该返回错误")
}
// 测试5: 不存在的执行ID
args5 := map[string]interface{}{
"execution_id": "nonexistent_id",
"page": float64(1),
}
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult5.IsError {
t.Fatal("不存在的执行ID应该返回错误")
}
@@ -161,22 +161,22 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
executor, _ := setupTestExecutor(t)
ctx := context.Background()
args := map[string]interface{}{
"test": "value",
}
// 测试未知的内部工具类型
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
if err != nil {
t.Fatalf("执行内部工具失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未知的工具类型应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
t.Errorf("错误消息应该包含'未知的内部工具类型'")
}
@@ -185,21 +185,21 @@ func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
executor, _ := setupTestExecutor(t)
// 不设置存储,测试未初始化的情况
ctx := context.Background()
args := map[string]interface{}{
"execution_id": "test_id",
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未初始化的存储应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
t.Errorf("错误消息应该包含'结果存储未初始化'")
}
@@ -207,7 +207,7 @@ func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
func TestPaginateLines(t *testing.T) {
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
// 测试第一页
page := paginateLines(lines, 1, 2)
if page.Page != 1 {
@@ -225,7 +225,7 @@ func TestPaginateLines(t *testing.T) {
if len(page.Lines) != 2 {
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
}
// 测试第二页
page2 := paginateLines(lines, 2, 2)
if len(page2.Lines) != 2 {
@@ -234,13 +234,13 @@ func TestPaginateLines(t *testing.T) {
if page2.Lines[0] != "Line 3" {
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0])
}
// 测试最后一页
page3 := paginateLines(lines, 3, 2)
if len(page3.Lines) != 1 {
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
}
// 测试超出范围的页码(应该返回最后一页)
page4 := paginateLines(lines, 4, 2)
if page4.Page != 3 {
@@ -249,13 +249,13 @@ func TestPaginateLines(t *testing.T) {
if len(page4.Lines) != 1 {
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines))
}
// 测试无效页码(小于1
page0 := paginateLines(lines, 0, 2)
if page0.Page != 1 {
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page)
}
// 测试空列表
emptyPage := paginateLines([]string{}, 1, 10)
if emptyPage.TotalLines != 0 {
@@ -265,4 +265,3 @@ func TestPaginateLines(t *testing.T) {
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
}
}
+4 -4
View File
@@ -16,10 +16,10 @@ type rateLimitEntry struct {
// RateLimiter 基于 IP 的滑动窗口速率限制器
type RateLimiter struct {
mu sync.Mutex
entries map[string]*rateLimitEntry
limit int // 窗口内允许的最大请求数
window time.Duration // 窗口时长
mu sync.Mutex
entries map[string]*rateLimitEntry
limit int // 窗口内允许的最大请求数
window time.Duration // 窗口时长
}
// NewRateLimiter 创建速率限制器