diff --git a/internal/workflow/agent_subgraph.go b/internal/workflow/agent_subgraph.go new file mode 100644 index 00000000..408f1bd4 --- /dev/null +++ b/internal/workflow/agent_subgraph.go @@ -0,0 +1,24 @@ +package workflow + +import ( + "context" + + "github.com/cloudwego/eino/compose" +) + +// compileAgentSubgraph wraps an Agent canvas node as an Eino subgraph (AddGraphNode best practice). +func compileAgentSubgraph(_ context.Context, node graphNode) (compose.AnyGraph, error) { + n := node + innerID := n.ID + "__agent" + g := compose.NewGraph[WorkflowNodeOutput, WorkflowNodeOutput]() + _ = g.AddLambdaNode(innerID, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) { + return runWorkflowNodeLambda(runCtx, n) + })) + if err := g.AddEdge(compose.START, innerID); err != nil { + return nil, err + } + if err := g.AddEdge(innerID, compose.END); err != nil { + return nil, err + } + return g, nil +} diff --git a/internal/workflow/bindings.go b/internal/workflow/bindings.go new file mode 100644 index 00000000..9f1e7e14 --- /dev/null +++ b/internal/workflow/bindings.go @@ -0,0 +1,141 @@ +package workflow + +import ( + "encoding/json" + "fmt" + "strings" +) + +// FieldBinding selects a value from workflow state (replaces {{...}} templates). +type FieldBinding struct { + From string `json:"from"` // inputs | previous | + Field string `json:"field"` // e.g. output, message +} + +func parseFieldBinding(cfg map[string]any, keys ...string) (FieldBinding, bool) { + for _, key := range keys { + if cfg == nil { + continue + } + raw, ok := cfg[key] + if !ok || raw == nil { + continue + } + switch v := raw.(type) { + case map[string]any: + return FieldBinding{ + From: strings.TrimSpace(fmt.Sprint(v["from"])), + Field: strings.TrimSpace(fmt.Sprint(v["field"])), + }, true + case string: + s := strings.TrimSpace(v) + if s == "" { + continue + } + var b FieldBinding + if err := json.Unmarshal([]byte(s), &b); err == nil && (b.From != "" || b.Field != "") { + return b, true + } + } + } + return FieldBinding{}, false +} + +func defaultBinding(from, field string) FieldBinding { + return FieldBinding{From: from, Field: field} +} + +func resolveBinding(b FieldBinding, state *WorkflowLocalState) any { + from := strings.TrimSpace(b.From) + field := strings.TrimSpace(b.Field) + if field == "" { + field = "output" + } + if from == "" || from == "previous" || from == "prev" { + if field == "output" && state.LastOutput != nil { + return state.LastOutput["output"] + } + return valueFromPath("previous."+field, state) + } + if from == "inputs" || from == "input" { + if field == "" { + return state.Inputs + } + return valueFromPath("inputs."+field, state) + } + if from == "outputs" { + return valueFromPath("outputs."+field, state) + } + return valueFromPath(from+"."+field, state) +} + +func resolveBindingString(b FieldBinding, state *WorkflowLocalState) string { + return strings.TrimSpace(fmt.Sprint(resolveBinding(b, state))) +} + +func resolveNodeInputBinding(cfg map[string]any, state *WorkflowLocalState) string { + if b, ok := parseFieldBinding(cfg, "input_binding"); ok { + return resolveBindingString(b, state) + } + // legacy template field removed — default previous.output + return resolveBindingString(defaultBinding("previous", "output"), state) +} + +func resolveOutputSourceBinding(cfg map[string]any, state *WorkflowLocalState) any { + if b, ok := parseFieldBinding(cfg, "source_binding"); ok { + return resolveBinding(b, state) + } + return resolveBinding(defaultBinding("previous", "output"), state) +} + +func resolveHITLPromptBinding(cfg map[string]any, state *WorkflowLocalState) string { + if b, ok := parseFieldBinding(cfg, "prompt_binding"); ok { + return resolveBindingString(b, state) + } + if s := cfgString(cfg, "prompt"); s != "" { + return s + } + return resolveBindingString(defaultBinding("previous", "output"), state) +} + +func toolArgumentBindings(cfg map[string]any) map[string]FieldBinding { + raw, ok := cfg["argument_bindings"].(map[string]any) + if !ok || len(raw) == 0 { + return nil + } + out := make(map[string]FieldBinding, len(raw)) + for argName, v := range raw { + m, ok := v.(map[string]any) + if !ok { + continue + } + out[argName] = FieldBinding{ + From: strings.TrimSpace(fmt.Sprint(m["from"])), + Field: strings.TrimSpace(fmt.Sprint(m["field"])), + } + } + return out +} + +func resolveToolArguments(cfg map[string]any, state *WorkflowLocalState) (map[string]interface{}, error) { + bindings := toolArgumentBindings(cfg) + if len(bindings) > 0 { + args := make(map[string]interface{}, len(bindings)) + for k, b := range bindings { + args[k] = resolveBinding(b, state) + } + return args, nil + } + raw := cfgString(cfg, "arguments") + if raw == "" { + return map[string]interface{}{}, nil + } + var args map[string]interface{} + if err := json.Unmarshal([]byte(raw), &args); err != nil { + return nil, err + } + if args == nil { + args = map[string]interface{}{} + } + return args, nil +} diff --git a/internal/workflow/checkpoint_store.go b/internal/workflow/checkpoint_store.go new file mode 100644 index 00000000..5d254ac0 --- /dev/null +++ b/internal/workflow/checkpoint_store.go @@ -0,0 +1,69 @@ +package workflow + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "sync" +) + +// fileCheckPointStore persists Eino workflow checkpoints on disk (per run id). +type fileCheckPointStore struct { + dir string + mu sync.RWMutex +} + +func newFileCheckPointStore(dir string) (*fileCheckPointStore, error) { + dir = strings.TrimSpace(dir) + if dir == "" { + dir = filepath.Join("data", "workflow-checkpoints") + } + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, fmt.Errorf("create workflow checkpoint dir: %w", err) + } + return &fileCheckPointStore{dir: dir}, nil +} + +func (s *fileCheckPointStore) path(id string) (string, error) { + id = strings.TrimSpace(id) + if id == "" { + return "", fmt.Errorf("checkpoint id is empty") + } + if strings.Contains(id, "..") || strings.ContainsAny(id, `/\`) { + return "", fmt.Errorf("invalid checkpoint id") + } + return filepath.Join(s.dir, id+".ckpt"), nil +} + +func (s *fileCheckPointStore) Get(_ context.Context, checkPointID string) ([]byte, bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() + p, err := s.path(checkPointID) + if err != nil { + return nil, false, err + } + data, err := os.ReadFile(p) + if err != nil { + if os.IsNotExist(err) { + return nil, false, nil + } + return nil, false, err + } + return data, true, nil +} + +func (s *fileCheckPointStore) Set(_ context.Context, checkPointID string, checkPoint []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + p, err := s.path(checkPointID) + if err != nil { + return err + } + tmp := p + ".tmp" + if err := os.WriteFile(tmp, checkPoint, 0o600); err != nil { + return err + } + return os.Rename(tmp, p) +} diff --git a/internal/workflow/eino_branch.go b/internal/workflow/eino_branch.go new file mode 100644 index 00000000..1b3b9b13 --- /dev/null +++ b/internal/workflow/eino_branch.go @@ -0,0 +1,107 @@ +package workflow + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/compose" +) + +func hasConditionalOutgoingEdges(idx *graphIndex, nodeID string) bool { + for _, edge := range idx.outgoing[nodeID] { + cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression")) + if cond != "" { + return true + } + } + return false +} + +func wireConditionBranch( + wf *compose.Workflow[WorkflowInput, WorkflowOutput], + nodeRefs map[string]*compose.WorkflowNode, + idx *graphIndex, + condID string, + condNode graphNode, +) error { + edges := idx.outgoing[condID] + if len(edges) == 0 { + return nil + } + branchID := branchNodeID(condID) + wf.AddPassthroughNode(branchID).AddInput(condID) + + endNodes := map[string]bool{compose.END: true} + for _, edge := range edges { + endNodes[edge.Target] = true + } + + sortedEdges := append([]graphEdge(nil), edges...) + sortEdgesByCanvas(sortedEdges, idx.nodes) + + branch := compose.NewGraphBranch(func(runCtx context.Context, _ map[string]any) (string, error) { + rt := workflowRuntimeFrom(runCtx) + if rt == nil { + return compose.END, fmt.Errorf("workflow runtime missing in context") + } + emitConditionBranchProgress(rt.args, rt.runID, condNode, sortedEdges, idx.nodes, rt.state) + for edgeIdx, edge := range sortedEdges { + if conditionBranchAllowed(edge, edgeIdx, rt.state) { + return edge.Target, nil + } + } + return compose.END, nil + }, endNodes) + wf.AddBranch(branchID, branch) + + for _, edge := range edges { + if target, ok := nodeRefs[edge.Target]; ok { + target.AddInput(branchID) + } + } + return nil +} + +func wireEdgeConditionBranch( + wf *compose.Workflow[WorkflowInput, WorkflowOutput], + nodeRefs map[string]*compose.WorkflowNode, + idx *graphIndex, + sourceID string, + sourceNode graphNode, +) error { + edges := idx.outgoing[sourceID] + if len(edges) == 0 { + return nil + } + branchID := edgeBranchNodeID(sourceID) + wf.AddPassthroughNode(branchID).AddInput(sourceID) + + endNodes := map[string]bool{compose.END: true} + for _, edge := range edges { + endNodes[edge.Target] = true + } + + sortedEdges := append([]graphEdge(nil), edges...) + sortEdgesByCanvas(sortedEdges, idx.nodes) + + branch := compose.NewGraphBranch(func(runCtx context.Context, _ map[string]any) (string, error) { + rt := workflowRuntimeFrom(runCtx) + if rt == nil { + return compose.END, fmt.Errorf("workflow runtime missing in context") + } + for edgeIdx, edge := range sortedEdges { + if edgeAllowed(edge, sourceNode, edgeIdx, rt.state) { + return edge.Target, nil + } + } + return compose.END, nil + }, endNodes) + wf.AddBranch(branchID, branch) + + for _, edge := range edges { + if target, ok := nodeRefs[edge.Target]; ok { + target.AddInput(branchID) + } + } + return nil +} diff --git a/internal/workflow/eino_callbacks.go b/internal/workflow/eino_callbacks.go new file mode 100644 index 00000000..e4761c89 --- /dev/null +++ b/internal/workflow/eino_callbacks.go @@ -0,0 +1,22 @@ +package workflow + +import ( + "context" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/einoobserve" +) + +func attachWorkflowCallbacks(ctx context.Context, cfg *config.Config, args RunArgs, workflowName string) context.Context { + if cfg == nil { + return ctx + } + cbCfg := &cfg.MultiAgent.EinoCallbacks + return einoobserve.AttachAgentRunCallbacks(ctx, cbCfg, einoobserve.Params{ + Logger: args.Logger, + Progress: args.Progress, + ConversationID: args.ConversationID, + OrchMode: "workflow", + OrchestratorName: workflowName, + }) +} diff --git a/internal/workflow/eino_compile.go b/internal/workflow/eino_compile.go new file mode 100644 index 00000000..2628ab0a --- /dev/null +++ b/internal/workflow/eino_compile.go @@ -0,0 +1,190 @@ +package workflow + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/cloudwego/eino/compose" +) + +func executeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState) error { + _, err := invokeEinoGraph(ctx, args, runID, workflowID, version, g, state, false) + return err +} + +func invokeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState, resume bool) (bool, error) { + wfInput := workflowInputFromMap(state.Inputs) + if resume { + wfInput = WorkflowInput{} + } + rt := &workflowRuntime{ + args: args, + runID: runID, + idx: indexGraph(g), + state: state, + } + + art, err := defaultEngine.getOrCompile(ctx, workflowID, version, g) + if err != nil { + return false, fmt.Errorf("编译 Eino Workflow 失败: %w", err) + } + rt.idx = art.idx + + runCtx := withWorkflowRuntime(ctx, rt) + runCtx = attachWorkflowCallbacks(runCtx, args.AppCfg, args, workflowID) + + invokeOpts := []compose.Option{compose.WithCheckPointID(runID)} + for { + _, err = art.runnable.Invoke(runCtx, wfInput, invokeOpts...) + if err == nil { + return false, nil + } + if hitlErr := extractAwaitingHITL(err, art, runID, args, state); hitlErr != nil { + return true, hitlErr + } + return false, err + } +} + +func extractAwaitingHITL(err error, art *compiledArtifact, runID string, args RunArgs, state *WorkflowLocalState) error { + info, ok := compose.ExtractInterruptInfo(err) + if !ok || len(art.hitlIDs) == 0 { + return nil + } + nodeID := nextHITLNodeID(info, art.hitlIDs) + node := art.idx.nodes[nodeID] + if nodeID == "" { + return nil + } + prompt := resolveHITLPromptBinding(node.Config, state) + label := firstNonEmpty(node.Label, nodeID) + if args.DB != nil { + pending := map[string]any{ + "nodeId": nodeID, + "label": label, + "prompt": prompt, + "reviewer": cfgString(node.Config, "reviewer"), + } + pendingJSON, _ := json.Marshal(pending) + _ = args.DB.SetWorkflowRunAwaitingHITL(runID, nodeID, string(pendingJSON)) + } + if args.Progress != nil { + args.Progress("workflow_hitl_waiting", fmt.Sprintf("等待人工确认:%s", label), map[string]any{ + "workflowRunId": runID, + "nodeId": nodeID, + "label": label, + "prompt": prompt, + "reviewer": cfgString(node.Config, "reviewer"), + "mode": "interactive", + "resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID), + }) + } + return &AwaitingHITLError{ + RunID: runID, + NodeID: nodeID, + NodeLabel: label, + Prompt: prompt, + Reviewer: cfgString(node.Config, "reviewer"), + } +} + +func nextHITLNodeID(info *compose.InterruptInfo, hitlIDs []string) string { + if info != nil && len(info.BeforeNodes) > 0 { + for _, id := range info.BeforeNodes { + for _, hitl := range hitlIDs { + if id == hitl { + return id + } + } + } + return info.BeforeNodes[0] + } + if len(hitlIDs) == 0 { + return "" + } + return hitlIDs[0] +} + +// ResumeWorkflowRun continues a run paused at HITL after human decision. +func ResumeWorkflowRun(ctx context.Context, args RunArgs, runID string, approved bool, comment string) (*RunResult, error) { + run, err := args.DB.GetWorkflowRun(runID) + if err != nil { + return nil, err + } + if run == nil { + return nil, fmt.Errorf("工作流运行不存在") + } + if run.Status != "awaiting_hitl" { + return nil, fmt.Errorf("工作流运行不在等待审批状态: %s", run.Status) + } + wf, err := args.DB.GetWorkflowDefinition(run.WorkflowID) + if err != nil || wf == nil { + return nil, fmt.Errorf("工作流定义不存在") + } + graph, err := parseGraph(wf.GraphJSON) + if err != nil { + return nil, err + } + + var input map[string]interface{} + _ = json.Unmarshal([]byte(run.InputJSON), &input) + state := newWorkflowLocalState(input, runID) + if state.Inputs == nil { + state.Inputs = map[string]any{} + } + state.Inputs["_hitl_approved"] = approved + state.Inputs["_hitl_comment"] = strings.TrimSpace(comment) + state.Inputs["_hitl_node_id"] = run.PendingHITLNodeID + + if !approved { + errText := strings.TrimSpace(comment) + if errText == "" { + errText = "人工审批拒绝" + } + _ = args.DB.FinishWorkflowRun(runID, "rejected", "", errText) + return &RunResult{ + RunID: runID, + Response: fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", run.PendingHITLNodeID), + Status: "rejected", + }, nil + } + + _ = args.DB.SetWorkflowRunStatus(runID, "running") + resumeArgs := args + if strings.TrimSpace(resumeArgs.ConversationID) == "" { + resumeArgs.ConversationID = run.ConversationID + } + + awaiting, err := invokeEinoGraph(ctx, resumeArgs, runID, wf.ID, run.WorkflowVersion, graph, state, true) + if err != nil { + if IsAwaitingHITL(err) { + return &RunResult{ + RunID: runID, + Status: "awaiting_hitl", + Response: fmt.Sprintf("工作流在节点「%s」等待下一次人工确认。", err.(*AwaitingHITLError).NodeID), + AwaitingHITL: true, + }, nil + } + _ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error()) + return nil, err + } + _ = awaiting + + output := map[string]interface{}{ + "workflowId": wf.ID, + "workflowName": wf.Name, + "workflowVersion": wf.Version, + "workflowRunId": runID, + "status": "completed", + "outputs": state.Outputs, + "executedNodes": state.Executed, + "skippedNodes": state.Skipped, + "engine": "eino_workflow", + } + outputJSON, _ := json.Marshal(output) + response := renderWorkflowResponse(args.Role.Name, wf.Name, wf.Version, runID, state) + _ = args.DB.FinishWorkflowRun(runID, "completed", string(outputJSON), "") + return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil +} diff --git a/internal/workflow/eino_compile_test.go b/internal/workflow/eino_compile_test.go new file mode 100644 index 00000000..227eeea7 --- /dev/null +++ b/internal/workflow/eino_compile_test.go @@ -0,0 +1,195 @@ +package workflow + +import ( + "context" + "path/filepath" + "testing" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +func testWorkflowDB(t *testing.T) *database.DB { + t.Helper() + dir := t.TempDir() + db, err := database.NewDB(filepath.Join(dir, "workflow.db"), zap.NewNop()) + if err != nil { + t.Fatalf("NewDB: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db +} + +func linearStartOutputGraph() string { + return `{ + "nodes": [ + {"id": "start-1", "type": "start", "label": "开始", "position": {"x": 0, "y": 0}, "config": {}}, + {"id": "out-1", "type": "output", "label": "输出", "position": {"x": 0, "y": 120}, "config": {"output_key": "result", "source_binding": {"from": "inputs", "field": "message"}}} + ], + "edges": [ + {"id": "e1", "source": "start-1", "target": "out-1"} + ], + "config": {"schema_version": 1} +}` +} + +func conditionBranchGraph() string { + return `{ + "nodes": [ + {"id": "start-1", "type": "start", "label": "开始", "position": {"x": 0, "y": 0}, "config": {}}, + {"id": "cond-1", "type": "condition", "label": "判断", "position": {"x": 0, "y": 80}, "config": {"expression": "{{inputs.message}} == yes"}}, + {"id": "out-yes", "type": "output", "label": "是", "position": {"x": -80, "y": 160}, "config": {"output_key": "branch", "static_value": "yes"}}, + {"id": "out-no", "type": "output", "label": "否", "position": {"x": 80, "y": 160}, "config": {"output_key": "branch", "static_value": "no"}} + ], + "edges": [ + {"id": "e1", "source": "start-1", "target": "cond-1"}, + {"id": "e2", "source": "cond-1", "target": "out-yes", "label": "是"}, + {"id": "e3", "source": "cond-1", "target": "out-no", "label": "否"} + ], + "config": {"schema_version": 1} +}` +} + +func TestValidateGraphJSON_linear(t *testing.T) { + if err := ValidateGraphJSON(context.Background(), linearStartOutputGraph()); err != nil { + t.Fatalf("validate: %v", err) + } +} + +func TestCompileEngine_linear(t *testing.T) { + ctx := context.Background() + SetCheckpointDir(t.TempDir()) + g, err := parseGraph(linearStartOutputGraph()) + if err != nil { + t.Fatal(err) + } + if _, err := defaultEngine.compile(ctx, g); err != nil { + t.Fatalf("compile: %v", err) + } +} + +func createTestWorkflowRun(t *testing.T, db *database.DB, runID string) { + t.Helper() + if err := db.CreateWorkflowRun(&database.WorkflowRun{ + ID: runID, + WorkflowID: "test-wf", + Status: "running", + }); err != nil { + t.Fatalf("CreateWorkflowRun: %v", err) + } +} + +func TestExecuteEinoGraph_linearStartOutput(t *testing.T) { + ctx := context.Background() + SetCheckpointDir(t.TempDir()) + db := testWorkflowDB(t) + createTestWorkflowRun(t, db, "run-linear") + g, err := parseGraph(linearStartOutputGraph()) + if err != nil { + t.Fatal(err) + } + state := newWorkflowLocalState(map[string]interface{}{"message": "ping"}, "run-linear") + args := RunArgs{DB: db} + if err := executeEinoGraph(ctx, args, "run-linear", "test-wf", 1, g, state); err != nil { + t.Fatalf("execute: %v", err) + } + if got := state.Outputs["result"]; got != "ping" { + t.Fatalf("outputs[result] = %v, want ping", got) + } + if len(state.Executed) != 2 { + t.Fatalf("executed nodes = %d, want 2", len(state.Executed)) + } +} + +func TestExecuteEinoGraph_conditionBranch(t *testing.T) { + ctx := context.Background() + SetCheckpointDir(t.TempDir()) + db := testWorkflowDB(t) + createTestWorkflowRun(t, db, "run-yes") + createTestWorkflowRun(t, db, "run-no") + g, err := parseGraph(conditionBranchGraph()) + if err != nil { + t.Fatal(err) + } + + stateYes := newWorkflowLocalState(map[string]interface{}{"message": "yes"}, "run-yes") + if err := executeEinoGraph(ctx, RunArgs{DB: db}, "run-yes", "test-wf-branch", 1, g, stateYes); err != nil { + t.Fatalf("execute yes: %v", err) + } + if got := stateYes.Outputs["branch"]; got != "yes" { + t.Fatalf("yes branch output = %v", got) + } + + stateNo := newWorkflowLocalState(map[string]interface{}{"message": "no"}, "run-no") + if err := executeEinoGraph(ctx, RunArgs{DB: db}, "run-no", "test-wf-branch", 1, g, stateNo); err != nil { + t.Fatalf("execute no: %v", err) + } + if got := stateNo.Outputs["branch"]; got != "no" { + t.Fatalf("no branch output = %v", got) + } +} + +func TestRunRoleBoundWorkflow_integration(t *testing.T) { + ctx := context.Background() + SetCheckpointDir(t.TempDir()) + db := testWorkflowDB(t) + graph := linearStartOutputGraph() + if err := db.UpsertWorkflowDefinition(&database.WorkflowDefinition{ + ID: "wf-linear", + Name: "线性流程", + Version: 1, + GraphJSON: graph, + Enabled: true, + }); err != nil { + t.Fatal(err) + } + role := config.RoleConfig{ + Name: "tester", + Enabled: true, + WorkflowID: "wf-linear", + WorkflowPolicy: "auto", + } + result, err := RunRoleBoundWorkflow(ctx, RunArgs{ + DB: db, + Logger: zap.NewNop(), + Role: role, + UserMessage: "from-role", + }) + if err != nil { + t.Fatalf("RunRoleBoundWorkflow: %v", err) + } + if result == nil || result.RunID == "" { + t.Fatal("expected run result") + } +} + +func TestCompiledCache_reuse(t *testing.T) { + ctx := context.Background() + SetCheckpointDir(t.TempDir()) + InvalidateCompiledCache("cache-wf") + g, err := parseGraph(linearStartOutputGraph()) + if err != nil { + t.Fatal(err) + } + a1, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g) + if err != nil { + t.Fatal(err) + } + a2, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g) + if err != nil { + t.Fatal(err) + } + if a1 != a2 { + t.Fatal("expected cached artifact pointer reuse") + } + InvalidateCompiledCache("cache-wf") + a3, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g) + if err != nil { + t.Fatal(err) + } + if a1 == a3 { + t.Fatal("expected new artifact after invalidation") + } +} diff --git a/internal/workflow/eino_runtime.go b/internal/workflow/eino_runtime.go new file mode 100644 index 00000000..904b6c5a --- /dev/null +++ b/internal/workflow/eino_runtime.go @@ -0,0 +1,64 @@ +package workflow + +import ( + "context" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +type workflowRuntimeCtxKey struct{} + +// workflowRuntime carries per-run execution context into Eino Workflow local state. +type workflowRuntime struct { + args RunArgs + runID string + idx *graphIndex + state *WorkflowLocalState +} + +func withWorkflowRuntime(ctx context.Context, rt *workflowRuntime) context.Context { + return context.WithValue(ctx, workflowRuntimeCtxKey{}, rt) +} + +func workflowRuntimeFrom(ctx context.Context) *workflowRuntime { + rt, _ := ctx.Value(workflowRuntimeCtxKey{}).(*workflowRuntime) + return rt +} + +func newWorkflowRuntime(args RunArgs, runID string, idx *graphIndex, inputs map[string]interface{}) *workflowRuntime { + return &workflowRuntime{ + args: args, + runID: runID, + idx: idx, + state: newWorkflowLocalState(inputs, runID), + } +} + +// RunArgs is the execution context for a role-bound workflow run. +type RunArgs struct { + DB *database.DB + Logger *zap.Logger + Role config.RoleConfig + AppCfg *config.Config + Agent *agent.Agent + ConversationID string + ProjectID string + UserMessage string + History []agent.ChatMessage + RoleTools []string + AgentsMarkdownDir string + SystemPromptExtra string + AssistantMessageID string + Progress agent.ProgressCallback +} + +type RunResult struct { + Response string + RunID string + Status string + AwaitingHITL bool +} diff --git a/internal/workflow/engine.go b/internal/workflow/engine.go new file mode 100644 index 00000000..769f364c --- /dev/null +++ b/internal/workflow/engine.go @@ -0,0 +1,236 @@ +package workflow + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/cloudwego/eino/compose" +) + +type compiledArtifact struct { + runnable compose.Runnable[WorkflowInput, WorkflowOutput] + idx *graphIndex + hitlIDs []string +} + +// Engine compiles and caches Eino Workflow artifacts. +type Engine struct { + mu sync.RWMutex + cache map[string]*compiledArtifact + cpStore compose.CheckPointStore + cpStoreMu sync.Once + cpStoreErr error + checkpointDir string +} + +var defaultEngine = &Engine{ + cache: make(map[string]*compiledArtifact), + checkpointDir: "data/workflow-checkpoints", +} + +// SetCheckpointDir overrides the workflow checkpoint root (mainly for tests). +func SetCheckpointDir(dir string) { + defaultEngine.mu.Lock() + defer defaultEngine.mu.Unlock() + defaultEngine.checkpointDir = strings.TrimSpace(dir) + defaultEngine.cpStore = nil + defaultEngine.cpStoreErr = nil + defaultEngine.cpStoreMu = sync.Once{} +} + +func (e *Engine) checkpointStore() (compose.CheckPointStore, error) { + e.cpStoreMu.Do(func() { + e.cpStore, e.cpStoreErr = newFileCheckPointStore(e.checkpointDir) + }) + return e.cpStore, e.cpStoreErr +} + +// InvalidateCompiledCache drops cached compilations for a workflow id. +func InvalidateCompiledCache(workflowID string) { + workflowID = strings.TrimSpace(workflowID) + if workflowID == "" { + return + } + defaultEngine.mu.Lock() + defer defaultEngine.mu.Unlock() + for key := range defaultEngine.cache { + if strings.HasPrefix(key, workflowID+":") { + delete(defaultEngine.cache, key) + } + } +} + +// ValidateGraphJSON parses and trial-compiles a canvas graph (save-time gate). +func ValidateGraphJSON(ctx context.Context, graphJSON string) error { + g, err := parseGraph(graphJSON) + if err != nil { + return err + } + idx := indexGraph(g) + if len(findStartNodeIDs(idx)) == 0 { + return fmt.Errorf("工作流缺少可执行的起点节点") + } + if !hasTerminalNode(idx) { + return fmt.Errorf("工作流至少需要一个无出边的终点或 output/end 节点") + } + _, err = defaultEngine.compile(ctx, g) + return err +} + +func hasTerminalNode(idx *graphIndex) bool { + for id, node := range idx.nodes { + if len(idx.outgoing[id]) == 0 { + return true + } + if strings.EqualFold(node.Type, "end") || strings.EqualFold(node.Type, "output") { + return true + } + } + return false +} + +func (e *Engine) getOrCompile(ctx context.Context, workflowID string, version int, g *graphDef) (*compiledArtifact, error) { + key := cacheKey(workflowID, version) + e.mu.RLock() + if art, ok := e.cache[key]; ok { + e.mu.RUnlock() + return art, nil + } + e.mu.RUnlock() + + art, err := e.compile(ctx, g) + if err != nil { + return nil, err + } + e.mu.Lock() + if existing, ok := e.cache[key]; ok { + e.mu.Unlock() + return existing, nil + } + e.cache[key] = art + e.mu.Unlock() + return art, nil +} + +func (e *Engine) compile(ctx context.Context, g *graphDef) (*compiledArtifact, error) { + cpStore, err := e.checkpointStore() + if err != nil { + return nil, err + } + idx := indexGraph(g) + hitlIDs := collectHITLNodeIDs(idx) + compileOpts := []compose.GraphCompileOption{ + compose.WithGraphName("CyberStrikeWorkflow"), + compose.WithCheckPointStore(cpStore), + } + if len(hitlIDs) > 0 { + compileOpts = append(compileOpts, compose.WithInterruptBeforeNodes(hitlIDs)) + } + + wf := compose.NewWorkflow[WorkflowInput, WorkflowOutput]( + compose.WithGenLocalState(func(runCtx context.Context) *WorkflowLocalState { + if rt := workflowRuntimeFrom(runCtx); rt != nil && rt.state != nil { + return rt.state + } + return &WorkflowLocalState{ + Outputs: make(map[string]any), + NodeOutputs: make(map[string]map[string]any), + NodeProceed: make(map[string]bool), + } + }), + ) + + nodeRefs := make(map[string]*compose.WorkflowNode, len(idx.nodes)) + for id, node := range idx.nodes { + n := node + if strings.EqualFold(n.Type, "agent") { + sub, err := compileAgentSubgraph(ctx, n) + if err != nil { + return nil, fmt.Errorf("编译 Agent 子图 %s 失败: %w", id, err) + } + nodeRefs[id] = wf.AddGraphNode(id, sub) + continue + } + if strings.EqualFold(n.Type, "start") { + nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowInput) (WorkflowNodeOutput, error) { + return runWorkflowNodeLambda(runCtx, n) + })) + continue + } + nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) { + return runWorkflowNodeLambda(runCtx, n) + })) + } + + for id, node := range idx.nodes { + if strings.EqualFold(node.Type, "condition") { + if err := wireConditionBranch(wf, nodeRefs, idx, id, node); err != nil { + return nil, err + } + continue + } + if hasConditionalOutgoingEdges(idx, id) { + if err := wireEdgeConditionBranch(wf, nodeRefs, idx, id, node); err != nil { + return nil, err + } + continue + } + for _, edge := range idx.outgoing[id] { + if target, ok := nodeRefs[edge.Target]; ok { + target.AddInput(id) + } + } + } + + for _, startID := range findStartNodeIDs(idx) { + if ref, ok := nodeRefs[startID]; ok { + ref.AddInput(compose.START) + } + } + + endNode := wf.End() + for id, node := range idx.nodes { + if len(idx.outgoing[id]) == 0 || strings.EqualFold(node.Type, "end") { + endNode.AddInput(id, compose.ToField(id)) + } + } + + runnable, err := wf.Compile(ctx, compileOpts...) + if err != nil { + return nil, err + } + return &compiledArtifact{runnable: runnable, idx: idx, hitlIDs: hitlIDs}, nil +} + +func collectHITLNodeIDs(idx *graphIndex) []string { + var ids []string + for id, node := range idx.nodes { + if strings.EqualFold(node.Type, "hitl") { + ids = append(ids, id) + } + } + return ids +} + +func runWorkflowNodeLambda(runCtx context.Context, n graphNode) (WorkflowNodeOutput, error) { + localRT := workflowRuntimeFrom(runCtx) + if localRT == nil { + return nil, fmt.Errorf("workflow runtime missing in context") + } + result, proceed, err := executeNode(runCtx, localRT.args, localRT.runID, n, localRT.state) + if err != nil { + return nil, err + } + localRT.state.NodeOutputs[n.ID] = result + localRT.state.LastOutput = result + if !proceed && !strings.EqualFold(n.Type, "end") { + label := firstNonEmpty(n.Label, n.ID) + if errText := cfgString(result, "error"); errText != "" { + return result, fmt.Errorf("节点「%s」失败: %s", label, errText) + } + return result, fmt.Errorf("节点「%s」未继续执行", label) + } + return result, nil +} diff --git a/internal/workflow/errors.go b/internal/workflow/errors.go new file mode 100644 index 00000000..ffec51da --- /dev/null +++ b/internal/workflow/errors.go @@ -0,0 +1,24 @@ +package workflow + +import "errors" + +// AwaitingHITLError indicates the workflow paused before a HITL node for human approval. +type AwaitingHITLError struct { + RunID string + NodeID string + NodeLabel string + Prompt string + Reviewer string +} + +func (e *AwaitingHITLError) Error() string { + if e == nil { + return "workflow awaiting human approval" + } + return "workflow awaiting human approval at node " + e.NodeID +} + +func IsAwaitingHITL(err error) bool { + var target *AwaitingHITLError + return errors.As(err, &target) +} diff --git a/internal/workflow/graph_types.go b/internal/workflow/graph_types.go new file mode 100644 index 00000000..8e49be11 --- /dev/null +++ b/internal/workflow/graph_types.go @@ -0,0 +1,153 @@ +package workflow + +import ( + "encoding/json" + "fmt" + "sort" + "strings" +) + +type graphDef struct { + Nodes []graphNode `json:"nodes"` + Edges []graphEdge `json:"edges"` + Config map[string]any `json:"config"` +} + +type graphNode struct { + ID string `json:"id"` + Type string `json:"type"` + Label string `json:"label"` + Position graphPosition `json:"position"` + Config map[string]any `json:"config"` +} + +type graphEdge struct { + ID string `json:"id"` + Source string `json:"source"` + Target string `json:"target"` + Label string `json:"label"` + Config map[string]any `json:"config"` +} + +type graphPosition struct { + X float64 `json:"x"` + Y float64 `json:"y"` +} + +type graphIndex struct { + nodes map[string]graphNode + outgoing map[string][]graphEdge + incoming map[string][]graphEdge +} + +func parseGraph(raw string) (*graphDef, error) { + var g graphDef + if err := json.Unmarshal([]byte(strings.TrimSpace(raw)), &g); err != nil { + return nil, fmt.Errorf("解析工作流图失败: %w", err) + } + if len(g.Nodes) == 0 { + return nil, fmt.Errorf("工作流没有节点") + } + if g.Config == nil { + g.Config = make(map[string]any) + } + return &g, nil +} + +func indexGraph(g *graphDef) *graphIndex { + idx := &graphIndex{ + nodes: make(map[string]graphNode, len(g.Nodes)), + outgoing: make(map[string][]graphEdge), + incoming: make(map[string][]graphEdge), + } + for _, node := range g.Nodes { + node.ID = strings.TrimSpace(node.ID) + if node.ID == "" { + continue + } + if strings.TrimSpace(node.Type) == "" { + node.Type = "tool" + } + if node.Config == nil { + node.Config = make(map[string]any) + } + idx.nodes[node.ID] = node + } + for _, edge := range g.Edges { + if _, ok := idx.nodes[edge.Source]; !ok { + continue + } + if _, ok := idx.nodes[edge.Target]; !ok { + continue + } + idx.outgoing[edge.Source] = append(idx.outgoing[edge.Source], edge) + idx.incoming[edge.Target] = append(idx.incoming[edge.Target], edge) + } + for source := range idx.outgoing { + sortEdgesByCanvas(idx.outgoing[source], idx.nodes) + } + return idx +} + +func sortEdgesByCanvas(edges []graphEdge, nodes map[string]graphNode) { + sort.SliceStable(edges, func(i, j int) bool { + a := nodes[edges[i].Target] + b := nodes[edges[j].Target] + if a.Position.Y != b.Position.Y { + return a.Position.Y < b.Position.Y + } + if a.Position.X != b.Position.X { + return a.Position.X < b.Position.X + } + return edges[i].Target < edges[j].Target + }) +} + +func sortNodeIDsByCanvas(ids []string, nodes map[string]graphNode) { + sort.SliceStable(ids, func(i, j int) bool { + a := nodes[ids[i]] + b := nodes[ids[j]] + if a.Position.Y != b.Position.Y { + return a.Position.Y < b.Position.Y + } + if a.Position.X != b.Position.X { + return a.Position.X < b.Position.X + } + return ids[i] < ids[j] + }) +} + +func findStartNodeIDs(idx *graphIndex) []string { + var queue []string + for id, node := range idx.nodes { + if strings.EqualFold(node.Type, "start") { + queue = append(queue, id) + } + } + if len(queue) == 0 { + inDegree := make(map[string]int, len(idx.nodes)) + for id := range idx.nodes { + inDegree[id] = 0 + } + for _, edges := range idx.outgoing { + for _, edge := range edges { + inDegree[edge.Target]++ + } + } + for id, deg := range inDegree { + if deg == 0 { + queue = append(queue, id) + } + } + } + sortNodeIDsByCanvas(queue, idx.nodes) + return queue +} + +func branchNodeID(nodeID string) string { + return nodeID + "__eino_branch" +} + +func edgeBranchNodeID(nodeID string) string { + return nodeID + "__eino_edge_branch" +} diff --git a/internal/workflow/node_exec.go b/internal/workflow/node_exec.go new file mode 100644 index 00000000..68d57ca3 --- /dev/null +++ b/internal/workflow/node_exec.go @@ -0,0 +1,131 @@ +package workflow + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "cyberstrike-ai/internal/database" + + "github.com/google/uuid" +) + +func executeNode(ctx context.Context, args RunArgs, runID string, node graphNode, state *WorkflowLocalState) (map[string]any, bool, error) { + label := node.Label + if strings.TrimSpace(label) == "" { + label = node.ID + } + nodeRunID := uuid.NewString() + input := map[string]any{ + "nodeId": node.ID, + "nodeType": node.Type, + "label": label, + "inputs": state.Inputs, + "previous": state.LastOutput, + } + inputJSON, _ := json.Marshal(input) + if err := args.DB.CreateWorkflowNodeRun(&database.WorkflowNodeRun{ + ID: nodeRunID, + RunID: runID, + NodeID: node.ID, + Status: "running", + InputJSON: string(inputJSON), + StartedAt: time.Now(), + }); err != nil { + return nil, false, err + } + if args.Progress != nil { + args.Progress("workflow_node_start", fmt.Sprintf("开始节点:%s", label), map[string]any{ + "workflowRunId": runID, + "nodeRunId": nodeRunID, + "nodeId": node.ID, + "nodeType": node.Type, + "label": label, + }) + } + + result, proceed, status, errText := runBuiltinNode(ctx, args, node, state) + outputJSON, _ := json.Marshal(result) + if err := args.DB.FinishWorkflowNodeRun(nodeRunID, status, string(outputJSON), errText); err != nil { + return nil, false, err + } + if status == "skipped" { + state.Skipped = append(state.Skipped, label) + } else { + state.Executed = append(state.Executed, label) + } + if args.Progress != nil { + progressData := map[string]any{ + "workflowRunId": runID, + "nodeRunId": nodeRunID, + "nodeId": node.ID, + "nodeType": node.Type, + "label": label, + "status": status, + "output": result, + } + progressMsg := fmt.Sprintf("节点完成:%s(%s)", label, status) + if strings.EqualFold(node.Type, "condition") { + matched := false + if v, ok := result["matched"].(bool); ok { + matched = v + } + expr := cfgString(node.Config, "expression") + if matched { + progressMsg = fmt.Sprintf("条件判断:%s → 是", label) + } else { + progressMsg = fmt.Sprintf("条件判断:%s → 否", label) + } + progressData["expression"] = expr + progressData["matched"] = matched + } + args.Progress("workflow_node_result", progressMsg, progressData) + } + state.NodeProceed[node.ID] = proceed + return result, proceed, nil +} + +func emitConditionBranchProgress(args RunArgs, runID string, node graphNode, edges []graphEdge, nodes map[string]graphNode, state *WorkflowLocalState) { + if args.Progress == nil || len(edges) == 0 { + return + } + for edgeIdx, edge := range edges { + allowed := edgeAllowed(edge, node, edgeIdx, state) + target := nodes[edge.Target] + targetLabel := strings.TrimSpace(target.Label) + if targetLabel == "" { + targetLabel = edge.Target + } + branchLabel := strings.TrimSpace(edge.Label) + if branchLabel == "" { + switch edgeIdx { + case 0: + branchLabel = "是" + case 1: + branchLabel = "否" + default: + branchLabel = fmt.Sprintf("分支 %d", edgeIdx+1) + } + } + cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression")) + eventType := "workflow_branch_skipped" + msg := fmt.Sprintf("跳过分支「%s」→ %s", branchLabel, targetLabel) + if allowed { + eventType = "workflow_branch_taken" + msg = fmt.Sprintf("执行分支「%s」→ %s", branchLabel, targetLabel) + } + args.Progress(eventType, msg, map[string]any{ + "workflowRunId": runID, + "nodeId": node.ID, + "nodeType": node.Type, + "label": node.Label, + "branchLabel": branchLabel, + "targetId": edge.Target, + "targetLabel": targetLabel, + "edgeCondition": cond, + "matched": conditionMatched(state), + }) + } +} diff --git a/internal/workflow/nodes.go b/internal/workflow/nodes.go new file mode 100644 index 00000000..f63bfa73 --- /dev/null +++ b/internal/workflow/nodes.go @@ -0,0 +1,323 @@ +package workflow + +import ( + "context" + "fmt" + "strings" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/multiagent" +) + +func runBuiltinNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) { + cfg := node.Config + switch strings.ToLower(strings.TrimSpace(node.Type)) { + case "start": + out := map[string]any{ + "output": state.Inputs["message"], + "message": state.Inputs["message"], + "conversationId": state.Inputs["conversationId"], + "projectId": state.Inputs["projectId"], + } + return out, true, "completed", "" + case "condition": + expr := cfgString(cfg, "expression") + ok := evalCondition(expr, state) + out := map[string]any{"output": ok, "condition": expr, "matched": ok} + return out, true, "completed", "" + case "output": + key := cfgString(cfg, "output_key") + if key == "" { + key = "result" + } + var value any + if v := cfgString(cfg, "static_value"); v != "" { + value = v + } else { + value = resolveOutputSourceBinding(cfg, state) + } + state.Outputs[key] = value + return map[string]any{"output": value, "outputs": map[string]any{key: value}}, true, "completed", "" + case "end": + value := resolveOutputSourceBinding(cfg, state) + if b, ok := parseFieldBinding(cfg, "result_binding"); ok { + value = resolveBinding(b, state) + } + return map[string]any{"output": value}, false, "completed", "" + case "tool": + return runToolNode(ctx, args, node, state) + case "agent": + return runAgentNode(ctx, args, node, state) + case "hitl": + return runHITLNode(args, node, state) + default: + reason := "未知节点类型" + return map[string]any{"output": "", "skipped": true, "reason": reason, "node_type": node.Type}, true, "skipped", reason + } +} + +func runToolNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) { + toolName := cfgString(node.Config, "tool_name") + if toolName == "" { + errText := "工具节点未选择 MCP 工具" + return map[string]any{"output": "", "error": errText}, false, "failed", errText + } + if args.Agent == nil { + errText := "工具节点执行失败:Agent 为空" + return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText + } + toolArgs, err := resolveToolArguments(node.Config, state) + if err != nil { + errText := fmt.Sprintf("工具参数不是合法 JSON:%v", err) + return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText + } + if args.Progress != nil { + args.Progress("workflow_tool_start", fmt.Sprintf("调用工具:%s", toolName), map[string]any{ + "nodeId": node.ID, + "tool": toolName, + "args": toolArgs, + }) + } + result, err := args.Agent.ExecuteMCPToolForConversation(ctx, args.ConversationID, toolName, toolArgs) + if err != nil { + errText := err.Error() + return map[string]any{"output": "", "tool_name": toolName, "arguments": toolArgs, "error": errText}, false, "failed", errText + } + output := "" + executionID := "" + isError := false + if result != nil { + output = result.Result + executionID = result.ExecutionID + isError = result.IsError + } + out := map[string]any{ + "output": output, + "tool_name": toolName, + "arguments": toolArgs, + "execution_id": executionID, + "is_error": isError, + } + if key := cfgString(node.Config, "output_key"); key != "" { + state.Outputs[key] = output + } + if isError { + errText := strings.TrimSpace(output) + if errText == "" { + errText = "工具返回错误" + } + return out, false, "failed", errText + } + return out, true, "completed", "" +} + +func runAgentNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) { + if args.AppCfg == nil || args.Agent == nil { + errText := "Agent 节点执行失败:应用配置或 Agent 为空" + return map[string]any{"output": "", "error": errText}, false, "failed", errText + } + mode := strings.ToLower(cfgString(node.Config, "agent_mode")) + if mode == "" { + mode = "eino_single" + } + inputSource := resolveNodeInputBinding(node.Config, state) + message := buildAgentNodeMessage(node, state, inputSource) + var result *multiagent.RunResult + var err error + state.SegmentMaxIteration = 0 + agentProgress := workflowAgentProgress(args.Progress, state, node) + switch mode { + case "eino_single", "single", "chat": + result, err = multiagent.RunEinoSingleChatModelAgent( + ctx, + args.AppCfg, + &args.AppCfg.MultiAgent, + args.Agent, + args.DB, + args.Logger, + args.ConversationID, + args.ProjectID, + message, + args.History, + args.RoleTools, + agentProgress, + nil, + args.SystemPromptExtra, + ) + default: + result, err = multiagent.RunDeepAgent( + ctx, + args.AppCfg, + &args.AppCfg.MultiAgent, + args.Agent, + args.DB, + args.Logger, + args.ConversationID, + args.ProjectID, + message, + args.History, + args.RoleTools, + agentProgress, + args.AgentsMarkdownDir, + mode, + nil, + args.SystemPromptExtra, + ) + } + if err != nil { + errText := err.Error() + state.MainIterationOffset += state.SegmentMaxIteration + return map[string]any{"output": "", "mode": mode, "error": errText}, false, "failed", errText + } + state.MainIterationOffset += state.SegmentMaxIteration + response := "" + mcpIDs := []string{} + if result != nil { + response = result.Response + mcpIDs = result.MCPExecutionIDs + } + if args.Progress != nil { + args.Progress("workflow_agent_output", response, map[string]any{ + "nodeId": node.ID, + "label": firstNonEmpty(node.Label, node.ID), + "mode": mode, + "inputSource": inputSource, + "inputPreview": truncateWorkflowPreview(inputSource, 500), + "mcpExecutionIds": mcpIDs, + }) + } + if key := cfgString(node.Config, "output_key"); key != "" { + state.Outputs[key] = response + } + return map[string]any{ + "output": response, + "mode": mode, + "mcp_execution_ids": mcpIDs, + }, true, "completed", "" +} + +func buildAgentNodeMessage(node graphNode, state *WorkflowLocalState, upstreamInput string) string { + instruction := strings.TrimSpace(cfgString(node.Config, "instruction")) + upstreamInput = strings.TrimSpace(upstreamInput) + if instruction == "" { + if upstreamInput != "" { + return fmt.Sprintf("请基于上游节点输出继续处理:\n%s", upstreamInput) + } + return fmt.Sprintf("请基于上游节点输出继续处理:\n%v", state.LastOutput["output"]) + } + if upstreamInput == "" { + return instruction + } + return strings.TrimSpace(fmt.Sprintf("上游输入:\n%s\n\n节点指令:\n%s", upstreamInput, instruction)) +} + +func workflowAgentProgress(progress agent.ProgressCallback, state *WorkflowLocalState, node graphNode) agent.ProgressCallback { + if progress == nil { + return nil + } + return func(eventType, message string, data interface{}) { + switch eventType { + case "response_start", "response_delta", "response", "done": + return + default: + enrichWorkflowAgentEventData(data, state, node) + if eventType == "iteration" { + applyWorkflowMainIterationOffset(data, state) + } + progress(eventType, message, data) + } + } +} + +func enrichWorkflowAgentEventData(data interface{}, state *WorkflowLocalState, node graphNode) { + m, ok := data.(map[string]interface{}) + if !ok || m == nil { + return + } + if node.ID != "" { + m["workflowNodeId"] = node.ID + } + if state != nil && strings.TrimSpace(state.WorkflowRunID) != "" { + m["workflowRunId"] = state.WorkflowRunID + } +} + +func applyWorkflowMainIterationOffset(data interface{}, state *WorkflowLocalState) { + if state == nil { + return + } + m, ok := data.(map[string]interface{}) + if !ok || m == nil { + return + } + scope, _ := m["einoScope"].(string) + if strings.TrimSpace(scope) != "main" { + return + } + raw := iterationNumberFromProgressData(m) + if raw <= 0 { + return + } + if raw > state.SegmentMaxIteration { + state.SegmentMaxIteration = raw + } + m["iteration"] = raw + state.MainIterationOffset +} + +func iterationNumberFromProgressData(m map[string]interface{}) int { + switch v := m["iteration"].(type) { + case int: + return v + case int32: + return int(v) + case int64: + return int(v) + case float64: + return int(v) + case float32: + return int(v) + default: + return 0 + } +} + +func runHITLNode(args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) { + prompt := resolveHITLPromptBinding(node.Config, state) + reviewer := cfgString(node.Config, "reviewer") + if reviewer == "" { + reviewer = "human" + } + approved := true + if state != nil && state.Inputs != nil { + if v, ok := state.Inputs["_hitl_approved"]; ok { + approved = fmt.Sprint(v) == "true" + } + } + if !approved { + reason := "人工审批已拒绝" + if state != nil && state.Inputs != nil { + if v, ok := state.Inputs["_hitl_comment"]; ok { + if s := strings.TrimSpace(fmt.Sprint(v)); s != "" { + reason = s + } + } + } + return map[string]any{"output": "", "prompt": prompt, "approved": false, "mode": "interactive"}, false, "failed", reason + } + if args.Progress != nil { + args.Progress("workflow_hitl_checkpoint", "人工确认节点已通过", map[string]any{ + "nodeId": node.ID, + "prompt": prompt, + "reviewer": reviewer, + "mode": "interactive", + "approved": true, + }) + } + return map[string]any{ + "output": prompt, + "prompt": prompt, + "reviewer": reviewer, + "approved": true, + "mode": "interactive", + }, true, "completed", "" +} diff --git a/internal/workflow/runner.go b/internal/workflow/runner.go index 8e5e6f60..3c8152b2 100644 --- a/internal/workflow/runner.go +++ b/internal/workflow/runner.go @@ -4,82 +4,16 @@ import ( "context" "encoding/json" "fmt" - "regexp" - "sort" "strings" "time" - "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" - "cyberstrike-ai/internal/multiagent" "github.com/google/uuid" "go.uber.org/zap" ) -type RunArgs struct { - DB *database.DB - Logger *zap.Logger - Role config.RoleConfig - AppCfg *config.Config - Agent *agent.Agent - ConversationID string - ProjectID string - UserMessage string - History []agent.ChatMessage - RoleTools []string - AgentsMarkdownDir string - SystemPromptExtra string - AssistantMessageID string - Progress agent.ProgressCallback -} - -type RunResult struct { - Response string - RunID string -} - -type graphDef struct { - Nodes []graphNode `json:"nodes"` - Edges []graphEdge `json:"edges"` - Config map[string]any `json:"config"` -} - -type graphNode struct { - ID string `json:"id"` - Type string `json:"type"` - Label string `json:"label"` - Position graphPosition `json:"position"` - Config map[string]any `json:"config"` -} - -type graphEdge struct { - ID string `json:"id"` - Source string `json:"source"` - Target string `json:"target"` - Label string `json:"label"` - Config map[string]any `json:"config"` -} - -type graphPosition struct { - X float64 `json:"x"` - Y float64 `json:"y"` -} - -type workflowExecState struct { - inputs map[string]any - outputs map[string]any - nodeOutputs map[string]map[string]any - lastOutput map[string]any - executed []string - skipped []string - workflowRunID string - // 图编排内多个 Agent 节点各自从第 1 轮上报 iteration;累计偏移避免对话页迭代序号回跳与流式条目复用错乱。 - mainIterationOffset int - segmentMaxIteration int -} - // ShouldAutoRunRoleWorkflow returns true when a role explicitly binds a workflow // and does not turn it off. Empty policy defaults to auto to keep role UX simple. func ShouldAutoRunRoleWorkflow(role config.RoleConfig) bool { @@ -90,10 +24,7 @@ func ShouldAutoRunRoleWorkflow(role config.RoleConfig) bool { return policy == "" || policy == "auto" } -// RunRoleBoundWorkflow executes the persisted role-bound workflow graph. -// Control nodes are interpreted locally, tool nodes call the existing MCP bridge, -// and agent nodes reuse the existing Eino ADK runners so role-bound flows share -// the same model/tool/session behavior as the chat page. +// RunRoleBoundWorkflow executes the persisted role-bound workflow via cached Eino Workflow. func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) { if args.DB == nil { return nil, fmt.Errorf("workflow db is nil") @@ -150,6 +81,7 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) "workflowVersion": wf.Version, "workflowRunId": runID, "conversationId": args.ConversationID, + "engine": "eino_workflow", }) } @@ -158,13 +90,44 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) _ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error()) return nil, err } - state := &workflowExecState{ - inputs: input, - outputs: make(map[string]any), - nodeOutputs: make(map[string]map[string]any), - workflowRunID: runID, - } - if err := executeGraph(ctx, args, runID, graph, state); err != nil { + state := newWorkflowLocalState(input, runID) + if err := executeEinoGraph(ctx, args, runID, wf.ID, wf.Version, graph, state); err != nil { + if IsAwaitingHITL(err) { + hitl := err.(*AwaitingHITLError) + partial := map[string]interface{}{ + "workflowId": wf.ID, + "workflowName": wf.Name, + "workflowVersion": wf.Version, + "workflowRunId": runID, + "status": "awaiting_hitl", + "outputs": state.Outputs, + "executedNodes": state.Executed, + "skippedNodes": state.Skipped, + "pendingHitl": map[string]interface{}{ + "nodeId": hitl.NodeID, + "label": hitl.NodeLabel, + "prompt": hitl.Prompt, + }, + "engine": "eino_workflow", + } + partialJSON, _ := json.Marshal(partial) + _ = args.DB.SetWorkflowRunAwaitingHITL(runID, hitl.NodeID, string(partialJSON)) + response := fmt.Sprintf("工作流「%s」已在节点「%s」暂停,等待人工审批。\n运行 ID:%s", wf.Name, firstNonEmpty(hitl.NodeLabel, hitl.NodeID), runID) + if args.Progress != nil { + args.Progress("workflow_paused", response, map[string]interface{}{ + "workflowRunId": runID, + "status": "awaiting_hitl", + "nodeId": hitl.NodeID, + "resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID), + }) + } + return &RunResult{ + Response: response, + RunID: runID, + Status: "awaiting_hitl", + AwaitingHITL: true, + }, nil + } _ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error()) return nil, err } @@ -175,9 +138,10 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) "workflowVersion": wf.Version, "workflowRunId": runID, "status": "completed", - "outputs": state.outputs, - "executedNodes": state.executed, - "skippedNodes": state.skipped, + "outputs": state.Outputs, + "executedNodes": state.Executed, + "skippedNodes": state.Skipped, + "engine": "eino_workflow", } outputJSON, _ := json.Marshal(output) @@ -189,8 +153,9 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) args.Progress("workflow_done", fmt.Sprintf("流程「%s」运行完成", wf.Name), map[string]interface{}{ "workflowRunId": runID, "workflowId": wf.ID, - "outputs": state.outputs, + "outputs": state.Outputs, "response": response, + "engine": "eino_workflow", }) } if args.Logger != nil { @@ -199,746 +164,8 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) zap.String("workflow_run_id", runID), zap.String("conversation_id", args.ConversationID), zap.String("role", args.Role.Name), + zap.String("engine", "eino_workflow"), ) } - return &RunResult{Response: response, RunID: runID}, nil -} - -func parseGraph(raw string) (*graphDef, error) { - var g graphDef - if err := json.Unmarshal([]byte(strings.TrimSpace(raw)), &g); err != nil { - return nil, fmt.Errorf("解析工作流图失败: %w", err) - } - if len(g.Nodes) == 0 { - return nil, fmt.Errorf("工作流没有节点") - } - if g.Config == nil { - g.Config = make(map[string]any) - } - return &g, nil -} - -func executeGraph(ctx context.Context, args RunArgs, runID string, g *graphDef, state *workflowExecState) error { - nodes := make(map[string]graphNode, len(g.Nodes)) - inDegree := make(map[string]int, len(g.Nodes)) - outgoing := make(map[string][]graphEdge) - for _, node := range g.Nodes { - node.ID = strings.TrimSpace(node.ID) - if node.ID == "" { - continue - } - if strings.TrimSpace(node.Type) == "" { - node.Type = "tool" - } - if node.Config == nil { - node.Config = make(map[string]any) - } - nodes[node.ID] = node - inDegree[node.ID] = 0 - } - for _, edge := range g.Edges { - if _, ok := nodes[edge.Source]; !ok { - continue - } - if _, ok := nodes[edge.Target]; !ok { - continue - } - outgoing[edge.Source] = append(outgoing[edge.Source], edge) - inDegree[edge.Target]++ - } - for source := range outgoing { - sort.SliceStable(outgoing[source], func(i, j int) bool { - a := nodes[outgoing[source][i].Target] - b := nodes[outgoing[source][j].Target] - if a.Position.Y != b.Position.Y { - return a.Position.Y < b.Position.Y - } - if a.Position.X != b.Position.X { - return a.Position.X < b.Position.X - } - return outgoing[source][i].Target < outgoing[source][j].Target - }) - } - - var queue []string - for id, node := range nodes { - if strings.EqualFold(node.Type, "start") { - queue = append(queue, id) - } - } - if len(queue) == 0 { - for id, deg := range inDegree { - if deg == 0 { - queue = append(queue, id) - } - } - } - sortNodeIDsByCanvas(queue, nodes) - seen := make(map[string]bool) - remainingIncoming := make(map[string]int, len(inDegree)) - for id, deg := range inDegree { - remainingIncoming[id] = deg - } - for len(queue) > 0 { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - id := queue[0] - queue = queue[1:] - if seen[id] { - continue - } - seen[id] = true - node := nodes[id] - result, proceed, err := executeNode(ctx, args, runID, node, state) - if err != nil { - return err - } - state.nodeOutputs[id] = result - state.lastOutput = result - if proceed { - edges := outgoing[id] - if strings.EqualFold(node.Type, "condition") { - emitConditionBranchProgress(args, runID, node, edges, nodes, state) - } - for edgeIdx, edge := range edges { - if !edgeAllowed(edge, node, edgeIdx, state) { - continue - } - remainingIncoming[edge.Target]-- - if remainingIncoming[edge.Target] > 0 { - continue - } - queue = append(queue, edge.Target) - } - sortNodeIDsByCanvas(queue, nodes) - } - } - return nil -} - -func sortNodeIDsByCanvas(ids []string, nodes map[string]graphNode) { - sort.SliceStable(ids, func(i, j int) bool { - a := nodes[ids[i]] - b := nodes[ids[j]] - if a.Position.Y != b.Position.Y { - return a.Position.Y < b.Position.Y - } - if a.Position.X != b.Position.X { - return a.Position.X < b.Position.X - } - return ids[i] < ids[j] - }) -} - -func executeNode(ctx context.Context, args RunArgs, runID string, node graphNode, state *workflowExecState) (map[string]any, bool, error) { - label := node.Label - if strings.TrimSpace(label) == "" { - label = node.ID - } - nodeRunID := uuid.NewString() - input := map[string]any{ - "nodeId": node.ID, - "nodeType": node.Type, - "label": label, - "inputs": state.inputs, - "previous": state.lastOutput, - } - inputJSON, _ := json.Marshal(input) - if err := args.DB.CreateWorkflowNodeRun(&database.WorkflowNodeRun{ - ID: nodeRunID, - RunID: runID, - NodeID: node.ID, - Status: "running", - InputJSON: string(inputJSON), - StartedAt: time.Now(), - }); err != nil { - return nil, false, err - } - if args.Progress != nil { - args.Progress("workflow_node_start", fmt.Sprintf("开始节点:%s", label), map[string]any{ - "workflowRunId": runID, - "nodeRunId": nodeRunID, - "nodeId": node.ID, - "nodeType": node.Type, - "label": label, - }) - } - - result, proceed, status, errText := runBuiltinNode(ctx, args, node, state) - outputJSON, _ := json.Marshal(result) - if err := args.DB.FinishWorkflowNodeRun(nodeRunID, status, string(outputJSON), errText); err != nil { - return nil, false, err - } - if status == "skipped" { - state.skipped = append(state.skipped, label) - } else { - state.executed = append(state.executed, label) - } - if args.Progress != nil { - progressData := map[string]any{ - "workflowRunId": runID, - "nodeRunId": nodeRunID, - "nodeId": node.ID, - "nodeType": node.Type, - "label": label, - "status": status, - "output": result, - } - progressMsg := fmt.Sprintf("节点完成:%s(%s)", label, status) - if strings.EqualFold(node.Type, "condition") { - matched := false - if v, ok := result["matched"].(bool); ok { - matched = v - } - expr := cfgString(node.Config, "expression") - if matched { - progressMsg = fmt.Sprintf("条件判断:%s → 是", label) - } else { - progressMsg = fmt.Sprintf("条件判断:%s → 否", label) - } - progressData["expression"] = expr - progressData["matched"] = matched - } - args.Progress("workflow_node_result", progressMsg, progressData) - } - return result, proceed, nil -} - -func runBuiltinNode(ctx context.Context, args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) { - cfg := node.Config - switch strings.ToLower(strings.TrimSpace(node.Type)) { - case "start": - out := map[string]any{ - "output": state.inputs["message"], - "message": state.inputs["message"], - "conversationId": state.inputs["conversationId"], - "projectId": state.inputs["projectId"], - } - return out, true, "completed", "" - case "condition": - expr := cfgString(cfg, "expression") - ok := evalCondition(expr, state) - out := map[string]any{"output": ok, "condition": expr, "matched": ok} - // 条件节点始终继续,由出边条件(或连线标签/顺序)决定走「是/否」分支。 - return out, true, "completed", "" - case "output": - key := cfgString(cfg, "output_key") - if key == "" { - key = "result" - } - value := resolveTemplate(cfgString(cfg, "source"), state) - state.outputs[key] = value - return map[string]any{"output": value, "outputs": map[string]any{key: value}}, true, "completed", "" - case "end": - value := resolveTemplate(cfgString(cfg, "result_template"), state) - return map[string]any{"output": value}, false, "completed", "" - case "tool": - return runToolNode(ctx, args, node, state) - case "agent": - return runAgentNode(ctx, args, node, state) - case "hitl": - return runHITLNode(args, node, state) - default: - reason := "未知节点类型" - return map[string]any{"output": "", "skipped": true, "reason": reason, "node_type": node.Type}, true, "skipped", reason - } -} - -func runToolNode(ctx context.Context, args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) { - toolName := cfgString(node.Config, "tool_name") - if toolName == "" { - errText := "工具节点未选择 MCP 工具" - return map[string]any{"output": "", "error": errText}, false, "failed", errText - } - if args.Agent == nil { - errText := "工具节点执行失败:Agent 为空" - return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText - } - toolArgs, err := parseToolArguments(cfgString(node.Config, "arguments"), state) - if err != nil { - errText := fmt.Sprintf("工具参数不是合法 JSON:%v", err) - return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText - } - if args.Progress != nil { - args.Progress("workflow_tool_start", fmt.Sprintf("调用工具:%s", toolName), map[string]any{ - "nodeId": node.ID, - "tool": toolName, - "args": toolArgs, - }) - } - result, err := args.Agent.ExecuteMCPToolForConversation(ctx, args.ConversationID, toolName, toolArgs) - if err != nil { - errText := err.Error() - return map[string]any{"output": "", "tool_name": toolName, "arguments": toolArgs, "error": errText}, false, "failed", errText - } - output := "" - executionID := "" - isError := false - if result != nil { - output = result.Result - executionID = result.ExecutionID - isError = result.IsError - } - out := map[string]any{ - "output": output, - "tool_name": toolName, - "arguments": toolArgs, - "execution_id": executionID, - "is_error": isError, - } - if key := cfgString(node.Config, "output_key"); key != "" { - state.outputs[key] = output - } - if isError { - errText := strings.TrimSpace(output) - if errText == "" { - errText = "工具返回错误" - } - return out, false, "failed", errText - } - return out, true, "completed", "" -} - -func runAgentNode(ctx context.Context, args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) { - if args.AppCfg == nil || args.Agent == nil { - errText := "Agent 节点执行失败:应用配置或 Agent 为空" - return map[string]any{"output": "", "error": errText}, false, "failed", errText - } - mode := strings.ToLower(cfgString(node.Config, "agent_mode")) - if mode == "" { - mode = "eino_single" - } - inputSource := cfgString(node.Config, "input_source") - if inputSource == "" { - inputSource = "{{previous.output}}" - } - upstreamInput := strings.TrimSpace(resolveTemplate(inputSource, state)) - message := buildAgentNodeMessage(node, state) - var result *multiagent.RunResult - var err error - state.segmentMaxIteration = 0 - agentProgress := workflowAgentProgress(args.Progress, state, node) - switch mode { - case "eino_single", "single", "chat": - result, err = multiagent.RunEinoSingleChatModelAgent( - ctx, - args.AppCfg, - &args.AppCfg.MultiAgent, - args.Agent, - args.DB, - args.Logger, - args.ConversationID, - args.ProjectID, - message, - args.History, - args.RoleTools, - agentProgress, - nil, - args.SystemPromptExtra, - ) - default: - result, err = multiagent.RunDeepAgent( - ctx, - args.AppCfg, - &args.AppCfg.MultiAgent, - args.Agent, - args.DB, - args.Logger, - args.ConversationID, - args.ProjectID, - message, - args.History, - args.RoleTools, - agentProgress, - args.AgentsMarkdownDir, - mode, - nil, - args.SystemPromptExtra, - ) - } - if err != nil { - errText := err.Error() - state.mainIterationOffset += state.segmentMaxIteration - return map[string]any{"output": "", "mode": mode, "error": errText}, false, "failed", errText - } - state.mainIterationOffset += state.segmentMaxIteration - response := "" - mcpIDs := []string{} - if result != nil { - response = result.Response - mcpIDs = result.MCPExecutionIDs - } - if args.Progress != nil { - args.Progress("workflow_agent_output", response, map[string]any{ - "nodeId": node.ID, - "label": firstNonEmpty(node.Label, node.ID), - "mode": mode, - "inputSource": inputSource, - "inputPreview": truncateWorkflowPreview(upstreamInput, 500), - "mcpExecutionIds": mcpIDs, - }) - } - if key := cfgString(node.Config, "output_key"); key != "" { - state.outputs[key] = response - } - return map[string]any{ - "output": response, - "mode": mode, - "mcp_execution_ids": mcpIDs, - }, true, "completed", "" -} - -func buildAgentNodeMessage(node graphNode, state *workflowExecState) string { - instruction := strings.TrimSpace(resolveTemplate(cfgString(node.Config, "instruction"), state)) - inputSource := cfgString(node.Config, "input_source") - if inputSource == "" { - inputSource = "{{previous.output}}" - } - upstreamInput := strings.TrimSpace(resolveTemplate(inputSource, state)) - if instruction == "" { - if upstreamInput != "" { - return fmt.Sprintf("请基于上游节点输出继续处理:\n%s", upstreamInput) - } - return fmt.Sprintf("请基于上游节点输出继续处理:\n%v", state.lastOutput["output"]) - } - if upstreamInput == "" { - return instruction - } - return strings.TrimSpace(fmt.Sprintf("上游输入:\n%s\n\n节点指令:\n%s", upstreamInput, instruction)) -} - -func workflowAgentProgress(progress agent.ProgressCallback, state *workflowExecState, node graphNode) agent.ProgressCallback { - if progress == nil { - return nil - } - return func(eventType, message string, data interface{}) { - switch eventType { - case "response_start", "response_delta", "response", "done": - return - default: - enrichWorkflowAgentEventData(data, state, node) - if eventType == "iteration" { - applyWorkflowMainIterationOffset(data, state) - } - progress(eventType, message, data) - } - } -} - -func enrichWorkflowAgentEventData(data interface{}, state *workflowExecState, node graphNode) { - m, ok := data.(map[string]interface{}) - if !ok || m == nil { - return - } - if node.ID != "" { - m["workflowNodeId"] = node.ID - } - if state != nil && strings.TrimSpace(state.workflowRunID) != "" { - m["workflowRunId"] = state.workflowRunID - } -} - -func applyWorkflowMainIterationOffset(data interface{}, state *workflowExecState) { - if state == nil { - return - } - m, ok := data.(map[string]interface{}) - if !ok || m == nil { - return - } - scope, _ := m["einoScope"].(string) - if strings.TrimSpace(scope) != "main" { - return - } - raw := iterationNumberFromProgressData(m) - if raw <= 0 { - return - } - if raw > state.segmentMaxIteration { - state.segmentMaxIteration = raw - } - m["iteration"] = raw + state.mainIterationOffset -} - -func iterationNumberFromProgressData(m map[string]interface{}) int { - switch v := m["iteration"].(type) { - case int: - return v - case int32: - return int(v) - case int64: - return int(v) - case float64: - return int(v) - case float32: - return int(v) - default: - return 0 - } -} - -func runHITLNode(args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) { - prompt := resolveTemplate(cfgString(node.Config, "prompt"), state) - reviewer := cfgString(node.Config, "reviewer") - if args.Progress != nil { - args.Progress("workflow_hitl_checkpoint", "人工确认节点已记录", map[string]any{ - "nodeId": node.ID, - "prompt": prompt, - "reviewer": reviewer, - "mode": "record_only", - }) - } - return map[string]any{ - "output": prompt, - "prompt": prompt, - "reviewer": reviewer, - "approved": true, - "mode": "record_only", - }, true, "completed", "" -} - -func parseToolArguments(raw string, state *workflowExecState) (map[string]interface{}, error) { - if raw == "" { - return map[string]interface{}{}, nil - } - raw = strings.TrimSpace(resolveTemplate(raw, state)) - if raw == "" { - return map[string]interface{}{}, nil - } - var args map[string]interface{} - if err := json.Unmarshal([]byte(raw), &args); err != nil { - return nil, err - } - if args == nil { - args = map[string]interface{}{} - } - return args, nil -} - -func edgeAllowed(edge graphEdge, sourceNode graphNode, edgeIndex int, state *workflowExecState) bool { - cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression")) - if cond != "" { - return evalCondition(cond, state) - } - if strings.EqualFold(strings.TrimSpace(sourceNode.Type), "condition") { - return conditionBranchAllowed(edge, edgeIndex, state) - } - return true -} - -func conditionBranchAllowed(edge graphEdge, edgeIndex int, state *workflowExecState) bool { - matched := conditionMatched(state) - if branch := conditionBranchHint(edge); branch != "" { - return (branch == "true" && matched) || (branch == "false" && !matched) - } - switch edgeIndex { - case 0: - return matched - case 1: - return !matched - default: - return false - } -} - -func conditionMatched(state *workflowExecState) bool { - v := strings.ToLower(cleanComparable(fmt.Sprint(valueFromPath("previous.matched", state)))) - return v == "true" || v == "1" -} - -func conditionBranchHint(edge graphEdge) string { - if edge.Config != nil { - switch strings.ToLower(strings.TrimSpace(cfgString(edge.Config, "branch"))) { - case "true", "yes", "y", "是": - return "true" - case "false", "no", "n", "否": - return "false" - } - } - switch strings.ToLower(strings.TrimSpace(edge.Label)) { - case "true", "yes", "y", "是": - return "true" - case "false", "no", "n", "否": - return "false" - } - return "" -} - -func emitConditionBranchProgress(args RunArgs, runID string, node graphNode, edges []graphEdge, nodes map[string]graphNode, state *workflowExecState) { - if args.Progress == nil || len(edges) == 0 { - return - } - for edgeIdx, edge := range edges { - allowed := edgeAllowed(edge, node, edgeIdx, state) - target := nodes[edge.Target] - targetLabel := strings.TrimSpace(target.Label) - if targetLabel == "" { - targetLabel = edge.Target - } - branchLabel := strings.TrimSpace(edge.Label) - if branchLabel == "" { - switch edgeIdx { - case 0: - branchLabel = "是" - case 1: - branchLabel = "否" - default: - branchLabel = fmt.Sprintf("分支 %d", edgeIdx+1) - } - } - cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression")) - eventType := "workflow_branch_skipped" - msg := fmt.Sprintf("跳过分支「%s」→ %s", branchLabel, targetLabel) - if allowed { - eventType = "workflow_branch_taken" - msg = fmt.Sprintf("执行分支「%s」→ %s", branchLabel, targetLabel) - } - args.Progress(eventType, msg, map[string]any{ - "workflowRunId": runID, - "nodeId": node.ID, - "nodeType": node.Type, - "label": node.Label, - "branchLabel": branchLabel, - "targetId": edge.Target, - "targetLabel": targetLabel, - "edgeCondition": cond, - "matched": conditionMatched(state), - }) - } -} - -func cfgString(cfg map[string]any, key string) string { - if cfg == nil { - return "" - } - if v, ok := cfg[key]; ok { - return strings.TrimSpace(fmt.Sprint(v)) - } - return "" -} - -func firstNonEmpty(values ...string) string { - for _, value := range values { - if s := strings.TrimSpace(value); s != "" { - return s - } - } - return "" -} - -func truncateWorkflowPreview(s string, limit int) string { - s = strings.TrimSpace(s) - if limit <= 0 || len([]rune(s)) <= limit { - return s - } - runes := []rune(s) - return string(runes[:limit]) + "..." -} - -var templateVarRe = regexp.MustCompile(`\{\{\s*([a-zA-Z0-9_.-]+)\s*\}\}`) - -func resolveTemplate(s string, state *workflowExecState) string { - if strings.TrimSpace(s) == "" { - return fmt.Sprint(valueFromPath("previous.output", state)) - } - return templateVarRe.ReplaceAllStringFunc(s, func(match string) string { - m := templateVarRe.FindStringSubmatch(match) - if len(m) != 2 { - return match - } - return fmt.Sprint(valueFromPath(m[1], state)) - }) -} - -func valueFromPath(path string, state *workflowExecState) any { - parts := strings.Split(path, ".") - if len(parts) == 0 { - return "" - } - var cur any - switch parts[0] { - case "inputs", "input": - cur = state.inputs - case "previous", "prev": - cur = state.lastOutput - case "outputs": - cur = state.outputs - default: - if v, ok := state.inputs[parts[0]]; ok { - cur = v - } else if v, ok := state.nodeOutputs[parts[0]]; ok { - cur = v - } else { - return "" - } - } - for _, p := range parts[1:] { - m, ok := cur.(map[string]any) - if !ok { - return "" - } - cur = m[p] - } - if cur == nil { - return "" - } - return cur -} - -func evalCondition(expr string, state *workflowExecState) bool { - expr = strings.TrimSpace(expr) - if expr == "" { - return true - } - resolved := strings.TrimSpace(resolveTemplate(expr, state)) - switch { - case strings.Contains(resolved, "!="): - parts := strings.SplitN(resolved, "!=", 2) - return cleanComparable(parts[0]) != cleanComparable(parts[1]) - case strings.Contains(resolved, "=="): - parts := strings.SplitN(resolved, "==", 2) - return cleanComparable(parts[0]) == cleanComparable(parts[1]) - default: - v := strings.ToLower(cleanComparable(resolved)) - return v != "" && v != "false" && v != "0" && v != "null" - } -} - -func cleanComparable(s string) string { - s = strings.TrimSpace(s) - s = strings.Trim(s, `"'`) - return s -} - -func renderWorkflowResponse(roleName, workflowName string, version int, runID string, state *workflowExecState) string { - var sb strings.Builder - sb.WriteString(fmt.Sprintf("角色「%s」已完成工作流「%s」(版本 %d)。\n\n", roleName, workflowName, version)) - sb.WriteString(fmt.Sprintf("运行 ID:%s\n", runID)) - sb.WriteString(fmt.Sprintf("已执行节点:%d", len(state.executed))) - if len(state.skipped) > 0 { - sb.WriteString(fmt.Sprintf(",跳过节点:%d", len(state.skipped))) - } - sb.WriteString("\n\n") - if len(state.outputs) > 0 { - sb.WriteString("输出:\n") - keys := make([]string, 0, len(state.outputs)) - for k := range state.outputs { - keys = append(keys, k) - } - sort.Strings(keys) - for _, k := range keys { - sb.WriteString(fmt.Sprintf("- %s:%v\n", k, state.outputs[k])) - } - } else { - sb.WriteString("暂无输出。请检查是否配置了输出节点,或条件分支是否命中。\n") - } - if len(state.skipped) > 0 { - sb.WriteString("\n未执行的节点类型仍会保留运行记录:") - sb.WriteString(strings.Join(state.skipped, "、")) - sb.WriteString("。") - } - return strings.TrimSpace(sb.String()) + return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil } diff --git a/internal/workflow/state.go b/internal/workflow/state.go new file mode 100644 index 00000000..9edbe880 --- /dev/null +++ b/internal/workflow/state.go @@ -0,0 +1,224 @@ +package workflow + +import ( + "fmt" + "regexp" + "sort" + "strings" + + "github.com/cloudwego/eino/schema" +) + +func init() { + schema.RegisterName[*WorkflowLocalState]("_cyberstrike_workflow_local_state") +} + +// WorkflowLocalState is the Eino WithGenLocalState payload (checkpoint-serializable). +type WorkflowLocalState struct { + Inputs map[string]any `json:"inputs,omitempty"` + Outputs map[string]any `json:"outputs,omitempty"` + NodeOutputs map[string]map[string]any `json:"nodeOutputs,omitempty"` + NodeProceed map[string]bool `json:"nodeProceed,omitempty"` + LastOutput map[string]any `json:"lastOutput,omitempty"` + Executed []string `json:"executed,omitempty"` + Skipped []string `json:"skipped,omitempty"` + WorkflowRunID string `json:"workflowRunId,omitempty"` + MainIterationOffset int `json:"mainIterationOffset,omitempty"` + SegmentMaxIteration int `json:"segmentMaxIteration,omitempty"` +} + +func newWorkflowLocalState(inputs map[string]interface{}, runID string) *WorkflowLocalState { + in := make(map[string]any, len(inputs)) + for k, v := range inputs { + in[k] = v + } + return &WorkflowLocalState{ + Inputs: in, + Outputs: make(map[string]any), + NodeOutputs: make(map[string]map[string]any), + NodeProceed: make(map[string]bool), + WorkflowRunID: runID, + } +} + +var templateVarRe = regexp.MustCompile(`\{\{\s*([a-zA-Z0-9_.-]+)\s*\}\}`) + +func resolveTemplate(s string, state *WorkflowLocalState) string { + if strings.TrimSpace(s) == "" { + return fmt.Sprint(valueFromPath("previous.output", state)) + } + return templateVarRe.ReplaceAllStringFunc(s, func(match string) string { + m := templateVarRe.FindStringSubmatch(match) + if len(m) != 2 { + return match + } + return fmt.Sprint(valueFromPath(m[1], state)) + }) +} + +func valueFromPath(path string, state *WorkflowLocalState) any { + parts := strings.Split(path, ".") + if len(parts) == 0 { + return "" + } + var cur any + switch parts[0] { + case "inputs", "input": + cur = state.Inputs + case "previous", "prev": + cur = state.LastOutput + case "outputs": + cur = state.Outputs + default: + if v, ok := state.Inputs[parts[0]]; ok { + cur = v + } else if v, ok := state.NodeOutputs[parts[0]]; ok { + cur = v + } else { + return "" + } + } + for _, p := range parts[1:] { + m, ok := cur.(map[string]any) + if !ok { + return "" + } + cur = m[p] + } + if cur == nil { + return "" + } + return cur +} + +func evalCondition(expr string, state *WorkflowLocalState) bool { + expr = strings.TrimSpace(expr) + if expr == "" { + return true + } + resolved := strings.TrimSpace(resolveTemplate(expr, state)) + switch { + case strings.Contains(resolved, "!="): + parts := strings.SplitN(resolved, "!=", 2) + return cleanComparable(parts[0]) != cleanComparable(parts[1]) + case strings.Contains(resolved, "=="): + parts := strings.SplitN(resolved, "==", 2) + return cleanComparable(parts[0]) == cleanComparable(parts[1]) + default: + v := strings.ToLower(cleanComparable(resolved)) + return v != "" && v != "false" && v != "0" && v != "null" + } +} + +func cleanComparable(s string) string { + s = strings.TrimSpace(s) + s = strings.Trim(s, `"'`) + return s +} + +func edgeAllowed(edge graphEdge, sourceNode graphNode, edgeIndex int, state *WorkflowLocalState) bool { + cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression")) + if cond != "" { + return evalCondition(cond, state) + } + if strings.EqualFold(strings.TrimSpace(sourceNode.Type), "condition") { + return conditionBranchAllowed(edge, edgeIndex, state) + } + return true +} + +func conditionBranchAllowed(edge graphEdge, edgeIndex int, state *WorkflowLocalState) bool { + matched := conditionMatched(state) + if branch := conditionBranchHint(edge); branch != "" { + return (branch == "true" && matched) || (branch == "false" && !matched) + } + switch edgeIndex { + case 0: + return matched + case 1: + return !matched + default: + return false + } +} + +func conditionMatched(state *WorkflowLocalState) bool { + v := strings.ToLower(cleanComparable(fmt.Sprint(valueFromPath("previous.matched", state)))) + return v == "true" || v == "1" +} + +func conditionBranchHint(edge graphEdge) string { + if edge.Config != nil { + switch strings.ToLower(strings.TrimSpace(cfgString(edge.Config, "branch"))) { + case "true", "yes", "y", "是": + return "true" + case "false", "no", "n", "否": + return "false" + } + } + switch strings.ToLower(strings.TrimSpace(edge.Label)) { + case "true", "yes", "y", "是": + return "true" + case "false", "no", "n", "否": + return "false" + } + return "" +} + +func cfgString(cfg map[string]any, key string) string { + if cfg == nil { + return "" + } + if v, ok := cfg[key]; ok { + return strings.TrimSpace(fmt.Sprint(v)) + } + return "" +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if s := strings.TrimSpace(value); s != "" { + return s + } + } + return "" +} + +func truncateWorkflowPreview(s string, limit int) string { + s = strings.TrimSpace(s) + if limit <= 0 || len([]rune(s)) <= limit { + return s + } + runes := []rune(s) + return string(runes[:limit]) + "..." +} + +func renderWorkflowResponse(roleName, workflowName string, version int, runID string, state *WorkflowLocalState) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("角色「%s」已完成工作流「%s」(版本 %d)。\n\n", roleName, workflowName, version)) + sb.WriteString(fmt.Sprintf("运行 ID:%s\n", runID)) + sb.WriteString(fmt.Sprintf("已执行节点:%d", len(state.Executed))) + if len(state.Skipped) > 0 { + sb.WriteString(fmt.Sprintf(",跳过节点:%d", len(state.Skipped))) + } + sb.WriteString("\n\n") + if len(state.Outputs) > 0 { + sb.WriteString("输出:\n") + keys := make([]string, 0, len(state.Outputs)) + for k := range state.Outputs { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + sb.WriteString(fmt.Sprintf("- %s:%v\n", k, state.Outputs[k])) + } + } else { + sb.WriteString("暂无输出。请检查是否配置了输出节点,或条件分支是否命中。\n") + } + if len(state.Skipped) > 0 { + sb.WriteString("\n未执行的节点类型仍会保留运行记录:") + sb.WriteString(strings.Join(state.Skipped, "、")) + sb.WriteString("。") + } + return strings.TrimSpace(sb.String()) +} diff --git a/internal/workflow/types.go b/internal/workflow/types.go new file mode 100644 index 00000000..218c024b --- /dev/null +++ b/internal/workflow/types.go @@ -0,0 +1,74 @@ +package workflow + +import ( + "fmt" + "strconv" +) + +// WorkflowInput is the typed entry for Eino compose.Workflow[I,O]. +type WorkflowInput struct { + Message string `json:"message"` + ConversationID string `json:"conversationId"` + ProjectID string `json:"projectId"` + Role string `json:"role"` + WorkflowID string `json:"workflowId"` + WorkflowVersion int `json:"workflowVersion"` +} + +// WorkflowOutput aggregates terminal node payloads keyed by canvas node id. +type WorkflowOutput map[string]any + +// WorkflowNodeOutput is the per-node lambda payload (alias for Eino edge type alignment). +type WorkflowNodeOutput = map[string]interface{} + +func workflowInputFromMap(m map[string]interface{}) WorkflowInput { + in := WorkflowInput{} + if m == nil { + return in + } + if v, ok := m["message"].(string); ok { + in.Message = v + } else if m["message"] != nil { + in.Message = fmt.Sprint(m["message"]) + } + if v, ok := m["conversationId"].(string); ok { + in.ConversationID = v + } + if v, ok := m["projectId"].(string); ok { + in.ProjectID = v + } + if v, ok := m["role"].(string); ok { + in.Role = v + } + if v, ok := m["workflowId"].(string); ok { + in.WorkflowID = v + } + switch v := m["workflowVersion"].(type) { + case int: + in.WorkflowVersion = v + case int64: + in.WorkflowVersion = int(v) + case float64: + in.WorkflowVersion = int(v) + case string: + if n, err := strconv.Atoi(v); err == nil { + in.WorkflowVersion = n + } + } + return in +} + +func (in WorkflowInput) toStateInputs() map[string]any { + return map[string]any{ + "message": in.Message, + "conversationId": in.ConversationID, + "projectId": in.ProjectID, + "role": in.Role, + "workflowId": in.WorkflowID, + "workflowVersion": in.WorkflowVersion, + } +} + +func cacheKey(workflowID string, version int) string { + return workflowID + ":" + strconv.Itoa(version) +}