From c4e0b9735c7a96b0ef88412133cba2687a59e993 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sun, 10 May 2026 21:38:28 +0800 Subject: [PATCH] Add files via upload --- internal/handler/agent.go | 57 ++++++++--- internal/handler/eino_single_agent.go | 87 +++++++++++----- internal/handler/multi_agent.go | 138 +++++++++++++++++++++----- internal/handler/task_manager.go | 61 +++++++++++- 4 files changed, 277 insertions(+), 66 deletions(-) diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 9e9a47ff..d3c3fe58 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -1789,27 +1789,51 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { return } execID := h.tasks.ActiveMCPExecutionID(req.ConversationID) - if execID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "当前没有正在执行的 MCP 工具(例如模型尚在推理、尚未发起工具调用)。请等待工具开始执行后再试,或使用「彻底停止」结束整轮任务。"}) - return - } note := strings.TrimSpace(req.Reason) - if !h.agent.CancelMCPToolExecutionWithNote(execID, note) { - c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"}) + if execID != "" { + if !h.agent.CancelMCPToolExecutionWithNote(execID, note) { + c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"}) + return + } + h.logger.Info("对话页仅终止当前 MCP 工具", + zap.String("conversationId", req.ConversationID), + zap.String("executionId", execID), + zap.Bool("hasNote", note != ""), + ) + c.JSON(http.StatusOK, gin.H{ + "status": "tool_abort_requested", + "conversationId": req.ConversationID, + "executionId": execID, + "message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。", + "continueAfter": true, + "interruptWithNote": note != "", + "continueWithoutTool": false, + }) return } - h.logger.Info("对话页仅终止当前 MCP 工具", + // 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。 + h.tasks.SetInterruptContinueNote(req.ConversationID, note) + ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue) + if err != nil { + h.logger.Error("中断并继续(无工具)失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !ok { + c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"}) + return + } + h.logger.Info("对话页中断并继续(无 MCP 工具,将自动续跑)", zap.String("conversationId", req.ConversationID), - zap.String("executionId", execID), zap.Bool("hasNote", note != ""), ) c.JSON(http.StatusOK, gin.H{ - "status": "tool_abort_requested", - "conversationId": req.ConversationID, - "executionId": execID, - "message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。", - "continueAfter": true, - "interruptWithNote": note != "", + "status": "interrupt_continue_scheduled", + "conversationId": req.ConversationID, + "message": "已请求暂停当前推理;用户补充将合并到上下文并自动继续执行(无需整轮停止)。", + "continueAfter": true, + "interruptWithNote": note != "", + "continueWithoutTool": true, }) return } @@ -2901,6 +2925,11 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent if toolCallID, ok := msgMap["tool_call_id"].(string); ok { msg.ToolCallID = toolCallID } + if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" { + msg.ToolName = strings.TrimSpace(tn) + } else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") { + msg.ToolName = strings.TrimSpace(tn) + } agentMessages = append(agentMessages, msg) } diff --git a/internal/handler/eino_single_agent.go b/internal/handler/eino_single_agent.go index 21f73d28..1bce56af 100644 --- a/internal/handler/eino_single_agent.go +++ b/internal/handler/eino_single_agent.go @@ -46,7 +46,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { sendEvent := func(eventType, message string, data interface{}) { if eventType == "error" && baseCtx != nil { cause := context.Cause(baseCtx) - if errors.Is(cause, ErrTaskCancelled) { + if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) { return } } @@ -175,29 +175,68 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { } taskOwned = true - progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) - taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) - taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) - taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { - return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) - }) + var cumulativeMCPExecutionIDs []string - result, runErr = multiagent.RunEinoSingleChatModelAgent( - taskCtx, - h.config, - &h.config.MultiAgent, - h.agent, - h.logger, - conversationID, - curFinalMessage, - curHistory, - roleTools, - progressCallback, - ) - timeoutCancel() + for { + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) + taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) + }) + + result, runErr = multiagent.RunEinoSingleChatModelAgent( + taskCtxLoop, + h.config, + &h.config.MultiAgent, + h.agent, + h.logger, + conversationID, + curFinalMessage, + curHistory, + roleTools, + progressCallback, + ) + timeoutCancel() + + if result != nil && len(result.MCPExecutionIDs) > 0 { + cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) + } + + if runErr == nil { + break + } - if runErr != nil { cause := context.Cause(baseCtx) + if errors.Is(cause, multiagent.ErrInterruptContinue) { + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + note := h.tasks.TakeInterruptContinueNote(conversationID) + icSummary := interruptContinueTimelineSummary(note) + progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{ + "conversationId": conversationID, + "rawReason": strings.TrimSpace(note), + "emptyReason": strings.TrimSpace(note) == "", + "kind": "no_active_mcp_tool", + }) + inject := formatInterruptContinueUserMessage(note) + // 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。 + if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 { + curHistory = hist + } + curFinalMessage = inject + sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{ + "conversationId": conversationID, + "source": "interrupt_continue", + }) + h.tasks.UpdateTaskStatus(conversationID, "running") + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + continue + } + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { h.persistEinoAgentTraceForResume(conversationID, result) } @@ -259,8 +298,8 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { if assistantMessageID != "" { mcpIDsJSON := "" - if len(result.MCPExecutionIDs) > 0 { - jsonData, _ := json.Marshal(result.MCPExecutionIDs) + if len(cumulativeMCPExecutionIDs) > 0 { + jsonData, _ := json.Marshal(cumulativeMCPExecutionIDs) mcpIDsJSON = string(jsonData) } _, _ = h.db.Exec( @@ -279,7 +318,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { } sendEvent("response", result.Response, map[string]interface{}{ - "mcpExecutionIds": result.MCPExecutionIDs, + "mcpExecutionIds": cumulativeMCPExecutionIDs, "conversationId": conversationID, "messageId": assistantMessageID, "agentMode": "eino_single", diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go index f68f42b1..4278119d 100644 --- a/internal/handler/multi_agent.go +++ b/internal/handler/multi_agent.go @@ -63,7 +63,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { // 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。 if eventType == "error" && baseCtx != nil { cause := context.Cause(baseCtx) - if errors.Is(cause, ErrTaskCancelled) { + if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) { return } } @@ -184,31 +184,71 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { } taskOwned = true - progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) - taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) - taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) - taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { - return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) - }) + // 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表 + var cumulativeMCPExecutionIDs []string - result, runErr = multiagent.RunDeepAgent( - taskCtx, - h.config, - &h.config.MultiAgent, - h.agent, - h.logger, - conversationID, - curFinalMessage, - curHistory, - roleTools, - progressCallback, - h.agentsMarkdownDir, - orch, - ) - timeoutCancel() + for { + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) + taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) + }) + + result, runErr = multiagent.RunDeepAgent( + taskCtxLoop, + h.config, + &h.config.MultiAgent, + h.agent, + h.logger, + conversationID, + curFinalMessage, + curHistory, + roleTools, + progressCallback, + h.agentsMarkdownDir, + orch, + ) + timeoutCancel() + + if result != nil && len(result.MCPExecutionIDs) > 0 { + cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) + } + + if runErr == nil { + break + } - if runErr != nil { cause := context.Cause(baseCtx) + if errors.Is(cause, multiagent.ErrInterruptContinue) { + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + note := h.tasks.TakeInterruptContinueNote(conversationID) + icSummary := interruptContinueTimelineSummary(note) + progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{ + "conversationId": conversationID, + "rawReason": strings.TrimSpace(note), + "emptyReason": strings.TrimSpace(note) == "", + "kind": "no_active_mcp_tool", + }) + inject := formatInterruptContinueUserMessage(note) + // 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。 + if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 { + curHistory = hist + } + curFinalMessage = inject + sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{ + "conversationId": conversationID, + "source": "interrupt_continue", + }) + h.tasks.UpdateTaskStatus(conversationID, "running") + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + continue + } + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { h.persistEinoAgentTraceForResume(conversationID, result) } @@ -270,8 +310,8 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { if assistantMessageID != "" { mcpIDsJSON := "" - if len(result.MCPExecutionIDs) > 0 { - jsonData, _ := json.Marshal(result.MCPExecutionIDs) + if len(cumulativeMCPExecutionIDs) > 0 { + jsonData, _ := json.Marshal(cumulativeMCPExecutionIDs) mcpIDsJSON = string(jsonData) } _, _ = h.db.Exec( @@ -294,7 +334,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { effectiveOrch = config.NormalizeMultiAgentOrchestration(o) } sendEvent("response", result.Response, map[string]interface{}{ - "mcpExecutionIds": result.MCPExecutionIDs, + "mcpExecutionIds": cumulativeMCPExecutionIDs, "conversationId": conversationID, "messageId": assistantMessageID, "agentMode": "eino_" + effectiveOrch, @@ -406,6 +446,52 @@ func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, res } } +// mergeMCPExecutionIDLists 去重合并多段 Run 的 MCP execution id(顺序:先 dst 后 more)。 +func mergeMCPExecutionIDLists(dst []string, more []string) []string { + seen := make(map[string]struct{}, len(dst)+len(more)) + out := make([]string, 0, len(dst)+len(more)) + add := func(ids []string) { + for _, id := range ids { + id = strings.TrimSpace(id) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + } + add(dst) + add(more) + return out +} + +// interruptContinueTimelineSummary 时间线 / process_details 中展示的简短正文(完整模板已写入另一条用户消息)。 +func interruptContinueTimelineSummary(note string) string { + note = strings.TrimSpace(note) + if note == "" { + return "用户选择「中断并继续」,未填写说明;已按默认渗透补充模板合并上下文并续跑。" + } + return "用户中断说明(原文):\n\n" + note +} + +// formatInterruptContinueUserMessage 将「中断并继续」弹窗中的说明格式化为新一轮 user 消息(渗透场景下强调路径补充与端口复扫)。 +func formatInterruptContinueUserMessage(note string) string { + var b strings.Builder + b.WriteString("【用户补充 / 中断后继续】\n") + if s := strings.TrimSpace(note); s != "" { + b.WriteString(s) + b.WriteString("\n\n") + } + b.WriteString("【请在本轮落实】\n") + b.WriteString("- 将用户提供的接口路径、参数、业务变化纳入后续测试与推理。\n") + b.WriteString("- 若资产或目标信息有更新,请对目标重新执行端口/服务探测,再基于新结果规划下一步。\n") + b.WriteString("- 在已有轨迹基础上推进,避免无意义重复已完成的步骤。\n") + return strings.TrimSpace(b.String()) +} + func multiAgentHTTPErrorStatus(err error) (int, string) { msg := err.Error() switch { diff --git a/internal/handler/task_manager.go b/internal/handler/task_manager.go index 26fa4125..82e9f304 100644 --- a/internal/handler/task_manager.go +++ b/internal/handler/task_manager.go @@ -6,6 +6,8 @@ import ( "strings" "sync" "time" + + "cyberstrike-ai/internal/multiagent" ) // ErrTaskCancelled 用户取消任务的错误 @@ -32,6 +34,9 @@ type AgentTask struct { // ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具) ActiveMCPExecutionID string `json:"-"` + // InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空) + InterruptContinueNote string `json:"-"` + cancel func(error) } @@ -65,6 +70,50 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str } } +// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。 +func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + t.InterruptContinueNote = note + } +} + +// TakeInterruptContinueNote 读取并清空补充说明(续跑开始时调用一次)。 +func (m *AgentTaskManager) TakeInterruptContinueNote(conversationID string) string { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return "" + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + n := t.InterruptContinueNote + t.InterruptContinueNote = "" + return n + } + return "" +} + +// BindTaskCancel 在同一运行任务内替换与 context 绑定的 cancel 函数(用于中断后继续时换新 baseCtx)。 +func (m *AgentTaskManager) BindTaskCancel(conversationID string, cancel context.CancelCauseFunc) { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" || cancel == nil { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + t.cancel = func(err error) { + cancel(err) + } + } +} + // ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。 func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string { conversationID = strings.TrimSpace(conversationID) @@ -210,8 +259,16 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool, return true, nil } - task.Status = "cancelling" - task.CancellingAt = time.Now() + // ErrInterruptContinue:仅掐断当前推理步骤,随后由处理器续跑,不进入长时间「取消中」态。 + if cause != nil && errors.Is(cause, multiagent.ErrInterruptContinue) { + task.Status = "running" + } else { + task.Status = "cancelling" + task.CancellingAt = time.Now() + } + if cause != nil && errors.Is(cause, ErrTaskCancelled) { + task.InterruptContinueNote = "" + } cancel := task.cancel m.mu.Unlock()