From aae71a0c3e3f7d41e908b4089de2da776046d70f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 3 Jul 2026 20:27:51 +0800 Subject: [PATCH] Add files via upload --- internal/handler/workflow_integration.go | 177 +++++++++++++++++++---- internal/handler/workflow_run.go | 41 +++++- 2 files changed, 185 insertions(+), 33 deletions(-) diff --git a/internal/handler/workflow_integration.go b/internal/handler/workflow_integration.go index 3eae0a0a..dee8bd35 100644 --- a/internal/handler/workflow_integration.go +++ b/internal/handler/workflow_integration.go @@ -2,6 +2,7 @@ package handler import ( "context" + "errors" "net/http" "strings" "time" @@ -10,6 +11,7 @@ import ( workflowrunner "cyberstrike-ai/internal/workflow" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) func (h *AgentHandler) roleForWorkflow(req *ChatRequest) (config.RoleConfig, bool) { @@ -42,33 +44,108 @@ func (h *AgentHandler) runRoleWorkflowStreamIfBound( if !ok || prep == nil { return false } + + conversationID := prep.ConversationID + assistantMessageID := prep.AssistantMessageID + userMessage := "" + if req != nil { + userMessage = req.Message + } + + taskStatus := "completed" + taskOwned := false + defer func() { + if taskOwned { + h.tasks.FinishTask(conversationID, taskStatus) + } + }() + baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) defer cancelWithCause(nil) - progress := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, sendEvent) - result, err := workflowrunner.RunRoleBoundWorkflow(baseCtx, workflowrunner.RunArgs{ + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) + defer timeoutCancel() + + if _, err := h.tasks.StartTask(conversationID, userMessage, cancelWithCause); err != nil { + var errorMsg string + if errors.Is(err, ErrTaskAlreadyRunning) { + errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。" + sendEvent("error", errorMsg, map[string]interface{}{ + "conversationId": conversationID, + "errorType": "task_already_running", + }) + } else { + errorMsg = "❌ 无法启动任务: " + err.Error() + sendEvent("error", errorMsg, nil) + } + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID) + } + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + return true + } + taskOwned = true + + progress := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + result, err := workflowrunner.RunRoleBoundWorkflow(taskCtx, workflowrunner.RunArgs{ DB: h.db, Logger: h.logger, Role: role, AppCfg: h.config, Agent: h.agent, - ConversationID: prep.ConversationID, - ProjectID: h.conversationProjectID(prep.ConversationID), + ConversationID: conversationID, + ProjectID: h.conversationProjectID(conversationID), UserMessage: prep.FinalMessage, History: prep.History, RoleTools: prep.RoleTools, AgentsMarkdownDir: h.agentsMarkdownDir, - SystemPromptExtra: h.agentSessionContextBlock(prep.ConversationID), - AssistantMessageID: prep.AssistantMessageID, + SystemPromptExtra: h.agentSessionContextBlock(conversationID), + AssistantMessageID: assistantMessageID, Progress: progress, }) if err != nil { - errMsg := "执行角色绑定流程失败: " + err.Error() - if prep.AssistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID) - _ = h.db.AddProcessDetail(prep.AssistantMessageID, prep.ConversationID, "error", errMsg, nil) + cause := context.Cause(baseCtx) + if errors.Is(cause, ErrTaskCancelled) { + taskStatus = "cancelled" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + cancelMsg := "任务已被用户取消,后续操作已停止。" + if assistantMessageID != "" { + if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil { + h.logger.Warn("更新取消后的助手消息失败", zap.Error(err)) + } + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) + } + sendEvent("cancelled", cancelMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + return true } - sendEvent("error", errMsg, map[string]interface{}{"conversationId": prep.ConversationID}) - sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID}) + if errors.Is(err, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) { + taskStatus = "timeout" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + timeoutMsg := "任务执行超时,已自动终止。" + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) + } + sendEvent("error", timeoutMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + "errorType": "timeout", + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + return true + } + errMsg := "执行角色绑定流程失败: " + err.Error() + taskStatus = "failed" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil) + } + sendEvent("error", errMsg, map[string]interface{}{"conversationId": conversationID}) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) return true } if prep.AssistantMessageID != "" { @@ -85,13 +162,6 @@ func (h *AgentHandler) runRoleWorkflowStreamIfBound( payload["awaitingHitl"] = true } sendEvent("response", result.Response, payload) - if result.AwaitingHITL { - sendEvent("done", "", map[string]interface{}{ - "conversationId": prep.ConversationID, - "workflowStatus": "awaiting_hitl", - }) - return true - } sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID}) return true } @@ -101,31 +171,80 @@ func (h *AgentHandler) runRoleWorkflowJSONIfBound(c *gin.Context, req *ChatReque if !ok || prep == nil { return false } + + conversationID := prep.ConversationID + assistantMessageID := prep.AssistantMessageID + userMessage := "" + if req != nil { + userMessage = req.Message + } + + taskStatus := "completed" + taskOwned := false + defer func() { + if taskOwned { + h.tasks.FinishTask(conversationID, taskStatus) + } + }() + baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context()) defer cancelWithCause(nil) - progress := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil) - result, err := workflowrunner.RunRoleBoundWorkflow(baseCtx, workflowrunner.RunArgs{ + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) + defer timeoutCancel() + + if _, err := h.tasks.StartTask(conversationID, userMessage, cancelWithCause); err != nil { + if errors.Is(err, ErrTaskAlreadyRunning) { + c.JSON(http.StatusConflict, gin.H{ + "error": "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。", + "conversationId": conversationID, + "errorType": "task_already_running", + }) + } else { + c.JSON(http.StatusInternalServerError, gin.H{"error": "❌ 无法启动任务: " + err.Error()}) + } + return true + } + taskOwned = true + + progress := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil) + result, err := workflowrunner.RunRoleBoundWorkflow(taskCtx, workflowrunner.RunArgs{ DB: h.db, Logger: h.logger, Role: role, AppCfg: h.config, Agent: h.agent, - ConversationID: prep.ConversationID, - ProjectID: h.conversationProjectID(prep.ConversationID), + ConversationID: conversationID, + ProjectID: h.conversationProjectID(conversationID), UserMessage: prep.FinalMessage, History: prep.History, RoleTools: prep.RoleTools, AgentsMarkdownDir: h.agentsMarkdownDir, - SystemPromptExtra: h.agentSessionContextBlock(prep.ConversationID), - AssistantMessageID: prep.AssistantMessageID, + SystemPromptExtra: h.agentSessionContextBlock(conversationID), + AssistantMessageID: assistantMessageID, Progress: progress, }) if err != nil { - errMsg := "执行角色绑定流程失败: " + err.Error() - if prep.AssistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID) + cause := context.Cause(baseCtx) + if errors.Is(cause, ErrTaskCancelled) { + taskStatus = "cancelled" + cancelMsg := "任务已被用户取消,后续操作已停止。" + if assistantMessageID != "" { + _ = h.appendAssistantMessageNotice(assistantMessageID, cancelMsg) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) + } + c.JSON(http.StatusOK, gin.H{ + "status": "cancelled", + "message": cancelMsg, + "conversationId": conversationID, + }) + return true } - c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg, "conversationId": prep.ConversationID}) + errMsg := "执行角色绑定流程失败: " + err.Error() + taskStatus = "failed" + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg, "conversationId": conversationID}) return true } if prep.AssistantMessageID != "" { diff --git a/internal/handler/workflow_run.go b/internal/handler/workflow_run.go index 0a19c0b8..0754912f 100644 --- a/internal/handler/workflow_run.go +++ b/internal/handler/workflow_run.go @@ -3,6 +3,7 @@ package handler import ( "net/http" "strings" + "time" "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" @@ -31,7 +32,8 @@ func (h *WorkflowHandler) GetRun(c *gin.Context) { } func (h *WorkflowHandler) ListPendingRuns(c *gin.Context) { - runs, err := h.db.ListWorkflowRunsAwaitingHITL(50) + conversationID := strings.TrimSpace(c.Query("conversationId")) + runs, err := h.db.ListWorkflowRunsAwaitingHITLFiltered(conversationID, 50) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -73,6 +75,37 @@ func (h *WorkflowHandler) ResumeRun(c *gin.Context) { } } } + if run.Status != "awaiting_hitl" { + c.JSON(http.StatusBadRequest, gin.H{"error": "工作流运行不在等待审批状态: " + run.Status}) + return + } + if err := h.db.RecordWorkflowRunHITLDecision(runID, req.Approved, req.Comment); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + decision := workflowrunner.HITLDecision{ + Approved: req.Approved, + Comment: strings.TrimSpace(req.Comment), + } + delegated := workflowrunner.NotifyHITLDecision(runID, decision) + if !delegated { + for i := 0; i < 10; i++ { + time.Sleep(50 * time.Millisecond) + if workflowrunner.NotifyHITLDecision(runID, decision) { + delegated = true + break + } + } + } + if delegated { + c.JSON(http.StatusOK, gin.H{ + "workflowRunId": runID, + "status": "delegated", + "streamResuming": true, + "approved": req.Approved, + }) + return + } result, err := workflowrunner.ResumeWorkflowRun(c.Request.Context(), workflowrunner.RunArgs{ DB: h.db, Logger: h.logger, @@ -87,9 +120,9 @@ func (h *WorkflowHandler) ResumeRun(c *gin.Context) { return } c.JSON(http.StatusOK, gin.H{ - "response": result.Response, + "response": result.Response, "workflowRunId": result.RunID, - "status": result.Status, - "awaitingHitl": result.AwaitingHITL, + "status": result.Status, + "awaitingHitl": result.AwaitingHITL, }) }