mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-04 11:37:57 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
// compileAgentSubgraph wraps an Agent canvas node as an Eino subgraph (AddGraphNode best practice).
|
||||
func compileAgentSubgraph(_ context.Context, node graphNode) (compose.AnyGraph, error) {
|
||||
n := node
|
||||
innerID := n.ID + "__agent"
|
||||
g := compose.NewGraph[WorkflowNodeOutput, WorkflowNodeOutput]()
|
||||
_ = g.AddLambdaNode(innerID, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) {
|
||||
return runWorkflowNodeLambda(runCtx, n)
|
||||
}))
|
||||
if err := g.AddEdge(compose.START, innerID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := g.AddEdge(innerID, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FieldBinding selects a value from workflow state (replaces {{...}} templates).
|
||||
type FieldBinding struct {
|
||||
From string `json:"from"` // inputs | previous | <nodeId>
|
||||
Field string `json:"field"` // e.g. output, message
|
||||
}
|
||||
|
||||
func parseFieldBinding(cfg map[string]any, keys ...string) (FieldBinding, bool) {
|
||||
for _, key := range keys {
|
||||
if cfg == nil {
|
||||
continue
|
||||
}
|
||||
raw, ok := cfg[key]
|
||||
if !ok || raw == nil {
|
||||
continue
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case map[string]any:
|
||||
return FieldBinding{
|
||||
From: strings.TrimSpace(fmt.Sprint(v["from"])),
|
||||
Field: strings.TrimSpace(fmt.Sprint(v["field"])),
|
||||
}, true
|
||||
case string:
|
||||
s := strings.TrimSpace(v)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
var b FieldBinding
|
||||
if err := json.Unmarshal([]byte(s), &b); err == nil && (b.From != "" || b.Field != "") {
|
||||
return b, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return FieldBinding{}, false
|
||||
}
|
||||
|
||||
func defaultBinding(from, field string) FieldBinding {
|
||||
return FieldBinding{From: from, Field: field}
|
||||
}
|
||||
|
||||
func resolveBinding(b FieldBinding, state *WorkflowLocalState) any {
|
||||
from := strings.TrimSpace(b.From)
|
||||
field := strings.TrimSpace(b.Field)
|
||||
if field == "" {
|
||||
field = "output"
|
||||
}
|
||||
if from == "" || from == "previous" || from == "prev" {
|
||||
if field == "output" && state.LastOutput != nil {
|
||||
return state.LastOutput["output"]
|
||||
}
|
||||
return valueFromPath("previous."+field, state)
|
||||
}
|
||||
if from == "inputs" || from == "input" {
|
||||
if field == "" {
|
||||
return state.Inputs
|
||||
}
|
||||
return valueFromPath("inputs."+field, state)
|
||||
}
|
||||
if from == "outputs" {
|
||||
return valueFromPath("outputs."+field, state)
|
||||
}
|
||||
return valueFromPath(from+"."+field, state)
|
||||
}
|
||||
|
||||
func resolveBindingString(b FieldBinding, state *WorkflowLocalState) string {
|
||||
return strings.TrimSpace(fmt.Sprint(resolveBinding(b, state)))
|
||||
}
|
||||
|
||||
func resolveNodeInputBinding(cfg map[string]any, state *WorkflowLocalState) string {
|
||||
if b, ok := parseFieldBinding(cfg, "input_binding"); ok {
|
||||
return resolveBindingString(b, state)
|
||||
}
|
||||
// legacy template field removed — default previous.output
|
||||
return resolveBindingString(defaultBinding("previous", "output"), state)
|
||||
}
|
||||
|
||||
func resolveOutputSourceBinding(cfg map[string]any, state *WorkflowLocalState) any {
|
||||
if b, ok := parseFieldBinding(cfg, "source_binding"); ok {
|
||||
return resolveBinding(b, state)
|
||||
}
|
||||
return resolveBinding(defaultBinding("previous", "output"), state)
|
||||
}
|
||||
|
||||
func resolveHITLPromptBinding(cfg map[string]any, state *WorkflowLocalState) string {
|
||||
if b, ok := parseFieldBinding(cfg, "prompt_binding"); ok {
|
||||
return resolveBindingString(b, state)
|
||||
}
|
||||
if s := cfgString(cfg, "prompt"); s != "" {
|
||||
return s
|
||||
}
|
||||
return resolveBindingString(defaultBinding("previous", "output"), state)
|
||||
}
|
||||
|
||||
func toolArgumentBindings(cfg map[string]any) map[string]FieldBinding {
|
||||
raw, ok := cfg["argument_bindings"].(map[string]any)
|
||||
if !ok || len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]FieldBinding, len(raw))
|
||||
for argName, v := range raw {
|
||||
m, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out[argName] = FieldBinding{
|
||||
From: strings.TrimSpace(fmt.Sprint(m["from"])),
|
||||
Field: strings.TrimSpace(fmt.Sprint(m["field"])),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveToolArguments(cfg map[string]any, state *WorkflowLocalState) (map[string]interface{}, error) {
|
||||
bindings := toolArgumentBindings(cfg)
|
||||
if len(bindings) > 0 {
|
||||
args := make(map[string]interface{}, len(bindings))
|
||||
for k, b := range bindings {
|
||||
args[k] = resolveBinding(b, state)
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
raw := cfgString(cfg, "arguments")
|
||||
if raw == "" {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// fileCheckPointStore persists Eino workflow checkpoints on disk (per run id).
|
||||
type fileCheckPointStore struct {
|
||||
dir string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newFileCheckPointStore(dir string) (*fileCheckPointStore, error) {
|
||||
dir = strings.TrimSpace(dir)
|
||||
if dir == "" {
|
||||
dir = filepath.Join("data", "workflow-checkpoints")
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create workflow checkpoint dir: %w", err)
|
||||
}
|
||||
return &fileCheckPointStore{dir: dir}, nil
|
||||
}
|
||||
|
||||
func (s *fileCheckPointStore) path(id string) (string, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return "", fmt.Errorf("checkpoint id is empty")
|
||||
}
|
||||
if strings.Contains(id, "..") || strings.ContainsAny(id, `/\`) {
|
||||
return "", fmt.Errorf("invalid checkpoint id")
|
||||
}
|
||||
return filepath.Join(s.dir, id+".ckpt"), nil
|
||||
}
|
||||
|
||||
func (s *fileCheckPointStore) Get(_ context.Context, checkPointID string) ([]byte, bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
p, err := s.path(checkPointID)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
data, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, false, err
|
||||
}
|
||||
return data, true, nil
|
||||
}
|
||||
|
||||
func (s *fileCheckPointStore) Set(_ context.Context, checkPointID string, checkPoint []byte) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
p, err := s.path(checkPointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmp := p + ".tmp"
|
||||
if err := os.WriteFile(tmp, checkPoint, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmp, p)
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
func hasConditionalOutgoingEdges(idx *graphIndex, nodeID string) bool {
|
||||
for _, edge := range idx.outgoing[nodeID] {
|
||||
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
|
||||
if cond != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func wireConditionBranch(
|
||||
wf *compose.Workflow[WorkflowInput, WorkflowOutput],
|
||||
nodeRefs map[string]*compose.WorkflowNode,
|
||||
idx *graphIndex,
|
||||
condID string,
|
||||
condNode graphNode,
|
||||
) error {
|
||||
edges := idx.outgoing[condID]
|
||||
if len(edges) == 0 {
|
||||
return nil
|
||||
}
|
||||
branchID := branchNodeID(condID)
|
||||
wf.AddPassthroughNode(branchID).AddInput(condID)
|
||||
|
||||
endNodes := map[string]bool{compose.END: true}
|
||||
for _, edge := range edges {
|
||||
endNodes[edge.Target] = true
|
||||
}
|
||||
|
||||
sortedEdges := append([]graphEdge(nil), edges...)
|
||||
sortEdgesByCanvas(sortedEdges, idx.nodes)
|
||||
|
||||
branch := compose.NewGraphBranch(func(runCtx context.Context, _ map[string]any) (string, error) {
|
||||
rt := workflowRuntimeFrom(runCtx)
|
||||
if rt == nil {
|
||||
return compose.END, fmt.Errorf("workflow runtime missing in context")
|
||||
}
|
||||
emitConditionBranchProgress(rt.args, rt.runID, condNode, sortedEdges, idx.nodes, rt.state)
|
||||
for edgeIdx, edge := range sortedEdges {
|
||||
if conditionBranchAllowed(edge, edgeIdx, rt.state) {
|
||||
return edge.Target, nil
|
||||
}
|
||||
}
|
||||
return compose.END, nil
|
||||
}, endNodes)
|
||||
wf.AddBranch(branchID, branch)
|
||||
|
||||
for _, edge := range edges {
|
||||
if target, ok := nodeRefs[edge.Target]; ok {
|
||||
target.AddInput(branchID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func wireEdgeConditionBranch(
|
||||
wf *compose.Workflow[WorkflowInput, WorkflowOutput],
|
||||
nodeRefs map[string]*compose.WorkflowNode,
|
||||
idx *graphIndex,
|
||||
sourceID string,
|
||||
sourceNode graphNode,
|
||||
) error {
|
||||
edges := idx.outgoing[sourceID]
|
||||
if len(edges) == 0 {
|
||||
return nil
|
||||
}
|
||||
branchID := edgeBranchNodeID(sourceID)
|
||||
wf.AddPassthroughNode(branchID).AddInput(sourceID)
|
||||
|
||||
endNodes := map[string]bool{compose.END: true}
|
||||
for _, edge := range edges {
|
||||
endNodes[edge.Target] = true
|
||||
}
|
||||
|
||||
sortedEdges := append([]graphEdge(nil), edges...)
|
||||
sortEdgesByCanvas(sortedEdges, idx.nodes)
|
||||
|
||||
branch := compose.NewGraphBranch(func(runCtx context.Context, _ map[string]any) (string, error) {
|
||||
rt := workflowRuntimeFrom(runCtx)
|
||||
if rt == nil {
|
||||
return compose.END, fmt.Errorf("workflow runtime missing in context")
|
||||
}
|
||||
for edgeIdx, edge := range sortedEdges {
|
||||
if edgeAllowed(edge, sourceNode, edgeIdx, rt.state) {
|
||||
return edge.Target, nil
|
||||
}
|
||||
}
|
||||
return compose.END, nil
|
||||
}, endNodes)
|
||||
wf.AddBranch(branchID, branch)
|
||||
|
||||
for _, edge := range edges {
|
||||
if target, ok := nodeRefs[edge.Target]; ok {
|
||||
target.AddInput(branchID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einoobserve"
|
||||
)
|
||||
|
||||
func attachWorkflowCallbacks(ctx context.Context, cfg *config.Config, args RunArgs, workflowName string) context.Context {
|
||||
if cfg == nil {
|
||||
return ctx
|
||||
}
|
||||
cbCfg := &cfg.MultiAgent.EinoCallbacks
|
||||
return einoobserve.AttachAgentRunCallbacks(ctx, cbCfg, einoobserve.Params{
|
||||
Logger: args.Logger,
|
||||
Progress: args.Progress,
|
||||
ConversationID: args.ConversationID,
|
||||
OrchMode: "workflow",
|
||||
OrchestratorName: workflowName,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
func executeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState) error {
|
||||
_, err := invokeEinoGraph(ctx, args, runID, workflowID, version, g, state, false)
|
||||
return err
|
||||
}
|
||||
|
||||
func invokeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState, resume bool) (bool, error) {
|
||||
wfInput := workflowInputFromMap(state.Inputs)
|
||||
if resume {
|
||||
wfInput = WorkflowInput{}
|
||||
}
|
||||
rt := &workflowRuntime{
|
||||
args: args,
|
||||
runID: runID,
|
||||
idx: indexGraph(g),
|
||||
state: state,
|
||||
}
|
||||
|
||||
art, err := defaultEngine.getOrCompile(ctx, workflowID, version, g)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("编译 Eino Workflow 失败: %w", err)
|
||||
}
|
||||
rt.idx = art.idx
|
||||
|
||||
runCtx := withWorkflowRuntime(ctx, rt)
|
||||
runCtx = attachWorkflowCallbacks(runCtx, args.AppCfg, args, workflowID)
|
||||
|
||||
invokeOpts := []compose.Option{compose.WithCheckPointID(runID)}
|
||||
for {
|
||||
_, err = art.runnable.Invoke(runCtx, wfInput, invokeOpts...)
|
||||
if err == nil {
|
||||
return false, nil
|
||||
}
|
||||
if hitlErr := extractAwaitingHITL(err, art, runID, args, state); hitlErr != nil {
|
||||
return true, hitlErr
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
func extractAwaitingHITL(err error, art *compiledArtifact, runID string, args RunArgs, state *WorkflowLocalState) error {
|
||||
info, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok || len(art.hitlIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
nodeID := nextHITLNodeID(info, art.hitlIDs)
|
||||
node := art.idx.nodes[nodeID]
|
||||
if nodeID == "" {
|
||||
return nil
|
||||
}
|
||||
prompt := resolveHITLPromptBinding(node.Config, state)
|
||||
label := firstNonEmpty(node.Label, nodeID)
|
||||
if args.DB != nil {
|
||||
pending := map[string]any{
|
||||
"nodeId": nodeID,
|
||||
"label": label,
|
||||
"prompt": prompt,
|
||||
"reviewer": cfgString(node.Config, "reviewer"),
|
||||
}
|
||||
pendingJSON, _ := json.Marshal(pending)
|
||||
_ = args.DB.SetWorkflowRunAwaitingHITL(runID, nodeID, string(pendingJSON))
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_hitl_waiting", fmt.Sprintf("等待人工确认:%s", label), map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeId": nodeID,
|
||||
"label": label,
|
||||
"prompt": prompt,
|
||||
"reviewer": cfgString(node.Config, "reviewer"),
|
||||
"mode": "interactive",
|
||||
"resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID),
|
||||
})
|
||||
}
|
||||
return &AwaitingHITLError{
|
||||
RunID: runID,
|
||||
NodeID: nodeID,
|
||||
NodeLabel: label,
|
||||
Prompt: prompt,
|
||||
Reviewer: cfgString(node.Config, "reviewer"),
|
||||
}
|
||||
}
|
||||
|
||||
func nextHITLNodeID(info *compose.InterruptInfo, hitlIDs []string) string {
|
||||
if info != nil && len(info.BeforeNodes) > 0 {
|
||||
for _, id := range info.BeforeNodes {
|
||||
for _, hitl := range hitlIDs {
|
||||
if id == hitl {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
return info.BeforeNodes[0]
|
||||
}
|
||||
if len(hitlIDs) == 0 {
|
||||
return ""
|
||||
}
|
||||
return hitlIDs[0]
|
||||
}
|
||||
|
||||
// ResumeWorkflowRun continues a run paused at HITL after human decision.
|
||||
func ResumeWorkflowRun(ctx context.Context, args RunArgs, runID string, approved bool, comment string) (*RunResult, error) {
|
||||
run, err := args.DB.GetWorkflowRun(runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if run == nil {
|
||||
return nil, fmt.Errorf("工作流运行不存在")
|
||||
}
|
||||
if run.Status != "awaiting_hitl" {
|
||||
return nil, fmt.Errorf("工作流运行不在等待审批状态: %s", run.Status)
|
||||
}
|
||||
wf, err := args.DB.GetWorkflowDefinition(run.WorkflowID)
|
||||
if err != nil || wf == nil {
|
||||
return nil, fmt.Errorf("工作流定义不存在")
|
||||
}
|
||||
graph, err := parseGraph(wf.GraphJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var input map[string]interface{}
|
||||
_ = json.Unmarshal([]byte(run.InputJSON), &input)
|
||||
state := newWorkflowLocalState(input, runID)
|
||||
if state.Inputs == nil {
|
||||
state.Inputs = map[string]any{}
|
||||
}
|
||||
state.Inputs["_hitl_approved"] = approved
|
||||
state.Inputs["_hitl_comment"] = strings.TrimSpace(comment)
|
||||
state.Inputs["_hitl_node_id"] = run.PendingHITLNodeID
|
||||
|
||||
if !approved {
|
||||
errText := strings.TrimSpace(comment)
|
||||
if errText == "" {
|
||||
errText = "人工审批拒绝"
|
||||
}
|
||||
_ = args.DB.FinishWorkflowRun(runID, "rejected", "", errText)
|
||||
return &RunResult{
|
||||
RunID: runID,
|
||||
Response: fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", run.PendingHITLNodeID),
|
||||
Status: "rejected",
|
||||
}, nil
|
||||
}
|
||||
|
||||
_ = args.DB.SetWorkflowRunStatus(runID, "running")
|
||||
resumeArgs := args
|
||||
if strings.TrimSpace(resumeArgs.ConversationID) == "" {
|
||||
resumeArgs.ConversationID = run.ConversationID
|
||||
}
|
||||
|
||||
awaiting, err := invokeEinoGraph(ctx, resumeArgs, runID, wf.ID, run.WorkflowVersion, graph, state, true)
|
||||
if err != nil {
|
||||
if IsAwaitingHITL(err) {
|
||||
return &RunResult{
|
||||
RunID: runID,
|
||||
Status: "awaiting_hitl",
|
||||
Response: fmt.Sprintf("工作流在节点「%s」等待下一次人工确认。", err.(*AwaitingHITLError).NodeID),
|
||||
AwaitingHITL: true,
|
||||
}, nil
|
||||
}
|
||||
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
_ = awaiting
|
||||
|
||||
output := map[string]interface{}{
|
||||
"workflowId": wf.ID,
|
||||
"workflowName": wf.Name,
|
||||
"workflowVersion": wf.Version,
|
||||
"workflowRunId": runID,
|
||||
"status": "completed",
|
||||
"outputs": state.Outputs,
|
||||
"executedNodes": state.Executed,
|
||||
"skippedNodes": state.Skipped,
|
||||
"engine": "eino_workflow",
|
||||
}
|
||||
outputJSON, _ := json.Marshal(output)
|
||||
response := renderWorkflowResponse(args.Role.Name, wf.Name, wf.Version, runID, state)
|
||||
_ = args.DB.FinishWorkflowRun(runID, "completed", string(outputJSON), "")
|
||||
return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func testWorkflowDB(t *testing.T) *database.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(dir, "workflow.db"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
return db
|
||||
}
|
||||
|
||||
func linearStartOutputGraph() string {
|
||||
return `{
|
||||
"nodes": [
|
||||
{"id": "start-1", "type": "start", "label": "开始", "position": {"x": 0, "y": 0}, "config": {}},
|
||||
{"id": "out-1", "type": "output", "label": "输出", "position": {"x": 0, "y": 120}, "config": {"output_key": "result", "source_binding": {"from": "inputs", "field": "message"}}}
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "start-1", "target": "out-1"}
|
||||
],
|
||||
"config": {"schema_version": 1}
|
||||
}`
|
||||
}
|
||||
|
||||
func conditionBranchGraph() string {
|
||||
return `{
|
||||
"nodes": [
|
||||
{"id": "start-1", "type": "start", "label": "开始", "position": {"x": 0, "y": 0}, "config": {}},
|
||||
{"id": "cond-1", "type": "condition", "label": "判断", "position": {"x": 0, "y": 80}, "config": {"expression": "{{inputs.message}} == yes"}},
|
||||
{"id": "out-yes", "type": "output", "label": "是", "position": {"x": -80, "y": 160}, "config": {"output_key": "branch", "static_value": "yes"}},
|
||||
{"id": "out-no", "type": "output", "label": "否", "position": {"x": 80, "y": 160}, "config": {"output_key": "branch", "static_value": "no"}}
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "start-1", "target": "cond-1"},
|
||||
{"id": "e2", "source": "cond-1", "target": "out-yes", "label": "是"},
|
||||
{"id": "e3", "source": "cond-1", "target": "out-no", "label": "否"}
|
||||
],
|
||||
"config": {"schema_version": 1}
|
||||
}`
|
||||
}
|
||||
|
||||
func TestValidateGraphJSON_linear(t *testing.T) {
|
||||
if err := ValidateGraphJSON(context.Background(), linearStartOutputGraph()); err != nil {
|
||||
t.Fatalf("validate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileEngine_linear(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
g, err := parseGraph(linearStartOutputGraph())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := defaultEngine.compile(ctx, g); err != nil {
|
||||
t.Fatalf("compile: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestWorkflowRun(t *testing.T, db *database.DB, runID string) {
|
||||
t.Helper()
|
||||
if err := db.CreateWorkflowRun(&database.WorkflowRun{
|
||||
ID: runID,
|
||||
WorkflowID: "test-wf",
|
||||
Status: "running",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateWorkflowRun: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEinoGraph_linearStartOutput(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
db := testWorkflowDB(t)
|
||||
createTestWorkflowRun(t, db, "run-linear")
|
||||
g, err := parseGraph(linearStartOutputGraph())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
state := newWorkflowLocalState(map[string]interface{}{"message": "ping"}, "run-linear")
|
||||
args := RunArgs{DB: db}
|
||||
if err := executeEinoGraph(ctx, args, "run-linear", "test-wf", 1, g, state); err != nil {
|
||||
t.Fatalf("execute: %v", err)
|
||||
}
|
||||
if got := state.Outputs["result"]; got != "ping" {
|
||||
t.Fatalf("outputs[result] = %v, want ping", got)
|
||||
}
|
||||
if len(state.Executed) != 2 {
|
||||
t.Fatalf("executed nodes = %d, want 2", len(state.Executed))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEinoGraph_conditionBranch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
db := testWorkflowDB(t)
|
||||
createTestWorkflowRun(t, db, "run-yes")
|
||||
createTestWorkflowRun(t, db, "run-no")
|
||||
g, err := parseGraph(conditionBranchGraph())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stateYes := newWorkflowLocalState(map[string]interface{}{"message": "yes"}, "run-yes")
|
||||
if err := executeEinoGraph(ctx, RunArgs{DB: db}, "run-yes", "test-wf-branch", 1, g, stateYes); err != nil {
|
||||
t.Fatalf("execute yes: %v", err)
|
||||
}
|
||||
if got := stateYes.Outputs["branch"]; got != "yes" {
|
||||
t.Fatalf("yes branch output = %v", got)
|
||||
}
|
||||
|
||||
stateNo := newWorkflowLocalState(map[string]interface{}{"message": "no"}, "run-no")
|
||||
if err := executeEinoGraph(ctx, RunArgs{DB: db}, "run-no", "test-wf-branch", 1, g, stateNo); err != nil {
|
||||
t.Fatalf("execute no: %v", err)
|
||||
}
|
||||
if got := stateNo.Outputs["branch"]; got != "no" {
|
||||
t.Fatalf("no branch output = %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRoleBoundWorkflow_integration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
db := testWorkflowDB(t)
|
||||
graph := linearStartOutputGraph()
|
||||
if err := db.UpsertWorkflowDefinition(&database.WorkflowDefinition{
|
||||
ID: "wf-linear",
|
||||
Name: "线性流程",
|
||||
Version: 1,
|
||||
GraphJSON: graph,
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
role := config.RoleConfig{
|
||||
Name: "tester",
|
||||
Enabled: true,
|
||||
WorkflowID: "wf-linear",
|
||||
WorkflowPolicy: "auto",
|
||||
}
|
||||
result, err := RunRoleBoundWorkflow(ctx, RunArgs{
|
||||
DB: db,
|
||||
Logger: zap.NewNop(),
|
||||
Role: role,
|
||||
UserMessage: "from-role",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RunRoleBoundWorkflow: %v", err)
|
||||
}
|
||||
if result == nil || result.RunID == "" {
|
||||
t.Fatal("expected run result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompiledCache_reuse(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
InvalidateCompiledCache("cache-wf")
|
||||
g, err := parseGraph(linearStartOutputGraph())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a1, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a2, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if a1 != a2 {
|
||||
t.Fatal("expected cached artifact pointer reuse")
|
||||
}
|
||||
InvalidateCompiledCache("cache-wf")
|
||||
a3, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if a1 == a3 {
|
||||
t.Fatal("expected new artifact after invalidation")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type workflowRuntimeCtxKey struct{}
|
||||
|
||||
// workflowRuntime carries per-run execution context into Eino Workflow local state.
|
||||
type workflowRuntime struct {
|
||||
args RunArgs
|
||||
runID string
|
||||
idx *graphIndex
|
||||
state *WorkflowLocalState
|
||||
}
|
||||
|
||||
func withWorkflowRuntime(ctx context.Context, rt *workflowRuntime) context.Context {
|
||||
return context.WithValue(ctx, workflowRuntimeCtxKey{}, rt)
|
||||
}
|
||||
|
||||
func workflowRuntimeFrom(ctx context.Context) *workflowRuntime {
|
||||
rt, _ := ctx.Value(workflowRuntimeCtxKey{}).(*workflowRuntime)
|
||||
return rt
|
||||
}
|
||||
|
||||
func newWorkflowRuntime(args RunArgs, runID string, idx *graphIndex, inputs map[string]interface{}) *workflowRuntime {
|
||||
return &workflowRuntime{
|
||||
args: args,
|
||||
runID: runID,
|
||||
idx: idx,
|
||||
state: newWorkflowLocalState(inputs, runID),
|
||||
}
|
||||
}
|
||||
|
||||
// RunArgs is the execution context for a role-bound workflow run.
|
||||
type RunArgs struct {
|
||||
DB *database.DB
|
||||
Logger *zap.Logger
|
||||
Role config.RoleConfig
|
||||
AppCfg *config.Config
|
||||
Agent *agent.Agent
|
||||
ConversationID string
|
||||
ProjectID string
|
||||
UserMessage string
|
||||
History []agent.ChatMessage
|
||||
RoleTools []string
|
||||
AgentsMarkdownDir string
|
||||
SystemPromptExtra string
|
||||
AssistantMessageID string
|
||||
Progress agent.ProgressCallback
|
||||
}
|
||||
|
||||
type RunResult struct {
|
||||
Response string
|
||||
RunID string
|
||||
Status string
|
||||
AwaitingHITL bool
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
type compiledArtifact struct {
|
||||
runnable compose.Runnable[WorkflowInput, WorkflowOutput]
|
||||
idx *graphIndex
|
||||
hitlIDs []string
|
||||
}
|
||||
|
||||
// Engine compiles and caches Eino Workflow artifacts.
|
||||
type Engine struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]*compiledArtifact
|
||||
cpStore compose.CheckPointStore
|
||||
cpStoreMu sync.Once
|
||||
cpStoreErr error
|
||||
checkpointDir string
|
||||
}
|
||||
|
||||
var defaultEngine = &Engine{
|
||||
cache: make(map[string]*compiledArtifact),
|
||||
checkpointDir: "data/workflow-checkpoints",
|
||||
}
|
||||
|
||||
// SetCheckpointDir overrides the workflow checkpoint root (mainly for tests).
|
||||
func SetCheckpointDir(dir string) {
|
||||
defaultEngine.mu.Lock()
|
||||
defer defaultEngine.mu.Unlock()
|
||||
defaultEngine.checkpointDir = strings.TrimSpace(dir)
|
||||
defaultEngine.cpStore = nil
|
||||
defaultEngine.cpStoreErr = nil
|
||||
defaultEngine.cpStoreMu = sync.Once{}
|
||||
}
|
||||
|
||||
func (e *Engine) checkpointStore() (compose.CheckPointStore, error) {
|
||||
e.cpStoreMu.Do(func() {
|
||||
e.cpStore, e.cpStoreErr = newFileCheckPointStore(e.checkpointDir)
|
||||
})
|
||||
return e.cpStore, e.cpStoreErr
|
||||
}
|
||||
|
||||
// InvalidateCompiledCache drops cached compilations for a workflow id.
|
||||
func InvalidateCompiledCache(workflowID string) {
|
||||
workflowID = strings.TrimSpace(workflowID)
|
||||
if workflowID == "" {
|
||||
return
|
||||
}
|
||||
defaultEngine.mu.Lock()
|
||||
defer defaultEngine.mu.Unlock()
|
||||
for key := range defaultEngine.cache {
|
||||
if strings.HasPrefix(key, workflowID+":") {
|
||||
delete(defaultEngine.cache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateGraphJSON parses and trial-compiles a canvas graph (save-time gate).
|
||||
func ValidateGraphJSON(ctx context.Context, graphJSON string) error {
|
||||
g, err := parseGraph(graphJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idx := indexGraph(g)
|
||||
if len(findStartNodeIDs(idx)) == 0 {
|
||||
return fmt.Errorf("工作流缺少可执行的起点节点")
|
||||
}
|
||||
if !hasTerminalNode(idx) {
|
||||
return fmt.Errorf("工作流至少需要一个无出边的终点或 output/end 节点")
|
||||
}
|
||||
_, err = defaultEngine.compile(ctx, g)
|
||||
return err
|
||||
}
|
||||
|
||||
func hasTerminalNode(idx *graphIndex) bool {
|
||||
for id, node := range idx.nodes {
|
||||
if len(idx.outgoing[id]) == 0 {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(node.Type, "end") || strings.EqualFold(node.Type, "output") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *Engine) getOrCompile(ctx context.Context, workflowID string, version int, g *graphDef) (*compiledArtifact, error) {
|
||||
key := cacheKey(workflowID, version)
|
||||
e.mu.RLock()
|
||||
if art, ok := e.cache[key]; ok {
|
||||
e.mu.RUnlock()
|
||||
return art, nil
|
||||
}
|
||||
e.mu.RUnlock()
|
||||
|
||||
art, err := e.compile(ctx, g)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.mu.Lock()
|
||||
if existing, ok := e.cache[key]; ok {
|
||||
e.mu.Unlock()
|
||||
return existing, nil
|
||||
}
|
||||
e.cache[key] = art
|
||||
e.mu.Unlock()
|
||||
return art, nil
|
||||
}
|
||||
|
||||
func (e *Engine) compile(ctx context.Context, g *graphDef) (*compiledArtifact, error) {
|
||||
cpStore, err := e.checkpointStore()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idx := indexGraph(g)
|
||||
hitlIDs := collectHITLNodeIDs(idx)
|
||||
compileOpts := []compose.GraphCompileOption{
|
||||
compose.WithGraphName("CyberStrikeWorkflow"),
|
||||
compose.WithCheckPointStore(cpStore),
|
||||
}
|
||||
if len(hitlIDs) > 0 {
|
||||
compileOpts = append(compileOpts, compose.WithInterruptBeforeNodes(hitlIDs))
|
||||
}
|
||||
|
||||
wf := compose.NewWorkflow[WorkflowInput, WorkflowOutput](
|
||||
compose.WithGenLocalState(func(runCtx context.Context) *WorkflowLocalState {
|
||||
if rt := workflowRuntimeFrom(runCtx); rt != nil && rt.state != nil {
|
||||
return rt.state
|
||||
}
|
||||
return &WorkflowLocalState{
|
||||
Outputs: make(map[string]any),
|
||||
NodeOutputs: make(map[string]map[string]any),
|
||||
NodeProceed: make(map[string]bool),
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
nodeRefs := make(map[string]*compose.WorkflowNode, len(idx.nodes))
|
||||
for id, node := range idx.nodes {
|
||||
n := node
|
||||
if strings.EqualFold(n.Type, "agent") {
|
||||
sub, err := compileAgentSubgraph(ctx, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("编译 Agent 子图 %s 失败: %w", id, err)
|
||||
}
|
||||
nodeRefs[id] = wf.AddGraphNode(id, sub)
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(n.Type, "start") {
|
||||
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowInput) (WorkflowNodeOutput, error) {
|
||||
return runWorkflowNodeLambda(runCtx, n)
|
||||
}))
|
||||
continue
|
||||
}
|
||||
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) {
|
||||
return runWorkflowNodeLambda(runCtx, n)
|
||||
}))
|
||||
}
|
||||
|
||||
for id, node := range idx.nodes {
|
||||
if strings.EqualFold(node.Type, "condition") {
|
||||
if err := wireConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if hasConditionalOutgoingEdges(idx, id) {
|
||||
if err := wireEdgeConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, edge := range idx.outgoing[id] {
|
||||
if target, ok := nodeRefs[edge.Target]; ok {
|
||||
target.AddInput(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, startID := range findStartNodeIDs(idx) {
|
||||
if ref, ok := nodeRefs[startID]; ok {
|
||||
ref.AddInput(compose.START)
|
||||
}
|
||||
}
|
||||
|
||||
endNode := wf.End()
|
||||
for id, node := range idx.nodes {
|
||||
if len(idx.outgoing[id]) == 0 || strings.EqualFold(node.Type, "end") {
|
||||
endNode.AddInput(id, compose.ToField(id))
|
||||
}
|
||||
}
|
||||
|
||||
runnable, err := wf.Compile(ctx, compileOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &compiledArtifact{runnable: runnable, idx: idx, hitlIDs: hitlIDs}, nil
|
||||
}
|
||||
|
||||
func collectHITLNodeIDs(idx *graphIndex) []string {
|
||||
var ids []string
|
||||
for id, node := range idx.nodes {
|
||||
if strings.EqualFold(node.Type, "hitl") {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func runWorkflowNodeLambda(runCtx context.Context, n graphNode) (WorkflowNodeOutput, error) {
|
||||
localRT := workflowRuntimeFrom(runCtx)
|
||||
if localRT == nil {
|
||||
return nil, fmt.Errorf("workflow runtime missing in context")
|
||||
}
|
||||
result, proceed, err := executeNode(runCtx, localRT.args, localRT.runID, n, localRT.state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
localRT.state.NodeOutputs[n.ID] = result
|
||||
localRT.state.LastOutput = result
|
||||
if !proceed && !strings.EqualFold(n.Type, "end") {
|
||||
label := firstNonEmpty(n.Label, n.ID)
|
||||
if errText := cfgString(result, "error"); errText != "" {
|
||||
return result, fmt.Errorf("节点「%s」失败: %s", label, errText)
|
||||
}
|
||||
return result, fmt.Errorf("节点「%s」未继续执行", label)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package workflow
|
||||
|
||||
import "errors"
|
||||
|
||||
// AwaitingHITLError indicates the workflow paused before a HITL node for human approval.
|
||||
type AwaitingHITLError struct {
|
||||
RunID string
|
||||
NodeID string
|
||||
NodeLabel string
|
||||
Prompt string
|
||||
Reviewer string
|
||||
}
|
||||
|
||||
func (e *AwaitingHITLError) Error() string {
|
||||
if e == nil {
|
||||
return "workflow awaiting human approval"
|
||||
}
|
||||
return "workflow awaiting human approval at node " + e.NodeID
|
||||
}
|
||||
|
||||
func IsAwaitingHITL(err error) bool {
|
||||
var target *AwaitingHITLError
|
||||
return errors.As(err, &target)
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type graphDef struct {
|
||||
Nodes []graphNode `json:"nodes"`
|
||||
Edges []graphEdge `json:"edges"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type graphNode struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Label string `json:"label"`
|
||||
Position graphPosition `json:"position"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type graphEdge struct {
|
||||
ID string `json:"id"`
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target"`
|
||||
Label string `json:"label"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type graphPosition struct {
|
||||
X float64 `json:"x"`
|
||||
Y float64 `json:"y"`
|
||||
}
|
||||
|
||||
type graphIndex struct {
|
||||
nodes map[string]graphNode
|
||||
outgoing map[string][]graphEdge
|
||||
incoming map[string][]graphEdge
|
||||
}
|
||||
|
||||
func parseGraph(raw string) (*graphDef, error) {
|
||||
var g graphDef
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(raw)), &g); err != nil {
|
||||
return nil, fmt.Errorf("解析工作流图失败: %w", err)
|
||||
}
|
||||
if len(g.Nodes) == 0 {
|
||||
return nil, fmt.Errorf("工作流没有节点")
|
||||
}
|
||||
if g.Config == nil {
|
||||
g.Config = make(map[string]any)
|
||||
}
|
||||
return &g, nil
|
||||
}
|
||||
|
||||
func indexGraph(g *graphDef) *graphIndex {
|
||||
idx := &graphIndex{
|
||||
nodes: make(map[string]graphNode, len(g.Nodes)),
|
||||
outgoing: make(map[string][]graphEdge),
|
||||
incoming: make(map[string][]graphEdge),
|
||||
}
|
||||
for _, node := range g.Nodes {
|
||||
node.ID = strings.TrimSpace(node.ID)
|
||||
if node.ID == "" {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(node.Type) == "" {
|
||||
node.Type = "tool"
|
||||
}
|
||||
if node.Config == nil {
|
||||
node.Config = make(map[string]any)
|
||||
}
|
||||
idx.nodes[node.ID] = node
|
||||
}
|
||||
for _, edge := range g.Edges {
|
||||
if _, ok := idx.nodes[edge.Source]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := idx.nodes[edge.Target]; !ok {
|
||||
continue
|
||||
}
|
||||
idx.outgoing[edge.Source] = append(idx.outgoing[edge.Source], edge)
|
||||
idx.incoming[edge.Target] = append(idx.incoming[edge.Target], edge)
|
||||
}
|
||||
for source := range idx.outgoing {
|
||||
sortEdgesByCanvas(idx.outgoing[source], idx.nodes)
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func sortEdgesByCanvas(edges []graphEdge, nodes map[string]graphNode) {
|
||||
sort.SliceStable(edges, func(i, j int) bool {
|
||||
a := nodes[edges[i].Target]
|
||||
b := nodes[edges[j].Target]
|
||||
if a.Position.Y != b.Position.Y {
|
||||
return a.Position.Y < b.Position.Y
|
||||
}
|
||||
if a.Position.X != b.Position.X {
|
||||
return a.Position.X < b.Position.X
|
||||
}
|
||||
return edges[i].Target < edges[j].Target
|
||||
})
|
||||
}
|
||||
|
||||
func sortNodeIDsByCanvas(ids []string, nodes map[string]graphNode) {
|
||||
sort.SliceStable(ids, func(i, j int) bool {
|
||||
a := nodes[ids[i]]
|
||||
b := nodes[ids[j]]
|
||||
if a.Position.Y != b.Position.Y {
|
||||
return a.Position.Y < b.Position.Y
|
||||
}
|
||||
if a.Position.X != b.Position.X {
|
||||
return a.Position.X < b.Position.X
|
||||
}
|
||||
return ids[i] < ids[j]
|
||||
})
|
||||
}
|
||||
|
||||
func findStartNodeIDs(idx *graphIndex) []string {
|
||||
var queue []string
|
||||
for id, node := range idx.nodes {
|
||||
if strings.EqualFold(node.Type, "start") {
|
||||
queue = append(queue, id)
|
||||
}
|
||||
}
|
||||
if len(queue) == 0 {
|
||||
inDegree := make(map[string]int, len(idx.nodes))
|
||||
for id := range idx.nodes {
|
||||
inDegree[id] = 0
|
||||
}
|
||||
for _, edges := range idx.outgoing {
|
||||
for _, edge := range edges {
|
||||
inDegree[edge.Target]++
|
||||
}
|
||||
}
|
||||
for id, deg := range inDegree {
|
||||
if deg == 0 {
|
||||
queue = append(queue, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
sortNodeIDsByCanvas(queue, idx.nodes)
|
||||
return queue
|
||||
}
|
||||
|
||||
func branchNodeID(nodeID string) string {
|
||||
return nodeID + "__eino_branch"
|
||||
}
|
||||
|
||||
func edgeBranchNodeID(nodeID string) string {
|
||||
return nodeID + "__eino_edge_branch"
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func executeNode(ctx context.Context, args RunArgs, runID string, node graphNode, state *WorkflowLocalState) (map[string]any, bool, error) {
|
||||
label := node.Label
|
||||
if strings.TrimSpace(label) == "" {
|
||||
label = node.ID
|
||||
}
|
||||
nodeRunID := uuid.NewString()
|
||||
input := map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": label,
|
||||
"inputs": state.Inputs,
|
||||
"previous": state.LastOutput,
|
||||
}
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
if err := args.DB.CreateWorkflowNodeRun(&database.WorkflowNodeRun{
|
||||
ID: nodeRunID,
|
||||
RunID: runID,
|
||||
NodeID: node.ID,
|
||||
Status: "running",
|
||||
InputJSON: string(inputJSON),
|
||||
StartedAt: time.Now(),
|
||||
}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_node_start", fmt.Sprintf("开始节点:%s", label), map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeRunId": nodeRunID,
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": label,
|
||||
})
|
||||
}
|
||||
|
||||
result, proceed, status, errText := runBuiltinNode(ctx, args, node, state)
|
||||
outputJSON, _ := json.Marshal(result)
|
||||
if err := args.DB.FinishWorkflowNodeRun(nodeRunID, status, string(outputJSON), errText); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if status == "skipped" {
|
||||
state.Skipped = append(state.Skipped, label)
|
||||
} else {
|
||||
state.Executed = append(state.Executed, label)
|
||||
}
|
||||
if args.Progress != nil {
|
||||
progressData := map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeRunId": nodeRunID,
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": label,
|
||||
"status": status,
|
||||
"output": result,
|
||||
}
|
||||
progressMsg := fmt.Sprintf("节点完成:%s(%s)", label, status)
|
||||
if strings.EqualFold(node.Type, "condition") {
|
||||
matched := false
|
||||
if v, ok := result["matched"].(bool); ok {
|
||||
matched = v
|
||||
}
|
||||
expr := cfgString(node.Config, "expression")
|
||||
if matched {
|
||||
progressMsg = fmt.Sprintf("条件判断:%s → 是", label)
|
||||
} else {
|
||||
progressMsg = fmt.Sprintf("条件判断:%s → 否", label)
|
||||
}
|
||||
progressData["expression"] = expr
|
||||
progressData["matched"] = matched
|
||||
}
|
||||
args.Progress("workflow_node_result", progressMsg, progressData)
|
||||
}
|
||||
state.NodeProceed[node.ID] = proceed
|
||||
return result, proceed, nil
|
||||
}
|
||||
|
||||
func emitConditionBranchProgress(args RunArgs, runID string, node graphNode, edges []graphEdge, nodes map[string]graphNode, state *WorkflowLocalState) {
|
||||
if args.Progress == nil || len(edges) == 0 {
|
||||
return
|
||||
}
|
||||
for edgeIdx, edge := range edges {
|
||||
allowed := edgeAllowed(edge, node, edgeIdx, state)
|
||||
target := nodes[edge.Target]
|
||||
targetLabel := strings.TrimSpace(target.Label)
|
||||
if targetLabel == "" {
|
||||
targetLabel = edge.Target
|
||||
}
|
||||
branchLabel := strings.TrimSpace(edge.Label)
|
||||
if branchLabel == "" {
|
||||
switch edgeIdx {
|
||||
case 0:
|
||||
branchLabel = "是"
|
||||
case 1:
|
||||
branchLabel = "否"
|
||||
default:
|
||||
branchLabel = fmt.Sprintf("分支 %d", edgeIdx+1)
|
||||
}
|
||||
}
|
||||
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
|
||||
eventType := "workflow_branch_skipped"
|
||||
msg := fmt.Sprintf("跳过分支「%s」→ %s", branchLabel, targetLabel)
|
||||
if allowed {
|
||||
eventType = "workflow_branch_taken"
|
||||
msg = fmt.Sprintf("执行分支「%s」→ %s", branchLabel, targetLabel)
|
||||
}
|
||||
args.Progress(eventType, msg, map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": node.Label,
|
||||
"branchLabel": branchLabel,
|
||||
"targetId": edge.Target,
|
||||
"targetLabel": targetLabel,
|
||||
"edgeCondition": cond,
|
||||
"matched": conditionMatched(state),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
func runBuiltinNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
|
||||
cfg := node.Config
|
||||
switch strings.ToLower(strings.TrimSpace(node.Type)) {
|
||||
case "start":
|
||||
out := map[string]any{
|
||||
"output": state.Inputs["message"],
|
||||
"message": state.Inputs["message"],
|
||||
"conversationId": state.Inputs["conversationId"],
|
||||
"projectId": state.Inputs["projectId"],
|
||||
}
|
||||
return out, true, "completed", ""
|
||||
case "condition":
|
||||
expr := cfgString(cfg, "expression")
|
||||
ok := evalCondition(expr, state)
|
||||
out := map[string]any{"output": ok, "condition": expr, "matched": ok}
|
||||
return out, true, "completed", ""
|
||||
case "output":
|
||||
key := cfgString(cfg, "output_key")
|
||||
if key == "" {
|
||||
key = "result"
|
||||
}
|
||||
var value any
|
||||
if v := cfgString(cfg, "static_value"); v != "" {
|
||||
value = v
|
||||
} else {
|
||||
value = resolveOutputSourceBinding(cfg, state)
|
||||
}
|
||||
state.Outputs[key] = value
|
||||
return map[string]any{"output": value, "outputs": map[string]any{key: value}}, true, "completed", ""
|
||||
case "end":
|
||||
value := resolveOutputSourceBinding(cfg, state)
|
||||
if b, ok := parseFieldBinding(cfg, "result_binding"); ok {
|
||||
value = resolveBinding(b, state)
|
||||
}
|
||||
return map[string]any{"output": value}, false, "completed", ""
|
||||
case "tool":
|
||||
return runToolNode(ctx, args, node, state)
|
||||
case "agent":
|
||||
return runAgentNode(ctx, args, node, state)
|
||||
case "hitl":
|
||||
return runHITLNode(args, node, state)
|
||||
default:
|
||||
reason := "未知节点类型"
|
||||
return map[string]any{"output": "", "skipped": true, "reason": reason, "node_type": node.Type}, true, "skipped", reason
|
||||
}
|
||||
}
|
||||
|
||||
func runToolNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
|
||||
toolName := cfgString(node.Config, "tool_name")
|
||||
if toolName == "" {
|
||||
errText := "工具节点未选择 MCP 工具"
|
||||
return map[string]any{"output": "", "error": errText}, false, "failed", errText
|
||||
}
|
||||
if args.Agent == nil {
|
||||
errText := "工具节点执行失败:Agent 为空"
|
||||
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
|
||||
}
|
||||
toolArgs, err := resolveToolArguments(node.Config, state)
|
||||
if err != nil {
|
||||
errText := fmt.Sprintf("工具参数不是合法 JSON:%v", err)
|
||||
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_tool_start", fmt.Sprintf("调用工具:%s", toolName), map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"tool": toolName,
|
||||
"args": toolArgs,
|
||||
})
|
||||
}
|
||||
result, err := args.Agent.ExecuteMCPToolForConversation(ctx, args.ConversationID, toolName, toolArgs)
|
||||
if err != nil {
|
||||
errText := err.Error()
|
||||
return map[string]any{"output": "", "tool_name": toolName, "arguments": toolArgs, "error": errText}, false, "failed", errText
|
||||
}
|
||||
output := ""
|
||||
executionID := ""
|
||||
isError := false
|
||||
if result != nil {
|
||||
output = result.Result
|
||||
executionID = result.ExecutionID
|
||||
isError = result.IsError
|
||||
}
|
||||
out := map[string]any{
|
||||
"output": output,
|
||||
"tool_name": toolName,
|
||||
"arguments": toolArgs,
|
||||
"execution_id": executionID,
|
||||
"is_error": isError,
|
||||
}
|
||||
if key := cfgString(node.Config, "output_key"); key != "" {
|
||||
state.Outputs[key] = output
|
||||
}
|
||||
if isError {
|
||||
errText := strings.TrimSpace(output)
|
||||
if errText == "" {
|
||||
errText = "工具返回错误"
|
||||
}
|
||||
return out, false, "failed", errText
|
||||
}
|
||||
return out, true, "completed", ""
|
||||
}
|
||||
|
||||
func runAgentNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
|
||||
if args.AppCfg == nil || args.Agent == nil {
|
||||
errText := "Agent 节点执行失败:应用配置或 Agent 为空"
|
||||
return map[string]any{"output": "", "error": errText}, false, "failed", errText
|
||||
}
|
||||
mode := strings.ToLower(cfgString(node.Config, "agent_mode"))
|
||||
if mode == "" {
|
||||
mode = "eino_single"
|
||||
}
|
||||
inputSource := resolveNodeInputBinding(node.Config, state)
|
||||
message := buildAgentNodeMessage(node, state, inputSource)
|
||||
var result *multiagent.RunResult
|
||||
var err error
|
||||
state.SegmentMaxIteration = 0
|
||||
agentProgress := workflowAgentProgress(args.Progress, state, node)
|
||||
switch mode {
|
||||
case "eino_single", "single", "chat":
|
||||
result, err = multiagent.RunEinoSingleChatModelAgent(
|
||||
ctx,
|
||||
args.AppCfg,
|
||||
&args.AppCfg.MultiAgent,
|
||||
args.Agent,
|
||||
args.DB,
|
||||
args.Logger,
|
||||
args.ConversationID,
|
||||
args.ProjectID,
|
||||
message,
|
||||
args.History,
|
||||
args.RoleTools,
|
||||
agentProgress,
|
||||
nil,
|
||||
args.SystemPromptExtra,
|
||||
)
|
||||
default:
|
||||
result, err = multiagent.RunDeepAgent(
|
||||
ctx,
|
||||
args.AppCfg,
|
||||
&args.AppCfg.MultiAgent,
|
||||
args.Agent,
|
||||
args.DB,
|
||||
args.Logger,
|
||||
args.ConversationID,
|
||||
args.ProjectID,
|
||||
message,
|
||||
args.History,
|
||||
args.RoleTools,
|
||||
agentProgress,
|
||||
args.AgentsMarkdownDir,
|
||||
mode,
|
||||
nil,
|
||||
args.SystemPromptExtra,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
errText := err.Error()
|
||||
state.MainIterationOffset += state.SegmentMaxIteration
|
||||
return map[string]any{"output": "", "mode": mode, "error": errText}, false, "failed", errText
|
||||
}
|
||||
state.MainIterationOffset += state.SegmentMaxIteration
|
||||
response := ""
|
||||
mcpIDs := []string{}
|
||||
if result != nil {
|
||||
response = result.Response
|
||||
mcpIDs = result.MCPExecutionIDs
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_agent_output", response, map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"label": firstNonEmpty(node.Label, node.ID),
|
||||
"mode": mode,
|
||||
"inputSource": inputSource,
|
||||
"inputPreview": truncateWorkflowPreview(inputSource, 500),
|
||||
"mcpExecutionIds": mcpIDs,
|
||||
})
|
||||
}
|
||||
if key := cfgString(node.Config, "output_key"); key != "" {
|
||||
state.Outputs[key] = response
|
||||
}
|
||||
return map[string]any{
|
||||
"output": response,
|
||||
"mode": mode,
|
||||
"mcp_execution_ids": mcpIDs,
|
||||
}, true, "completed", ""
|
||||
}
|
||||
|
||||
func buildAgentNodeMessage(node graphNode, state *WorkflowLocalState, upstreamInput string) string {
|
||||
instruction := strings.TrimSpace(cfgString(node.Config, "instruction"))
|
||||
upstreamInput = strings.TrimSpace(upstreamInput)
|
||||
if instruction == "" {
|
||||
if upstreamInput != "" {
|
||||
return fmt.Sprintf("请基于上游节点输出继续处理:\n%s", upstreamInput)
|
||||
}
|
||||
return fmt.Sprintf("请基于上游节点输出继续处理:\n%v", state.LastOutput["output"])
|
||||
}
|
||||
if upstreamInput == "" {
|
||||
return instruction
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("上游输入:\n%s\n\n节点指令:\n%s", upstreamInput, instruction))
|
||||
}
|
||||
|
||||
func workflowAgentProgress(progress agent.ProgressCallback, state *WorkflowLocalState, node graphNode) agent.ProgressCallback {
|
||||
if progress == nil {
|
||||
return nil
|
||||
}
|
||||
return func(eventType, message string, data interface{}) {
|
||||
switch eventType {
|
||||
case "response_start", "response_delta", "response", "done":
|
||||
return
|
||||
default:
|
||||
enrichWorkflowAgentEventData(data, state, node)
|
||||
if eventType == "iteration" {
|
||||
applyWorkflowMainIterationOffset(data, state)
|
||||
}
|
||||
progress(eventType, message, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enrichWorkflowAgentEventData(data interface{}, state *WorkflowLocalState, node graphNode) {
|
||||
m, ok := data.(map[string]interface{})
|
||||
if !ok || m == nil {
|
||||
return
|
||||
}
|
||||
if node.ID != "" {
|
||||
m["workflowNodeId"] = node.ID
|
||||
}
|
||||
if state != nil && strings.TrimSpace(state.WorkflowRunID) != "" {
|
||||
m["workflowRunId"] = state.WorkflowRunID
|
||||
}
|
||||
}
|
||||
|
||||
func applyWorkflowMainIterationOffset(data interface{}, state *WorkflowLocalState) {
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
m, ok := data.(map[string]interface{})
|
||||
if !ok || m == nil {
|
||||
return
|
||||
}
|
||||
scope, _ := m["einoScope"].(string)
|
||||
if strings.TrimSpace(scope) != "main" {
|
||||
return
|
||||
}
|
||||
raw := iterationNumberFromProgressData(m)
|
||||
if raw <= 0 {
|
||||
return
|
||||
}
|
||||
if raw > state.SegmentMaxIteration {
|
||||
state.SegmentMaxIteration = raw
|
||||
}
|
||||
m["iteration"] = raw + state.MainIterationOffset
|
||||
}
|
||||
|
||||
func iterationNumberFromProgressData(m map[string]interface{}) int {
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
return v
|
||||
case int32:
|
||||
return int(v)
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case float32:
|
||||
return int(v)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func runHITLNode(args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
|
||||
prompt := resolveHITLPromptBinding(node.Config, state)
|
||||
reviewer := cfgString(node.Config, "reviewer")
|
||||
if reviewer == "" {
|
||||
reviewer = "human"
|
||||
}
|
||||
approved := true
|
||||
if state != nil && state.Inputs != nil {
|
||||
if v, ok := state.Inputs["_hitl_approved"]; ok {
|
||||
approved = fmt.Sprint(v) == "true"
|
||||
}
|
||||
}
|
||||
if !approved {
|
||||
reason := "人工审批已拒绝"
|
||||
if state != nil && state.Inputs != nil {
|
||||
if v, ok := state.Inputs["_hitl_comment"]; ok {
|
||||
if s := strings.TrimSpace(fmt.Sprint(v)); s != "" {
|
||||
reason = s
|
||||
}
|
||||
}
|
||||
}
|
||||
return map[string]any{"output": "", "prompt": prompt, "approved": false, "mode": "interactive"}, false, "failed", reason
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_hitl_checkpoint", "人工确认节点已通过", map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"prompt": prompt,
|
||||
"reviewer": reviewer,
|
||||
"mode": "interactive",
|
||||
"approved": true,
|
||||
})
|
||||
}
|
||||
return map[string]any{
|
||||
"output": prompt,
|
||||
"prompt": prompt,
|
||||
"reviewer": reviewer,
|
||||
"approved": true,
|
||||
"mode": "interactive",
|
||||
}, true, "completed", ""
|
||||
}
|
||||
+48
-821
@@ -4,82 +4,16 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type RunArgs struct {
|
||||
DB *database.DB
|
||||
Logger *zap.Logger
|
||||
Role config.RoleConfig
|
||||
AppCfg *config.Config
|
||||
Agent *agent.Agent
|
||||
ConversationID string
|
||||
ProjectID string
|
||||
UserMessage string
|
||||
History []agent.ChatMessage
|
||||
RoleTools []string
|
||||
AgentsMarkdownDir string
|
||||
SystemPromptExtra string
|
||||
AssistantMessageID string
|
||||
Progress agent.ProgressCallback
|
||||
}
|
||||
|
||||
type RunResult struct {
|
||||
Response string
|
||||
RunID string
|
||||
}
|
||||
|
||||
type graphDef struct {
|
||||
Nodes []graphNode `json:"nodes"`
|
||||
Edges []graphEdge `json:"edges"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type graphNode struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Label string `json:"label"`
|
||||
Position graphPosition `json:"position"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type graphEdge struct {
|
||||
ID string `json:"id"`
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target"`
|
||||
Label string `json:"label"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type graphPosition struct {
|
||||
X float64 `json:"x"`
|
||||
Y float64 `json:"y"`
|
||||
}
|
||||
|
||||
type workflowExecState struct {
|
||||
inputs map[string]any
|
||||
outputs map[string]any
|
||||
nodeOutputs map[string]map[string]any
|
||||
lastOutput map[string]any
|
||||
executed []string
|
||||
skipped []string
|
||||
workflowRunID string
|
||||
// 图编排内多个 Agent 节点各自从第 1 轮上报 iteration;累计偏移避免对话页迭代序号回跳与流式条目复用错乱。
|
||||
mainIterationOffset int
|
||||
segmentMaxIteration int
|
||||
}
|
||||
|
||||
// ShouldAutoRunRoleWorkflow returns true when a role explicitly binds a workflow
|
||||
// and does not turn it off. Empty policy defaults to auto to keep role UX simple.
|
||||
func ShouldAutoRunRoleWorkflow(role config.RoleConfig) bool {
|
||||
@@ -90,10 +24,7 @@ func ShouldAutoRunRoleWorkflow(role config.RoleConfig) bool {
|
||||
return policy == "" || policy == "auto"
|
||||
}
|
||||
|
||||
// RunRoleBoundWorkflow executes the persisted role-bound workflow graph.
|
||||
// Control nodes are interpreted locally, tool nodes call the existing MCP bridge,
|
||||
// and agent nodes reuse the existing Eino ADK runners so role-bound flows share
|
||||
// the same model/tool/session behavior as the chat page.
|
||||
// RunRoleBoundWorkflow executes the persisted role-bound workflow via cached Eino Workflow.
|
||||
func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) {
|
||||
if args.DB == nil {
|
||||
return nil, fmt.Errorf("workflow db is nil")
|
||||
@@ -150,6 +81,7 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
|
||||
"workflowVersion": wf.Version,
|
||||
"workflowRunId": runID,
|
||||
"conversationId": args.ConversationID,
|
||||
"engine": "eino_workflow",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -158,13 +90,44 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
|
||||
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
state := &workflowExecState{
|
||||
inputs: input,
|
||||
outputs: make(map[string]any),
|
||||
nodeOutputs: make(map[string]map[string]any),
|
||||
workflowRunID: runID,
|
||||
}
|
||||
if err := executeGraph(ctx, args, runID, graph, state); err != nil {
|
||||
state := newWorkflowLocalState(input, runID)
|
||||
if err := executeEinoGraph(ctx, args, runID, wf.ID, wf.Version, graph, state); err != nil {
|
||||
if IsAwaitingHITL(err) {
|
||||
hitl := err.(*AwaitingHITLError)
|
||||
partial := map[string]interface{}{
|
||||
"workflowId": wf.ID,
|
||||
"workflowName": wf.Name,
|
||||
"workflowVersion": wf.Version,
|
||||
"workflowRunId": runID,
|
||||
"status": "awaiting_hitl",
|
||||
"outputs": state.Outputs,
|
||||
"executedNodes": state.Executed,
|
||||
"skippedNodes": state.Skipped,
|
||||
"pendingHitl": map[string]interface{}{
|
||||
"nodeId": hitl.NodeID,
|
||||
"label": hitl.NodeLabel,
|
||||
"prompt": hitl.Prompt,
|
||||
},
|
||||
"engine": "eino_workflow",
|
||||
}
|
||||
partialJSON, _ := json.Marshal(partial)
|
||||
_ = args.DB.SetWorkflowRunAwaitingHITL(runID, hitl.NodeID, string(partialJSON))
|
||||
response := fmt.Sprintf("工作流「%s」已在节点「%s」暂停,等待人工审批。\n运行 ID:%s", wf.Name, firstNonEmpty(hitl.NodeLabel, hitl.NodeID), runID)
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_paused", response, map[string]interface{}{
|
||||
"workflowRunId": runID,
|
||||
"status": "awaiting_hitl",
|
||||
"nodeId": hitl.NodeID,
|
||||
"resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID),
|
||||
})
|
||||
}
|
||||
return &RunResult{
|
||||
Response: response,
|
||||
RunID: runID,
|
||||
Status: "awaiting_hitl",
|
||||
AwaitingHITL: true,
|
||||
}, nil
|
||||
}
|
||||
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
@@ -175,9 +138,10 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
|
||||
"workflowVersion": wf.Version,
|
||||
"workflowRunId": runID,
|
||||
"status": "completed",
|
||||
"outputs": state.outputs,
|
||||
"executedNodes": state.executed,
|
||||
"skippedNodes": state.skipped,
|
||||
"outputs": state.Outputs,
|
||||
"executedNodes": state.Executed,
|
||||
"skippedNodes": state.Skipped,
|
||||
"engine": "eino_workflow",
|
||||
}
|
||||
outputJSON, _ := json.Marshal(output)
|
||||
|
||||
@@ -189,8 +153,9 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
|
||||
args.Progress("workflow_done", fmt.Sprintf("流程「%s」运行完成", wf.Name), map[string]interface{}{
|
||||
"workflowRunId": runID,
|
||||
"workflowId": wf.ID,
|
||||
"outputs": state.outputs,
|
||||
"outputs": state.Outputs,
|
||||
"response": response,
|
||||
"engine": "eino_workflow",
|
||||
})
|
||||
}
|
||||
if args.Logger != nil {
|
||||
@@ -199,746 +164,8 @@ func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error)
|
||||
zap.String("workflow_run_id", runID),
|
||||
zap.String("conversation_id", args.ConversationID),
|
||||
zap.String("role", args.Role.Name),
|
||||
zap.String("engine", "eino_workflow"),
|
||||
)
|
||||
}
|
||||
return &RunResult{Response: response, RunID: runID}, nil
|
||||
}
|
||||
|
||||
func parseGraph(raw string) (*graphDef, error) {
|
||||
var g graphDef
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(raw)), &g); err != nil {
|
||||
return nil, fmt.Errorf("解析工作流图失败: %w", err)
|
||||
}
|
||||
if len(g.Nodes) == 0 {
|
||||
return nil, fmt.Errorf("工作流没有节点")
|
||||
}
|
||||
if g.Config == nil {
|
||||
g.Config = make(map[string]any)
|
||||
}
|
||||
return &g, nil
|
||||
}
|
||||
|
||||
func executeGraph(ctx context.Context, args RunArgs, runID string, g *graphDef, state *workflowExecState) error {
|
||||
nodes := make(map[string]graphNode, len(g.Nodes))
|
||||
inDegree := make(map[string]int, len(g.Nodes))
|
||||
outgoing := make(map[string][]graphEdge)
|
||||
for _, node := range g.Nodes {
|
||||
node.ID = strings.TrimSpace(node.ID)
|
||||
if node.ID == "" {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(node.Type) == "" {
|
||||
node.Type = "tool"
|
||||
}
|
||||
if node.Config == nil {
|
||||
node.Config = make(map[string]any)
|
||||
}
|
||||
nodes[node.ID] = node
|
||||
inDegree[node.ID] = 0
|
||||
}
|
||||
for _, edge := range g.Edges {
|
||||
if _, ok := nodes[edge.Source]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := nodes[edge.Target]; !ok {
|
||||
continue
|
||||
}
|
||||
outgoing[edge.Source] = append(outgoing[edge.Source], edge)
|
||||
inDegree[edge.Target]++
|
||||
}
|
||||
for source := range outgoing {
|
||||
sort.SliceStable(outgoing[source], func(i, j int) bool {
|
||||
a := nodes[outgoing[source][i].Target]
|
||||
b := nodes[outgoing[source][j].Target]
|
||||
if a.Position.Y != b.Position.Y {
|
||||
return a.Position.Y < b.Position.Y
|
||||
}
|
||||
if a.Position.X != b.Position.X {
|
||||
return a.Position.X < b.Position.X
|
||||
}
|
||||
return outgoing[source][i].Target < outgoing[source][j].Target
|
||||
})
|
||||
}
|
||||
|
||||
var queue []string
|
||||
for id, node := range nodes {
|
||||
if strings.EqualFold(node.Type, "start") {
|
||||
queue = append(queue, id)
|
||||
}
|
||||
}
|
||||
if len(queue) == 0 {
|
||||
for id, deg := range inDegree {
|
||||
if deg == 0 {
|
||||
queue = append(queue, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
sortNodeIDsByCanvas(queue, nodes)
|
||||
seen := make(map[string]bool)
|
||||
remainingIncoming := make(map[string]int, len(inDegree))
|
||||
for id, deg := range inDegree {
|
||||
remainingIncoming[id] = deg
|
||||
}
|
||||
for len(queue) > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
id := queue[0]
|
||||
queue = queue[1:]
|
||||
if seen[id] {
|
||||
continue
|
||||
}
|
||||
seen[id] = true
|
||||
node := nodes[id]
|
||||
result, proceed, err := executeNode(ctx, args, runID, node, state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
state.nodeOutputs[id] = result
|
||||
state.lastOutput = result
|
||||
if proceed {
|
||||
edges := outgoing[id]
|
||||
if strings.EqualFold(node.Type, "condition") {
|
||||
emitConditionBranchProgress(args, runID, node, edges, nodes, state)
|
||||
}
|
||||
for edgeIdx, edge := range edges {
|
||||
if !edgeAllowed(edge, node, edgeIdx, state) {
|
||||
continue
|
||||
}
|
||||
remainingIncoming[edge.Target]--
|
||||
if remainingIncoming[edge.Target] > 0 {
|
||||
continue
|
||||
}
|
||||
queue = append(queue, edge.Target)
|
||||
}
|
||||
sortNodeIDsByCanvas(queue, nodes)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sortNodeIDsByCanvas(ids []string, nodes map[string]graphNode) {
|
||||
sort.SliceStable(ids, func(i, j int) bool {
|
||||
a := nodes[ids[i]]
|
||||
b := nodes[ids[j]]
|
||||
if a.Position.Y != b.Position.Y {
|
||||
return a.Position.Y < b.Position.Y
|
||||
}
|
||||
if a.Position.X != b.Position.X {
|
||||
return a.Position.X < b.Position.X
|
||||
}
|
||||
return ids[i] < ids[j]
|
||||
})
|
||||
}
|
||||
|
||||
func executeNode(ctx context.Context, args RunArgs, runID string, node graphNode, state *workflowExecState) (map[string]any, bool, error) {
|
||||
label := node.Label
|
||||
if strings.TrimSpace(label) == "" {
|
||||
label = node.ID
|
||||
}
|
||||
nodeRunID := uuid.NewString()
|
||||
input := map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": label,
|
||||
"inputs": state.inputs,
|
||||
"previous": state.lastOutput,
|
||||
}
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
if err := args.DB.CreateWorkflowNodeRun(&database.WorkflowNodeRun{
|
||||
ID: nodeRunID,
|
||||
RunID: runID,
|
||||
NodeID: node.ID,
|
||||
Status: "running",
|
||||
InputJSON: string(inputJSON),
|
||||
StartedAt: time.Now(),
|
||||
}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_node_start", fmt.Sprintf("开始节点:%s", label), map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeRunId": nodeRunID,
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": label,
|
||||
})
|
||||
}
|
||||
|
||||
result, proceed, status, errText := runBuiltinNode(ctx, args, node, state)
|
||||
outputJSON, _ := json.Marshal(result)
|
||||
if err := args.DB.FinishWorkflowNodeRun(nodeRunID, status, string(outputJSON), errText); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if status == "skipped" {
|
||||
state.skipped = append(state.skipped, label)
|
||||
} else {
|
||||
state.executed = append(state.executed, label)
|
||||
}
|
||||
if args.Progress != nil {
|
||||
progressData := map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeRunId": nodeRunID,
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": label,
|
||||
"status": status,
|
||||
"output": result,
|
||||
}
|
||||
progressMsg := fmt.Sprintf("节点完成:%s(%s)", label, status)
|
||||
if strings.EqualFold(node.Type, "condition") {
|
||||
matched := false
|
||||
if v, ok := result["matched"].(bool); ok {
|
||||
matched = v
|
||||
}
|
||||
expr := cfgString(node.Config, "expression")
|
||||
if matched {
|
||||
progressMsg = fmt.Sprintf("条件判断:%s → 是", label)
|
||||
} else {
|
||||
progressMsg = fmt.Sprintf("条件判断:%s → 否", label)
|
||||
}
|
||||
progressData["expression"] = expr
|
||||
progressData["matched"] = matched
|
||||
}
|
||||
args.Progress("workflow_node_result", progressMsg, progressData)
|
||||
}
|
||||
return result, proceed, nil
|
||||
}
|
||||
|
||||
func runBuiltinNode(ctx context.Context, args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) {
|
||||
cfg := node.Config
|
||||
switch strings.ToLower(strings.TrimSpace(node.Type)) {
|
||||
case "start":
|
||||
out := map[string]any{
|
||||
"output": state.inputs["message"],
|
||||
"message": state.inputs["message"],
|
||||
"conversationId": state.inputs["conversationId"],
|
||||
"projectId": state.inputs["projectId"],
|
||||
}
|
||||
return out, true, "completed", ""
|
||||
case "condition":
|
||||
expr := cfgString(cfg, "expression")
|
||||
ok := evalCondition(expr, state)
|
||||
out := map[string]any{"output": ok, "condition": expr, "matched": ok}
|
||||
// 条件节点始终继续,由出边条件(或连线标签/顺序)决定走「是/否」分支。
|
||||
return out, true, "completed", ""
|
||||
case "output":
|
||||
key := cfgString(cfg, "output_key")
|
||||
if key == "" {
|
||||
key = "result"
|
||||
}
|
||||
value := resolveTemplate(cfgString(cfg, "source"), state)
|
||||
state.outputs[key] = value
|
||||
return map[string]any{"output": value, "outputs": map[string]any{key: value}}, true, "completed", ""
|
||||
case "end":
|
||||
value := resolveTemplate(cfgString(cfg, "result_template"), state)
|
||||
return map[string]any{"output": value}, false, "completed", ""
|
||||
case "tool":
|
||||
return runToolNode(ctx, args, node, state)
|
||||
case "agent":
|
||||
return runAgentNode(ctx, args, node, state)
|
||||
case "hitl":
|
||||
return runHITLNode(args, node, state)
|
||||
default:
|
||||
reason := "未知节点类型"
|
||||
return map[string]any{"output": "", "skipped": true, "reason": reason, "node_type": node.Type}, true, "skipped", reason
|
||||
}
|
||||
}
|
||||
|
||||
func runToolNode(ctx context.Context, args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) {
|
||||
toolName := cfgString(node.Config, "tool_name")
|
||||
if toolName == "" {
|
||||
errText := "工具节点未选择 MCP 工具"
|
||||
return map[string]any{"output": "", "error": errText}, false, "failed", errText
|
||||
}
|
||||
if args.Agent == nil {
|
||||
errText := "工具节点执行失败:Agent 为空"
|
||||
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
|
||||
}
|
||||
toolArgs, err := parseToolArguments(cfgString(node.Config, "arguments"), state)
|
||||
if err != nil {
|
||||
errText := fmt.Sprintf("工具参数不是合法 JSON:%v", err)
|
||||
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_tool_start", fmt.Sprintf("调用工具:%s", toolName), map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"tool": toolName,
|
||||
"args": toolArgs,
|
||||
})
|
||||
}
|
||||
result, err := args.Agent.ExecuteMCPToolForConversation(ctx, args.ConversationID, toolName, toolArgs)
|
||||
if err != nil {
|
||||
errText := err.Error()
|
||||
return map[string]any{"output": "", "tool_name": toolName, "arguments": toolArgs, "error": errText}, false, "failed", errText
|
||||
}
|
||||
output := ""
|
||||
executionID := ""
|
||||
isError := false
|
||||
if result != nil {
|
||||
output = result.Result
|
||||
executionID = result.ExecutionID
|
||||
isError = result.IsError
|
||||
}
|
||||
out := map[string]any{
|
||||
"output": output,
|
||||
"tool_name": toolName,
|
||||
"arguments": toolArgs,
|
||||
"execution_id": executionID,
|
||||
"is_error": isError,
|
||||
}
|
||||
if key := cfgString(node.Config, "output_key"); key != "" {
|
||||
state.outputs[key] = output
|
||||
}
|
||||
if isError {
|
||||
errText := strings.TrimSpace(output)
|
||||
if errText == "" {
|
||||
errText = "工具返回错误"
|
||||
}
|
||||
return out, false, "failed", errText
|
||||
}
|
||||
return out, true, "completed", ""
|
||||
}
|
||||
|
||||
func runAgentNode(ctx context.Context, args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) {
|
||||
if args.AppCfg == nil || args.Agent == nil {
|
||||
errText := "Agent 节点执行失败:应用配置或 Agent 为空"
|
||||
return map[string]any{"output": "", "error": errText}, false, "failed", errText
|
||||
}
|
||||
mode := strings.ToLower(cfgString(node.Config, "agent_mode"))
|
||||
if mode == "" {
|
||||
mode = "eino_single"
|
||||
}
|
||||
inputSource := cfgString(node.Config, "input_source")
|
||||
if inputSource == "" {
|
||||
inputSource = "{{previous.output}}"
|
||||
}
|
||||
upstreamInput := strings.TrimSpace(resolveTemplate(inputSource, state))
|
||||
message := buildAgentNodeMessage(node, state)
|
||||
var result *multiagent.RunResult
|
||||
var err error
|
||||
state.segmentMaxIteration = 0
|
||||
agentProgress := workflowAgentProgress(args.Progress, state, node)
|
||||
switch mode {
|
||||
case "eino_single", "single", "chat":
|
||||
result, err = multiagent.RunEinoSingleChatModelAgent(
|
||||
ctx,
|
||||
args.AppCfg,
|
||||
&args.AppCfg.MultiAgent,
|
||||
args.Agent,
|
||||
args.DB,
|
||||
args.Logger,
|
||||
args.ConversationID,
|
||||
args.ProjectID,
|
||||
message,
|
||||
args.History,
|
||||
args.RoleTools,
|
||||
agentProgress,
|
||||
nil,
|
||||
args.SystemPromptExtra,
|
||||
)
|
||||
default:
|
||||
result, err = multiagent.RunDeepAgent(
|
||||
ctx,
|
||||
args.AppCfg,
|
||||
&args.AppCfg.MultiAgent,
|
||||
args.Agent,
|
||||
args.DB,
|
||||
args.Logger,
|
||||
args.ConversationID,
|
||||
args.ProjectID,
|
||||
message,
|
||||
args.History,
|
||||
args.RoleTools,
|
||||
agentProgress,
|
||||
args.AgentsMarkdownDir,
|
||||
mode,
|
||||
nil,
|
||||
args.SystemPromptExtra,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
errText := err.Error()
|
||||
state.mainIterationOffset += state.segmentMaxIteration
|
||||
return map[string]any{"output": "", "mode": mode, "error": errText}, false, "failed", errText
|
||||
}
|
||||
state.mainIterationOffset += state.segmentMaxIteration
|
||||
response := ""
|
||||
mcpIDs := []string{}
|
||||
if result != nil {
|
||||
response = result.Response
|
||||
mcpIDs = result.MCPExecutionIDs
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_agent_output", response, map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"label": firstNonEmpty(node.Label, node.ID),
|
||||
"mode": mode,
|
||||
"inputSource": inputSource,
|
||||
"inputPreview": truncateWorkflowPreview(upstreamInput, 500),
|
||||
"mcpExecutionIds": mcpIDs,
|
||||
})
|
||||
}
|
||||
if key := cfgString(node.Config, "output_key"); key != "" {
|
||||
state.outputs[key] = response
|
||||
}
|
||||
return map[string]any{
|
||||
"output": response,
|
||||
"mode": mode,
|
||||
"mcp_execution_ids": mcpIDs,
|
||||
}, true, "completed", ""
|
||||
}
|
||||
|
||||
func buildAgentNodeMessage(node graphNode, state *workflowExecState) string {
|
||||
instruction := strings.TrimSpace(resolveTemplate(cfgString(node.Config, "instruction"), state))
|
||||
inputSource := cfgString(node.Config, "input_source")
|
||||
if inputSource == "" {
|
||||
inputSource = "{{previous.output}}"
|
||||
}
|
||||
upstreamInput := strings.TrimSpace(resolveTemplate(inputSource, state))
|
||||
if instruction == "" {
|
||||
if upstreamInput != "" {
|
||||
return fmt.Sprintf("请基于上游节点输出继续处理:\n%s", upstreamInput)
|
||||
}
|
||||
return fmt.Sprintf("请基于上游节点输出继续处理:\n%v", state.lastOutput["output"])
|
||||
}
|
||||
if upstreamInput == "" {
|
||||
return instruction
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("上游输入:\n%s\n\n节点指令:\n%s", upstreamInput, instruction))
|
||||
}
|
||||
|
||||
func workflowAgentProgress(progress agent.ProgressCallback, state *workflowExecState, node graphNode) agent.ProgressCallback {
|
||||
if progress == nil {
|
||||
return nil
|
||||
}
|
||||
return func(eventType, message string, data interface{}) {
|
||||
switch eventType {
|
||||
case "response_start", "response_delta", "response", "done":
|
||||
return
|
||||
default:
|
||||
enrichWorkflowAgentEventData(data, state, node)
|
||||
if eventType == "iteration" {
|
||||
applyWorkflowMainIterationOffset(data, state)
|
||||
}
|
||||
progress(eventType, message, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enrichWorkflowAgentEventData(data interface{}, state *workflowExecState, node graphNode) {
|
||||
m, ok := data.(map[string]interface{})
|
||||
if !ok || m == nil {
|
||||
return
|
||||
}
|
||||
if node.ID != "" {
|
||||
m["workflowNodeId"] = node.ID
|
||||
}
|
||||
if state != nil && strings.TrimSpace(state.workflowRunID) != "" {
|
||||
m["workflowRunId"] = state.workflowRunID
|
||||
}
|
||||
}
|
||||
|
||||
func applyWorkflowMainIterationOffset(data interface{}, state *workflowExecState) {
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
m, ok := data.(map[string]interface{})
|
||||
if !ok || m == nil {
|
||||
return
|
||||
}
|
||||
scope, _ := m["einoScope"].(string)
|
||||
if strings.TrimSpace(scope) != "main" {
|
||||
return
|
||||
}
|
||||
raw := iterationNumberFromProgressData(m)
|
||||
if raw <= 0 {
|
||||
return
|
||||
}
|
||||
if raw > state.segmentMaxIteration {
|
||||
state.segmentMaxIteration = raw
|
||||
}
|
||||
m["iteration"] = raw + state.mainIterationOffset
|
||||
}
|
||||
|
||||
func iterationNumberFromProgressData(m map[string]interface{}) int {
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
return v
|
||||
case int32:
|
||||
return int(v)
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case float32:
|
||||
return int(v)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func runHITLNode(args RunArgs, node graphNode, state *workflowExecState) (map[string]any, bool, string, string) {
|
||||
prompt := resolveTemplate(cfgString(node.Config, "prompt"), state)
|
||||
reviewer := cfgString(node.Config, "reviewer")
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_hitl_checkpoint", "人工确认节点已记录", map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"prompt": prompt,
|
||||
"reviewer": reviewer,
|
||||
"mode": "record_only",
|
||||
})
|
||||
}
|
||||
return map[string]any{
|
||||
"output": prompt,
|
||||
"prompt": prompt,
|
||||
"reviewer": reviewer,
|
||||
"approved": true,
|
||||
"mode": "record_only",
|
||||
}, true, "completed", ""
|
||||
}
|
||||
|
||||
func parseToolArguments(raw string, state *workflowExecState) (map[string]interface{}, error) {
|
||||
if raw == "" {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
raw = strings.TrimSpace(resolveTemplate(raw, state))
|
||||
if raw == "" {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func edgeAllowed(edge graphEdge, sourceNode graphNode, edgeIndex int, state *workflowExecState) bool {
|
||||
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
|
||||
if cond != "" {
|
||||
return evalCondition(cond, state)
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(sourceNode.Type), "condition") {
|
||||
return conditionBranchAllowed(edge, edgeIndex, state)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func conditionBranchAllowed(edge graphEdge, edgeIndex int, state *workflowExecState) bool {
|
||||
matched := conditionMatched(state)
|
||||
if branch := conditionBranchHint(edge); branch != "" {
|
||||
return (branch == "true" && matched) || (branch == "false" && !matched)
|
||||
}
|
||||
switch edgeIndex {
|
||||
case 0:
|
||||
return matched
|
||||
case 1:
|
||||
return !matched
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func conditionMatched(state *workflowExecState) bool {
|
||||
v := strings.ToLower(cleanComparable(fmt.Sprint(valueFromPath("previous.matched", state))))
|
||||
return v == "true" || v == "1"
|
||||
}
|
||||
|
||||
func conditionBranchHint(edge graphEdge) string {
|
||||
if edge.Config != nil {
|
||||
switch strings.ToLower(strings.TrimSpace(cfgString(edge.Config, "branch"))) {
|
||||
case "true", "yes", "y", "是":
|
||||
return "true"
|
||||
case "false", "no", "n", "否":
|
||||
return "false"
|
||||
}
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(edge.Label)) {
|
||||
case "true", "yes", "y", "是":
|
||||
return "true"
|
||||
case "false", "no", "n", "否":
|
||||
return "false"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func emitConditionBranchProgress(args RunArgs, runID string, node graphNode, edges []graphEdge, nodes map[string]graphNode, state *workflowExecState) {
|
||||
if args.Progress == nil || len(edges) == 0 {
|
||||
return
|
||||
}
|
||||
for edgeIdx, edge := range edges {
|
||||
allowed := edgeAllowed(edge, node, edgeIdx, state)
|
||||
target := nodes[edge.Target]
|
||||
targetLabel := strings.TrimSpace(target.Label)
|
||||
if targetLabel == "" {
|
||||
targetLabel = edge.Target
|
||||
}
|
||||
branchLabel := strings.TrimSpace(edge.Label)
|
||||
if branchLabel == "" {
|
||||
switch edgeIdx {
|
||||
case 0:
|
||||
branchLabel = "是"
|
||||
case 1:
|
||||
branchLabel = "否"
|
||||
default:
|
||||
branchLabel = fmt.Sprintf("分支 %d", edgeIdx+1)
|
||||
}
|
||||
}
|
||||
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
|
||||
eventType := "workflow_branch_skipped"
|
||||
msg := fmt.Sprintf("跳过分支「%s」→ %s", branchLabel, targetLabel)
|
||||
if allowed {
|
||||
eventType = "workflow_branch_taken"
|
||||
msg = fmt.Sprintf("执行分支「%s」→ %s", branchLabel, targetLabel)
|
||||
}
|
||||
args.Progress(eventType, msg, map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": node.Label,
|
||||
"branchLabel": branchLabel,
|
||||
"targetId": edge.Target,
|
||||
"targetLabel": targetLabel,
|
||||
"edgeCondition": cond,
|
||||
"matched": conditionMatched(state),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func cfgString(cfg map[string]any, key string) string {
|
||||
if cfg == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if s := strings.TrimSpace(value); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func truncateWorkflowPreview(s string, limit int) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if limit <= 0 || len([]rune(s)) <= limit {
|
||||
return s
|
||||
}
|
||||
runes := []rune(s)
|
||||
return string(runes[:limit]) + "..."
|
||||
}
|
||||
|
||||
var templateVarRe = regexp.MustCompile(`\{\{\s*([a-zA-Z0-9_.-]+)\s*\}\}`)
|
||||
|
||||
func resolveTemplate(s string, state *workflowExecState) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return fmt.Sprint(valueFromPath("previous.output", state))
|
||||
}
|
||||
return templateVarRe.ReplaceAllStringFunc(s, func(match string) string {
|
||||
m := templateVarRe.FindStringSubmatch(match)
|
||||
if len(m) != 2 {
|
||||
return match
|
||||
}
|
||||
return fmt.Sprint(valueFromPath(m[1], state))
|
||||
})
|
||||
}
|
||||
|
||||
func valueFromPath(path string, state *workflowExecState) any {
|
||||
parts := strings.Split(path, ".")
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
var cur any
|
||||
switch parts[0] {
|
||||
case "inputs", "input":
|
||||
cur = state.inputs
|
||||
case "previous", "prev":
|
||||
cur = state.lastOutput
|
||||
case "outputs":
|
||||
cur = state.outputs
|
||||
default:
|
||||
if v, ok := state.inputs[parts[0]]; ok {
|
||||
cur = v
|
||||
} else if v, ok := state.nodeOutputs[parts[0]]; ok {
|
||||
cur = v
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
for _, p := range parts[1:] {
|
||||
m, ok := cur.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
cur = m[p]
|
||||
}
|
||||
if cur == nil {
|
||||
return ""
|
||||
}
|
||||
return cur
|
||||
}
|
||||
|
||||
func evalCondition(expr string, state *workflowExecState) bool {
|
||||
expr = strings.TrimSpace(expr)
|
||||
if expr == "" {
|
||||
return true
|
||||
}
|
||||
resolved := strings.TrimSpace(resolveTemplate(expr, state))
|
||||
switch {
|
||||
case strings.Contains(resolved, "!="):
|
||||
parts := strings.SplitN(resolved, "!=", 2)
|
||||
return cleanComparable(parts[0]) != cleanComparable(parts[1])
|
||||
case strings.Contains(resolved, "=="):
|
||||
parts := strings.SplitN(resolved, "==", 2)
|
||||
return cleanComparable(parts[0]) == cleanComparable(parts[1])
|
||||
default:
|
||||
v := strings.ToLower(cleanComparable(resolved))
|
||||
return v != "" && v != "false" && v != "0" && v != "null"
|
||||
}
|
||||
}
|
||||
|
||||
func cleanComparable(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
s = strings.Trim(s, `"'`)
|
||||
return s
|
||||
}
|
||||
|
||||
func renderWorkflowResponse(roleName, workflowName string, version int, runID string, state *workflowExecState) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("角色「%s」已完成工作流「%s」(版本 %d)。\n\n", roleName, workflowName, version))
|
||||
sb.WriteString(fmt.Sprintf("运行 ID:%s\n", runID))
|
||||
sb.WriteString(fmt.Sprintf("已执行节点:%d", len(state.executed)))
|
||||
if len(state.skipped) > 0 {
|
||||
sb.WriteString(fmt.Sprintf(",跳过节点:%d", len(state.skipped)))
|
||||
}
|
||||
sb.WriteString("\n\n")
|
||||
if len(state.outputs) > 0 {
|
||||
sb.WriteString("输出:\n")
|
||||
keys := make([]string, 0, len(state.outputs))
|
||||
for k := range state.outputs {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
sb.WriteString(fmt.Sprintf("- %s:%v\n", k, state.outputs[k]))
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("暂无输出。请检查是否配置了输出节点,或条件分支是否命中。\n")
|
||||
}
|
||||
if len(state.skipped) > 0 {
|
||||
sb.WriteString("\n未执行的节点类型仍会保留运行记录:")
|
||||
sb.WriteString(strings.Join(state.skipped, "、"))
|
||||
sb.WriteString("。")
|
||||
}
|
||||
return strings.TrimSpace(sb.String())
|
||||
return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func init() {
|
||||
schema.RegisterName[*WorkflowLocalState]("_cyberstrike_workflow_local_state")
|
||||
}
|
||||
|
||||
// WorkflowLocalState is the Eino WithGenLocalState payload (checkpoint-serializable).
|
||||
type WorkflowLocalState struct {
|
||||
Inputs map[string]any `json:"inputs,omitempty"`
|
||||
Outputs map[string]any `json:"outputs,omitempty"`
|
||||
NodeOutputs map[string]map[string]any `json:"nodeOutputs,omitempty"`
|
||||
NodeProceed map[string]bool `json:"nodeProceed,omitempty"`
|
||||
LastOutput map[string]any `json:"lastOutput,omitempty"`
|
||||
Executed []string `json:"executed,omitempty"`
|
||||
Skipped []string `json:"skipped,omitempty"`
|
||||
WorkflowRunID string `json:"workflowRunId,omitempty"`
|
||||
MainIterationOffset int `json:"mainIterationOffset,omitempty"`
|
||||
SegmentMaxIteration int `json:"segmentMaxIteration,omitempty"`
|
||||
}
|
||||
|
||||
func newWorkflowLocalState(inputs map[string]interface{}, runID string) *WorkflowLocalState {
|
||||
in := make(map[string]any, len(inputs))
|
||||
for k, v := range inputs {
|
||||
in[k] = v
|
||||
}
|
||||
return &WorkflowLocalState{
|
||||
Inputs: in,
|
||||
Outputs: make(map[string]any),
|
||||
NodeOutputs: make(map[string]map[string]any),
|
||||
NodeProceed: make(map[string]bool),
|
||||
WorkflowRunID: runID,
|
||||
}
|
||||
}
|
||||
|
||||
var templateVarRe = regexp.MustCompile(`\{\{\s*([a-zA-Z0-9_.-]+)\s*\}\}`)
|
||||
|
||||
func resolveTemplate(s string, state *WorkflowLocalState) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return fmt.Sprint(valueFromPath("previous.output", state))
|
||||
}
|
||||
return templateVarRe.ReplaceAllStringFunc(s, func(match string) string {
|
||||
m := templateVarRe.FindStringSubmatch(match)
|
||||
if len(m) != 2 {
|
||||
return match
|
||||
}
|
||||
return fmt.Sprint(valueFromPath(m[1], state))
|
||||
})
|
||||
}
|
||||
|
||||
func valueFromPath(path string, state *WorkflowLocalState) any {
|
||||
parts := strings.Split(path, ".")
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
var cur any
|
||||
switch parts[0] {
|
||||
case "inputs", "input":
|
||||
cur = state.Inputs
|
||||
case "previous", "prev":
|
||||
cur = state.LastOutput
|
||||
case "outputs":
|
||||
cur = state.Outputs
|
||||
default:
|
||||
if v, ok := state.Inputs[parts[0]]; ok {
|
||||
cur = v
|
||||
} else if v, ok := state.NodeOutputs[parts[0]]; ok {
|
||||
cur = v
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
for _, p := range parts[1:] {
|
||||
m, ok := cur.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
cur = m[p]
|
||||
}
|
||||
if cur == nil {
|
||||
return ""
|
||||
}
|
||||
return cur
|
||||
}
|
||||
|
||||
func evalCondition(expr string, state *WorkflowLocalState) bool {
|
||||
expr = strings.TrimSpace(expr)
|
||||
if expr == "" {
|
||||
return true
|
||||
}
|
||||
resolved := strings.TrimSpace(resolveTemplate(expr, state))
|
||||
switch {
|
||||
case strings.Contains(resolved, "!="):
|
||||
parts := strings.SplitN(resolved, "!=", 2)
|
||||
return cleanComparable(parts[0]) != cleanComparable(parts[1])
|
||||
case strings.Contains(resolved, "=="):
|
||||
parts := strings.SplitN(resolved, "==", 2)
|
||||
return cleanComparable(parts[0]) == cleanComparable(parts[1])
|
||||
default:
|
||||
v := strings.ToLower(cleanComparable(resolved))
|
||||
return v != "" && v != "false" && v != "0" && v != "null"
|
||||
}
|
||||
}
|
||||
|
||||
func cleanComparable(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
s = strings.Trim(s, `"'`)
|
||||
return s
|
||||
}
|
||||
|
||||
func edgeAllowed(edge graphEdge, sourceNode graphNode, edgeIndex int, state *WorkflowLocalState) bool {
|
||||
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
|
||||
if cond != "" {
|
||||
return evalCondition(cond, state)
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(sourceNode.Type), "condition") {
|
||||
return conditionBranchAllowed(edge, edgeIndex, state)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func conditionBranchAllowed(edge graphEdge, edgeIndex int, state *WorkflowLocalState) bool {
|
||||
matched := conditionMatched(state)
|
||||
if branch := conditionBranchHint(edge); branch != "" {
|
||||
return (branch == "true" && matched) || (branch == "false" && !matched)
|
||||
}
|
||||
switch edgeIndex {
|
||||
case 0:
|
||||
return matched
|
||||
case 1:
|
||||
return !matched
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func conditionMatched(state *WorkflowLocalState) bool {
|
||||
v := strings.ToLower(cleanComparable(fmt.Sprint(valueFromPath("previous.matched", state))))
|
||||
return v == "true" || v == "1"
|
||||
}
|
||||
|
||||
func conditionBranchHint(edge graphEdge) string {
|
||||
if edge.Config != nil {
|
||||
switch strings.ToLower(strings.TrimSpace(cfgString(edge.Config, "branch"))) {
|
||||
case "true", "yes", "y", "是":
|
||||
return "true"
|
||||
case "false", "no", "n", "否":
|
||||
return "false"
|
||||
}
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(edge.Label)) {
|
||||
case "true", "yes", "y", "是":
|
||||
return "true"
|
||||
case "false", "no", "n", "否":
|
||||
return "false"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func cfgString(cfg map[string]any, key string) string {
|
||||
if cfg == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if s := strings.TrimSpace(value); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func truncateWorkflowPreview(s string, limit int) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if limit <= 0 || len([]rune(s)) <= limit {
|
||||
return s
|
||||
}
|
||||
runes := []rune(s)
|
||||
return string(runes[:limit]) + "..."
|
||||
}
|
||||
|
||||
func renderWorkflowResponse(roleName, workflowName string, version int, runID string, state *WorkflowLocalState) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("角色「%s」已完成工作流「%s」(版本 %d)。\n\n", roleName, workflowName, version))
|
||||
sb.WriteString(fmt.Sprintf("运行 ID:%s\n", runID))
|
||||
sb.WriteString(fmt.Sprintf("已执行节点:%d", len(state.Executed)))
|
||||
if len(state.Skipped) > 0 {
|
||||
sb.WriteString(fmt.Sprintf(",跳过节点:%d", len(state.Skipped)))
|
||||
}
|
||||
sb.WriteString("\n\n")
|
||||
if len(state.Outputs) > 0 {
|
||||
sb.WriteString("输出:\n")
|
||||
keys := make([]string, 0, len(state.Outputs))
|
||||
for k := range state.Outputs {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
sb.WriteString(fmt.Sprintf("- %s:%v\n", k, state.Outputs[k]))
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("暂无输出。请检查是否配置了输出节点,或条件分支是否命中。\n")
|
||||
}
|
||||
if len(state.Skipped) > 0 {
|
||||
sb.WriteString("\n未执行的节点类型仍会保留运行记录:")
|
||||
sb.WriteString(strings.Join(state.Skipped, "、"))
|
||||
sb.WriteString("。")
|
||||
}
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// WorkflowInput is the typed entry for Eino compose.Workflow[I,O].
|
||||
type WorkflowInput struct {
|
||||
Message string `json:"message"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
ProjectID string `json:"projectId"`
|
||||
Role string `json:"role"`
|
||||
WorkflowID string `json:"workflowId"`
|
||||
WorkflowVersion int `json:"workflowVersion"`
|
||||
}
|
||||
|
||||
// WorkflowOutput aggregates terminal node payloads keyed by canvas node id.
|
||||
type WorkflowOutput map[string]any
|
||||
|
||||
// WorkflowNodeOutput is the per-node lambda payload (alias for Eino edge type alignment).
|
||||
type WorkflowNodeOutput = map[string]interface{}
|
||||
|
||||
func workflowInputFromMap(m map[string]interface{}) WorkflowInput {
|
||||
in := WorkflowInput{}
|
||||
if m == nil {
|
||||
return in
|
||||
}
|
||||
if v, ok := m["message"].(string); ok {
|
||||
in.Message = v
|
||||
} else if m["message"] != nil {
|
||||
in.Message = fmt.Sprint(m["message"])
|
||||
}
|
||||
if v, ok := m["conversationId"].(string); ok {
|
||||
in.ConversationID = v
|
||||
}
|
||||
if v, ok := m["projectId"].(string); ok {
|
||||
in.ProjectID = v
|
||||
}
|
||||
if v, ok := m["role"].(string); ok {
|
||||
in.Role = v
|
||||
}
|
||||
if v, ok := m["workflowId"].(string); ok {
|
||||
in.WorkflowID = v
|
||||
}
|
||||
switch v := m["workflowVersion"].(type) {
|
||||
case int:
|
||||
in.WorkflowVersion = v
|
||||
case int64:
|
||||
in.WorkflowVersion = int(v)
|
||||
case float64:
|
||||
in.WorkflowVersion = int(v)
|
||||
case string:
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
in.WorkflowVersion = n
|
||||
}
|
||||
}
|
||||
return in
|
||||
}
|
||||
|
||||
func (in WorkflowInput) toStateInputs() map[string]any {
|
||||
return map[string]any{
|
||||
"message": in.Message,
|
||||
"conversationId": in.ConversationID,
|
||||
"projectId": in.ProjectID,
|
||||
"role": in.Role,
|
||||
"workflowId": in.WorkflowID,
|
||||
"workflowVersion": in.WorkflowVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func cacheKey(workflowID string, version int) string {
|
||||
return workflowID + ":" + strconv.Itoa(version)
|
||||
}
|
||||
Reference in New Issue
Block a user