Add files via upload

This commit is contained in:
公明
2026-07-03 19:36:40 +08:00
committed by GitHub
parent c86825d365
commit 93a600d60e
16 changed files with 2025 additions and 821 deletions
+24
View File
@@ -0,0 +1,24 @@
package workflow
import (
"context"
"github.com/cloudwego/eino/compose"
)
// compileAgentSubgraph wraps an Agent canvas node as an Eino subgraph (AddGraphNode best practice).
func compileAgentSubgraph(_ context.Context, node graphNode) (compose.AnyGraph, error) {
n := node
innerID := n.ID + "__agent"
g := compose.NewGraph[WorkflowNodeOutput, WorkflowNodeOutput]()
_ = g.AddLambdaNode(innerID, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) {
return runWorkflowNodeLambda(runCtx, n)
}))
if err := g.AddEdge(compose.START, innerID); err != nil {
return nil, err
}
if err := g.AddEdge(innerID, compose.END); err != nil {
return nil, err
}
return g, nil
}
+141
View File
@@ -0,0 +1,141 @@
package workflow
import (
"encoding/json"
"fmt"
"strings"
)
// FieldBinding selects a value from workflow state (replaces {{...}} templates).
type FieldBinding struct {
From string `json:"from"` // inputs | previous | <nodeId>
Field string `json:"field"` // e.g. output, message
}
func parseFieldBinding(cfg map[string]any, keys ...string) (FieldBinding, bool) {
for _, key := range keys {
if cfg == nil {
continue
}
raw, ok := cfg[key]
if !ok || raw == nil {
continue
}
switch v := raw.(type) {
case map[string]any:
return FieldBinding{
From: strings.TrimSpace(fmt.Sprint(v["from"])),
Field: strings.TrimSpace(fmt.Sprint(v["field"])),
}, true
case string:
s := strings.TrimSpace(v)
if s == "" {
continue
}
var b FieldBinding
if err := json.Unmarshal([]byte(s), &b); err == nil && (b.From != "" || b.Field != "") {
return b, true
}
}
}
return FieldBinding{}, false
}
func defaultBinding(from, field string) FieldBinding {
return FieldBinding{From: from, Field: field}
}
func resolveBinding(b FieldBinding, state *WorkflowLocalState) any {
from := strings.TrimSpace(b.From)
field := strings.TrimSpace(b.Field)
if field == "" {
field = "output"
}
if from == "" || from == "previous" || from == "prev" {
if field == "output" && state.LastOutput != nil {
return state.LastOutput["output"]
}
return valueFromPath("previous."+field, state)
}
if from == "inputs" || from == "input" {
if field == "" {
return state.Inputs
}
return valueFromPath("inputs."+field, state)
}
if from == "outputs" {
return valueFromPath("outputs."+field, state)
}
return valueFromPath(from+"."+field, state)
}
func resolveBindingString(b FieldBinding, state *WorkflowLocalState) string {
return strings.TrimSpace(fmt.Sprint(resolveBinding(b, state)))
}
func resolveNodeInputBinding(cfg map[string]any, state *WorkflowLocalState) string {
if b, ok := parseFieldBinding(cfg, "input_binding"); ok {
return resolveBindingString(b, state)
}
// legacy template field removed — default previous.output
return resolveBindingString(defaultBinding("previous", "output"), state)
}
func resolveOutputSourceBinding(cfg map[string]any, state *WorkflowLocalState) any {
if b, ok := parseFieldBinding(cfg, "source_binding"); ok {
return resolveBinding(b, state)
}
return resolveBinding(defaultBinding("previous", "output"), state)
}
func resolveHITLPromptBinding(cfg map[string]any, state *WorkflowLocalState) string {
if b, ok := parseFieldBinding(cfg, "prompt_binding"); ok {
return resolveBindingString(b, state)
}
if s := cfgString(cfg, "prompt"); s != "" {
return s
}
return resolveBindingString(defaultBinding("previous", "output"), state)
}
func toolArgumentBindings(cfg map[string]any) map[string]FieldBinding {
raw, ok := cfg["argument_bindings"].(map[string]any)
if !ok || len(raw) == 0 {
return nil
}
out := make(map[string]FieldBinding, len(raw))
for argName, v := range raw {
m, ok := v.(map[string]any)
if !ok {
continue
}
out[argName] = FieldBinding{
From: strings.TrimSpace(fmt.Sprint(m["from"])),
Field: strings.TrimSpace(fmt.Sprint(m["field"])),
}
}
return out
}
func resolveToolArguments(cfg map[string]any, state *WorkflowLocalState) (map[string]interface{}, error) {
bindings := toolArgumentBindings(cfg)
if len(bindings) > 0 {
args := make(map[string]interface{}, len(bindings))
for k, b := range bindings {
args[k] = resolveBinding(b, state)
}
return args, nil
}
raw := cfgString(cfg, "arguments")
if raw == "" {
return map[string]interface{}{}, nil
}
var args map[string]interface{}
if err := json.Unmarshal([]byte(raw), &args); err != nil {
return nil, err
}
if args == nil {
args = map[string]interface{}{}
}
return args, nil
}
+69
View File
@@ -0,0 +1,69 @@
package workflow
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)
// fileCheckPointStore persists Eino workflow checkpoints on disk (per run id).
type fileCheckPointStore struct {
dir string
mu sync.RWMutex
}
func newFileCheckPointStore(dir string) (*fileCheckPointStore, error) {
dir = strings.TrimSpace(dir)
if dir == "" {
dir = filepath.Join("data", "workflow-checkpoints")
}
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("create workflow checkpoint dir: %w", err)
}
return &fileCheckPointStore{dir: dir}, nil
}
func (s *fileCheckPointStore) path(id string) (string, error) {
id = strings.TrimSpace(id)
if id == "" {
return "", fmt.Errorf("checkpoint id is empty")
}
if strings.Contains(id, "..") || strings.ContainsAny(id, `/\`) {
return "", fmt.Errorf("invalid checkpoint id")
}
return filepath.Join(s.dir, id+".ckpt"), nil
}
func (s *fileCheckPointStore) Get(_ context.Context, checkPointID string) ([]byte, bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
p, err := s.path(checkPointID)
if err != nil {
return nil, false, err
}
data, err := os.ReadFile(p)
if err != nil {
if os.IsNotExist(err) {
return nil, false, nil
}
return nil, false, err
}
return data, true, nil
}
func (s *fileCheckPointStore) Set(_ context.Context, checkPointID string, checkPoint []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
p, err := s.path(checkPointID)
if err != nil {
return err
}
tmp := p + ".tmp"
if err := os.WriteFile(tmp, checkPoint, 0o600); err != nil {
return err
}
return os.Rename(tmp, p)
}
+107
View File
@@ -0,0 +1,107 @@
package workflow
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
)
func hasConditionalOutgoingEdges(idx *graphIndex, nodeID string) bool {
for _, edge := range idx.outgoing[nodeID] {
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
if cond != "" {
return true
}
}
return false
}
func wireConditionBranch(
wf *compose.Workflow[WorkflowInput, WorkflowOutput],
nodeRefs map[string]*compose.WorkflowNode,
idx *graphIndex,
condID string,
condNode graphNode,
) error {
edges := idx.outgoing[condID]
if len(edges) == 0 {
return nil
}
branchID := branchNodeID(condID)
wf.AddPassthroughNode(branchID).AddInput(condID)
endNodes := map[string]bool{compose.END: true}
for _, edge := range edges {
endNodes[edge.Target] = true
}
sortedEdges := append([]graphEdge(nil), edges...)
sortEdgesByCanvas(sortedEdges, idx.nodes)
branch := compose.NewGraphBranch(func(runCtx context.Context, _ map[string]any) (string, error) {
rt := workflowRuntimeFrom(runCtx)
if rt == nil {
return compose.END, fmt.Errorf("workflow runtime missing in context")
}
emitConditionBranchProgress(rt.args, rt.runID, condNode, sortedEdges, idx.nodes, rt.state)
for edgeIdx, edge := range sortedEdges {
if conditionBranchAllowed(edge, edgeIdx, rt.state) {
return edge.Target, nil
}
}
return compose.END, nil
}, endNodes)
wf.AddBranch(branchID, branch)
for _, edge := range edges {
if target, ok := nodeRefs[edge.Target]; ok {
target.AddInput(branchID)
}
}
return nil
}
func wireEdgeConditionBranch(
wf *compose.Workflow[WorkflowInput, WorkflowOutput],
nodeRefs map[string]*compose.WorkflowNode,
idx *graphIndex,
sourceID string,
sourceNode graphNode,
) error {
edges := idx.outgoing[sourceID]
if len(edges) == 0 {
return nil
}
branchID := edgeBranchNodeID(sourceID)
wf.AddPassthroughNode(branchID).AddInput(sourceID)
endNodes := map[string]bool{compose.END: true}
for _, edge := range edges {
endNodes[edge.Target] = true
}
sortedEdges := append([]graphEdge(nil), edges...)
sortEdgesByCanvas(sortedEdges, idx.nodes)
branch := compose.NewGraphBranch(func(runCtx context.Context, _ map[string]any) (string, error) {
rt := workflowRuntimeFrom(runCtx)
if rt == nil {
return compose.END, fmt.Errorf("workflow runtime missing in context")
}
for edgeIdx, edge := range sortedEdges {
if edgeAllowed(edge, sourceNode, edgeIdx, rt.state) {
return edge.Target, nil
}
}
return compose.END, nil
}, endNodes)
wf.AddBranch(branchID, branch)
for _, edge := range edges {
if target, ok := nodeRefs[edge.Target]; ok {
target.AddInput(branchID)
}
}
return nil
}
+22
View File
@@ -0,0 +1,22 @@
package workflow
import (
"context"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/einoobserve"
)
func attachWorkflowCallbacks(ctx context.Context, cfg *config.Config, args RunArgs, workflowName string) context.Context {
if cfg == nil {
return ctx
}
cbCfg := &cfg.MultiAgent.EinoCallbacks
return einoobserve.AttachAgentRunCallbacks(ctx, cbCfg, einoobserve.Params{
Logger: args.Logger,
Progress: args.Progress,
ConversationID: args.ConversationID,
OrchMode: "workflow",
OrchestratorName: workflowName,
})
}
+190
View File
@@ -0,0 +1,190 @@
package workflow
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/cloudwego/eino/compose"
)
func executeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState) error {
_, err := invokeEinoGraph(ctx, args, runID, workflowID, version, g, state, false)
return err
}
func invokeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState, resume bool) (bool, error) {
wfInput := workflowInputFromMap(state.Inputs)
if resume {
wfInput = WorkflowInput{}
}
rt := &workflowRuntime{
args: args,
runID: runID,
idx: indexGraph(g),
state: state,
}
art, err := defaultEngine.getOrCompile(ctx, workflowID, version, g)
if err != nil {
return false, fmt.Errorf("编译 Eino Workflow 失败: %w", err)
}
rt.idx = art.idx
runCtx := withWorkflowRuntime(ctx, rt)
runCtx = attachWorkflowCallbacks(runCtx, args.AppCfg, args, workflowID)
invokeOpts := []compose.Option{compose.WithCheckPointID(runID)}
for {
_, err = art.runnable.Invoke(runCtx, wfInput, invokeOpts...)
if err == nil {
return false, nil
}
if hitlErr := extractAwaitingHITL(err, art, runID, args, state); hitlErr != nil {
return true, hitlErr
}
return false, err
}
}
func extractAwaitingHITL(err error, art *compiledArtifact, runID string, args RunArgs, state *WorkflowLocalState) error {
info, ok := compose.ExtractInterruptInfo(err)
if !ok || len(art.hitlIDs) == 0 {
return nil
}
nodeID := nextHITLNodeID(info, art.hitlIDs)
node := art.idx.nodes[nodeID]
if nodeID == "" {
return nil
}
prompt := resolveHITLPromptBinding(node.Config, state)
label := firstNonEmpty(node.Label, nodeID)
if args.DB != nil {
pending := map[string]any{
"nodeId": nodeID,
"label": label,
"prompt": prompt,
"reviewer": cfgString(node.Config, "reviewer"),
}
pendingJSON, _ := json.Marshal(pending)
_ = args.DB.SetWorkflowRunAwaitingHITL(runID, nodeID, string(pendingJSON))
}
if args.Progress != nil {
args.Progress("workflow_hitl_waiting", fmt.Sprintf("等待人工确认:%s", label), map[string]any{
"workflowRunId": runID,
"nodeId": nodeID,
"label": label,
"prompt": prompt,
"reviewer": cfgString(node.Config, "reviewer"),
"mode": "interactive",
"resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID),
})
}
return &AwaitingHITLError{
RunID: runID,
NodeID: nodeID,
NodeLabel: label,
Prompt: prompt,
Reviewer: cfgString(node.Config, "reviewer"),
}
}
func nextHITLNodeID(info *compose.InterruptInfo, hitlIDs []string) string {
if info != nil && len(info.BeforeNodes) > 0 {
for _, id := range info.BeforeNodes {
for _, hitl := range hitlIDs {
if id == hitl {
return id
}
}
}
return info.BeforeNodes[0]
}
if len(hitlIDs) == 0 {
return ""
}
return hitlIDs[0]
}
// ResumeWorkflowRun continues a run paused at HITL after human decision.
func ResumeWorkflowRun(ctx context.Context, args RunArgs, runID string, approved bool, comment string) (*RunResult, error) {
run, err := args.DB.GetWorkflowRun(runID)
if err != nil {
return nil, err
}
if run == nil {
return nil, fmt.Errorf("工作流运行不存在")
}
if run.Status != "awaiting_hitl" {
return nil, fmt.Errorf("工作流运行不在等待审批状态: %s", run.Status)
}
wf, err := args.DB.GetWorkflowDefinition(run.WorkflowID)
if err != nil || wf == nil {
return nil, fmt.Errorf("工作流定义不存在")
}
graph, err := parseGraph(wf.GraphJSON)
if err != nil {
return nil, err
}
var input map[string]interface{}
_ = json.Unmarshal([]byte(run.InputJSON), &input)
state := newWorkflowLocalState(input, runID)
if state.Inputs == nil {
state.Inputs = map[string]any{}
}
state.Inputs["_hitl_approved"] = approved
state.Inputs["_hitl_comment"] = strings.TrimSpace(comment)
state.Inputs["_hitl_node_id"] = run.PendingHITLNodeID
if !approved {
errText := strings.TrimSpace(comment)
if errText == "" {
errText = "人工审批拒绝"
}
_ = args.DB.FinishWorkflowRun(runID, "rejected", "", errText)
return &RunResult{
RunID: runID,
Response: fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", run.PendingHITLNodeID),
Status: "rejected",
}, nil
}
_ = args.DB.SetWorkflowRunStatus(runID, "running")
resumeArgs := args
if strings.TrimSpace(resumeArgs.ConversationID) == "" {
resumeArgs.ConversationID = run.ConversationID
}
awaiting, err := invokeEinoGraph(ctx, resumeArgs, runID, wf.ID, run.WorkflowVersion, graph, state, true)
if err != nil {
if IsAwaitingHITL(err) {
return &RunResult{
RunID: runID,
Status: "awaiting_hitl",
Response: fmt.Sprintf("工作流在节点「%s」等待下一次人工确认。", err.(*AwaitingHITLError).NodeID),
AwaitingHITL: true,
}, nil
}
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
return nil, err
}
_ = awaiting
output := map[string]interface{}{
"workflowId": wf.ID,
"workflowName": wf.Name,
"workflowVersion": wf.Version,
"workflowRunId": runID,
"status": "completed",
"outputs": state.Outputs,
"executedNodes": state.Executed,
"skippedNodes": state.Skipped,
"engine": "eino_workflow",
}
outputJSON, _ := json.Marshal(output)
response := renderWorkflowResponse(args.Role.Name, wf.Name, wf.Version, runID, state)
_ = args.DB.FinishWorkflowRun(runID, "completed", string(outputJSON), "")
return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil
}
+195
View File
@@ -0,0 +1,195 @@
package workflow
import (
"context"
"path/filepath"
"testing"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
func testWorkflowDB(t *testing.T) *database.DB {
t.Helper()
dir := t.TempDir()
db, err := database.NewDB(filepath.Join(dir, "workflow.db"), zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
return db
}
func linearStartOutputGraph() string {
return `{
"nodes": [
{"id": "start-1", "type": "start", "label": "开始", "position": {"x": 0, "y": 0}, "config": {}},
{"id": "out-1", "type": "output", "label": "输出", "position": {"x": 0, "y": 120}, "config": {"output_key": "result", "source_binding": {"from": "inputs", "field": "message"}}}
],
"edges": [
{"id": "e1", "source": "start-1", "target": "out-1"}
],
"config": {"schema_version": 1}
}`
}
func conditionBranchGraph() string {
return `{
"nodes": [
{"id": "start-1", "type": "start", "label": "开始", "position": {"x": 0, "y": 0}, "config": {}},
{"id": "cond-1", "type": "condition", "label": "判断", "position": {"x": 0, "y": 80}, "config": {"expression": "{{inputs.message}} == yes"}},
{"id": "out-yes", "type": "output", "label": "是", "position": {"x": -80, "y": 160}, "config": {"output_key": "branch", "static_value": "yes"}},
{"id": "out-no", "type": "output", "label": "否", "position": {"x": 80, "y": 160}, "config": {"output_key": "branch", "static_value": "no"}}
],
"edges": [
{"id": "e1", "source": "start-1", "target": "cond-1"},
{"id": "e2", "source": "cond-1", "target": "out-yes", "label": "是"},
{"id": "e3", "source": "cond-1", "target": "out-no", "label": "否"}
],
"config": {"schema_version": 1}
}`
}
func TestValidateGraphJSON_linear(t *testing.T) {
if err := ValidateGraphJSON(context.Background(), linearStartOutputGraph()); err != nil {
t.Fatalf("validate: %v", err)
}
}
func TestCompileEngine_linear(t *testing.T) {
ctx := context.Background()
SetCheckpointDir(t.TempDir())
g, err := parseGraph(linearStartOutputGraph())
if err != nil {
t.Fatal(err)
}
if _, err := defaultEngine.compile(ctx, g); err != nil {
t.Fatalf("compile: %v", err)
}
}
func createTestWorkflowRun(t *testing.T, db *database.DB, runID string) {
t.Helper()
if err := db.CreateWorkflowRun(&database.WorkflowRun{
ID: runID,
WorkflowID: "test-wf",
Status: "running",
}); err != nil {
t.Fatalf("CreateWorkflowRun: %v", err)
}
}
func TestExecuteEinoGraph_linearStartOutput(t *testing.T) {
ctx := context.Background()
SetCheckpointDir(t.TempDir())
db := testWorkflowDB(t)
createTestWorkflowRun(t, db, "run-linear")
g, err := parseGraph(linearStartOutputGraph())
if err != nil {
t.Fatal(err)
}
state := newWorkflowLocalState(map[string]interface{}{"message": "ping"}, "run-linear")
args := RunArgs{DB: db}
if err := executeEinoGraph(ctx, args, "run-linear", "test-wf", 1, g, state); err != nil {
t.Fatalf("execute: %v", err)
}
if got := state.Outputs["result"]; got != "ping" {
t.Fatalf("outputs[result] = %v, want ping", got)
}
if len(state.Executed) != 2 {
t.Fatalf("executed nodes = %d, want 2", len(state.Executed))
}
}
func TestExecuteEinoGraph_conditionBranch(t *testing.T) {
ctx := context.Background()
SetCheckpointDir(t.TempDir())
db := testWorkflowDB(t)
createTestWorkflowRun(t, db, "run-yes")
createTestWorkflowRun(t, db, "run-no")
g, err := parseGraph(conditionBranchGraph())
if err != nil {
t.Fatal(err)
}
stateYes := newWorkflowLocalState(map[string]interface{}{"message": "yes"}, "run-yes")
if err := executeEinoGraph(ctx, RunArgs{DB: db}, "run-yes", "test-wf-branch", 1, g, stateYes); err != nil {
t.Fatalf("execute yes: %v", err)
}
if got := stateYes.Outputs["branch"]; got != "yes" {
t.Fatalf("yes branch output = %v", got)
}
stateNo := newWorkflowLocalState(map[string]interface{}{"message": "no"}, "run-no")
if err := executeEinoGraph(ctx, RunArgs{DB: db}, "run-no", "test-wf-branch", 1, g, stateNo); err != nil {
t.Fatalf("execute no: %v", err)
}
if got := stateNo.Outputs["branch"]; got != "no" {
t.Fatalf("no branch output = %v", got)
}
}
func TestRunRoleBoundWorkflow_integration(t *testing.T) {
ctx := context.Background()
SetCheckpointDir(t.TempDir())
db := testWorkflowDB(t)
graph := linearStartOutputGraph()
if err := db.UpsertWorkflowDefinition(&database.WorkflowDefinition{
ID: "wf-linear",
Name: "线性流程",
Version: 1,
GraphJSON: graph,
Enabled: true,
}); err != nil {
t.Fatal(err)
}
role := config.RoleConfig{
Name: "tester",
Enabled: true,
WorkflowID: "wf-linear",
WorkflowPolicy: "auto",
}
result, err := RunRoleBoundWorkflow(ctx, RunArgs{
DB: db,
Logger: zap.NewNop(),
Role: role,
UserMessage: "from-role",
})
if err != nil {
t.Fatalf("RunRoleBoundWorkflow: %v", err)
}
if result == nil || result.RunID == "" {
t.Fatal("expected run result")
}
}
func TestCompiledCache_reuse(t *testing.T) {
ctx := context.Background()
SetCheckpointDir(t.TempDir())
InvalidateCompiledCache("cache-wf")
g, err := parseGraph(linearStartOutputGraph())
if err != nil {
t.Fatal(err)
}
a1, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g)
if err != nil {
t.Fatal(err)
}
a2, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g)
if err != nil {
t.Fatal(err)
}
if a1 != a2 {
t.Fatal("expected cached artifact pointer reuse")
}
InvalidateCompiledCache("cache-wf")
a3, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g)
if err != nil {
t.Fatal(err)
}
if a1 == a3 {
t.Fatal("expected new artifact after invalidation")
}
}
+64
View File
@@ -0,0 +1,64 @@
package workflow
import (
"context"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
type workflowRuntimeCtxKey struct{}
// workflowRuntime carries per-run execution context into Eino Workflow local state.
type workflowRuntime struct {
args RunArgs
runID string
idx *graphIndex
state *WorkflowLocalState
}
func withWorkflowRuntime(ctx context.Context, rt *workflowRuntime) context.Context {
return context.WithValue(ctx, workflowRuntimeCtxKey{}, rt)
}
func workflowRuntimeFrom(ctx context.Context) *workflowRuntime {
rt, _ := ctx.Value(workflowRuntimeCtxKey{}).(*workflowRuntime)
return rt
}
func newWorkflowRuntime(args RunArgs, runID string, idx *graphIndex, inputs map[string]interface{}) *workflowRuntime {
return &workflowRuntime{
args: args,
runID: runID,
idx: idx,
state: newWorkflowLocalState(inputs, runID),
}
}
// RunArgs is the execution context for a role-bound workflow run.
type RunArgs struct {
DB *database.DB
Logger *zap.Logger
Role config.RoleConfig
AppCfg *config.Config
Agent *agent.Agent
ConversationID string
ProjectID string
UserMessage string
History []agent.ChatMessage
RoleTools []string
AgentsMarkdownDir string
SystemPromptExtra string
AssistantMessageID string
Progress agent.ProgressCallback
}
type RunResult struct {
Response string
RunID string
Status string
AwaitingHITL bool
}
+236
View File
@@ -0,0 +1,236 @@
package workflow
import (
"context"
"fmt"
"strings"
"sync"
"github.com/cloudwego/eino/compose"
)
type compiledArtifact struct {
runnable compose.Runnable[WorkflowInput, WorkflowOutput]
idx *graphIndex
hitlIDs []string
}
// Engine compiles and caches Eino Workflow artifacts.
type Engine struct {
mu sync.RWMutex
cache map[string]*compiledArtifact
cpStore compose.CheckPointStore
cpStoreMu sync.Once
cpStoreErr error
checkpointDir string
}
var defaultEngine = &Engine{
cache: make(map[string]*compiledArtifact),
checkpointDir: "data/workflow-checkpoints",
}
// SetCheckpointDir overrides the workflow checkpoint root (mainly for tests).
func SetCheckpointDir(dir string) {
defaultEngine.mu.Lock()
defer defaultEngine.mu.Unlock()
defaultEngine.checkpointDir = strings.TrimSpace(dir)
defaultEngine.cpStore = nil
defaultEngine.cpStoreErr = nil
defaultEngine.cpStoreMu = sync.Once{}
}
func (e *Engine) checkpointStore() (compose.CheckPointStore, error) {
e.cpStoreMu.Do(func() {
e.cpStore, e.cpStoreErr = newFileCheckPointStore(e.checkpointDir)
})
return e.cpStore, e.cpStoreErr
}
// InvalidateCompiledCache drops cached compilations for a workflow id.
func InvalidateCompiledCache(workflowID string) {
workflowID = strings.TrimSpace(workflowID)
if workflowID == "" {
return
}
defaultEngine.mu.Lock()
defer defaultEngine.mu.Unlock()
for key := range defaultEngine.cache {
if strings.HasPrefix(key, workflowID+":") {
delete(defaultEngine.cache, key)
}
}
}
// ValidateGraphJSON parses and trial-compiles a canvas graph (save-time gate).
func ValidateGraphJSON(ctx context.Context, graphJSON string) error {
g, err := parseGraph(graphJSON)
if err != nil {
return err
}
idx := indexGraph(g)
if len(findStartNodeIDs(idx)) == 0 {
return fmt.Errorf("工作流缺少可执行的起点节点")
}
if !hasTerminalNode(idx) {
return fmt.Errorf("工作流至少需要一个无出边的终点或 output/end 节点")
}
_, err = defaultEngine.compile(ctx, g)
return err
}
func hasTerminalNode(idx *graphIndex) bool {
for id, node := range idx.nodes {
if len(idx.outgoing[id]) == 0 {
return true
}
if strings.EqualFold(node.Type, "end") || strings.EqualFold(node.Type, "output") {
return true
}
}
return false
}
func (e *Engine) getOrCompile(ctx context.Context, workflowID string, version int, g *graphDef) (*compiledArtifact, error) {
key := cacheKey(workflowID, version)
e.mu.RLock()
if art, ok := e.cache[key]; ok {
e.mu.RUnlock()
return art, nil
}
e.mu.RUnlock()
art, err := e.compile(ctx, g)
if err != nil {
return nil, err
}
e.mu.Lock()
if existing, ok := e.cache[key]; ok {
e.mu.Unlock()
return existing, nil
}
e.cache[key] = art
e.mu.Unlock()
return art, nil
}
func (e *Engine) compile(ctx context.Context, g *graphDef) (*compiledArtifact, error) {
cpStore, err := e.checkpointStore()
if err != nil {
return nil, err
}
idx := indexGraph(g)
hitlIDs := collectHITLNodeIDs(idx)
compileOpts := []compose.GraphCompileOption{
compose.WithGraphName("CyberStrikeWorkflow"),
compose.WithCheckPointStore(cpStore),
}
if len(hitlIDs) > 0 {
compileOpts = append(compileOpts, compose.WithInterruptBeforeNodes(hitlIDs))
}
wf := compose.NewWorkflow[WorkflowInput, WorkflowOutput](
compose.WithGenLocalState(func(runCtx context.Context) *WorkflowLocalState {
if rt := workflowRuntimeFrom(runCtx); rt != nil && rt.state != nil {
return rt.state
}
return &WorkflowLocalState{
Outputs: make(map[string]any),
NodeOutputs: make(map[string]map[string]any),
NodeProceed: make(map[string]bool),
}
}),
)
nodeRefs := make(map[string]*compose.WorkflowNode, len(idx.nodes))
for id, node := range idx.nodes {
n := node
if strings.EqualFold(n.Type, "agent") {
sub, err := compileAgentSubgraph(ctx, n)
if err != nil {
return nil, fmt.Errorf("编译 Agent 子图 %s 失败: %w", id, err)
}
nodeRefs[id] = wf.AddGraphNode(id, sub)
continue
}
if strings.EqualFold(n.Type, "start") {
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowInput) (WorkflowNodeOutput, error) {
return runWorkflowNodeLambda(runCtx, n)
}))
continue
}
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) {
return runWorkflowNodeLambda(runCtx, n)
}))
}
for id, node := range idx.nodes {
if strings.EqualFold(node.Type, "condition") {
if err := wireConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
return nil, err
}
continue
}
if hasConditionalOutgoingEdges(idx, id) {
if err := wireEdgeConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
return nil, err
}
continue
}
for _, edge := range idx.outgoing[id] {
if target, ok := nodeRefs[edge.Target]; ok {
target.AddInput(id)
}
}
}
for _, startID := range findStartNodeIDs(idx) {
if ref, ok := nodeRefs[startID]; ok {
ref.AddInput(compose.START)
}
}
endNode := wf.End()
for id, node := range idx.nodes {
if len(idx.outgoing[id]) == 0 || strings.EqualFold(node.Type, "end") {
endNode.AddInput(id, compose.ToField(id))
}
}
runnable, err := wf.Compile(ctx, compileOpts...)
if err != nil {
return nil, err
}
return &compiledArtifact{runnable: runnable, idx: idx, hitlIDs: hitlIDs}, nil
}
func collectHITLNodeIDs(idx *graphIndex) []string {
var ids []string
for id, node := range idx.nodes {
if strings.EqualFold(node.Type, "hitl") {
ids = append(ids, id)
}
}
return ids
}
func runWorkflowNodeLambda(runCtx context.Context, n graphNode) (WorkflowNodeOutput, error) {
localRT := workflowRuntimeFrom(runCtx)
if localRT == nil {
return nil, fmt.Errorf("workflow runtime missing in context")
}
result, proceed, err := executeNode(runCtx, localRT.args, localRT.runID, n, localRT.state)
if err != nil {
return nil, err
}
localRT.state.NodeOutputs[n.ID] = result
localRT.state.LastOutput = result
if !proceed && !strings.EqualFold(n.Type, "end") {
label := firstNonEmpty(n.Label, n.ID)
if errText := cfgString(result, "error"); errText != "" {
return result, fmt.Errorf("节点「%s」失败: %s", label, errText)
}
return result, fmt.Errorf("节点「%s」未继续执行", label)
}
return result, nil
}
+24
View File
@@ -0,0 +1,24 @@
package workflow
import "errors"
// AwaitingHITLError indicates the workflow paused before a HITL node for human approval.
type AwaitingHITLError struct {
RunID string
NodeID string
NodeLabel string
Prompt string
Reviewer string
}
func (e *AwaitingHITLError) Error() string {
if e == nil {
return "workflow awaiting human approval"
}
return "workflow awaiting human approval at node " + e.NodeID
}
func IsAwaitingHITL(err error) bool {
var target *AwaitingHITLError
return errors.As(err, &target)
}
+153
View File
@@ -0,0 +1,153 @@
package workflow
import (
"encoding/json"
"fmt"
"sort"
"strings"
)
type graphDef struct {
Nodes []graphNode `json:"nodes"`
Edges []graphEdge `json:"edges"`
Config map[string]any `json:"config"`
}
type graphNode struct {
ID string `json:"id"`
Type string `json:"type"`
Label string `json:"label"`
Position graphPosition `json:"position"`
Config map[string]any `json:"config"`
}
type graphEdge struct {
ID string `json:"id"`
Source string `json:"source"`
Target string `json:"target"`
Label string `json:"label"`
Config map[string]any `json:"config"`
}
type graphPosition struct {
X float64 `json:"x"`
Y float64 `json:"y"`
}
type graphIndex struct {
nodes map[string]graphNode
outgoing map[string][]graphEdge
incoming map[string][]graphEdge
}
func parseGraph(raw string) (*graphDef, error) {
var g graphDef
if err := json.Unmarshal([]byte(strings.TrimSpace(raw)), &g); err != nil {
return nil, fmt.Errorf("解析工作流图失败: %w", err)
}
if len(g.Nodes) == 0 {
return nil, fmt.Errorf("工作流没有节点")
}
if g.Config == nil {
g.Config = make(map[string]any)
}
return &g, nil
}
func indexGraph(g *graphDef) *graphIndex {
idx := &graphIndex{
nodes: make(map[string]graphNode, len(g.Nodes)),
outgoing: make(map[string][]graphEdge),
incoming: make(map[string][]graphEdge),
}
for _, node := range g.Nodes {
node.ID = strings.TrimSpace(node.ID)
if node.ID == "" {
continue
}
if strings.TrimSpace(node.Type) == "" {
node.Type = "tool"
}
if node.Config == nil {
node.Config = make(map[string]any)
}
idx.nodes[node.ID] = node
}
for _, edge := range g.Edges {
if _, ok := idx.nodes[edge.Source]; !ok {
continue
}
if _, ok := idx.nodes[edge.Target]; !ok {
continue
}
idx.outgoing[edge.Source] = append(idx.outgoing[edge.Source], edge)
idx.incoming[edge.Target] = append(idx.incoming[edge.Target], edge)
}
for source := range idx.outgoing {
sortEdgesByCanvas(idx.outgoing[source], idx.nodes)
}
return idx
}
func sortEdgesByCanvas(edges []graphEdge, nodes map[string]graphNode) {
sort.SliceStable(edges, func(i, j int) bool {
a := nodes[edges[i].Target]
b := nodes[edges[j].Target]
if a.Position.Y != b.Position.Y {
return a.Position.Y < b.Position.Y
}
if a.Position.X != b.Position.X {
return a.Position.X < b.Position.X
}
return edges[i].Target < edges[j].Target
})
}
func sortNodeIDsByCanvas(ids []string, nodes map[string]graphNode) {
sort.SliceStable(ids, func(i, j int) bool {
a := nodes[ids[i]]
b := nodes[ids[j]]
if a.Position.Y != b.Position.Y {
return a.Position.Y < b.Position.Y
}
if a.Position.X != b.Position.X {
return a.Position.X < b.Position.X
}
return ids[i] < ids[j]
})
}
func findStartNodeIDs(idx *graphIndex) []string {
var queue []string
for id, node := range idx.nodes {
if strings.EqualFold(node.Type, "start") {
queue = append(queue, id)
}
}
if len(queue) == 0 {
inDegree := make(map[string]int, len(idx.nodes))
for id := range idx.nodes {
inDegree[id] = 0
}
for _, edges := range idx.outgoing {
for _, edge := range edges {
inDegree[edge.Target]++
}
}
for id, deg := range inDegree {
if deg == 0 {
queue = append(queue, id)
}
}
}
sortNodeIDsByCanvas(queue, idx.nodes)
return queue
}
func branchNodeID(nodeID string) string {
return nodeID + "__eino_branch"
}
func edgeBranchNodeID(nodeID string) string {
return nodeID + "__eino_edge_branch"
}
+131
View File
@@ -0,0 +1,131 @@
package workflow
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"cyberstrike-ai/internal/database"
"github.com/google/uuid"
)
func executeNode(ctx context.Context, args RunArgs, runID string, node graphNode, state *WorkflowLocalState) (map[string]any, bool, error) {
label := node.Label
if strings.TrimSpace(label) == "" {
label = node.ID
}
nodeRunID := uuid.NewString()
input := map[string]any{
"nodeId": node.ID,
"nodeType": node.Type,
"label": label,
"inputs": state.Inputs,
"previous": state.LastOutput,
}
inputJSON, _ := json.Marshal(input)
if err := args.DB.CreateWorkflowNodeRun(&database.WorkflowNodeRun{
ID: nodeRunID,
RunID: runID,
NodeID: node.ID,
Status: "running",
InputJSON: string(inputJSON),
StartedAt: time.Now(),
}); err != nil {
return nil, false, err
}
if args.Progress != nil {
args.Progress("workflow_node_start", fmt.Sprintf("开始节点:%s", label), map[string]any{
"workflowRunId": runID,
"nodeRunId": nodeRunID,
"nodeId": node.ID,
"nodeType": node.Type,
"label": label,
})
}
result, proceed, status, errText := runBuiltinNode(ctx, args, node, state)
outputJSON, _ := json.Marshal(result)
if err := args.DB.FinishWorkflowNodeRun(nodeRunID, status, string(outputJSON), errText); err != nil {
return nil, false, err
}
if status == "skipped" {
state.Skipped = append(state.Skipped, label)
} else {
state.Executed = append(state.Executed, label)
}
if args.Progress != nil {
progressData := map[string]any{
"workflowRunId": runID,
"nodeRunId": nodeRunID,
"nodeId": node.ID,
"nodeType": node.Type,
"label": label,
"status": status,
"output": result,
}
progressMsg := fmt.Sprintf("节点完成:%s%s", label, status)
if strings.EqualFold(node.Type, "condition") {
matched := false
if v, ok := result["matched"].(bool); ok {
matched = v
}
expr := cfgString(node.Config, "expression")
if matched {
progressMsg = fmt.Sprintf("条件判断:%s → 是", label)
} else {
progressMsg = fmt.Sprintf("条件判断:%s → 否", label)
}
progressData["expression"] = expr
progressData["matched"] = matched
}
args.Progress("workflow_node_result", progressMsg, progressData)
}
state.NodeProceed[node.ID] = proceed
return result, proceed, nil
}
func emitConditionBranchProgress(args RunArgs, runID string, node graphNode, edges []graphEdge, nodes map[string]graphNode, state *WorkflowLocalState) {
if args.Progress == nil || len(edges) == 0 {
return
}
for edgeIdx, edge := range edges {
allowed := edgeAllowed(edge, node, edgeIdx, state)
target := nodes[edge.Target]
targetLabel := strings.TrimSpace(target.Label)
if targetLabel == "" {
targetLabel = edge.Target
}
branchLabel := strings.TrimSpace(edge.Label)
if branchLabel == "" {
switch edgeIdx {
case 0:
branchLabel = "是"
case 1:
branchLabel = "否"
default:
branchLabel = fmt.Sprintf("分支 %d", edgeIdx+1)
}
}
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
eventType := "workflow_branch_skipped"
msg := fmt.Sprintf("跳过分支「%s」→ %s", branchLabel, targetLabel)
if allowed {
eventType = "workflow_branch_taken"
msg = fmt.Sprintf("执行分支「%s」→ %s", branchLabel, targetLabel)
}
args.Progress(eventType, msg, map[string]any{
"workflowRunId": runID,
"nodeId": node.ID,
"nodeType": node.Type,
"label": node.Label,
"branchLabel": branchLabel,
"targetId": edge.Target,
"targetLabel": targetLabel,
"edgeCondition": cond,
"matched": conditionMatched(state),
})
}
}
+323
View File
@@ -0,0 +1,323 @@
package workflow
import (
"context"
"fmt"
"strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/multiagent"
)
func runBuiltinNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
cfg := node.Config
switch strings.ToLower(strings.TrimSpace(node.Type)) {
case "start":
out := map[string]any{
"output": state.Inputs["message"],
"message": state.Inputs["message"],
"conversationId": state.Inputs["conversationId"],
"projectId": state.Inputs["projectId"],
}
return out, true, "completed", ""
case "condition":
expr := cfgString(cfg, "expression")
ok := evalCondition(expr, state)
out := map[string]any{"output": ok, "condition": expr, "matched": ok}
return out, true, "completed", ""
case "output":
key := cfgString(cfg, "output_key")
if key == "" {
key = "result"
}
var value any
if v := cfgString(cfg, "static_value"); v != "" {
value = v
} else {
value = resolveOutputSourceBinding(cfg, state)
}
state.Outputs[key] = value
return map[string]any{"output": value, "outputs": map[string]any{key: value}}, true, "completed", ""
case "end":
value := resolveOutputSourceBinding(cfg, state)
if b, ok := parseFieldBinding(cfg, "result_binding"); ok {
value = resolveBinding(b, state)
}
return map[string]any{"output": value}, false, "completed", ""
case "tool":
return runToolNode(ctx, args, node, state)
case "agent":
return runAgentNode(ctx, args, node, state)
case "hitl":
return runHITLNode(args, node, state)
default:
reason := "未知节点类型"
return map[string]any{"output": "", "skipped": true, "reason": reason, "node_type": node.Type}, true, "skipped", reason
}
}
func runToolNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
toolName := cfgString(node.Config, "tool_name")
if toolName == "" {
errText := "工具节点未选择 MCP 工具"
return map[string]any{"output": "", "error": errText}, false, "failed", errText
}
if args.Agent == nil {
errText := "工具节点执行失败:Agent 为空"
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
}
toolArgs, err := resolveToolArguments(node.Config, state)
if err != nil {
errText := fmt.Sprintf("工具参数不是合法 JSON%v", err)
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
}
if args.Progress != nil {
args.Progress("workflow_tool_start", fmt.Sprintf("调用工具:%s", toolName), map[string]any{
"nodeId": node.ID,
"tool": toolName,
"args": toolArgs,
})
}
result, err := args.Agent.ExecuteMCPToolForConversation(ctx, args.ConversationID, toolName, toolArgs)
if err != nil {
errText := err.Error()
return map[string]any{"output": "", "tool_name": toolName, "arguments": toolArgs, "error": errText}, false, "failed", errText
}
output := ""
executionID := ""
isError := false
if result != nil {
output = result.Result
executionID = result.ExecutionID
isError = result.IsError
}
out := map[string]any{
"output": output,
"tool_name": toolName,
"arguments": toolArgs,
"execution_id": executionID,
"is_error": isError,
}
if key := cfgString(node.Config, "output_key"); key != "" {
state.Outputs[key] = output
}
if isError {
errText := strings.TrimSpace(output)
if errText == "" {
errText = "工具返回错误"
}
return out, false, "failed", errText
}
return out, true, "completed", ""
}
func runAgentNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
if args.AppCfg == nil || args.Agent == nil {
errText := "Agent 节点执行失败:应用配置或 Agent 为空"
return map[string]any{"output": "", "error": errText}, false, "failed", errText
}
mode := strings.ToLower(cfgString(node.Config, "agent_mode"))
if mode == "" {
mode = "eino_single"
}
inputSource := resolveNodeInputBinding(node.Config, state)
message := buildAgentNodeMessage(node, state, inputSource)
var result *multiagent.RunResult
var err error
state.SegmentMaxIteration = 0
agentProgress := workflowAgentProgress(args.Progress, state, node)
switch mode {
case "eino_single", "single", "chat":
result, err = multiagent.RunEinoSingleChatModelAgent(
ctx,
args.AppCfg,
&args.AppCfg.MultiAgent,
args.Agent,
args.DB,
args.Logger,
args.ConversationID,
args.ProjectID,
message,
args.History,
args.RoleTools,
agentProgress,
nil,
args.SystemPromptExtra,
)
default:
result, err = multiagent.RunDeepAgent(
ctx,
args.AppCfg,
&args.AppCfg.MultiAgent,
args.Agent,
args.DB,
args.Logger,
args.ConversationID,
args.ProjectID,
message,
args.History,
args.RoleTools,
agentProgress,
args.AgentsMarkdownDir,
mode,
nil,
args.SystemPromptExtra,
)
}
if err != nil {
errText := err.Error()
state.MainIterationOffset += state.SegmentMaxIteration
return map[string]any{"output": "", "mode": mode, "error": errText}, false, "failed", errText
}
state.MainIterationOffset += state.SegmentMaxIteration
response := ""
mcpIDs := []string{}
if result != nil {
response = result.Response
mcpIDs = result.MCPExecutionIDs
}
if args.Progress != nil {
args.Progress("workflow_agent_output", response, map[string]any{
"nodeId": node.ID,
"label": firstNonEmpty(node.Label, node.ID),
"mode": mode,
"inputSource": inputSource,
"inputPreview": truncateWorkflowPreview(inputSource, 500),
"mcpExecutionIds": mcpIDs,
})
}
if key := cfgString(node.Config, "output_key"); key != "" {
state.Outputs[key] = response
}
return map[string]any{
"output": response,
"mode": mode,
"mcp_execution_ids": mcpIDs,
}, true, "completed", ""
}
func buildAgentNodeMessage(node graphNode, state *WorkflowLocalState, upstreamInput string) string {
instruction := strings.TrimSpace(cfgString(node.Config, "instruction"))
upstreamInput = strings.TrimSpace(upstreamInput)
if instruction == "" {
if upstreamInput != "" {
return fmt.Sprintf("请基于上游节点输出继续处理:\n%s", upstreamInput)
}
return fmt.Sprintf("请基于上游节点输出继续处理:\n%v", state.LastOutput["output"])
}
if upstreamInput == "" {
return instruction
}
return strings.TrimSpace(fmt.Sprintf("上游输入:\n%s\n\n节点指令:\n%s", upstreamInput, instruction))
}
func workflowAgentProgress(progress agent.ProgressCallback, state *WorkflowLocalState, node graphNode) agent.ProgressCallback {
if progress == nil {
return nil
}
return func(eventType, message string, data interface{}) {
switch eventType {
case "response_start", "response_delta", "response", "done":
return
default:
enrichWorkflowAgentEventData(data, state, node)
if eventType == "iteration" {
applyWorkflowMainIterationOffset(data, state)
}
progress(eventType, message, data)
}
}
}
func enrichWorkflowAgentEventData(data interface{}, state *WorkflowLocalState, node graphNode) {
m, ok := data.(map[string]interface{})
if !ok || m == nil {
return
}
if node.ID != "" {
m["workflowNodeId"] = node.ID
}
if state != nil && strings.TrimSpace(state.WorkflowRunID) != "" {
m["workflowRunId"] = state.WorkflowRunID
}
}
func applyWorkflowMainIterationOffset(data interface{}, state *WorkflowLocalState) {
if state == nil {
return
}
m, ok := data.(map[string]interface{})
if !ok || m == nil {
return
}
scope, _ := m["einoScope"].(string)
if strings.TrimSpace(scope) != "main" {
return
}
raw := iterationNumberFromProgressData(m)
if raw <= 0 {
return
}
if raw > state.SegmentMaxIteration {
state.SegmentMaxIteration = raw
}
m["iteration"] = raw + state.MainIterationOffset
}
func iterationNumberFromProgressData(m map[string]interface{}) int {
switch v := m["iteration"].(type) {
case int:
return v
case int32:
return int(v)
case int64:
return int(v)
case float64:
return int(v)
case float32:
return int(v)
default:
return 0
}
}
func runHITLNode(args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
prompt := resolveHITLPromptBinding(node.Config, state)
reviewer := cfgString(node.Config, "reviewer")
if reviewer == "" {
reviewer = "human"
}
approved := true
if state != nil && state.Inputs != nil {
if v, ok := state.Inputs["_hitl_approved"]; ok {
approved = fmt.Sprint(v) == "true"
}
}
if !approved {
reason := "人工审批已拒绝"
if state != nil && state.Inputs != nil {
if v, ok := state.Inputs["_hitl_comment"]; ok {
if s := strings.TrimSpace(fmt.Sprint(v)); s != "" {
reason = s
}
}
}
return map[string]any{"output": "", "prompt": prompt, "approved": false, "mode": "interactive"}, false, "failed", reason
}
if args.Progress != nil {
args.Progress("workflow_hitl_checkpoint", "人工确认节点已通过", map[string]any{
"nodeId": node.ID,
"prompt": prompt,
"reviewer": reviewer,
"mode": "interactive",
"approved": true,
})
}
return map[string]any{
"output": prompt,
"prompt": prompt,
"reviewer": reviewer,
"approved": true,
"mode": "interactive",
}, true, "completed", ""
}
+48 -821
View File
@@ -4,82 +4,16 @@ import (
"context"
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/multiagent"
"github.com/google/uuid"
"go.uber.org/zap"
)
type RunArgs struct {
DB *database.DB
Logger *zap.Logger
Role config.RoleConfig
AppCfg *config.Config
Agent *agent.Agent
ConversationID string
ProjectID string
UserMessage string
History []agent.ChatMessage
RoleTools []string
AgentsMarkdownDir string
SystemPromptExtra string
AssistantMessageID string
Progress agent.ProgressCallback
}
type RunResult struct {
Response string
RunID string
}
type graphDef struct {
Nodes []graphNode `json:"nodes"`
Edges []graphEdge `json:"edges"`
Config map[string]any `json:"config"`
}
type graphNode struct {
ID string `json:"id"`
Type string `json:"type"`
Label string `json:"label"`
Position graphPosition `json:"position"`
Config map[string]any `json:"config"`
}
type graphEdge struct {
ID string `json:"id"`
Source string `json:"source"`
Target string `json:"target"`
Label string `json:"label"`
Config map[string]any `json:"config"`
}
type graphPosition struct {
X float64 `json:"x"`
Y float64 `json:"y"`
}
type workflowExecState struct {
inputs map[string]any
outputs map[string]any
nodeOutputs map[string]map[string]any
lastOutput map[string]any
executed []string
skipped []string
workflowRunID string
// 图编排内多个 Agent 节点各自从第 1 轮上报 iteration;累计偏移避免对话页迭代序号回跳与流式条目复用错乱。
mainIterationOffset int
segmentMaxIteration int
}
// ShouldAutoRunRoleWorkflow returns true when a role explicitly binds a workflow
// and does not turn it off. Empty policy defaults to auto to keep role UX simple.
func ShouldAutoRunRoleWorkflow(role config.RoleConfig) bool {
@@ -90,10 +24,7 @@ func ShouldAutoRunRoleWorkflow(role config.RoleConfig) bool {
return policy == "" || policy == "auto"
}
// RunRoleBoundWorkflow executes the persisted role-bound workflow graph.
// Control nodes are interpreted locally, tool nodes call the existing MCP bridge,
// and agent nodes reuse the existing Eino ADK runners so role-bound flows share
// the same model/tool/session behavior as the chat page.
// RunRoleBoundWorkflow executes the persisted role-bound workflow via cached Eino Workflow.
func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) {
if args.DB == nil {
return nil, fmt.Errorf("workflow db is nil")
@@ -150,6 +81,7 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
"workflowVersion": wf.Version,
"workflowRunId": runID,
"conversationId": args.ConversationID,
"engine": "eino_workflow",
})
}
@@ -158,13 +90,44 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
return nil, err
}
state := &workflowExecState{
inputs: input,
outputs: make(map[string]any),
nodeOutputs: make(map[string]map[string]any),
workflowRunID: runID,
}
if err := executeGraph(ctx, args, runID, graph, state); err != nil {
state := newWorkflowLocalState(input, runID)
if err := executeEinoGraph(ctx, args, runID, wf.ID, wf.Version, graph, state); err != nil {
if IsAwaitingHITL(err) {
hitl := err.(*AwaitingHITLError)
partial := map[string]interface{}{
"workflowId": wf.ID,
"workflowName": wf.Name,
"workflowVersion": wf.Version,
"workflowRunId": runID,
"status": "awaiting_hitl",
"outputs": state.Outputs,
"executedNodes": state.Executed,
"skippedNodes": state.Skipped,
"pendingHitl": map[string]interface{}{
"nodeId": hitl.NodeID,
"label": hitl.NodeLabel,
"prompt": hitl.Prompt,
},
"engine": "eino_workflow",
}
partialJSON, _ := json.Marshal(partial)
_ = args.DB.SetWorkflowRunAwaitingHITL(runID, hitl.NodeID, string(partialJSON))
response := fmt.Sprintf("工作流「%s」已在节点「%s」暂停,等待人工审批。\n运行 ID:%s", wf.Name, firstNonEmpty(hitl.NodeLabel, hitl.NodeID), runID)
if args.Progress != nil {
args.Progress("workflow_paused", response, map[string]interface{}{
"workflowRunId": runID,
"status": "awaiting_hitl",
"nodeId": hitl.NodeID,
"resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID),
})
}
return &RunResult{
Response: response,
RunID: runID,
Status: "awaiting_hitl",
AwaitingHITL: true,
}, nil
}
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
return nil, err
}
@@ -175,9 +138,10 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
"workflowVersion": wf.Version,
"workflowRunId": runID,
"status": "completed",
"outputs": state.outputs,
"executedNodes": state.executed,
"skippedNodes": state.skipped,
"outputs": state.Outputs,
"executedNodes": state.Executed,
"skippedNodes": state.Skipped,
"engine": "eino_workflow",
}
outputJSON, _ := json.Marshal(output)
@@ -189,8 +153,9 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
args.Progress("workflow_done", fmt.Sprintf("流程「%s」运行完成", wf.Name), map[string]interface{}{
"workflowRunId": runID,
"workflowId": wf.ID,
"outputs": state.outputs,
"outputs": state.Outputs,
"response": response,
"engine": "eino_workflow",
})
}
if args.Logger != nil {
@@ -199,746 +164,8 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
zap.String("workflow_run_id", runID),
zap.String("conversation_id", args.ConversationID),
zap.String("role", args.Role.Name),
zap.String("engine", "eino_workflow"),
)
}
return &RunResult{Response: response, RunID: runID}, nil
}
func parseGraph(raw string) (*graphDef, error) {
var g graphDef
if err := json.Unmarshal([]byte(strings.TrimSpace(raw)), &g); err != nil {
return nil, fmt.Errorf("解析工作流图失败: %w", err)
}
if len(g.Nodes) == 0 {
return nil, fmt.Errorf("工作流没有节点")
}
if g.Config == nil {
g.Config = make(map[string]any)
}
return &g, nil
}
func executeGraph(ctx context.Context, args RunArgs, runID string, g *graphDef, state *workflowExecState) error {
nodes := make(map[string]graphNode, len(g.Nodes))
inDegree := make(map[string]int, len(g.Nodes))
outgoing := make(map[string][]graphEdge)
for _, node := range g.Nodes {
node.ID = strings.TrimSpace(node.ID)
if node.ID == "" {
continue
}
if strings.TrimSpace(node.Type) == "" {
node.Type = "tool"
}
if node.Config == nil {
node.Config = make(map[string]any)
}
nodes[node.ID] = node
inDegree[node.ID] = 0
}
for _, edge := range g.Edges {
if _, ok := nodes[edge.Source]; !ok {
continue
}
if _, ok := nodes[edge.Target]; !ok {
continue
}
outgoing[edge.Source] = append(outgoing[edge.Source], edge)
inDegree[edge.Target]++
}
for source := range outgoing {
sort.SliceStable(outgoing[source], func(i, j int) bool {
a := nodes[outgoing[source][i].Target]
b := nodes[outgoing[source][j].Target]
if a.Position.Y != b.Position.Y {
return a.Position.Y < b.Position.Y
}
if a.Position.X != b.Position.X {
return a.Position.X < b.Position.X
}
return outgoing[source][i].Target < outgoing[source][j].Target
})
}
var queue []string
for id, node := range nodes {
if strings.EqualFold(node.Type, "start") {
queue = append(queue, id)
}
}
if len(queue) == 0 {
for id, deg := range inDegree {
if deg == 0 {
queue = append(queue, id)
}
}
}
sortNodeIDsByCanvas(queue, nodes)
seen := make(map[string]bool)
remainingIncoming := make(map[string]int, len(inDegree))
for id, deg := range inDegree {
remainingIncoming[id] = deg
}
for len(queue) > 0 {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
id := queue[0]
queue = queue[1:]
if seen[id] {
continue
}
seen[id] = true
node := nodes[id]
result, proceed, err := executeNode(ctx, args, runID, node, state)
if err != nil {
return err
}
state.nodeOutputs[id] = result
state.lastOutput = result
if proceed {
edges := outgoing[id]
if strings.EqualFold(node.Type, "condition") {
emitConditionBranchProgress(args, runID, node, edges, nodes, state)
}
for edgeIdx, edge := range edges {
if !edgeAllowed(edge, node, edgeIdx, state) {
continue
}
remainingIncoming[edge.Target]--
if remainingIncoming[edge.Target] > 0 {
continue
}
queue = append(queue, edge.Target)
}
sortNodeIDsByCanvas(queue, nodes)
}
}
return nil
}
func sortNodeIDsByCanvas(ids []string, nodes map[string]graphNode) {
sort.SliceStable(ids, func(i, j int) bool {
a := nodes[ids[i]]
b := nodes[ids[j]]
if a.Position.Y != b.Position.Y {
return a.Position.Y < b.Position.Y
}
if a.Position.X != b.Position.X {
return a.Position.X < b.Position.X
}
return ids[i] < ids[j]
})
}
func executeNode(ctx context.Context, args RunArgs, runID string, node graphNode, state *workflowExecState) (map[string]any, bool, error) {
label := node.Label
if strings.TrimSpace(label) == "" {
label = node.ID
}
nodeRunID := uuid.NewString()
input := map[string]any{
"nodeId": node.ID,
"nodeType": node.Type,
"label": label,
"inputs": state.inputs,
"previous": state.lastOutput,
}
inputJSON, _ := json.Marshal(input)
if err := args.DB.CreateWorkflowNodeRun(&database.WorkflowNodeRun{
ID: nodeRunID,
RunID: runID,
NodeID: node.ID,
Status: "running",
InputJSON: string(inputJSON),
StartedAt: time.Now(),
}); err != nil {
return nil, false, err
}
if args.Progress != nil {
args.Progress("workflow_node_start", fmt.Sprintf("开始节点:%s", label), map[string]any{
"workflowRunId": runID,
"nodeRunId": nodeRunID,
"nodeId": node.ID,
"nodeType": node.Type,
"label": label,
})
}
result, proceed, status, errText := runBuiltinNode(ctx, args, node, state)
outputJSON, _ := json.Marshal(result)
if err := args.DB.FinishWorkflowNodeRun(nodeRunID, status, string(outputJSON), errText); err != nil {
return nil, false, err
}
if status == "skipped" {
state.skipped = append(state.skipped, label)
} else {
state.executed = append(state.executed, label)
}
if args.Progress != nil {
progressData := map[string]any{
"workflowRunId": runID,
"nodeRunId": nodeRunID,
"nodeId": node.ID,
"nodeType": node.Type,
"label": label,
"status": status,
"output": result,
}
progressMsg := fmt.Sprintf("节点完成:%s%s", label, status)
if strings.EqualFold(node.Type, "condition") {
matched := false
if v, ok := result["matched"].(bool); ok {
matched = v
}
expr := cfgString(node.Config, "expression")
if matched {
progressMsg = fmt.Sprintf("条件判断:%s → 是", label)
} else {
progressMsg = fmt.Sprintf("条件判断:%s → 否", label)
}
progressData["expression"] = expr
progressData["matched"] = matched
}
args.Progress("workflow_node_result", progressMsg, progressData)
}
return result, proceed, nil
}
func runBuiltinNode(ctx context.Context, args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) {
cfg := node.Config
switch strings.ToLower(strings.TrimSpace(node.Type)) {
case "start":
out := map[string]any{
"output": state.inputs["message"],
"message": state.inputs["message"],
"conversationId": state.inputs["conversationId"],
"projectId": state.inputs["projectId"],
}
return out, true, "completed", ""
case "condition":
expr := cfgString(cfg, "expression")
ok := evalCondition(expr, state)
out := map[string]any{"output": ok, "condition": expr, "matched": ok}
// 条件节点始终继续,由出边条件(或连线标签/顺序)决定走「是/否」分支。
return out, true, "completed", ""
case "output":
key := cfgString(cfg, "output_key")
if key == "" {
key = "result"
}
value := resolveTemplate(cfgString(cfg, "source"), state)
state.outputs[key] = value
return map[string]any{"output": value, "outputs": map[string]any{key: value}}, true, "completed", ""
case "end":
value := resolveTemplate(cfgString(cfg, "result_template"), state)
return map[string]any{"output": value}, false, "completed", ""
case "tool":
return runToolNode(ctx, args, node, state)
case "agent":
return runAgentNode(ctx, args, node, state)
case "hitl":
return runHITLNode(args, node, state)
default:
reason := "未知节点类型"
return map[string]any{"output": "", "skipped": true, "reason": reason, "node_type": node.Type}, true, "skipped", reason
}
}
func runToolNode(ctx context.Context, args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) {
toolName := cfgString(node.Config, "tool_name")
if toolName == "" {
errText := "工具节点未选择 MCP 工具"
return map[string]any{"output": "", "error": errText}, false, "failed", errText
}
if args.Agent == nil {
errText := "工具节点执行失败:Agent 为空"
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
}
toolArgs, err := parseToolArguments(cfgString(node.Config, "arguments"), state)
if err != nil {
errText := fmt.Sprintf("工具参数不是合法 JSON%v", err)
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
}
if args.Progress != nil {
args.Progress("workflow_tool_start", fmt.Sprintf("调用工具:%s", toolName), map[string]any{
"nodeId": node.ID,
"tool": toolName,
"args": toolArgs,
})
}
result, err := args.Agent.ExecuteMCPToolForConversation(ctx, args.ConversationID, toolName, toolArgs)
if err != nil {
errText := err.Error()
return map[string]any{"output": "", "tool_name": toolName, "arguments": toolArgs, "error": errText}, false, "failed", errText
}
output := ""
executionID := ""
isError := false
if result != nil {
output = result.Result
executionID = result.ExecutionID
isError = result.IsError
}
out := map[string]any{
"output": output,
"tool_name": toolName,
"arguments": toolArgs,
"execution_id": executionID,
"is_error": isError,
}
if key := cfgString(node.Config, "output_key"); key != "" {
state.outputs[key] = output
}
if isError {
errText := strings.TrimSpace(output)
if errText == "" {
errText = "工具返回错误"
}
return out, false, "failed", errText
}
return out, true, "completed", ""
}
func runAgentNode(ctx context.Context, args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) {
if args.AppCfg == nil || args.Agent == nil {
errText := "Agent 节点执行失败:应用配置或 Agent 为空"
return map[string]any{"output": "", "error": errText}, false, "failed", errText
}
mode := strings.ToLower(cfgString(node.Config, "agent_mode"))
if mode == "" {
mode = "eino_single"
}
inputSource := cfgString(node.Config, "input_source")
if inputSource == "" {
inputSource = "{{previous.output}}"
}
upstreamInput := strings.TrimSpace(resolveTemplate(inputSource, state))
message := buildAgentNodeMessage(node, state)
var result *multiagent.RunResult
var err error
state.segmentMaxIteration = 0
agentProgress := workflowAgentProgress(args.Progress, state, node)
switch mode {
case "eino_single", "single", "chat":
result, err = multiagent.RunEinoSingleChatModelAgent(
ctx,
args.AppCfg,
&args.AppCfg.MultiAgent,
args.Agent,
args.DB,
args.Logger,
args.ConversationID,
args.ProjectID,
message,
args.History,
args.RoleTools,
agentProgress,
nil,
args.SystemPromptExtra,
)
default:
result, err = multiagent.RunDeepAgent(
ctx,
args.AppCfg,
&args.AppCfg.MultiAgent,
args.Agent,
args.DB,
args.Logger,
args.ConversationID,
args.ProjectID,
message,
args.History,
args.RoleTools,
agentProgress,
args.AgentsMarkdownDir,
mode,
nil,
args.SystemPromptExtra,
)
}
if err != nil {
errText := err.Error()
state.mainIterationOffset += state.segmentMaxIteration
return map[string]any{"output": "", "mode": mode, "error": errText}, false, "failed", errText
}
state.mainIterationOffset += state.segmentMaxIteration
response := ""
mcpIDs := []string{}
if result != nil {
response = result.Response
mcpIDs = result.MCPExecutionIDs
}
if args.Progress != nil {
args.Progress("workflow_agent_output", response, map[string]any{
"nodeId": node.ID,
"label": firstNonEmpty(node.Label, node.ID),
"mode": mode,
"inputSource": inputSource,
"inputPreview": truncateWorkflowPreview(upstreamInput, 500),
"mcpExecutionIds": mcpIDs,
})
}
if key := cfgString(node.Config, "output_key"); key != "" {
state.outputs[key] = response
}
return map[string]any{
"output": response,
"mode": mode,
"mcp_execution_ids": mcpIDs,
}, true, "completed", ""
}
func buildAgentNodeMessage(node graphNode, state *workflowExecState) string {
instruction := strings.TrimSpace(resolveTemplate(cfgString(node.Config, "instruction"), state))
inputSource := cfgString(node.Config, "input_source")
if inputSource == "" {
inputSource = "{{previous.output}}"
}
upstreamInput := strings.TrimSpace(resolveTemplate(inputSource, state))
if instruction == "" {
if upstreamInput != "" {
return fmt.Sprintf("请基于上游节点输出继续处理:\n%s", upstreamInput)
}
return fmt.Sprintf("请基于上游节点输出继续处理:\n%v", state.lastOutput["output"])
}
if upstreamInput == "" {
return instruction
}
return strings.TrimSpace(fmt.Sprintf("上游输入:\n%s\n\n节点指令:\n%s", upstreamInput, instruction))
}
func workflowAgentProgress(progress agent.ProgressCallback, state *workflowExecState, node graphNode) agent.ProgressCallback {
if progress == nil {
return nil
}
return func(eventType, message string, data interface{}) {
switch eventType {
case "response_start", "response_delta", "response", "done":
return
default:
enrichWorkflowAgentEventData(data, state, node)
if eventType == "iteration" {
applyWorkflowMainIterationOffset(data, state)
}
progress(eventType, message, data)
}
}
}
func enrichWorkflowAgentEventData(data interface{}, state *workflowExecState, node graphNode) {
m, ok := data.(map[string]interface{})
if !ok || m == nil {
return
}
if node.ID != "" {
m["workflowNodeId"] = node.ID
}
if state != nil && strings.TrimSpace(state.workflowRunID) != "" {
m["workflowRunId"] = state.workflowRunID
}
}
func applyWorkflowMainIterationOffset(data interface{}, state *workflowExecState) {
if state == nil {
return
}
m, ok := data.(map[string]interface{})
if !ok || m == nil {
return
}
scope, _ := m["einoScope"].(string)
if strings.TrimSpace(scope) != "main" {
return
}
raw := iterationNumberFromProgressData(m)
if raw <= 0 {
return
}
if raw > state.segmentMaxIteration {
state.segmentMaxIteration = raw
}
m["iteration"] = raw + state.mainIterationOffset
}
func iterationNumberFromProgressData(m map[string]interface{}) int {
switch v := m["iteration"].(type) {
case int:
return v
case int32:
return int(v)
case int64:
return int(v)
case float64:
return int(v)
case float32:
return int(v)
default:
return 0
}
}
func runHITLNode(args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) {
prompt := resolveTemplate(cfgString(node.Config, "prompt"), state)
reviewer := cfgString(node.Config, "reviewer")
if args.Progress != nil {
args.Progress("workflow_hitl_checkpoint", "人工确认节点已记录", map[string]any{
"nodeId": node.ID,
"prompt": prompt,
"reviewer": reviewer,
"mode": "record_only",
})
}
return map[string]any{
"output": prompt,
"prompt": prompt,
"reviewer": reviewer,
"approved": true,
"mode": "record_only",
}, true, "completed", ""
}
func parseToolArguments(raw string, state *workflowExecState) (map[string]interface{}, error) {
if raw == "" {
return map[string]interface{}{}, nil
}
raw = strings.TrimSpace(resolveTemplate(raw, state))
if raw == "" {
return map[string]interface{}{}, nil
}
var args map[string]interface{}
if err := json.Unmarshal([]byte(raw), &args); err != nil {
return nil, err
}
if args == nil {
args = map[string]interface{}{}
}
return args, nil
}
func edgeAllowed(edge graphEdge, sourceNode graphNode, edgeIndex int, state *workflowExecState) bool {
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
if cond != "" {
return evalCondition(cond, state)
}
if strings.EqualFold(strings.TrimSpace(sourceNode.Type), "condition") {
return conditionBranchAllowed(edge, edgeIndex, state)
}
return true
}
func conditionBranchAllowed(edge graphEdge, edgeIndex int, state *workflowExecState) bool {
matched := conditionMatched(state)
if branch := conditionBranchHint(edge); branch != "" {
return (branch == "true" && matched) || (branch == "false" && !matched)
}
switch edgeIndex {
case 0:
return matched
case 1:
return !matched
default:
return false
}
}
func conditionMatched(state *workflowExecState) bool {
v := strings.ToLower(cleanComparable(fmt.Sprint(valueFromPath("previous.matched", state))))
return v == "true" || v == "1"
}
func conditionBranchHint(edge graphEdge) string {
if edge.Config != nil {
switch strings.ToLower(strings.TrimSpace(cfgString(edge.Config, "branch"))) {
case "true", "yes", "y", "是":
return "true"
case "false", "no", "n", "否":
return "false"
}
}
switch strings.ToLower(strings.TrimSpace(edge.Label)) {
case "true", "yes", "y", "是":
return "true"
case "false", "no", "n", "否":
return "false"
}
return ""
}
func emitConditionBranchProgress(args RunArgs, runID string, node graphNode, edges []graphEdge, nodes map[string]graphNode, state *workflowExecState) {
if args.Progress == nil || len(edges) == 0 {
return
}
for edgeIdx, edge := range edges {
allowed := edgeAllowed(edge, node, edgeIdx, state)
target := nodes[edge.Target]
targetLabel := strings.TrimSpace(target.Label)
if targetLabel == "" {
targetLabel = edge.Target
}
branchLabel := strings.TrimSpace(edge.Label)
if branchLabel == "" {
switch edgeIdx {
case 0:
branchLabel = "是"
case 1:
branchLabel = "否"
default:
branchLabel = fmt.Sprintf("分支 %d", edgeIdx+1)
}
}
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
eventType := "workflow_branch_skipped"
msg := fmt.Sprintf("跳过分支「%s」→ %s", branchLabel, targetLabel)
if allowed {
eventType = "workflow_branch_taken"
msg = fmt.Sprintf("执行分支「%s」→ %s", branchLabel, targetLabel)
}
args.Progress(eventType, msg, map[string]any{
"workflowRunId": runID,
"nodeId": node.ID,
"nodeType": node.Type,
"label": node.Label,
"branchLabel": branchLabel,
"targetId": edge.Target,
"targetLabel": targetLabel,
"edgeCondition": cond,
"matched": conditionMatched(state),
})
}
}
func cfgString(cfg map[string]any, key string) string {
if cfg == nil {
return ""
}
if v, ok := cfg[key]; ok {
return strings.TrimSpace(fmt.Sprint(v))
}
return ""
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if s := strings.TrimSpace(value); s != "" {
return s
}
}
return ""
}
func truncateWorkflowPreview(s string, limit int) string {
s = strings.TrimSpace(s)
if limit <= 0 || len([]rune(s)) <= limit {
return s
}
runes := []rune(s)
return string(runes[:limit]) + "..."
}
var templateVarRe = regexp.MustCompile(`\{\{\s*([a-zA-Z0-9_.-]+)\s*\}\}`)
func resolveTemplate(s string, state *workflowExecState) string {
if strings.TrimSpace(s) == "" {
return fmt.Sprint(valueFromPath("previous.output", state))
}
return templateVarRe.ReplaceAllStringFunc(s, func(match string) string {
m := templateVarRe.FindStringSubmatch(match)
if len(m) != 2 {
return match
}
return fmt.Sprint(valueFromPath(m[1], state))
})
}
func valueFromPath(path string, state *workflowExecState) any {
parts := strings.Split(path, ".")
if len(parts) == 0 {
return ""
}
var cur any
switch parts[0] {
case "inputs", "input":
cur = state.inputs
case "previous", "prev":
cur = state.lastOutput
case "outputs":
cur = state.outputs
default:
if v, ok := state.inputs[parts[0]]; ok {
cur = v
} else if v, ok := state.nodeOutputs[parts[0]]; ok {
cur = v
} else {
return ""
}
}
for _, p := range parts[1:] {
m, ok := cur.(map[string]any)
if !ok {
return ""
}
cur = m[p]
}
if cur == nil {
return ""
}
return cur
}
func evalCondition(expr string, state *workflowExecState) bool {
expr = strings.TrimSpace(expr)
if expr == "" {
return true
}
resolved := strings.TrimSpace(resolveTemplate(expr, state))
switch {
case strings.Contains(resolved, "!="):
parts := strings.SplitN(resolved, "!=", 2)
return cleanComparable(parts[0]) != cleanComparable(parts[1])
case strings.Contains(resolved, "=="):
parts := strings.SplitN(resolved, "==", 2)
return cleanComparable(parts[0]) == cleanComparable(parts[1])
default:
v := strings.ToLower(cleanComparable(resolved))
return v != "" && v != "false" && v != "0" && v != "null"
}
}
func cleanComparable(s string) string {
s = strings.TrimSpace(s)
s = strings.Trim(s, `"'`)
return s
}
func renderWorkflowResponse(roleName, workflowName string, version int, runID string, state *workflowExecState) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("角色「%s」已完成工作流「%s」(版本 %d)。\n\n", roleName, workflowName, version))
sb.WriteString(fmt.Sprintf("运行 ID%s\n", runID))
sb.WriteString(fmt.Sprintf("已执行节点:%d", len(state.executed)))
if len(state.skipped) > 0 {
sb.WriteString(fmt.Sprintf(",跳过节点:%d", len(state.skipped)))
}
sb.WriteString("\n\n")
if len(state.outputs) > 0 {
sb.WriteString("输出:\n")
keys := make([]string, 0, len(state.outputs))
for k := range state.outputs {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
sb.WriteString(fmt.Sprintf("- %s%v\n", k, state.outputs[k]))
}
} else {
sb.WriteString("暂无输出。请检查是否配置了输出节点,或条件分支是否命中。\n")
}
if len(state.skipped) > 0 {
sb.WriteString("\n未执行的节点类型仍会保留运行记录:")
sb.WriteString(strings.Join(state.skipped, "、"))
sb.WriteString("。")
}
return strings.TrimSpace(sb.String())
return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil
}
+224
View File
@@ -0,0 +1,224 @@
package workflow
import (
"fmt"
"regexp"
"sort"
"strings"
"github.com/cloudwego/eino/schema"
)
func init() {
schema.RegisterName[*WorkflowLocalState]("_cyberstrike_workflow_local_state")
}
// WorkflowLocalState is the Eino WithGenLocalState payload (checkpoint-serializable).
type WorkflowLocalState struct {
Inputs map[string]any `json:"inputs,omitempty"`
Outputs map[string]any `json:"outputs,omitempty"`
NodeOutputs map[string]map[string]any `json:"nodeOutputs,omitempty"`
NodeProceed map[string]bool `json:"nodeProceed,omitempty"`
LastOutput map[string]any `json:"lastOutput,omitempty"`
Executed []string `json:"executed,omitempty"`
Skipped []string `json:"skipped,omitempty"`
WorkflowRunID string `json:"workflowRunId,omitempty"`
MainIterationOffset int `json:"mainIterationOffset,omitempty"`
SegmentMaxIteration int `json:"segmentMaxIteration,omitempty"`
}
func newWorkflowLocalState(inputs map[string]interface{}, runID string) *WorkflowLocalState {
in := make(map[string]any, len(inputs))
for k, v := range inputs {
in[k] = v
}
return &WorkflowLocalState{
Inputs: in,
Outputs: make(map[string]any),
NodeOutputs: make(map[string]map[string]any),
NodeProceed: make(map[string]bool),
WorkflowRunID: runID,
}
}
var templateVarRe = regexp.MustCompile(`\{\{\s*([a-zA-Z0-9_.-]+)\s*\}\}`)
func resolveTemplate(s string, state *WorkflowLocalState) string {
if strings.TrimSpace(s) == "" {
return fmt.Sprint(valueFromPath("previous.output", state))
}
return templateVarRe.ReplaceAllStringFunc(s, func(match string) string {
m := templateVarRe.FindStringSubmatch(match)
if len(m) != 2 {
return match
}
return fmt.Sprint(valueFromPath(m[1], state))
})
}
func valueFromPath(path string, state *WorkflowLocalState) any {
parts := strings.Split(path, ".")
if len(parts) == 0 {
return ""
}
var cur any
switch parts[0] {
case "inputs", "input":
cur = state.Inputs
case "previous", "prev":
cur = state.LastOutput
case "outputs":
cur = state.Outputs
default:
if v, ok := state.Inputs[parts[0]]; ok {
cur = v
} else if v, ok := state.NodeOutputs[parts[0]]; ok {
cur = v
} else {
return ""
}
}
for _, p := range parts[1:] {
m, ok := cur.(map[string]any)
if !ok {
return ""
}
cur = m[p]
}
if cur == nil {
return ""
}
return cur
}
func evalCondition(expr string, state *WorkflowLocalState) bool {
expr = strings.TrimSpace(expr)
if expr == "" {
return true
}
resolved := strings.TrimSpace(resolveTemplate(expr, state))
switch {
case strings.Contains(resolved, "!="):
parts := strings.SplitN(resolved, "!=", 2)
return cleanComparable(parts[0]) != cleanComparable(parts[1])
case strings.Contains(resolved, "=="):
parts := strings.SplitN(resolved, "==", 2)
return cleanComparable(parts[0]) == cleanComparable(parts[1])
default:
v := strings.ToLower(cleanComparable(resolved))
return v != "" && v != "false" && v != "0" && v != "null"
}
}
func cleanComparable(s string) string {
s = strings.TrimSpace(s)
s = strings.Trim(s, `"'`)
return s
}
func edgeAllowed(edge graphEdge, sourceNode graphNode, edgeIndex int, state *WorkflowLocalState) bool {
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
if cond != "" {
return evalCondition(cond, state)
}
if strings.EqualFold(strings.TrimSpace(sourceNode.Type), "condition") {
return conditionBranchAllowed(edge, edgeIndex, state)
}
return true
}
func conditionBranchAllowed(edge graphEdge, edgeIndex int, state *WorkflowLocalState) bool {
matched := conditionMatched(state)
if branch := conditionBranchHint(edge); branch != "" {
return (branch == "true" && matched) || (branch == "false" && !matched)
}
switch edgeIndex {
case 0:
return matched
case 1:
return !matched
default:
return false
}
}
func conditionMatched(state *WorkflowLocalState) bool {
v := strings.ToLower(cleanComparable(fmt.Sprint(valueFromPath("previous.matched", state))))
return v == "true" || v == "1"
}
func conditionBranchHint(edge graphEdge) string {
if edge.Config != nil {
switch strings.ToLower(strings.TrimSpace(cfgString(edge.Config, "branch"))) {
case "true", "yes", "y", "是":
return "true"
case "false", "no", "n", "否":
return "false"
}
}
switch strings.ToLower(strings.TrimSpace(edge.Label)) {
case "true", "yes", "y", "是":
return "true"
case "false", "no", "n", "否":
return "false"
}
return ""
}
func cfgString(cfg map[string]any, key string) string {
if cfg == nil {
return ""
}
if v, ok := cfg[key]; ok {
return strings.TrimSpace(fmt.Sprint(v))
}
return ""
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if s := strings.TrimSpace(value); s != "" {
return s
}
}
return ""
}
func truncateWorkflowPreview(s string, limit int) string {
s = strings.TrimSpace(s)
if limit <= 0 || len([]rune(s)) <= limit {
return s
}
runes := []rune(s)
return string(runes[:limit]) + "..."
}
func renderWorkflowResponse(roleName, workflowName string, version int, runID string, state *WorkflowLocalState) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("角色「%s」已完成工作流「%s」(版本 %d)。\n\n", roleName, workflowName, version))
sb.WriteString(fmt.Sprintf("运行 ID%s\n", runID))
sb.WriteString(fmt.Sprintf("已执行节点:%d", len(state.Executed)))
if len(state.Skipped) > 0 {
sb.WriteString(fmt.Sprintf(",跳过节点:%d", len(state.Skipped)))
}
sb.WriteString("\n\n")
if len(state.Outputs) > 0 {
sb.WriteString("输出:\n")
keys := make([]string, 0, len(state.Outputs))
for k := range state.Outputs {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
sb.WriteString(fmt.Sprintf("- %s%v\n", k, state.Outputs[k]))
}
} else {
sb.WriteString("暂无输出。请检查是否配置了输出节点,或条件分支是否命中。\n")
}
if len(state.Skipped) > 0 {
sb.WriteString("\n未执行的节点类型仍会保留运行记录:")
sb.WriteString(strings.Join(state.Skipped, "、"))
sb.WriteString("。")
}
return strings.TrimSpace(sb.String())
}
+74
View File
@@ -0,0 +1,74 @@
package workflow
import (
"fmt"
"strconv"
)
// WorkflowInput is the typed entry for Eino compose.Workflow[I,O].
type WorkflowInput struct {
Message string `json:"message"`
ConversationID string `json:"conversationId"`
ProjectID string `json:"projectId"`
Role string `json:"role"`
WorkflowID string `json:"workflowId"`
WorkflowVersion int `json:"workflowVersion"`
}
// WorkflowOutput aggregates terminal node payloads keyed by canvas node id.
type WorkflowOutput map[string]any
// WorkflowNodeOutput is the per-node lambda payload (alias for Eino edge type alignment).
type WorkflowNodeOutput = map[string]interface{}
func workflowInputFromMap(m map[string]interface{}) WorkflowInput {
in := WorkflowInput{}
if m == nil {
return in
}
if v, ok := m["message"].(string); ok {
in.Message = v
} else if m["message"] != nil {
in.Message = fmt.Sprint(m["message"])
}
if v, ok := m["conversationId"].(string); ok {
in.ConversationID = v
}
if v, ok := m["projectId"].(string); ok {
in.ProjectID = v
}
if v, ok := m["role"].(string); ok {
in.Role = v
}
if v, ok := m["workflowId"].(string); ok {
in.WorkflowID = v
}
switch v := m["workflowVersion"].(type) {
case int:
in.WorkflowVersion = v
case int64:
in.WorkflowVersion = int(v)
case float64:
in.WorkflowVersion = int(v)
case string:
if n, err := strconv.Atoi(v); err == nil {
in.WorkflowVersion = n
}
}
return in
}
func (in WorkflowInput) toStateInputs() map[string]any {
return map[string]any{
"message": in.Message,
"conversationId": in.ConversationID,
"projectId": in.ProjectID,
"role": in.Role,
"workflowId": in.WorkflowID,
"workflowVersion": in.WorkflowVersion,
}
}
func cacheKey(workflowID string, version int) string {
return workflowID + ":" + strconv.Itoa(version)
}