Files
CyberStrikeAI/internal/workflow/eino_compile.go
T
2026-07-03 19:36:40 +08:00

191 lines
5.5 KiB
Go

package workflow
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/cloudwego/eino/compose"
)
func executeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState) error {
_, err := invokeEinoGraph(ctx, args, runID, workflowID, version, g, state, false)
return err
}
func invokeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState, resume bool) (bool, error) {
wfInput := workflowInputFromMap(state.Inputs)
if resume {
wfInput = WorkflowInput{}
}
rt := &workflowRuntime{
args: args,
runID: runID,
idx: indexGraph(g),
state: state,
}
art, err := defaultEngine.getOrCompile(ctx, workflowID, version, g)
if err != nil {
return false, fmt.Errorf("编译 Eino Workflow 失败: %w", err)
}
rt.idx = art.idx
runCtx := withWorkflowRuntime(ctx, rt)
runCtx = attachWorkflowCallbacks(runCtx, args.AppCfg, args, workflowID)
invokeOpts := []compose.Option{compose.WithCheckPointID(runID)}
for {
_, err = art.runnable.Invoke(runCtx, wfInput, invokeOpts...)
if err == nil {
return false, nil
}
if hitlErr := extractAwaitingHITL(err, art, runID, args, state); hitlErr != nil {
return true, hitlErr
}
return false, err
}
}
func extractAwaitingHITL(err error, art *compiledArtifact, runID string, args RunArgs, state *WorkflowLocalState) error {
info, ok := compose.ExtractInterruptInfo(err)
if !ok || len(art.hitlIDs) == 0 {
return nil
}
nodeID := nextHITLNodeID(info, art.hitlIDs)
node := art.idx.nodes[nodeID]
if nodeID == "" {
return nil
}
prompt := resolveHITLPromptBinding(node.Config, state)
label := firstNonEmpty(node.Label, nodeID)
if args.DB != nil {
pending := map[string]any{
"nodeId": nodeID,
"label": label,
"prompt": prompt,
"reviewer": cfgString(node.Config, "reviewer"),
}
pendingJSON, _ := json.Marshal(pending)
_ = args.DB.SetWorkflowRunAwaitingHITL(runID, nodeID, string(pendingJSON))
}
if args.Progress != nil {
args.Progress("workflow_hitl_waiting", fmt.Sprintf("等待人工确认:%s", label), map[string]any{
"workflowRunId": runID,
"nodeId": nodeID,
"label": label,
"prompt": prompt,
"reviewer": cfgString(node.Config, "reviewer"),
"mode": "interactive",
"resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID),
})
}
return &AwaitingHITLError{
RunID: runID,
NodeID: nodeID,
NodeLabel: label,
Prompt: prompt,
Reviewer: cfgString(node.Config, "reviewer"),
}
}
func nextHITLNodeID(info *compose.InterruptInfo, hitlIDs []string) string {
if info != nil && len(info.BeforeNodes) > 0 {
for _, id := range info.BeforeNodes {
for _, hitl := range hitlIDs {
if id == hitl {
return id
}
}
}
return info.BeforeNodes[0]
}
if len(hitlIDs) == 0 {
return ""
}
return hitlIDs[0]
}
// ResumeWorkflowRun continues a run paused at HITL after human decision.
func ResumeWorkflowRun(ctx context.Context, args RunArgs, runID string, approved bool, comment string) (*RunResult, error) {
run, err := args.DB.GetWorkflowRun(runID)
if err != nil {
return nil, err
}
if run == nil {
return nil, fmt.Errorf("工作流运行不存在")
}
if run.Status != "awaiting_hitl" {
return nil, fmt.Errorf("工作流运行不在等待审批状态: %s", run.Status)
}
wf, err := args.DB.GetWorkflowDefinition(run.WorkflowID)
if err != nil || wf == nil {
return nil, fmt.Errorf("工作流定义不存在")
}
graph, err := parseGraph(wf.GraphJSON)
if err != nil {
return nil, err
}
var input map[string]interface{}
_ = json.Unmarshal([]byte(run.InputJSON), &input)
state := newWorkflowLocalState(input, runID)
if state.Inputs == nil {
state.Inputs = map[string]any{}
}
state.Inputs["_hitl_approved"] = approved
state.Inputs["_hitl_comment"] = strings.TrimSpace(comment)
state.Inputs["_hitl_node_id"] = run.PendingHITLNodeID
if !approved {
errText := strings.TrimSpace(comment)
if errText == "" {
errText = "人工审批拒绝"
}
_ = args.DB.FinishWorkflowRun(runID, "rejected", "", errText)
return &RunResult{
RunID: runID,
Response: fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", run.PendingHITLNodeID),
Status: "rejected",
}, nil
}
_ = args.DB.SetWorkflowRunStatus(runID, "running")
resumeArgs := args
if strings.TrimSpace(resumeArgs.ConversationID) == "" {
resumeArgs.ConversationID = run.ConversationID
}
awaiting, err := invokeEinoGraph(ctx, resumeArgs, runID, wf.ID, run.WorkflowVersion, graph, state, true)
if err != nil {
if IsAwaitingHITL(err) {
return &RunResult{
RunID: runID,
Status: "awaiting_hitl",
Response: fmt.Sprintf("工作流在节点「%s」等待下一次人工确认。", err.(*AwaitingHITLError).NodeID),
AwaitingHITL: true,
}, nil
}
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
return nil, err
}
_ = awaiting
output := map[string]interface{}{
"workflowId": wf.ID,
"workflowName": wf.Name,
"workflowVersion": wf.Version,
"workflowRunId": runID,
"status": "completed",
"outputs": state.Outputs,
"executedNodes": state.Executed,
"skippedNodes": state.Skipped,
"engine": "eino_workflow",
}
outputJSON, _ := json.Marshal(output)
response := renderWorkflowResponse(args.Role.Name, wf.Name, wf.Version, runID, state)
_ = args.DB.FinishWorkflowRun(runID, "completed", string(outputJSON), "")
return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil
}