diff --git a/internal/handler/agent.go b/internal/handler/agent.go index ae05ba1f..7baf58cf 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -200,9 +200,7 @@ func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) { if h == nil || conversationID == "" || h.tasks == nil { return } - if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" { - h.agent.CancelMCPToolExecutionWithNote(execID, "") - } + h.cancelActiveMCPToolForConversation(conversationID) if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok { h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID)) } else if err != nil { @@ -210,6 +208,15 @@ func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) { } } +func (h *AgentHandler) cancelActiveMCPToolForConversation(conversationID string) { + if h == nil || h.tasks == nil || h.agent == nil { + return + } + if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" { + h.agent.CancelMCPToolExecutionWithNote(execID, "") + } +} + // HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘 type HitlToolWhitelistSaver interface { MergeHitlToolWhitelistIntoConfig(add []string) error @@ -239,6 +246,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo hitlManager: NewHITLManager(db, logger), batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), } + tm.SetToolCanceler(handler.cancelActiveMCPToolForConversation) if err := handler.hitlManager.EnsureSchema(); err != nil { logger.Warn("初始化 HITL 表失败", zap.Error(err)) } @@ -1411,6 +1419,7 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { var cause error = ErrTaskCancelled msg := "已提交取消请求,任务将在当前步骤完成后停止。" + h.cancelActiveMCPToolForConversation(req.ConversationID) ok, err := h.tasks.CancelTask(req.ConversationID, cause) if err != nil { h.logger.Error("取消任务失败", zap.Error(err)) diff --git a/internal/handler/project_context.go b/internal/handler/project_context.go index ef8eb7e5..ca2664bc 100644 --- a/internal/handler/project_context.go +++ b/internal/handler/project_context.go @@ -7,7 +7,7 @@ import ( "go.uber.org/zap" ) -// agentSessionContextBlock 注入会话工作目录与项目黑板(用于 system prompt 追加块)。 +// agentSessionContextBlock 注入会话工作目录、项目黑板与用户原文锚点(用于 system prompt 追加块)。 func (h *AgentHandler) agentSessionContextBlock(conversationID string) string { var parts []string if ws := h.buildWorkspaceBlock(conversationID); ws != "" { @@ -16,6 +16,9 @@ func (h *AgentHandler) agentSessionContextBlock(conversationID string) string { if bb := h.projectBlackboardBlock(conversationID); bb != "" { parts = append(parts, bb) } + if uv := h.userVerbatimAnchorBlock(conversationID); uv != "" { + parts = append(parts, uv) + } return strings.Join(parts, "\n\n") } @@ -67,6 +70,29 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string { return strings.TrimSpace(block) } +// userVerbatimAnchorBlock 从 messages 表构建用户各轮原文锚点(压缩后仍由 summarization Finalize 刷新)。 +func (h *AgentHandler) userVerbatimAnchorBlock(conversationID string) string { + if h == nil || h.db == nil || h.config == nil { + return "" + } + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return "" + } + maxRunes := h.config.MultiAgent.UserVerbatimAnchorMaxRunesEffective() + if maxRunes < 0 { + return "" + } + msgs, err := h.db.GetMessages(conversationID) + if err != nil { + if h.logger != nil { + h.logger.Warn("构建用户原文锚点失败", zap.String("conversationId", conversationID), zap.Error(err)) + } + return "" + } + return project.BuildUserVerbatimAnchorBlockFromMessages(msgs, maxRunes) +} + // conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。 func (h *AgentHandler) conversationProjectID(conversationID string) string { if h == nil || h.db == nil { diff --git a/internal/handler/task_manager.go b/internal/handler/task_manager.go index a40c4123..bfaea31b 100644 --- a/internal/handler/task_manager.go +++ b/internal/handler/task_manager.go @@ -247,6 +247,8 @@ type AgentTaskManager struct { maxHistorySize int // 最大历史记录数 historyRetention time.Duration // 历史记录保留时间 eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅 + // toolCanceler 在用户整轮停止任务时终止当前 MCP 工具(非「中断并继续」)。 + toolCanceler func(conversationID string) } const ( @@ -277,6 +279,13 @@ func (m *AgentTaskManager) SetTaskEventBus(b *TaskEventBus) { m.eventBus = b } +// SetToolCanceler 设置整轮停止任务时终止当前 MCP 工具的回调(由 AgentHandler 注入)。 +func (m *AgentTaskManager) SetToolCanceler(fn func(conversationID string)) { + m.mu.Lock() + defer m.mu.Unlock() + m.toolCanceler = fn +} + // GetTask 返回运行中任务(无则 nil)。 func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask { m.mu.RLock() @@ -372,14 +381,21 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool, task.InterruptContinueNote = "" } cancel := task.cancel - m.mu.Unlock() - if cause == nil { cause = ErrTaskCancelled } + var toolCanceler func(string) + if errors.Is(cause, ErrTaskCancelled) { + toolCanceler = m.toolCanceler + } + m.mu.Unlock() + if cancel != nil { cancel(cause) } + if toolCanceler != nil { + toolCanceler(conversationID) + } return true, nil } diff --git a/internal/handler/task_manager_tool_cancel_test.go b/internal/handler/task_manager_tool_cancel_test.go new file mode 100644 index 00000000..20c3a076 --- /dev/null +++ b/internal/handler/task_manager_tool_cancel_test.go @@ -0,0 +1,80 @@ +package handler + +import ( + "context" + "errors" + "testing" + + "cyberstrike-ai/internal/multiagent" +) + +func TestCancelTaskInvokesToolCancelerOnFullStop(t *testing.T) { + tm := NewAgentTaskManager() + called := false + tm.SetToolCanceler(func(conversationID string) { + if conversationID == "conv-1" { + called = true + } + }) + + _, cancel := context.WithCancelCause(context.Background()) + _, err := tm.StartTask("conv-1", "hello", cancel) + if err != nil { + t.Fatalf("StartTask: %v", err) + } + + ok, err := tm.CancelTask("conv-1", ErrTaskCancelled) + if err != nil || !ok { + t.Fatalf("CancelTask: ok=%v err=%v", ok, err) + } + if !called { + t.Fatal("expected tool canceler to be invoked on full task cancel") + } +} + +func TestCancelTaskSkipsToolCancelerOnInterruptContinue(t *testing.T) { + tm := NewAgentTaskManager() + called := false + tm.SetToolCanceler(func(conversationID string) { + called = true + }) + + _, cancel := context.WithCancelCause(context.Background()) + _, err := tm.StartTask("conv-1", "hello", cancel) + if err != nil { + t.Fatalf("StartTask: %v", err) + } + + ok, err := tm.CancelTask("conv-1", multiagent.ErrInterruptContinue) + if err != nil || !ok { + t.Fatalf("CancelTask: ok=%v err=%v", ok, err) + } + if called { + t.Fatal("tool canceler must not run for interrupt-continue") + } +} + +func TestCancelTaskDefaultCauseIsTaskCancelled(t *testing.T) { + tm := NewAgentTaskManager() + var gotCause error + tm.SetToolCanceler(func(conversationID string) { + if conversationID == "conv-2" { + gotCause = ErrTaskCancelled + } + }) + + ctx, cancel := context.WithCancelCause(context.Background()) + if _, err := tm.StartTask("conv-2", "hello", cancel); err != nil { + t.Fatalf("StartTask: %v", err) + } + + if _, err := tm.CancelTask("conv-2", nil); err != nil { + t.Fatalf("CancelTask: %v", err) + } + if !errors.Is(context.Cause(ctx), ErrTaskCancelled) { + t.Fatalf("expected ErrTaskCancelled cause, got %v", context.Cause(ctx)) + } + if gotCause != ErrTaskCancelled { + t.Fatalf("expected tool canceler path for default cancel cause") + } +}