From d17da2a47ddb2c5111ac9a3f6017639e3f76be01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 24 Apr 2026 01:54:38 +0800 Subject: [PATCH] Add files via upload --- internal/agent/agent.go | 101 +++- internal/app/app.go | 7 + internal/handler/agent.go | 210 +++++++- internal/handler/config.go | 59 ++ internal/handler/eino_single_agent.go | 75 ++- internal/handler/hitl.go | 748 ++++++++++++++++++++++++++ internal/handler/multi_agent.go | 66 ++- internal/handler/task_event_bus.go | 116 ++++ internal/handler/task_manager.go | 63 ++- 9 files changed, 1362 insertions(+), 83 deletions(-) create mode 100644 internal/handler/hitl.go create mode 100644 internal/handler/task_event_bus.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index a96c0983..36261379 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -53,6 +53,37 @@ type ResultStorage interface { DeleteResult(executionID string) error } +type toolCallInterceptorCtxKey struct{} + +type agentConversationIDKey struct{} + +func withAgentConversationID(ctx context.Context, id string) context.Context { + id = strings.TrimSpace(id) + if id == "" || ctx == nil { + return ctx + } + return context.WithValue(ctx, agentConversationIDKey{}, id) +} + +func agentConversationIDFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + v, _ := ctx.Value(agentConversationIDKey{}).(string) + return v +} + +// ToolCallInterceptor allows caller to gate or rewrite tool arguments just before execution. +// Returning a non-nil error means the tool call is rejected and execution is skipped. +type ToolCallInterceptor func(ctx context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error) + +func WithToolCallInterceptor(ctx context.Context, fn ToolCallInterceptor) context.Context { + if fn == nil { + return ctx + } + return context.WithValue(ctx, toolCallInterceptorCtxKey{}, fn) +} + // NewAgent 创建新的Agent func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer *mcp.Server, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger, maxIterations int) *Agent { // 如果 maxIterations 为 0 或负数,使用默认值 30 @@ -348,7 +379,8 @@ func (a *Agent) EinoSingleAgentSystemInstruction() string { // AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID) func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) { - // 设置当前对话ID + ctx = withAgentConversationID(ctx, conversationID) + // 设置当前对话ID(兼容未走 context 的旧路径;并发会话应以 context 为准) a.mu.Lock() a.currentConversationID = conversationID a.mu.Unlock() @@ -653,22 +685,49 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his "iteration": i + 1, }) + execArgs := toolCall.Function.Arguments + if interceptor, ok := ctx.Value(toolCallInterceptorCtxKey{}).(ToolCallInterceptor); ok && interceptor != nil { + newArgs, interceptErr := interceptor(ctx, toolCall.Function.Name, execArgs, toolCall.ID) + if interceptErr != nil { + errorMsg := fmt.Sprintf("工具调用被人工拒绝: %v", interceptErr) + messages = append(messages, ChatMessage{ + Role: "tool", + ToolCallID: toolCall.ID, + Content: errorMsg, + }) + sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{ + "toolName": toolCall.Function.Name, + "success": false, + "isError": true, + "error": errorMsg, + "toolCallId": toolCall.ID, + "index": idx + 1, + "total": len(choice.Message.ToolCalls), + "iteration": i + 1, + }) + continue + } + if newArgs != nil { + execArgs = newArgs + } + } + // 执行工具 toolCtx := context.WithValue(ctx, security.ToolOutputCallbackCtxKey, security.ToolOutputCallback(func(chunk string) { if strings.TrimSpace(chunk) == "" { return } sendProgress("tool_result_delta", chunk, map[string]interface{}{ - "toolName": toolCall.Function.Name, - "toolCallId": toolCall.ID, - "index": idx + 1, - "total": len(choice.Message.ToolCalls), - "iteration": i + 1, + "toolName": toolCall.Function.Name, + "toolCallId": toolCall.ID, + "index": idx + 1, + "total": len(choice.Message.ToolCalls), + "iteration": i + 1, // success 在最终 tool_result 事件里会以 success/isError 标记为准 }) })) - execResult, err := a.executeToolViaMCP(toolCtx, toolCall.Function.Name, toolCall.Function.Arguments) + execResult, err := a.executeToolViaMCP(toolCtx, toolCall.Function.Name, execArgs) if err != nil { // 构建详细的错误信息,帮助AI理解问题并做出决策 errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err) @@ -746,7 +805,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 流式调用OpenAI获取总结(不提供工具,强制AI直接回复) sendProgress("response_start", "", map[string]interface{}{ "conversationId": conversationID, - "mcpExecutionIds": result.MCPExecutionIDs, + "mcpExecutionIds": result.MCPExecutionIDs, "messageGeneratedBy": "summary", }) streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { @@ -793,7 +852,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 流式调用OpenAI获取总结(不提供工具,强制AI直接回复) sendProgress("response_start", "", map[string]interface{}{ "conversationId": conversationID, - "mcpExecutionIds": result.MCPExecutionIDs, + "mcpExecutionIds": result.MCPExecutionIDs, "messageGeneratedBy": "summary", }) streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { @@ -840,7 +899,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 流式调用OpenAI获取总结(不提供工具,强制AI直接回复) sendProgress("response_start", "", map[string]interface{}{ "conversationId": conversationID, - "mcpExecutionIds": result.MCPExecutionIDs, + "mcpExecutionIds": result.MCPExecutionIDs, "messageGeneratedBy": "max_iter_summary", }) streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { @@ -913,17 +972,13 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool { defer cancel() externalTools, err := a.externalMCPMgr.GetAllTools(ctx) + extMap := make(map[string]string) if err != nil { a.logger.Warn("获取外部MCP工具失败", zap.Error(err)) } else { // 获取外部MCP配置,用于检查工具启用状态 externalMCPConfigs := a.externalMCPMgr.GetConfigs() - // 清空并重建工具名称映射 - a.mu.Lock() - a.toolNameMapping = make(map[string]string) - a.mu.Unlock() - // 将外部MCP工具添加到工具列表(只添加启用的工具) for _, externalTool := range externalTools { // 外部工具使用 "mcpName::toolName" 作为toolKey @@ -983,9 +1038,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool { openAIName := strings.ReplaceAll(externalTool.Name, "::", "__") // 保存名称映射关系(OpenAI格式 -> 原始格式) - a.mu.Lock() - a.toolNameMapping[openAIName] = externalTool.Name - a.mu.Unlock() + extMap[openAIName] = externalTool.Name tools = append(tools, Tool{ Type: "function", @@ -997,6 +1050,9 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool { }) } } + a.mu.Lock() + a.toolNameMapping = extMap + a.mu.Unlock() } a.logger.Debug("获取可用工具列表", @@ -1390,9 +1446,12 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map // 如果是record_vulnerability工具,自动添加conversation_id if toolName == builtin.ToolRecordVulnerability { - a.mu.RLock() - conversationID := a.currentConversationID - a.mu.RUnlock() + conversationID := agentConversationIDFromContext(ctx) + if conversationID == "" { + a.mu.RLock() + conversationID = a.currentConversationID + a.mu.RUnlock() + } if conversationID != "" { args["conversation_id"] = conversationID diff --git a/internal/app/app.go b/internal/app/app.go index d4e3dfe7..7b0bd78f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -326,6 +326,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { registerWebshellTools(mcpServer, db, webshellHandler, log.Logger) registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) + agentHandler.SetHitlToolWhitelistSaver(configHandler) externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger) skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger) @@ -654,9 +655,15 @@ func setupRoutes( // Eino ADK 单代理(ChatModelAgent + Runner;不依赖 multi_agent.enabled) protected.POST("/eino-agent", agentHandler.EinoSingleAgentLoop) protected.POST("/eino-agent/stream", agentHandler.EinoSingleAgentLoopStream) + protected.GET("/hitl/pending", agentHandler.ListHITLPending) + protected.POST("/hitl/decision", agentHandler.DecideHITLInterrupt) + protected.GET("/hitl/config/:conversationId", agentHandler.GetHITLConversationConfig) + protected.PUT("/hitl/config", agentHandler.UpsertHITLConversationConfig) + protected.POST("/hitl/tool-whitelist", agentHandler.MergeHITLGlobalToolWhitelist) // Agent Loop 取消与任务列表 protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop) protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks) + protected.GET("/agent-loop/task-events", agentHandler.SubscribeAgentTaskEvents) protected.GET("/agent-loop/tasks/completed", agentHandler.ListCompletedTasks) // Eino DeepAgent 多代理(与单 Agent 并存,需 config.multi_agent.enabled) diff --git a/internal/handler/agent.go b/internal/handler/agent.go index e52a4f7b..5de90488 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -115,7 +115,9 @@ type AgentHandler struct { db *database.DB logger *zap.Logger tasks *AgentTaskManager + taskEventBus *TaskEventBus // 镜像 SSE 事件,供刷新后订阅同一运行中任务 batchTaskManager *BatchTaskManager + hitlManager *HITLManager config *config.Config // 配置引用,用于获取角色信息 knowledgeManager interface { // 知识库管理器接口 LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error @@ -124,6 +126,13 @@ type AgentHandler struct { batchCronParser cron.Parser batchRunnerMu sync.Mutex batchRunning map[string]struct{} + // hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选) + hitlWhitelistSaver HitlToolWhitelistSaver +} + +// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘 +type HitlToolWhitelistSaver interface { + MergeHitlToolWhitelistIntoConfig(add []string) error } // NewAgentHandler 创建新的Agent处理器 @@ -136,16 +145,24 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo logger.Warn("从数据库加载批量任务队列失败", zap.Error(err)) } + bus := NewTaskEventBus() + tm := NewAgentTaskManager() + tm.SetTaskEventBus(bus) handler := &AgentHandler{ agent: agent, db: db, logger: logger, - tasks: NewAgentTaskManager(), + tasks: tm, + taskEventBus: bus, batchTaskManager: batchTaskManager, config: cfg, + hitlManager: NewHITLManager(db, logger), batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), batchRunning: make(map[string]struct{}), } + if err := handler.hitlManager.EnsureSchema(); err != nil { + logger.Warn("初始化 HITL 表失败", zap.Error(err)) + } go handler.batchQueueSchedulerLoop() return handler } @@ -162,6 +179,11 @@ func (h *AgentHandler) SetAgentsMarkdownDir(absDir string) { h.agentsMarkdownDir = strings.TrimSpace(absDir) } +// SetHitlToolWhitelistSaver 设置 HITL 白名单落盘(与 ConfigHandler 配合,避免循环引用用接口) +func (h *AgentHandler) SetHitlToolWhitelistSaver(s HitlToolWhitelistSaver) { + h.hitlWhitelistSaver = s +} + // ChatAttachment 聊天附件(用户上传的文件) type ChatAttachment struct { FileName string `json:"fileName"` // 展示用文件名 @@ -177,10 +199,18 @@ type ChatRequest struct { Role string `json:"role,omitempty"` // 角色名称 Attachments []ChatAttachment `json:"attachments,omitempty"` WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具 + Hitl *HITLRequest `json:"hitl,omitempty"` // Orchestration 仅对 /api/multi-agent、/api/multi-agent/stream:deep | plan_execute | supervisor;空则等同 deep。机器人/批量等无请求体时由服务端默认 deep。/api/eino-agent* 不使用此字段。 Orchestration string `json:"orchestration,omitempty"` } +type HITLRequest struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode,omitempty"` + SensitiveTools []string `json:"sensitiveTools,omitempty"` + TimeoutSeconds int `json:"timeoutSeconds,omitempty"` +} + const ( maxAttachments = 10 chatUploadsDirName = "chat_uploads" // 对话附件保存的根目录(相对当前工作目录) @@ -462,6 +492,11 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { } } + h.activateHITLForConversation(conversationID, req.Hitl) + if h.hitlManager != nil { + defer h.hitlManager.DeactivateConversation(conversationID) + } + // 优先尝试从保存的ReAct数据恢复历史上下文 agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) if err != nil { @@ -566,8 +601,13 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { return } + baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context()) + defer cancelWithCause(nil) + progressCallback := h.createProgressCallback(baseCtx, cancelWithCause, conversationID, "", nil) + baseCtx = h.injectReactHITLInterceptor(baseCtx, cancelWithCause, conversationID, "", nil) + // 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表) - result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools) + result, err := h.agent.AgentLoopWithProgress(baseCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools) if err != nil { h.logger.Error("Agent Loop执行失败", zap.Error(err)) @@ -661,7 +701,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI if assistantMsg != nil { assistantMessageID = assistantMsg.ID } - progressCallback := h.createProgressCallback(conversationID, assistantMessageID, nil) + progressCallback := h.createProgressCallback(ctx, nil, conversationID, assistantMessageID, nil) useRobotMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.RobotUseMultiAgent if useRobotMulti { @@ -755,9 +795,41 @@ type StreamEvent struct { // createProgressCallback 创建进度回调函数,用于保存processDetails // sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件 -func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback { +func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback { // 用于保存tool_call事件中的参数,以便在tool_result时使用 toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments + skillCallCache := make(map[string]string) // toolCallId -> skillName + skillToolName := "skill" + if h.config != nil { + if customName := strings.TrimSpace(h.config.MultiAgent.EinoSkills.SkillToolName); customName != "" { + skillToolName = customName + } + } + + extractSkillName := func(args map[string]interface{}) string { + if len(args) == 0 { + return "" + } + for _, key := range []string{"skill_name", "skillName", "name", "skill", "id", "skill_id", "skillId"} { + if v, ok := args[key]; ok { + switch vv := v.(type) { + case string: + if s := strings.TrimSpace(vv); s != "" { + return s + } + case map[string]interface{}: + for _, nestedKey := range []string{"name", "id", "skill_name", "skillId"} { + if nestedV, nestedOK := vv[nestedKey].(string); nestedOK { + if s := strings.TrimSpace(nestedV); s != "" { + return s + } + } + } + } + } + } + return "" + } // thinking_stream_*:不逐条落库,按 streamId 聚合,在后续关键事件前补一条可持久化的 thinking type thinkingBuf struct { @@ -840,6 +912,16 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID } } } + if strings.EqualFold(strings.TrimSpace(toolName), skillToolName) { + toolCallID, _ := dataMap["toolCallId"].(string) + if toolCallID != "" { + if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { + if skillName := extractSkillName(argumentsObj); skillName != "" { + skillCallCache[toolCallID] = skillName + } + } + } + } } } @@ -953,6 +1035,45 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID } } + // 记录 skills 调用统计(tool_call + tool_result 关联) + if eventType == "tool_result" && h.db != nil { + if dataMap, ok := data.(map[string]interface{}); ok { + toolName, _ := dataMap["toolName"].(string) + if strings.EqualFold(strings.TrimSpace(toolName), skillToolName) { + toolCallID, _ := dataMap["toolCallId"].(string) + skillName := "" + if toolCallID != "" { + skillName = strings.TrimSpace(skillCallCache[toolCallID]) + delete(skillCallCache, toolCallID) + } + if skillName == "" { + if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { + skillName = strings.TrimSpace(extractSkillName(argumentsObj)) + } + } + if skillName != "" { + success, ok := dataMap["success"].(bool) + if !ok { + if isError, okErr := dataMap["isError"].(bool); okErr { + success = !isError + } + } + successCalls := 0 + failedCalls := 0 + if success { + successCalls = 1 + } else { + failedCalls = 1 + } + now := time.Now() + if err := h.db.UpdateSkillStats(skillName, 1, successCalls, failedCalls, &now); err != nil { + h.logger.Warn("更新Skills调用统计失败", zap.Error(err), zap.String("skill", skillName)) + } + } + } + } + } + // 子代理回复流式增量不落库;结束时合并为一条 eino_agent_reply if assistantMessageID != "" && eventType == "eino_agent_reply_stream_end" { flushResponsePlan() @@ -1108,6 +1229,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { clientDisconnected := false // 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。 var sseWriteMu sync.Mutex + var ssePublishConversationID string // 用于快速确认模型是否真的产生了流式 delta var responseDeltaCount int var responseStartLogged bool @@ -1155,7 +1277,24 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { } } - // 如果客户端已断开,不再发送事件 + event := StreamEvent{ + Type: eventType, + Message: message, + Data: data, + } + eventJSON, errJSON := json.Marshal(event) + if errJSON != nil { + eventJSON = []byte(`{"type":"error","message":"marshal failed"}`) + } + sseLine := make([]byte, 0, len(eventJSON)+8) + sseLine = append(sseLine, []byte("data: ")...) + sseLine = append(sseLine, eventJSON...) + sseLine = append(sseLine, '\n', '\n') + if ssePublishConversationID != "" && h.taskEventBus != nil { + h.taskEventBus.Publish(ssePublishConversationID, sseLine) + } + + // 如果客户端已断开,不再写入 HTTP(镜像订阅仍可收到事件) if clientDisconnected { return } @@ -1168,15 +1307,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { default: } - event := StreamEvent{ - Type: eventType, - Message: message, - Data: data, - } - eventJSON, _ := json.Marshal(event) - sseWriteMu.Lock() - _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON) + _, err := c.Writer.Write(sseLine) if err != nil { sseWriteMu.Unlock() clientDisconnected = true @@ -1220,6 +1352,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { return } } + ssePublishConversationID = conversationID // 优先尝试从保存的ReAct数据恢复历史上下文 agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) @@ -1350,14 +1483,14 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { } // 创建进度回调函数,复用统一逻辑 - progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent) - // 创建一个独立的上下文用于任务执行,不随HTTP请求取消 // 这样即使客户端断开连接(如刷新页面),任务也能继续执行 baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) defer timeoutCancel() defer cancelWithCause(nil) + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { var errorMsg string @@ -1606,6 +1739,51 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { }) } +// SubscribeAgentTaskEvents GET SSE:订阅指定会话当前运行中任务的事件镜像(帧格式与 POST .../stream 一致),用于刷新页面或断线后接续 UI。 +func (h *AgentHandler) SubscribeAgentTaskEvents(c *gin.Context) { + conversationID := strings.TrimSpace(c.Query("conversationId")) + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) + return + } + if h.tasks.GetTask(conversationID) == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "no active task for this conversation"}) + return + } + if h.taskEventBus == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "task event bus unavailable"}) + return + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + sub, ch := h.taskEventBus.Subscribe(conversationID) + defer h.taskEventBus.Unsubscribe(conversationID, sub) + + flusher, _ := c.Writer.(http.Flusher) + ctx := c.Request.Context() + + for { + select { + case <-ctx.Done(): + return + case chunk, ok := <-ch: + if !ok { + return + } + if _, err := c.Writer.Write(chunk); err != nil { + return + } + if flusher != nil { + flusher.Flush() + } + } + } +} + // ListAgentTasks 列出所有运行中的任务 func (h *AgentHandler) ListAgentTasks(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ @@ -2266,7 +2444,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { if assistantMsg != nil { assistantMessageID = assistantMsg.ID } - progressCallback := h.createProgressCallback(conversationID, assistantMessageID, nil) + progressCallback := h.createProgressCallback(context.Background(), nil, conversationID, assistantMessageID, nil) // 执行任务(使用包含角色提示词的finalMessage和角色工具列表) h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID)) diff --git a/internal/handler/config.go b/internal/handler/config.go index 0b604330..4b4a7d07 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -187,6 +187,7 @@ type GetConfigResponse struct { MCP config.MCPConfig `json:"mcp"` Tools []ToolConfigInfo `json:"tools"` Agent config.AgentConfig `json:"agent"` + Hitl config.HitlConfig `json:"hitl,omitempty"` Knowledge config.KnowledgeConfig `json:"knowledge"` Robots config.RobotsConfig `json:"robots,omitempty"` MultiAgent config.MultiAgentPublic `json:"multi_agent,omitempty"` @@ -282,6 +283,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { MCP: h.config.MCP, Tools: tools, Agent: h.config.Agent, + Hitl: h.config.Hitl, Knowledge: h.config.Knowledge, Robots: h.config.Robots, MultiAgent: multiPub, @@ -1132,6 +1134,7 @@ func (h *ConfigHandler) saveConfig() error { updateFOFAConfig(root, h.config.FOFA) updateKnowledgeConfig(root, h.config.Knowledge) updateRobotsConfig(root, h.config.Robots) + updateHitlConfig(root, h.config.Hitl) updateMultiAgentConfig(root, h.config.MultiAgent) // 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用) updateExternalMCPConfig(root, h.config.ExternalMCP) @@ -1308,6 +1311,47 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) { setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs) } +func mergeHitlToolWhitelistSlice(existing, add []string) []string { + seen := make(map[string]struct{}) + out := make([]string, 0, len(existing)+len(add)) + for _, list := range [][]string{existing, add} { + for _, t := range list { + n := strings.ToLower(strings.TrimSpace(t)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + out = append(out, strings.TrimSpace(t)) + } + } + return out +} + +// MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。 +func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error { + h.mu.Lock() + defer h.mu.Unlock() + merged := mergeHitlToolWhitelistSlice(h.config.Hitl.ToolWhitelist, add) + h.config.Hitl.ToolWhitelist = merged + if err := h.saveConfig(); err != nil { + return err + } + h.logger.Info("HITL 全局工具白名单已合并写入配置文件", + zap.Int("count", len(merged)), + ) + return nil +} + +func updateHitlConfig(doc *yaml.Node, cfg config.HitlConfig) { + root := doc.Content[0] + hitlNode := ensureMap(root, "hitl") + // flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数 + setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist) +} + func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) { root := doc.Content[0] robotsNode := ensureMap(root, "robots") @@ -1418,6 +1462,21 @@ func setStringSliceInMap(mapNode *yaml.Node, key string, values []string) { } } +func setFlowStringSliceInMap(mapNode *yaml.Node, key string, values []string) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.SequenceNode + valueNode.Tag = "!!seq" + valueNode.Style = yaml.FlowStyle + valueNode.Content = nil + for _, v := range values { + valueNode.Content = append(valueNode.Content, &yaml.Node{ + Kind: yaml.ScalarNode, + Tag: "!!str", + Value: v, + }) + } +} + func setIntInMap(mapNode *yaml.Node, key string, value int) { _, valueNode := ensureKeyValue(mapNode, key) valueNode.Kind = yaml.ScalarNode diff --git a/internal/handler/eino_single_agent.go b/internal/handler/eino_single_agent.go index 9b7feddc..7f18bf8c 100644 --- a/internal/handler/eino_single_agent.go +++ b/internal/handler/eino_single_agent.go @@ -41,11 +41,24 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { var baseCtx context.Context clientDisconnected := false var sseWriteMu sync.Mutex + var ssePublishConversationID string sendEvent := func(eventType, message string, data interface{}) { - if clientDisconnected { + if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) { return } - if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) { + ev := StreamEvent{Type: eventType, Message: message, Data: data} + b, errMarshal := json.Marshal(ev) + if errMarshal != nil { + b = []byte(`{"type":"error","message":"marshal failed"}`) + } + sseLine := make([]byte, 0, len(b)+8) + sseLine = append(sseLine, []byte("data: ")...) + sseLine = append(sseLine, b...) + sseLine = append(sseLine, '\n', '\n') + if ssePublishConversationID != "" && h.taskEventBus != nil { + h.taskEventBus.Publish(ssePublishConversationID, sseLine) + } + if clientDisconnected { return } select { @@ -54,10 +67,8 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { return default: } - ev := StreamEvent{Type: eventType, Message: message, Data: data} - b, _ := json.Marshal(ev) sseWriteMu.Lock() - _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", b) + _, err := c.Writer.Write(sseLine) if err != nil { sseWriteMu.Unlock() clientDisconnected = true @@ -81,6 +92,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { sendEvent("done", "", nil) return } + ssePublishConversationID = prep.ConversationID if prep.CreatedNew { sendEvent("conversation", "会话已创建", map[string]interface{}{ "conversationId": prep.ConversationID, @@ -89,6 +101,10 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { conversationID := prep.ConversationID assistantMessageID := prep.AssistantMessageID + h.activateHITLForConversation(conversationID, req.Hitl) + if h.hitlManager != nil { + defer h.hitlManager.DeactivateConversation(conversationID) + } if prep.UserMessageID != "" { sendEvent("message_saved", "", map[string]interface{}{ @@ -97,13 +113,15 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { }) } - progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent) - var cancelWithCause context.CancelCauseFunc baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) defer timeoutCancel() defer cancelWithCause(nil) + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) + }) if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { var errorMsg string @@ -136,6 +154,8 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { defer close(stopKeepalive) if h.config == nil { + taskStatus = "failed" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) sendEvent("error", "服务器配置未加载", nil) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) return @@ -166,7 +186,24 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { } sendEvent("cancelled", cancelMsg, map[string]interface{}{ "conversationId": conversationID, - "messageId": assistantMessageID, + "messageId": assistantMessageID, + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + return + } + + if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) { + taskStatus = "timeout" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + timeoutMsg := "任务执行超时,已自动终止。" + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", timeoutMsg, assistantMessageID) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) + } + sendEvent("error", timeoutMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + "errorType": "timeout", }) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) return @@ -232,12 +269,22 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + h.activateHITLForConversation(prep.ConversationID, req.Hitl) + if h.hitlManager != nil { + defer h.hitlManager.DeactivateConversation(prep.ConversationID) + } var progressBuf strings.Builder - progressCallback := func(eventType, message string, data interface{}) { + progressCallbackRaw := func(eventType, message string, data interface{}) { progressBuf.WriteString(eventType) progressBuf.WriteByte('\n') } + baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context()) + defer cancelWithCause(nil) + progressCallback := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, progressCallbackRaw) + baseCtx = multiagent.WithHITLToolInterceptor(baseCtx, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments) + }) if h.config == nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器配置未加载"}) @@ -245,7 +292,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) { } result, runErr := multiagent.RunEinoSingleChatModelAgent( - c.Request.Context(), + baseCtx, h.config, &h.config.MultiAgent, h.agent, @@ -279,10 +326,10 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{ - "response": result.Response, - "conversationId": prep.ConversationID, - "mcpExecutionIds": result.MCPExecutionIDs, + "response": result.Response, + "conversationId": prep.ConversationID, + "mcpExecutionIds": result.MCPExecutionIDs, "assistantMessageId": prep.AssistantMessageID, - "agentMode": "eino_single", + "agentMode": "eino_single", }) } diff --git a/internal/handler/hitl.go b/internal/handler/hitl.go new file mode 100644 index 00000000..4231f319 --- /dev/null +++ b/internal/handler/hitl.go @@ -0,0 +1,748 @@ +package handler + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "math" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/multiagent" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.uber.org/zap" +) + +type hitlRuntimeConfig struct { + Enabled bool + Mode string + SensitiveTools map[string]struct{} + Timeout time.Duration +} + +type hitlDecision struct { + Decision string + Comment string + EditedArguments map[string]interface{} +} + +type pendingInterrupt struct { + ConversationID string + InterruptID string + Mode string + ToolName string + ToolCallID string + decideCh chan hitlDecision +} + +type HITLManager struct { + db *database.DB + logger *zap.Logger + + mu sync.RWMutex + runtime map[string]hitlRuntimeConfig + pending map[string]*pendingInterrupt +} + +func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager { + return &HITLManager{ + db: db, + logger: logger, + runtime: make(map[string]hitlRuntimeConfig), + pending: make(map[string]*pendingInterrupt), + } +} + +func (m *HITLManager) EnsureSchema() error { + if _, err := m.db.Exec(` +CREATE TABLE IF NOT EXISTS hitl_interrupts ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + message_id TEXT, + mode TEXT NOT NULL, + tool_name TEXT NOT NULL, + tool_call_id TEXT, + payload TEXT, + status TEXT NOT NULL, + decision TEXT, + decision_comment TEXT, + created_at DATETIME NOT NULL, + decided_at DATETIME +);`); err != nil { + return err + } + _, err := m.db.Exec(` +CREATE TABLE IF NOT EXISTS hitl_conversation_configs ( + conversation_id TEXT PRIMARY KEY, + enabled INTEGER NOT NULL DEFAULT 0, + mode TEXT NOT NULL DEFAULT 'off', + sensitive_tools TEXT NOT NULL DEFAULT '[]', + timeout_seconds INTEGER NOT NULL DEFAULT 300, + updated_at DATETIME NOT NULL +);`) + return err +} + +func normalizeHitlMode(mode string) string { + v := strings.ToLower(strings.TrimSpace(mode)) + if v == "" { + return "approval" + } + switch v { + case "off": + return "off" + case "feedback", "followup": + return "approval" + case "approval", "review_edit": + return v + default: + return "approval" + } +} + +func (m *HITLManager) ActivateConversation(conversationID string, req *HITLRequest) { + if req == nil || !req.Enabled { + m.DeactivateConversation(conversationID) + return + } + tools := make(map[string]struct{}) + for _, t := range req.SensitiveTools { + n := strings.ToLower(strings.TrimSpace(t)) + if n != "" { + tools[n] = struct{}{} + } + } + timeout := 5 * time.Minute + if req.TimeoutSeconds > 0 { + timeout = time.Duration(req.TimeoutSeconds) * time.Second + } + m.mu.Lock() + m.runtime[conversationID] = hitlRuntimeConfig{ + Enabled: true, + Mode: normalizeHitlMode(req.Mode), + SensitiveTools: tools, + Timeout: timeout, + } + m.mu.Unlock() +} + +func (m *HITLManager) DeactivateConversation(conversationID string) { + m.mu.Lock() + delete(m.runtime, conversationID) + m.mu.Unlock() +} + +// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空)。 +func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string { + if h == nil || h.config == nil { + return nil + } + raw := h.config.Hitl.ToolWhitelist + if len(raw) == 0 { + return nil + } + seen := make(map[string]struct{}) + out := make([]string, 0, len(raw)) + for _, t := range raw { + n := strings.ToLower(strings.TrimSpace(t)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + out = append(out, strings.TrimSpace(t)) + } + return out +} + +// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单合并(并集),仅用于运行时 Activate;不写入数据库。 +func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest { + gw := h.hitlConfigGlobalToolWhitelist() + if len(gw) == 0 { + return req + } + if req == nil { + return nil + } + seen := make(map[string]struct{}) + union := make([]string, 0, len(gw)+len(req.SensitiveTools)) + for _, t := range gw { + n := strings.ToLower(strings.TrimSpace(t)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + union = append(union, strings.TrimSpace(t)) + } + for _, t := range req.SensitiveTools { + n := strings.ToLower(strings.TrimSpace(t)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + union = append(union, strings.TrimSpace(t)) + } + out := *req + out.SensitiveTools = union + return &out +} + +func (m *HITLManager) shouldInterrupt(conversationID, toolName string) (hitlRuntimeConfig, bool) { + m.mu.RLock() + cfg, ok := m.runtime[conversationID] + m.mu.RUnlock() + if !ok || !cfg.Enabled { + return hitlRuntimeConfig{}, false + } + // 语义:SensitiveTools 现在作为“白名单(免审批工具)” + // 空白名单 => 全部工具都需要审批 + if len(cfg.SensitiveTools) == 0 { + return cfg, true + } + _, inWhitelist := cfg.SensitiveTools[strings.ToLower(strings.TrimSpace(toolName))] + return cfg, !inWhitelist +} + +func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) { + now := time.Now() + id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "") + if _, err := m.db.Exec(`INSERT INTO hitl_interrupts + (id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)`, + id, conversationID, assistantMessageID, mode, toolName, toolCallID, payload, now); err != nil { + return nil, err + } + // 刷新页面后侧栏依赖 DB 配置;若仅内存 Activate 未落库,会导致「有待审批却显示关闭」 + _ = m.ensureConversationHITLModePersisted(conversationID, mode) + p := &pendingInterrupt{ + ConversationID: conversationID, + InterruptID: id, + Mode: normalizeHitlMode(mode), + ToolName: toolName, + ToolCallID: toolCallID, + decideCh: make(chan hitlDecision, 1), + } + m.mu.Lock() + m.pending[id] = p + m.mu.Unlock() + return p, nil +} + +// ensureConversationHITLModePersisted 在产生待审批时把 mode 写入 hitl_conversation_configs,避免刷新后 GET 配置仍为关闭。 +func (m *HITLManager) ensureConversationHITLModePersisted(conversationID, interruptMode string) error { + if strings.TrimSpace(conversationID) == "" { + return nil + } + nm := normalizeHitlMode(interruptMode) + if nm == "off" { + return nil + } + cfg, err := m.LoadConversationConfig(conversationID) + if err != nil { + return err + } + if cfg.Enabled && normalizeHitlMode(cfg.Mode) == nm { + return nil + } + cfg.Enabled = true + cfg.Mode = nm + if cfg.TimeoutSeconds <= 0 { + cfg.TimeoutSeconds = 300 + } + return m.SaveConversationConfig(conversationID, cfg) +} + +// PendingHITLInterruptMode 返回该会话最新一条 pending 中断的协同模式(用于 GET 配置时与库内「关闭」状态对齐)。 +func (m *HITLManager) PendingHITLInterruptMode(conversationID string) (string, bool) { + if strings.TrimSpace(conversationID) == "" { + return "", false + } + var mode string + err := m.db.QueryRow(`SELECT mode FROM hitl_interrupts WHERE conversation_id = ? AND status = 'pending' ORDER BY created_at DESC LIMIT 1`, conversationID). + Scan(&mode) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", false + } + return "", false + } + mode = strings.TrimSpace(mode) + if mode == "" { + return "", false + } + return mode, true +} + +func hitlStoredConfigEffective(cfg *HITLRequest) bool { + if cfg == nil { + return false + } + if cfg.Enabled { + return true + } + return normalizeHitlMode(cfg.Mode) != "off" +} + +func (m *HITLManager) ResolveInterrupt(interruptID, decision, comment string, editedArguments map[string]interface{}) error { + decision = strings.ToLower(strings.TrimSpace(decision)) + if decision != "approve" && decision != "reject" { + return errors.New("decision must be approve/reject") + } + m.mu.RLock() + p, ok := m.pending[interruptID] + m.mu.RUnlock() + if !ok { + return errors.New("interrupt not found or already resolved") + } + d := hitlDecision{ + Decision: decision, + Comment: strings.TrimSpace(comment), + EditedArguments: editedArguments, + } + select { + case p.decideCh <- d: + return nil + default: + return errors.New("interrupt already resolved or decision channel busy") + } +} + +func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLRequest) error { + if strings.TrimSpace(conversationID) == "" { + return errors.New("conversationId is required") + } + if req == nil { + req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 300} + } + mode := normalizeHitlMode(req.Mode) + if !req.Enabled { + mode = "off" + } + tools, _ := json.Marshal(req.SensitiveTools) + timeout := req.TimeoutSeconds + if timeout <= 0 { + timeout = 300 + } + _, err := m.db.Exec(`INSERT INTO hitl_conversation_configs + (conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(conversation_id) DO UPDATE SET + enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`, + conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now()) + return err +} + +func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) { + var enabledInt int + var mode, toolsJSON string + var timeout int + err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID). + Scan(&enabledInt, &mode, &toolsJSON, &timeout) + if errors.Is(err, sql.ErrNoRows) { + return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 300}, nil + } + if err != nil { + return nil, err + } + tools := make([]string, 0) + _ = json.Unmarshal([]byte(toolsJSON), &tools) + return &HITLRequest{ + Enabled: enabledInt == 1, + Mode: mode, + SensitiveTools: tools, + TimeoutSeconds: timeout, + }, nil +} + +func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, timeout time.Duration) (hitlDecision, error) { + defer func() { + m.mu.Lock() + delete(m.pending, p.InterruptID) + m.mu.Unlock() + }() + select { + case d := <-p.decideCh: + // 只有 review_edit 模式允许改参;其他模式一律忽略 edited arguments + if p.Mode != "review_edit" && len(d.EditedArguments) > 0 { + d.EditedArguments = nil + } + _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`, + d.Decision, d.Comment, time.Now(), p.InterruptID) + return d, nil + case <-time.After(timeout): + _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`, + time.Now(), p.InterruptID) + return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil + case <-ctx.Done(): + _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`, + time.Now(), p.InterruptID) + return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err() + } +} + +func (h *AgentHandler) activateHITLForConversation(conversationID string, req *HITLRequest) { + if h.hitlManager == nil { + return + } + if req == nil { + cfg, err := h.hitlManager.LoadConversationConfig(conversationID) + if err == nil { + req = cfg + } + } + h.hitlManager.ActivateConversation(conversationID, h.hitlRequestWithMergedConfigWhitelist(req)) +} + +func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID, toolName, toolCallID string, payload map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) (*hitlDecision, error) { + cfg, need := h.hitlManager.shouldInterrupt(conversationID, toolName) + if !need { + return nil, nil + } + payloadRaw, _ := json.Marshal(payload) + p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw)) + if err != nil { + h.logger.Warn("创建 HITL 中断失败", zap.Error(err)) + return nil, err + } + if sendEventFunc != nil { + sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{ + "conversationId": conversationID, + "interruptId": p.InterruptID, + "mode": cfg.Mode, + "toolName": toolName, + "toolCallId": toolCallID, + "payload": payload, + }) + } + d, waitErr := h.hitlManager.waitDecision(runCtx, p, cfg.Timeout) + if waitErr != nil { + if cancelRun != nil && (errors.Is(waitErr, context.Canceled) || errors.Is(waitErr, context.DeadlineExceeded)) { + cause := context.Cause(runCtx) + switch { + case errors.Is(cause, ErrTaskCancelled): + cancelRun(ErrTaskCancelled) + case cause != nil: + cancelRun(cause) + case errors.Is(waitErr, context.DeadlineExceeded): + cancelRun(context.DeadlineExceeded) + default: + cancelRun(ErrTaskCancelled) + } + } + return nil, waitErr + } + if d.Decision == "reject" { + if sendEventFunc != nil { + sendEventFunc("hitl_rejected", "人工拒绝本次工具调用,模型将基于反馈继续迭代", map[string]interface{}{ + "conversationId": conversationID, + "interruptId": p.InterruptID, + "toolName": toolName, + "comment": d.Comment, + }) + } + return &d, nil + } + if sendEventFunc != nil { + sendEventFunc("hitl_resumed", "人工确认通过,继续执行", map[string]interface{}{ + "conversationId": conversationID, + "interruptId": p.InterruptID, + "toolName": toolName, + "comment": d.Comment, + "editedArgs": d.EditedArguments, + }) + } + return &d, nil +} + +func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, data map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) { + if h.hitlManager == nil { + return + } + toolName, _ := data["toolName"].(string) + toolCallID, _ := data["toolCallId"].(string) + d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, toolCallID, data, sendEventFunc) + if err != nil || d == nil { + return + } + if len(d.EditedArguments) > 0 { + if argsObj, ok := data["argumentsObj"].(map[string]interface{}); ok { + for k := range argsObj { + delete(argsObj, k) + } + for k, v := range d.EditedArguments { + argsObj[k] = v + } + if b, mErr := json.Marshal(argsObj); mErr == nil { + data["arguments"] = string(b) + } + } + } +} + +func (h *AgentHandler) ListHITLPending(c *gin.Context) { + conversationID := strings.TrimSpace(c.Query("conversationId")) + status := strings.TrimSpace(c.Query("status")) + if status == "" { + status = "pending" + } + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + if page < 1 { + page = 1 + } + pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) + pageSize = int(math.Max(1, math.Min(float64(pageSize), 200))) + offset := (page - 1) * pageSize + q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1` + args := []interface{}{} + if conversationID != "" { + q += " AND conversation_id = ?" + args = append(args, conversationID) + } + if status != "all" { + q += " AND status = ?" + args = append(args, status) + } + q += " ORDER BY created_at DESC LIMIT ? OFFSET ?" + args = append(args, pageSize, offset) + rows, err := h.db.Query(q, args...) + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + defer rows.Close() + items := make([]map[string]interface{}, 0) + for rows.Next() { + var id, cid, mode, toolName, toolCallID, payload, rowStatus string + var messageID sql.NullString + var decision, comment sql.NullString + var createdAt time.Time + var decidedAt sql.NullTime + if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil { + continue + } + msgID := "" + if messageID.Valid { + msgID = messageID.String + } + items = append(items, map[string]interface{}{ + "id": id, + "conversationId": cid, + "messageId": msgID, + "mode": mode, + "toolName": toolName, + "toolCallId": toolCallID, + "payload": payload, + "status": rowStatus, + "decision": decision.String, + "comment": comment.String, + "createdAt": createdAt, + "decidedAt": func() interface{} { + if decidedAt.Valid { + return decidedAt.Time + } + return nil + }(), + }) + } + c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize}) +} + +type hitlDecisionReq struct { + InterruptID string `json:"interruptId" binding:"required"` + Decision string `json:"decision" binding:"required"` + Comment string `json:"comment,omitempty"` + EditedArguments map[string]interface{} `json:"editedArguments,omitempty"` +} + +func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) { + var req hitlDecisionReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + if h.hitlManager == nil { + c.JSON(500, gin.H{"error": "hitl manager unavailable"}) + return + } + if err := h.hitlManager.ResolveInterrupt(req.InterruptID, req.Decision, req.Comment, req.EditedArguments); err != nil { + c.JSON(http.StatusConflict, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"ok": true}) +} + +func (h *AgentHandler) interceptHITLForEinoTool(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{}), toolName, arguments string) (string, error) { + payload := map[string]interface{}{ + "toolName": toolName, + "arguments": arguments, + "source": "eino_middleware", + "toolCallId": "", + } + var argsObj map[string]interface{} + if strings.TrimSpace(arguments) != "" { + _ = json.Unmarshal([]byte(arguments), &argsObj) + if argsObj != nil { + payload["argumentsObj"] = argsObj + } + } + d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, "", payload, sendEventFunc) + if err != nil || d == nil { + return arguments, err + } + if d.Decision == "reject" { + return arguments, multiagent.NewHumanRejectError(d.Comment) + } + if len(d.EditedArguments) > 0 { + edited, mErr := json.Marshal(d.EditedArguments) + if mErr == nil { + return string(edited), nil + } + } + return arguments, nil +} + +func (h *AgentHandler) interceptHITLForReactTool(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{}), toolName string, arguments map[string]interface{}, toolCallID string) (map[string]interface{}, error) { + payload := map[string]interface{}{ + "toolName": toolName, + "argumentsObj": arguments, + "toolCallId": toolCallID, + "source": "react_pre_exec", + } + d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, toolCallID, payload, sendEventFunc) + if err != nil || d == nil { + return arguments, err + } + if d.Decision == "reject" { + comment := strings.TrimSpace(d.Comment) + if comment == "" { + comment = "no extra feedback" + } + return arguments, errors.New("human rejected this tool call; feedback: " + comment) + } + if len(d.EditedArguments) > 0 { + return d.EditedArguments, nil + } + return arguments, nil +} + +func (h *AgentHandler) injectReactHITLInterceptor(ctx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) context.Context { + return agent.WithToolCallInterceptor(ctx, func(c context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error) { + return h.interceptHITLForReactTool(c, cancelRun, conversationID, assistantMessageID, sendEventFunc, toolName, args, toolCallID) + }) +} + +type hitlConfigReq struct { + ConversationID string `json:"conversationId" binding:"required"` + HITLRequest +} + +func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) { + conversationID := strings.TrimSpace(c.Param("conversationId")) + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) + return + } + cfg, err := h.hitlManager.LoadConversationConfig(conversationID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !hitlStoredConfigEffective(cfg) { + if pendMode, ok := h.hitlManager.PendingHITLInterruptMode(conversationID); ok { + cfg2 := *cfg + cfg2.Enabled = true + cfg2.Mode = normalizeHitlMode(pendMode) + if cfg2.TimeoutSeconds <= 0 { + cfg2.TimeoutSeconds = 300 + } + cfg = &cfg2 + } + } + c.JSON(http.StatusOK, gin.H{ + "conversationId": conversationID, + "hitl": cfg, + "hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(), + }) +} + +func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) { + var req hitlConfigReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + req.Mode = normalizeHitlMode(req.Mode) + if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.hitlWhitelistSaver != nil && len(req.SensitiveTools) > 0 { + if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil { + h.logger.Warn("HITL 会话配置已保存,但合并工具白名单到 config.yaml 失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "会话配置已保存,但写入 config.yaml 失败: " + err.Error(), + }) + return + } + } + h.hitlManager.ActivateConversation(req.ConversationID, h.hitlRequestWithMergedConfigWhitelist(&req.HITLRequest)) + c.JSON(http.StatusOK, gin.H{"ok": true}) +} + +type mergeHitlGlobalWhitelistReq struct { + SensitiveTools []string `json:"sensitiveTools"` +} + +// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。 +func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) { + if h.hitlWhitelistSaver == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"}) + return + } + var req mergeHitlGlobalWhitelistReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if len(req.SensitiveTools) == 0 { + c.JSON(http.StatusOK, gin.H{ + "ok": true, + "hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(), + "hitlGlobalWhitelistMerged": false, + }) + return + } + if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil { + h.logger.Warn("合并 HITL 工具白名单到 config.yaml 失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "ok": true, + "hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(), + "hitlGlobalWhitelistMerged": true, + }) +} + +func boolToInt(v bool) int { + if v { + return 1 + } + return 0 +} diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go index b9f9e0af..5114a222 100644 --- a/internal/handler/multi_agent.go +++ b/internal/handler/multi_agent.go @@ -53,25 +53,36 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { clientDisconnected := false // 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。 var sseWriteMu sync.Mutex + var ssePublishConversationID string sendEvent := func(eventType, message string, data interface{}) { - if clientDisconnected { - return - } // 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。 // 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。 if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) { return } + ev := StreamEvent{Type: eventType, Message: message, Data: data} + b, errMarshal := json.Marshal(ev) + if errMarshal != nil { + b = []byte(`{"type":"error","message":"marshal failed"}`) + } + sseLine := make([]byte, 0, len(b)+8) + sseLine = append(sseLine, []byte("data: ")...) + sseLine = append(sseLine, b...) + sseLine = append(sseLine, '\n', '\n') + if ssePublishConversationID != "" && h.taskEventBus != nil { + h.taskEventBus.Publish(ssePublishConversationID, sseLine) + } + if clientDisconnected { + return + } select { case <-c.Request.Context().Done(): clientDisconnected = true return default: } - ev := StreamEvent{Type: eventType, Message: message, Data: data} - b, _ := json.Marshal(ev) sseWriteMu.Lock() - _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", b) + _, err := c.Writer.Write(sseLine) if err != nil { sseWriteMu.Unlock() clientDisconnected = true @@ -95,6 +106,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { sendEvent("done", "", nil) return } + ssePublishConversationID = prep.ConversationID if prep.CreatedNew { sendEvent("conversation", "会话已创建", map[string]interface{}{ "conversationId": prep.ConversationID, @@ -103,6 +115,10 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { conversationID := prep.ConversationID assistantMessageID := prep.AssistantMessageID + h.activateHITLForConversation(conversationID, req.Hitl) + if h.hitlManager != nil { + defer h.hitlManager.DeactivateConversation(conversationID) + } if prep.UserMessageID != "" { sendEvent("message_saved", "", map[string]interface{}{ @@ -111,12 +127,14 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { }) } - progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent) - baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) defer timeoutCancel() defer cancelWithCause(nil) + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) + }) if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { var errorMsg string @@ -181,6 +199,23 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { return } + if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) { + taskStatus = "timeout" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + timeoutMsg := "任务执行超时,已自动终止。" + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", timeoutMsg, assistantMessageID) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) + } + sendEvent("error", timeoutMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + "errorType": "timeout", + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + return + } + h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) taskStatus = "failed" h.tasks.UpdateTaskStatus(conversationID, taskStatus) @@ -251,9 +286,20 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { c.JSON(status, gin.H{"error": msg}) return } + h.activateHITLForConversation(prep.ConversationID, req.Hitl) + if h.hitlManager != nil { + defer h.hitlManager.DeactivateConversation(prep.ConversationID) + } + + baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context()) + defer cancelWithCause(nil) + progressCallback := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil) + baseCtx = multiagent.WithHITLToolInterceptor(baseCtx, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments) + }) result, runErr := multiagent.RunDeepAgent( - c.Request.Context(), + baseCtx, h.config, &h.config.MultiAgent, h.agent, @@ -262,7 +308,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { prep.FinalMessage, prep.History, prep.RoleTools, - nil, + progressCallback, h.agentsMarkdownDir, strings.TrimSpace(req.Orchestration), ) diff --git a/internal/handler/task_event_bus.go b/internal/handler/task_event_bus.go new file mode 100644 index 00000000..bf2ad880 --- /dev/null +++ b/internal/handler/task_event_bus.go @@ -0,0 +1,116 @@ +package handler + +import "sync" + +// TaskEventBus 将主 SSE 连接上的事件镜像给后订阅的客户端(例如刷新页面后、HITL 审批通过需继续收事件)。 +// 每个 payload 为完整 SSE 行: "data: {...}\n\n" +type TaskEventBus struct { + mu sync.RWMutex + subs map[string]map[*taskEventSub]struct{} +} + +type taskEventSub struct { + mu sync.Mutex + ch chan []byte + closed bool +} + +func (s *taskEventSub) sendNonBlocking(line []byte) bool { + if s == nil { + return false + } + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return false + } + select { + case s.ch <- line: + return true + default: + return false + } +} + +func (s *taskEventSub) closeOnce() { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return + } + s.closed = true + close(s.ch) +} + +func NewTaskEventBus() *TaskEventBus { + return &TaskEventBus{ + subs: make(map[string]map[*taskEventSub]struct{}), + } +} + +// Subscribe 注册订阅;cancel 时需调用 Unsubscribe。 +func (b *TaskEventBus) Subscribe(conversationID string) (sub *taskEventSub, ch <-chan []byte) { + chBuf := make(chan []byte, 256) + sub = &taskEventSub{ch: chBuf} + b.mu.Lock() + if b.subs[conversationID] == nil { + b.subs[conversationID] = make(map[*taskEventSub]struct{}) + } + b.subs[conversationID][sub] = struct{}{} + b.mu.Unlock() + return sub, chBuf +} + +func (b *TaskEventBus) Unsubscribe(conversationID string, sub *taskEventSub) { + if sub == nil { + return + } + b.mu.Lock() + m, ok := b.subs[conversationID] + if !ok { + b.mu.Unlock() + return + } + delete(m, sub) + if len(m) == 0 { + delete(b.subs, conversationID) + } + b.mu.Unlock() + sub.closeOnce() +} + +// Publish 非阻塞投递;慢消费者丢帧(HITL 场景以最新状态为准,丢帧可接受)。 +func (b *TaskEventBus) Publish(conversationID string, line []byte) { + if b == nil || conversationID == "" || len(line) == 0 { + return + } + b.mu.RLock() + m := b.subs[conversationID] + subs := make([]*taskEventSub, 0, len(m)) + for s := range m { + subs = append(subs, s) + } + b.mu.RUnlock() + + cp := append([]byte(nil), line...) + for _, s := range subs { + s.sendNonBlocking(cp) + } +} + +// CloseConversation 任务结束时关闭该会话所有订阅 channel。 +func (b *TaskEventBus) CloseConversation(conversationID string) { + if b == nil || conversationID == "" { + return + } + b.mu.Lock() + m := b.subs[conversationID] + delete(b.subs, conversationID) + b.mu.Unlock() + for sub := range m { + sub.closeOnce() + } +} diff --git a/internal/handler/task_manager.go b/internal/handler/task_manager.go index 9964ad5c..acbc4733 100644 --- a/internal/handler/task_manager.go +++ b/internal/handler/task_manager.go @@ -35,11 +35,12 @@ type CompletedTask struct { // AgentTaskManager 管理正在运行的Agent任务 type AgentTaskManager struct { - mu sync.RWMutex - tasks map[string]*AgentTask - completedTasks []*CompletedTask // 最近完成的任务历史 - maxHistorySize int // 最大历史记录数 - historyRetention time.Duration // 历史记录保留时间 + mu sync.RWMutex + tasks map[string]*AgentTask + completedTasks []*CompletedTask // 最近完成的任务历史 + maxHistorySize int // 最大历史记录数 + historyRetention time.Duration // 历史记录保留时间 + eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅 } const ( @@ -56,13 +57,27 @@ func NewAgentTaskManager() *AgentTaskManager { m := &AgentTaskManager{ tasks: make(map[string]*AgentTask), completedTasks: make([]*CompletedTask, 0), - maxHistorySize: 50, // 最多保留50条历史记录 - historyRetention: 24 * time.Hour, // 保留24小时 + maxHistorySize: 50, // 最多保留50条历史记录 + historyRetention: 24 * time.Hour, // 保留24小时 } go m.runStuckCancellingCleanup() return m } +// SetTaskEventBus 设置任务事件总线(与 AgentHandler 共用同一实例)。 +func (m *AgentTaskManager) SetTaskEventBus(b *TaskEventBus) { + m.mu.Lock() + defer m.mu.Unlock() + m.eventBus = b +} + +// GetTask 返回运行中任务(无则 nil)。 +func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask { + m.mu.RLock() + defer m.mu.RUnlock() + return m.tasks[conversationID] +} + // runStuckCancellingCleanup 定期将长时间处于「取消中」的任务强制结束,避免卡住无法发新消息 func (m *AgentTaskManager) runStuckCancellingCleanup() { ticker := time.NewTicker(cleanupInterval) @@ -172,10 +187,9 @@ func (m *AgentTaskManager) UpdateTaskStatus(conversationID string, status string // FinishTask 完成任务并从管理器中移除 func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string) { m.mu.Lock() - defer m.mu.Unlock() - task, exists := m.tasks[conversationID] if !exists { + m.mu.Unlock() return } @@ -187,26 +201,31 @@ func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string) completedTask := &CompletedTask{ ConversationID: task.ConversationID, Message: task.Message, - StartedAt: task.StartedAt, - CompletedAt: time.Now(), - Status: finalStatus, + StartedAt: task.StartedAt, + CompletedAt: time.Now(), + Status: finalStatus, } - + // 添加到历史记录 m.completedTasks = append(m.completedTasks, completedTask) - + // 清理过期和过多的历史记录 m.cleanupHistory() // 从运行任务中移除 delete(m.tasks, conversationID) + bus := m.eventBus + m.mu.Unlock() + if bus != nil { + bus.CloseConversation(conversationID) + } } // cleanupHistory 清理过期的历史记录 func (m *AgentTaskManager) cleanupHistory() { now := time.Now() cutoffTime := now.Add(-m.historyRetention) - + // 过滤掉过期的记录 validTasks := make([]*CompletedTask, 0, len(m.completedTasks)) for _, task := range m.completedTasks { @@ -214,7 +233,7 @@ func (m *AgentTaskManager) cleanupHistory() { validTasks = append(validTasks, task) } } - + // 如果仍然超过最大数量,只保留最新的 if len(validTasks) > m.maxHistorySize { // 按完成时间排序,保留最新的 @@ -222,7 +241,7 @@ func (m *AgentTaskManager) cleanupHistory() { start := len(validTasks) - m.maxHistorySize validTasks = validTasks[start:] } - + m.completedTasks = validTasks } @@ -247,30 +266,30 @@ func (m *AgentTaskManager) GetActiveTasks() []*AgentTask { func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask { m.mu.RLock() defer m.mu.RUnlock() - + // 清理过期记录(只读锁,不影响其他操作) // 注意:这里不能直接调用cleanupHistory,因为需要写锁 // 所以返回时过滤过期记录 now := time.Now() cutoffTime := now.Add(-m.historyRetention) - + result := make([]*CompletedTask, 0, len(m.completedTasks)) for _, task := range m.completedTasks { if task.CompletedAt.After(cutoffTime) { result = append(result, task) } } - + // 按完成时间倒序排序(最新的在前) // 由于是追加的,最新的在最后,需要反转 for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 { result[i], result[j] = result[j], result[i] } - + // 限制返回数量 if len(result) > m.maxHistorySize { result = result[:m.maxHistorySize] } - + return result }