mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-04 11:37:57 +02:00
191 lines
5.5 KiB
Go
191 lines
5.5 KiB
Go
package workflow
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/cloudwego/eino/compose"
|
|
)
|
|
|
|
func executeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState) error {
|
|
_, err := invokeEinoGraph(ctx, args, runID, workflowID, version, g, state, false)
|
|
return err
|
|
}
|
|
|
|
func invokeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState, resume bool) (bool, error) {
|
|
wfInput := workflowInputFromMap(state.Inputs)
|
|
if resume {
|
|
wfInput = WorkflowInput{}
|
|
}
|
|
rt := &workflowRuntime{
|
|
args: args,
|
|
runID: runID,
|
|
idx: indexGraph(g),
|
|
state: state,
|
|
}
|
|
|
|
art, err := defaultEngine.getOrCompile(ctx, workflowID, version, g)
|
|
if err != nil {
|
|
return false, fmt.Errorf("编译 Eino Workflow 失败: %w", err)
|
|
}
|
|
rt.idx = art.idx
|
|
|
|
runCtx := withWorkflowRuntime(ctx, rt)
|
|
runCtx = attachWorkflowCallbacks(runCtx, args.AppCfg, args, workflowID)
|
|
|
|
invokeOpts := []compose.Option{compose.WithCheckPointID(runID)}
|
|
for {
|
|
_, err = art.runnable.Invoke(runCtx, wfInput, invokeOpts...)
|
|
if err == nil {
|
|
return false, nil
|
|
}
|
|
if hitlErr := extractAwaitingHITL(err, art, runID, args, state); hitlErr != nil {
|
|
return true, hitlErr
|
|
}
|
|
return false, err
|
|
}
|
|
}
|
|
|
|
func extractAwaitingHITL(err error, art *compiledArtifact, runID string, args RunArgs, state *WorkflowLocalState) error {
|
|
info, ok := compose.ExtractInterruptInfo(err)
|
|
if !ok || len(art.hitlIDs) == 0 {
|
|
return nil
|
|
}
|
|
nodeID := nextHITLNodeID(info, art.hitlIDs)
|
|
node := art.idx.nodes[nodeID]
|
|
if nodeID == "" {
|
|
return nil
|
|
}
|
|
prompt := resolveHITLPromptBinding(node.Config, state)
|
|
label := firstNonEmpty(node.Label, nodeID)
|
|
if args.DB != nil {
|
|
pending := map[string]any{
|
|
"nodeId": nodeID,
|
|
"label": label,
|
|
"prompt": prompt,
|
|
"reviewer": cfgString(node.Config, "reviewer"),
|
|
}
|
|
pendingJSON, _ := json.Marshal(pending)
|
|
_ = args.DB.SetWorkflowRunAwaitingHITL(runID, nodeID, string(pendingJSON))
|
|
}
|
|
if args.Progress != nil {
|
|
args.Progress("workflow_hitl_waiting", fmt.Sprintf("等待人工确认:%s", label), map[string]any{
|
|
"workflowRunId": runID,
|
|
"nodeId": nodeID,
|
|
"label": label,
|
|
"prompt": prompt,
|
|
"reviewer": cfgString(node.Config, "reviewer"),
|
|
"mode": "interactive",
|
|
"resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID),
|
|
})
|
|
}
|
|
return &AwaitingHITLError{
|
|
RunID: runID,
|
|
NodeID: nodeID,
|
|
NodeLabel: label,
|
|
Prompt: prompt,
|
|
Reviewer: cfgString(node.Config, "reviewer"),
|
|
}
|
|
}
|
|
|
|
func nextHITLNodeID(info *compose.InterruptInfo, hitlIDs []string) string {
|
|
if info != nil && len(info.BeforeNodes) > 0 {
|
|
for _, id := range info.BeforeNodes {
|
|
for _, hitl := range hitlIDs {
|
|
if id == hitl {
|
|
return id
|
|
}
|
|
}
|
|
}
|
|
return info.BeforeNodes[0]
|
|
}
|
|
if len(hitlIDs) == 0 {
|
|
return ""
|
|
}
|
|
return hitlIDs[0]
|
|
}
|
|
|
|
// ResumeWorkflowRun continues a run paused at HITL after human decision.
|
|
func ResumeWorkflowRun(ctx context.Context, args RunArgs, runID string, approved bool, comment string) (*RunResult, error) {
|
|
run, err := args.DB.GetWorkflowRun(runID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if run == nil {
|
|
return nil, fmt.Errorf("工作流运行不存在")
|
|
}
|
|
if run.Status != "awaiting_hitl" {
|
|
return nil, fmt.Errorf("工作流运行不在等待审批状态: %s", run.Status)
|
|
}
|
|
wf, err := args.DB.GetWorkflowDefinition(run.WorkflowID)
|
|
if err != nil || wf == nil {
|
|
return nil, fmt.Errorf("工作流定义不存在")
|
|
}
|
|
graph, err := parseGraph(wf.GraphJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var input map[string]interface{}
|
|
_ = json.Unmarshal([]byte(run.InputJSON), &input)
|
|
state := newWorkflowLocalState(input, runID)
|
|
if state.Inputs == nil {
|
|
state.Inputs = map[string]any{}
|
|
}
|
|
state.Inputs["_hitl_approved"] = approved
|
|
state.Inputs["_hitl_comment"] = strings.TrimSpace(comment)
|
|
state.Inputs["_hitl_node_id"] = run.PendingHITLNodeID
|
|
|
|
if !approved {
|
|
errText := strings.TrimSpace(comment)
|
|
if errText == "" {
|
|
errText = "人工审批拒绝"
|
|
}
|
|
_ = args.DB.FinishWorkflowRun(runID, "rejected", "", errText)
|
|
return &RunResult{
|
|
RunID: runID,
|
|
Response: fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", run.PendingHITLNodeID),
|
|
Status: "rejected",
|
|
}, nil
|
|
}
|
|
|
|
_ = args.DB.SetWorkflowRunStatus(runID, "running")
|
|
resumeArgs := args
|
|
if strings.TrimSpace(resumeArgs.ConversationID) == "" {
|
|
resumeArgs.ConversationID = run.ConversationID
|
|
}
|
|
|
|
awaiting, err := invokeEinoGraph(ctx, resumeArgs, runID, wf.ID, run.WorkflowVersion, graph, state, true)
|
|
if err != nil {
|
|
if IsAwaitingHITL(err) {
|
|
return &RunResult{
|
|
RunID: runID,
|
|
Status: "awaiting_hitl",
|
|
Response: fmt.Sprintf("工作流在节点「%s」等待下一次人工确认。", err.(*AwaitingHITLError).NodeID),
|
|
AwaitingHITL: true,
|
|
}, nil
|
|
}
|
|
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
|
|
return nil, err
|
|
}
|
|
_ = awaiting
|
|
|
|
output := map[string]interface{}{
|
|
"workflowId": wf.ID,
|
|
"workflowName": wf.Name,
|
|
"workflowVersion": wf.Version,
|
|
"workflowRunId": runID,
|
|
"status": "completed",
|
|
"outputs": state.Outputs,
|
|
"executedNodes": state.Executed,
|
|
"skippedNodes": state.Skipped,
|
|
"engine": "eino_workflow",
|
|
}
|
|
outputJSON, _ := json.Marshal(output)
|
|
response := renderWorkflowResponse(args.Role.Name, wf.Name, wf.Version, runID, state)
|
|
_ = args.DB.FinishWorkflowRun(runID, "completed", string(outputJSON), "")
|
|
return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil
|
|
}
|