diff --git a/internal/workflow/eino_compile.go b/internal/workflow/eino_compile.go index 2628ab0a..7f151b4a 100644 --- a/internal/workflow/eino_compile.go +++ b/internal/workflow/eino_compile.go @@ -144,6 +144,13 @@ func ResumeWorkflowRun(ctx context.Context, args RunArgs, runID string, approved errText = "人工审批拒绝" } _ = args.DB.FinishWorkflowRun(runID, "rejected", "", errText) + if args.Progress != nil { + args.Progress("workflow_hitl_rejected", fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", run.PendingHITLNodeID), map[string]interface{}{ + "workflowRunId": runID, + "nodeId": run.PendingHITLNodeID, + "comment": errText, + }) + } return &RunResult{ RunID: runID, Response: fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", run.PendingHITLNodeID), @@ -151,6 +158,14 @@ func ResumeWorkflowRun(ctx context.Context, args RunArgs, runID string, approved }, nil } + if args.Progress != nil { + args.Progress("workflow_hitl_resumed", "人工审批已通过,继续执行", map[string]interface{}{ + "workflowRunId": runID, + "nodeId": run.PendingHITLNodeID, + "comment": strings.TrimSpace(comment), + }) + } + _ = args.DB.SetWorkflowRunStatus(runID, "running") resumeArgs := args if strings.TrimSpace(resumeArgs.ConversationID) == "" { @@ -186,5 +201,14 @@ func ResumeWorkflowRun(ctx context.Context, args RunArgs, runID string, approved outputJSON, _ := json.Marshal(output) response := renderWorkflowResponse(args.Role.Name, wf.Name, wf.Version, runID, state) _ = args.DB.FinishWorkflowRun(runID, "completed", string(outputJSON), "") + if args.Progress != nil { + args.Progress("workflow_done", fmt.Sprintf("流程「%s」运行完成", wf.Name), map[string]interface{}{ + "workflowRunId": runID, + "workflowId": wf.ID, + "outputs": state.Outputs, + "response": response, + "engine": "eino_workflow", + }) + } return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil } diff --git a/internal/workflow/hitl_wait.go b/internal/workflow/hitl_wait.go new file mode 100644 index 00000000..624cdf5b --- /dev/null +++ b/internal/workflow/hitl_wait.go @@ -0,0 +1,119 @@ +package workflow + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/database" +) + +// HITLDecision is a human decision on a workflow approval node. +type HITLDecision struct { + Approved bool + Comment string +} + +var hitlWaiters sync.Map // runID -> chan HITLDecision + +func registerHITLWaiter(runID string) chan HITLDecision { + ch := make(chan HITLDecision, 1) + hitlWaiters.Store(runID, ch) + return ch +} + +func unregisterHITLWaiter(runID string, ch chan HITLDecision) { + hitlWaiters.CompareAndDelete(runID, ch) +} + +// NotifyHITLDecision wakes a streaming workflow run waiting at a HITL node. +// Returns true when an active waiter was signaled. +func NotifyHITLDecision(runID string, decision HITLDecision) bool { + v, ok := hitlWaiters.Load(runID) + if !ok { + return false + } + ch, ok := v.(chan HITLDecision) + if !ok { + return false + } + select { + case ch <- decision: + return true + default: + return true + } +} + +func readHITLDecisionFromDB(db *database.DB, runID string) (HITLDecision, bool, error) { + if db == nil { + return HITLDecision{}, false, nil + } + run, err := db.GetWorkflowRun(runID) + if err != nil { + return HITLDecision{}, false, err + } + if run == nil || strings.TrimSpace(run.PendingHITLJSON) == "" { + return HITLDecision{}, false, nil + } + var pending map[string]interface{} + if err := json.Unmarshal([]byte(run.PendingHITLJSON), &pending); err != nil { + return HITLDecision{}, false, nil + } + raw, ok := pending["decision"] + if !ok { + return HITLDecision{}, false, nil + } + decision := strings.ToLower(strings.TrimSpace(fmt.Sprint(raw))) + switch decision { + case "approved", "approve": + comment := "" + if v, ok := pending["comment"]; ok { + comment = strings.TrimSpace(fmt.Sprint(v)) + } + return HITLDecision{Approved: true, Comment: comment}, true, nil + case "rejected", "reject": + comment := "" + if v, ok := pending["comment"]; ok { + comment = strings.TrimSpace(fmt.Sprint(v)) + } + return HITLDecision{Approved: false, Comment: comment}, true, nil + default: + return HITLDecision{}, false, nil + } +} + +func waitWorkflowHITLDecision(ctx context.Context, db *database.DB, runID string) (HITLDecision, error) { + ch := registerHITLWaiter(runID) + defer unregisterHITLWaiter(runID, ch) + return waitWorkflowHITLDecisionWithChannel(ctx, db, runID, ch) +} + +func waitWorkflowHITLDecisionWithChannel(ctx context.Context, db *database.DB, runID string, ch chan HITLDecision) (HITLDecision, error) { + if d, ok, err := readHITLDecisionFromDB(db, runID); err != nil { + return HITLDecision{}, err + } else if ok { + return d, nil + } + + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return HITLDecision{}, ctx.Err() + case d := <-ch: + return d, nil + case <-ticker.C: + if d, ok, err := readHITLDecisionFromDB(db, runID); err != nil { + return HITLDecision{}, err + } else if ok { + return d, nil + } + } + } +} diff --git a/internal/workflow/runner.go b/internal/workflow/runner.go index 3c8152b2..d9865573 100644 --- a/internal/workflow/runner.go +++ b/internal/workflow/runner.go @@ -91,36 +91,46 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) return nil, err } state := newWorkflowLocalState(input, runID) - if err := executeEinoGraph(ctx, args, runID, wf.ID, wf.Version, graph, state); err != nil { - if IsAwaitingHITL(err) { - hitl := err.(*AwaitingHITLError) - partial := map[string]interface{}{ - "workflowId": wf.ID, - "workflowName": wf.Name, - "workflowVersion": wf.Version, - "workflowRunId": runID, - "status": "awaiting_hitl", - "outputs": state.Outputs, - "executedNodes": state.Executed, - "skippedNodes": state.Skipped, - "pendingHitl": map[string]interface{}{ - "nodeId": hitl.NodeID, - "label": hitl.NodeLabel, - "prompt": hitl.Prompt, - }, - "engine": "eino_workflow", - } - partialJSON, _ := json.Marshal(partial) - _ = args.DB.SetWorkflowRunAwaitingHITL(runID, hitl.NodeID, string(partialJSON)) - response := fmt.Sprintf("工作流「%s」已在节点「%s」暂停,等待人工审批。\n运行 ID:%s", wf.Name, firstNonEmpty(hitl.NodeLabel, hitl.NodeID), runID) - if args.Progress != nil { - args.Progress("workflow_paused", response, map[string]interface{}{ - "workflowRunId": runID, - "status": "awaiting_hitl", - "nodeId": hitl.NodeID, - "resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID), - }) - } + streaming := args.Progress != nil + resuming := false + for { + _, err := invokeEinoGraph(ctx, args, runID, wf.ID, wf.Version, graph, state, resuming) + if err == nil { + break + } + if !IsAwaitingHITL(err) { + _ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error()) + return nil, err + } + hitl := err.(*AwaitingHITLError) + partial := map[string]interface{}{ + "workflowId": wf.ID, + "workflowName": wf.Name, + "workflowVersion": wf.Version, + "workflowRunId": runID, + "status": "awaiting_hitl", + "outputs": state.Outputs, + "executedNodes": state.Executed, + "skippedNodes": state.Skipped, + "pendingHitl": map[string]interface{}{ + "nodeId": hitl.NodeID, + "label": hitl.NodeLabel, + "prompt": hitl.Prompt, + }, + "engine": "eino_workflow", + } + partialJSON, _ := json.Marshal(partial) + _ = args.DB.SetWorkflowRunAwaitingHITL(runID, hitl.NodeID, string(partialJSON)) + response := fmt.Sprintf("工作流「%s」已在节点「%s」暂停,等待人工审批。\n运行 ID:%s", wf.Name, firstNonEmpty(hitl.NodeLabel, hitl.NodeID), runID) + if args.Progress != nil { + args.Progress("workflow_paused", response, map[string]interface{}{ + "workflowRunId": runID, + "status": "awaiting_hitl", + "nodeId": hitl.NodeID, + "resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID), + }) + } + if !streaming { return &RunResult{ Response: response, RunID: runID, @@ -128,8 +138,48 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) AwaitingHITL: true, }, nil } - _ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error()) - return nil, err + ch := registerHITLWaiter(runID) + decision, waitErr := waitWorkflowHITLDecisionWithChannel(ctx, args.DB, runID, ch) + unregisterHITLWaiter(runID, ch) + if waitErr != nil { + _ = args.DB.FinishWorkflowRun(runID, "cancelled", "", waitErr.Error()) + return nil, waitErr + } + if !decision.Approved { + errText := strings.TrimSpace(decision.Comment) + if errText == "" { + errText = "人工审批拒绝" + } + _ = args.DB.FinishWorkflowRun(runID, "rejected", "", errText) + rejectResponse := fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", firstNonEmpty(hitl.NodeLabel, hitl.NodeID)) + if args.Progress != nil { + args.Progress("workflow_hitl_rejected", rejectResponse, map[string]interface{}{ + "workflowRunId": runID, + "nodeId": hitl.NodeID, + "comment": errText, + }) + } + return &RunResult{ + Response: rejectResponse, + RunID: runID, + Status: "rejected", + }, nil + } + if args.Progress != nil { + args.Progress("workflow_hitl_resumed", "人工审批已通过,继续执行", map[string]interface{}{ + "workflowRunId": runID, + "nodeId": hitl.NodeID, + "comment": decision.Comment, + }) + } + if state.Inputs == nil { + state.Inputs = map[string]any{} + } + state.Inputs["_hitl_approved"] = true + state.Inputs["_hitl_comment"] = decision.Comment + state.Inputs["_hitl_node_id"] = hitl.NodeID + _ = args.DB.SetWorkflowRunStatus(runID, "running") + resuming = true } output := map[string]interface{}{