diff --git a/internal/handler/agent.go b/internal/handler/agent.go index b3599c3f..29739e51 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -19,6 +19,7 @@ import ( "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/multiagent" @@ -1495,6 +1496,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) defer timeoutCancel() defer cancelWithCause(nil) + taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) @@ -1728,22 +1731,39 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { return } - if req.ContinueAfter && strings.TrimSpace(req.Reason) == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "continueAfter 为 true 时必须提供非空的 reason(中断说明)"}) + if req.ContinueAfter { + if h.tasks.GetTask(req.ConversationID) == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"}) + return + } + execID := h.tasks.ActiveMCPExecutionID(req.ConversationID) + if execID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "当前没有正在执行的 MCP 工具(例如模型尚在推理、尚未发起工具调用)。请等待工具开始执行后再试,或使用「彻底停止」结束整轮任务。"}) + return + } + note := strings.TrimSpace(req.Reason) + if !h.agent.CancelMCPToolExecutionWithNote(execID, note) { + c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"}) + return + } + h.logger.Info("对话页仅终止当前 MCP 工具", + zap.String("conversationId", req.ConversationID), + zap.String("executionId", execID), + zap.Bool("hasNote", note != ""), + ) + c.JSON(http.StatusOK, gin.H{ + "status": "tool_abort_requested", + "conversationId": req.ConversationID, + "executionId": execID, + "message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。", + "continueAfter": true, + "interruptWithNote": note != "", + }) return } var cause error = ErrTaskCancelled msg := "已提交取消请求,任务将在当前步骤完成后停止。" - if req.ContinueAfter { - if !h.tasks.SetInterruptContinueReason(req.ConversationID, req.Reason) { - c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务,无法提交中断说明"}) - return - } - cause = ErrUserInterruptContinue - msg = "已提交中断说明,当前步骤结束后将写入对话并继续迭代。" - } - ok, err := h.tasks.CancelTask(req.ConversationID, cause) if err != nil { h.logger.Error("取消任务失败", zap.Error(err)) @@ -1758,10 +1778,10 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "status": "cancelling", - "conversationId": req.ConversationID, + "conversationId": req.ConversationID, "message": msg, - "continueAfter": req.ContinueAfter, - "interruptWithNote": req.ContinueAfter, + "continueAfter": false, + "interruptWithNote": false, }) } @@ -2539,6 +2559,8 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { // 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。 progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) // 使用队列配置的角色工具列表(如果为空,表示使用所有工具) useBatchMulti := false diff --git a/internal/handler/eino_single_agent.go b/internal/handler/eino_single_agent.go index 6801d93e..93fc603e 100644 --- a/internal/handler/eino_single_agent.go +++ b/internal/handler/eino_single_agent.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/multiagent" "github.com/gin-gonic/gin" @@ -45,7 +46,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { sendEvent := func(eventType, message string, data interface{}) { if eventType == "error" && baseCtx != nil { cause := context.Cause(baseCtx) - if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, ErrUserInterruptContinue) { + if errors.Is(cause, ErrTaskCancelled) { return } } @@ -117,7 +118,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { } var cancelWithCause context.CancelCauseFunc - firstRun := true curFinalMessage := prep.FinalMessage curHistory := prep.History roleTools := prep.RoleTools @@ -151,108 +151,56 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { var result *multiagent.RunResult var runErr error - for { - baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) - taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) - if firstRun { - if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { - var errorMsg string - if errors.Is(err, ErrTaskAlreadyRunning) { - errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。" - sendEvent("error", errorMsg, map[string]interface{}{ - "conversationId": conversationID, - "errorType": "task_already_running", - }) - } else { - errorMsg = "❌ 无法启动任务: " + err.Error() - sendEvent("error", errorMsg, nil) - } - if assistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID) - } - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) - timeoutCancel() - return - } - taskOwned = true - firstRun = false + if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { + var errorMsg string + if errors.Is(err, ErrTaskAlreadyRunning) { + errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。" + sendEvent("error", errorMsg, map[string]interface{}{ + "conversationId": conversationID, + "errorType": "task_already_running", + }) } else { - if err := h.tasks.ResetTaskCancelForContinue(conversationID, cancelWithCause); err != nil { - h.logger.Error("续跑任务时重置 cancel 失败", zap.Error(err)) - taskStatus = "failed" - sendEvent("error", err.Error(), nil) - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) - timeoutCancel() - return - } + errorMsg = "❌ 无法启动任务: " + err.Error() + sendEvent("error", errorMsg, nil) } - - progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) - taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { - return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) - }) - - result, runErr = multiagent.RunEinoSingleChatModelAgent( - taskCtx, - h.config, - &h.config.MultiAgent, - h.agent, - h.logger, - conversationID, - curFinalMessage, - curHistory, - roleTools, - progressCallback, - ) + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID) + } + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) timeoutCancel() + return + } + taskOwned = true - if runErr == nil { - break - } + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) + taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) + }) + result, runErr = multiagent.RunEinoSingleChatModelAgent( + taskCtx, + h.config, + &h.config.MultiAgent, + h.agent, + h.logger, + conversationID, + curFinalMessage, + curHistory, + roleTools, + progressCallback, + ) + timeoutCancel() + + if runErr != nil { cause := context.Cause(baseCtx) if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { h.persistEinoAgentTraceForResume(conversationID, result) } - if errors.Is(cause, ErrUserInterruptContinue) { - reason := h.tasks.TakeInterruptContinueReason(conversationID) - prepNext, perr := h.prepareSessionAfterUserInterrupt(conversationID, assistantMessageID, reason, roleTools) - if perr != nil { - h.logger.Error("准备中断后续跑失败", zap.Error(perr)) - taskStatus = "failed" - h.tasks.UpdateTaskStatus(conversationID, taskStatus) - errMsg := "中断后续跑失败: " + perr.Error() - if assistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID) - } - sendEvent("error", errMsg, map[string]interface{}{ - "conversationId": conversationID, - "messageId": assistantMessageID, - }) - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) - return - } - assistantMessageID = prepNext.AssistantMessageID - curFinalMessage = prepNext.FinalMessage - curHistory = prepNext.History - if prepNext.UserMessageID != "" { - sendEvent("message_saved", "", map[string]interface{}{ - "conversationId": conversationID, - "userMessageId": prepNext.UserMessageID, - }) - } - sendEvent("user_interrupt_continue", reason, map[string]interface{}{ - "conversationId": conversationID, - "reason": reason, - "messageId": assistantMessageID, - }) - sendEvent("progress", "已接收中断说明,继续迭代...", map[string]interface{}{ - "conversationId": conversationID, - }) - continue - } - if errors.Is(cause, ErrTaskCancelled) { taskStatus = "cancelled" h.tasks.UpdateTaskStatus(conversationID, taskStatus) diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go index 4793a28d..68b77a26 100644 --- a/internal/handler/multi_agent.go +++ b/internal/handler/multi_agent.go @@ -11,6 +11,7 @@ import ( "time" "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/multiagent" "github.com/gin-gonic/gin" @@ -62,7 +63,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { // 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。 if eventType == "error" && baseCtx != nil { cause := context.Cause(baseCtx) - if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, ErrUserInterruptContinue) { + if errors.Is(cause, ErrTaskCancelled) { return } } @@ -134,7 +135,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { } var cancelWithCause context.CancelCauseFunc - firstRun := true curFinalMessage := prep.FinalMessage curHistory := prep.History roleTools := prep.RoleTools @@ -160,110 +160,58 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { var result *multiagent.RunResult var runErr error - for { - baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) - taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) - if firstRun { - if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { - var errorMsg string - if errors.Is(err, ErrTaskAlreadyRunning) { - errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。" - sendEvent("error", errorMsg, map[string]interface{}{ - "conversationId": conversationID, - "errorType": "task_already_running", - }) - } else { - errorMsg = "❌ 无法启动任务: " + err.Error() - sendEvent("error", errorMsg, nil) - } - if assistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID) - } - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) - timeoutCancel() - return - } - taskOwned = true - firstRun = false + if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { + var errorMsg string + if errors.Is(err, ErrTaskAlreadyRunning) { + errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。" + sendEvent("error", errorMsg, map[string]interface{}{ + "conversationId": conversationID, + "errorType": "task_already_running", + }) } else { - if err := h.tasks.ResetTaskCancelForContinue(conversationID, cancelWithCause); err != nil { - h.logger.Error("续跑任务时重置 cancel 失败", zap.Error(err)) - taskStatus = "failed" - sendEvent("error", err.Error(), nil) - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) - timeoutCancel() - return - } + errorMsg = "❌ 无法启动任务: " + err.Error() + sendEvent("error", errorMsg, nil) } - - progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) - taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { - return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) - }) - - result, runErr = multiagent.RunDeepAgent( - taskCtx, - h.config, - &h.config.MultiAgent, - h.agent, - h.logger, - conversationID, - curFinalMessage, - curHistory, - roleTools, - progressCallback, - h.agentsMarkdownDir, - orch, - ) + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID) + } + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) timeoutCancel() + return + } + taskOwned = true - if runErr == nil { - break - } + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) + taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) + }) + result, runErr = multiagent.RunDeepAgent( + taskCtx, + h.config, + &h.config.MultiAgent, + h.agent, + h.logger, + conversationID, + curFinalMessage, + curHistory, + roleTools, + progressCallback, + h.agentsMarkdownDir, + orch, + ) + timeoutCancel() + + if runErr != nil { cause := context.Cause(baseCtx) if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { h.persistEinoAgentTraceForResume(conversationID, result) } - if errors.Is(cause, ErrUserInterruptContinue) { - reason := h.tasks.TakeInterruptContinueReason(conversationID) - prepNext, perr := h.prepareSessionAfterUserInterrupt(conversationID, assistantMessageID, reason, roleTools) - if perr != nil { - h.logger.Error("准备中断后续跑失败", zap.Error(perr)) - taskStatus = "failed" - h.tasks.UpdateTaskStatus(conversationID, taskStatus) - errMsg := "中断后续跑失败: " + perr.Error() - if assistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID) - } - sendEvent("error", errMsg, map[string]interface{}{ - "conversationId": conversationID, - "messageId": assistantMessageID, - }) - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) - return - } - assistantMessageID = prepNext.AssistantMessageID - curFinalMessage = prepNext.FinalMessage - curHistory = prepNext.History - if prepNext.UserMessageID != "" { - sendEvent("message_saved", "", map[string]interface{}{ - "conversationId": conversationID, - "userMessageId": prepNext.UserMessageID, - }) - } - sendEvent("user_interrupt_continue", reason, map[string]interface{}{ - "conversationId": conversationID, - "reason": reason, - "messageId": assistantMessageID, - }) - sendEvent("progress", "已接收中断说明,继续迭代...", map[string]interface{}{ - "conversationId": conversationID, - }) - continue - } - if errors.Is(cause, ErrTaskCancelled) { taskStatus = "cancelled" h.tasks.UpdateTaskStatus(conversationID, taskStatus) diff --git a/internal/handler/multi_agent_prepare.go b/internal/handler/multi_agent_prepare.go index 47f6ae09..51703e86 100644 --- a/internal/handler/multi_agent_prepare.go +++ b/internal/handler/multi_agent_prepare.go @@ -3,7 +3,6 @@ package handler import ( "fmt" "strings" - "time" "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/database" @@ -143,64 +142,3 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr UserMessageID: userMessageID, }, nil } - -// prepareSessionAfterUserInterrupt 用户「中断并说明」后:结束当前助手占位、写入用户说明、新建助手占位,并生成下一轮 Run 所需的 History + FinalMessage。 -func (h *AgentHandler) prepareSessionAfterUserInterrupt(conversationID, prevAssistantMessageID, reason string, roleTools []string) (*multiAgentPrepared, error) { - if strings.TrimSpace(conversationID) == "" { - return nil, fmt.Errorf("conversationId 为空") - } - if _, err := h.db.GetConversation(conversationID); err != nil { - return nil, fmt.Errorf("对话不存在") - } - note := "(已根据用户说明中断当前步骤,正在继续迭代。)" - if prevAssistantMessageID != "" { - if _, err := h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", note, time.Now(), prevAssistantMessageID); err != nil { - return nil, fmt.Errorf("更新助手消息失败: %w", err) - } - r := strings.TrimSpace(reason) - detail := "用户中断并说明" - if r != "" { - detail += ":" + r - } - _ = h.db.AddProcessDetail(prevAssistantMessageID, conversationID, "user_interrupt", detail, map[string]interface{}{ - "reason": r, - }) - } - userContent := fmt.Sprintf("【用户中断说明】%s\n\n请根据以上说明调整并继续任务。", strings.TrimSpace(reason)) - if strings.TrimSpace(reason) == "" { - userContent = "【用户中断说明】(未填写具体原因)\n\n请根据情况调整并继续任务。" - } - userMsgRow, err := h.db.AddMessage(conversationID, "user", userContent, nil) - if err != nil { - return nil, fmt.Errorf("保存用户消息失败: %w", err) - } - assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) - if err != nil || assistantMsg == nil { - return nil, fmt.Errorf("创建助手占位失败: %w", err) - } - msgs, err := h.db.GetMessages(conversationID) - if err != nil || len(msgs) < 2 { - return nil, fmt.Errorf("读取消息历史失败或消息不足") - } - histMsgs := msgs[:len(msgs)-2] - agentHistory := make([]agent.ChatMessage, 0, len(histMsgs)) - for _, msg := range histMsgs { - agentHistory = append(agentHistory, agent.ChatMessage{ - Role: msg.Role, - Content: msg.Content, - }) - } - userMessageID := "" - if userMsgRow != nil { - userMessageID = userMsgRow.ID - } - return &multiAgentPrepared{ - ConversationID: conversationID, - CreatedNew: false, - History: agentHistory, - FinalMessage: userContent, - RoleTools: roleTools, - AssistantMessageID: assistantMsg.ID, - UserMessageID: userMessageID, - }, nil -} diff --git a/internal/handler/openapi.go b/internal/handler/openapi.go index b91d9d2b..da785f0c 100644 --- a/internal/handler/openapi.go +++ b/internal/handler/openapi.go @@ -463,11 +463,11 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { }, "reason": map[string]interface{}{ "type": "string", - "description": "中断说明;与 continueAfter 同时为真时必填,将写入对话并由同一会话流式迭代继续", + "description": "可选。与 MCP 监控页「终止并说明」一致:非空时合并进当前工具返回给模型的文本(含 USER INTERRUPT NOTE 块)", }, "continueAfter": map[string]interface{}{ "type": "boolean", - "description": "为 true 时取消当前运行步骤并注入 reason 后继续迭代(非彻底停止)", + "description": "为 true 时仅终止当前进行中的 MCP 工具调用(不取消整轮任务);须已有工具在执行,否则 400", }, }, }, diff --git a/internal/handler/task_manager.go b/internal/handler/task_manager.go index 55a0910d..4609ad52 100644 --- a/internal/handler/task_manager.go +++ b/internal/handler/task_manager.go @@ -11,9 +11,6 @@ import ( // ErrTaskCancelled 用户取消任务的错误 var ErrTaskCancelled = errors.New("agent task cancelled by user") -// ErrUserInterruptContinue 用户在进度条上「中断并说明」:取消当前运行步骤,将说明写入对话并继续迭代(与 ErrTaskCancelled 区分) -var ErrUserInterruptContinue = errors.New("user interrupt with continue") - // ErrTaskAlreadyRunning 会话已有任务正在执行 var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation") @@ -34,12 +31,56 @@ type AgentTask struct { Status string `json:"status"` CancellingAt time.Time `json:"-"` // 进入 cancelling 状态的时间,用于清理长时间卡住的任务 - // InterruptContinueReason 由 /api/agent-loop/cancel 在 continueAfter 时写入,Run 返回后由 handler 取出并清空 - InterruptContinueReason string `json:"-"` + // ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具) + ActiveMCPExecutionID string `json:"-"` cancel func(error) } +// RegisterRunningTool 实现 mcp.ToolRunRegistry:工具开始时登记本会话当前 executionId。 +func (m *AgentTaskManager) RegisterRunningTool(conversationID, executionID string) { + conversationID = strings.TrimSpace(conversationID) + executionID = strings.TrimSpace(executionID) + if conversationID == "" || executionID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + t.ActiveMCPExecutionID = executionID + } +} + +// UnregisterRunningTool 工具结束时清除登记(仅当 id 仍匹配时清除,避免并发串单)。 +func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID string) { + conversationID = strings.TrimSpace(conversationID) + executionID = strings.TrimSpace(executionID) + if conversationID == "" || executionID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + if t.ActiveMCPExecutionID == executionID { + t.ActiveMCPExecutionID = "" + } + } +} + +// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。 +func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return "" + } + m.mu.RLock() + defer m.mu.RUnlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + return strings.TrimSpace(t.ActiveMCPExecutionID) + } + return "" +} + // CompletedTask 已完成的任务(用于历史记录) type CompletedTask struct { ConversationID string `json:"conversationId"` @@ -156,49 +197,6 @@ func (m *AgentTaskManager) StartTask(conversationID, message string, cancel cont return task, nil } -// SetInterruptContinueReason 在发起 ErrUserInterruptContinue 取消前写入用户说明(须任务仍存在)。 -func (m *AgentTaskManager) SetInterruptContinueReason(conversationID, reason string) bool { - m.mu.Lock() - defer m.mu.Unlock() - task, ok := m.tasks[conversationID] - if !ok { - return false - } - task.InterruptContinueReason = strings.TrimSpace(reason) - return true -} - -// TakeInterruptContinueReason 取出并清空用户中断说明。 -func (m *AgentTaskManager) TakeInterruptContinueReason(conversationID string) string { - m.mu.Lock() - defer m.mu.Unlock() - task, ok := m.tasks[conversationID] - if !ok { - return "" - } - r := task.InterruptContinueReason - task.InterruptContinueReason = "" - return r -} - -// ResetTaskCancelForContinue 在一次「中断并继续」后恢复任务为 running 并绑定新的 cancel(同一会话同一条 HTTP 流内续跑)。 -func (m *AgentTaskManager) ResetTaskCancelForContinue(conversationID string, cancel context.CancelCauseFunc) error { - m.mu.Lock() - defer m.mu.Unlock() - task, ok := m.tasks[conversationID] - if !ok { - return errors.New("no active task") - } - task.cancel = func(err error) { - if cancel != nil { - cancel(err) - } - } - task.Status = "running" - task.CancellingAt = time.Time{} - return nil -} - // CancelTask 取消指定会话的任务。若任务已在取消中,仍返回 (true, nil) 以便接口幂等、前端不报错。 func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool, error) { m.mu.Lock()