Files
CyberStrikeAI/internal/workflow/engine.go
T
2026-07-03 19:36:40 +08:00

237 lines
6.2 KiB
Go

package workflow
import (
"context"
"fmt"
"strings"
"sync"
"github.com/cloudwego/eino/compose"
)
type compiledArtifact struct {
runnable compose.Runnable[WorkflowInput, WorkflowOutput]
idx *graphIndex
hitlIDs []string
}
// Engine compiles and caches Eino Workflow artifacts.
type Engine struct {
mu sync.RWMutex
cache map[string]*compiledArtifact
cpStore compose.CheckPointStore
cpStoreMu sync.Once
cpStoreErr error
checkpointDir string
}
var defaultEngine = &Engine{
cache: make(map[string]*compiledArtifact),
checkpointDir: "data/workflow-checkpoints",
}
// SetCheckpointDir overrides the workflow checkpoint root (mainly for tests).
func SetCheckpointDir(dir string) {
defaultEngine.mu.Lock()
defer defaultEngine.mu.Unlock()
defaultEngine.checkpointDir = strings.TrimSpace(dir)
defaultEngine.cpStore = nil
defaultEngine.cpStoreErr = nil
defaultEngine.cpStoreMu = sync.Once{}
}
func (e *Engine) checkpointStore() (compose.CheckPointStore, error) {
e.cpStoreMu.Do(func() {
e.cpStore, e.cpStoreErr = newFileCheckPointStore(e.checkpointDir)
})
return e.cpStore, e.cpStoreErr
}
// InvalidateCompiledCache drops cached compilations for a workflow id.
func InvalidateCompiledCache(workflowID string) {
workflowID = strings.TrimSpace(workflowID)
if workflowID == "" {
return
}
defaultEngine.mu.Lock()
defer defaultEngine.mu.Unlock()
for key := range defaultEngine.cache {
if strings.HasPrefix(key, workflowID+":") {
delete(defaultEngine.cache, key)
}
}
}
// ValidateGraphJSON parses and trial-compiles a canvas graph (save-time gate).
func ValidateGraphJSON(ctx context.Context, graphJSON string) error {
g, err := parseGraph(graphJSON)
if err != nil {
return err
}
idx := indexGraph(g)
if len(findStartNodeIDs(idx)) == 0 {
return fmt.Errorf("工作流缺少可执行的起点节点")
}
if !hasTerminalNode(idx) {
return fmt.Errorf("工作流至少需要一个无出边的终点或 output/end 节点")
}
_, err = defaultEngine.compile(ctx, g)
return err
}
func hasTerminalNode(idx *graphIndex) bool {
for id, node := range idx.nodes {
if len(idx.outgoing[id]) == 0 {
return true
}
if strings.EqualFold(node.Type, "end") || strings.EqualFold(node.Type, "output") {
return true
}
}
return false
}
func (e *Engine) getOrCompile(ctx context.Context, workflowID string, version int, g *graphDef) (*compiledArtifact, error) {
key := cacheKey(workflowID, version)
e.mu.RLock()
if art, ok := e.cache[key]; ok {
e.mu.RUnlock()
return art, nil
}
e.mu.RUnlock()
art, err := e.compile(ctx, g)
if err != nil {
return nil, err
}
e.mu.Lock()
if existing, ok := e.cache[key]; ok {
e.mu.Unlock()
return existing, nil
}
e.cache[key] = art
e.mu.Unlock()
return art, nil
}
func (e *Engine) compile(ctx context.Context, g *graphDef) (*compiledArtifact, error) {
cpStore, err := e.checkpointStore()
if err != nil {
return nil, err
}
idx := indexGraph(g)
hitlIDs := collectHITLNodeIDs(idx)
compileOpts := []compose.GraphCompileOption{
compose.WithGraphName("CyberStrikeWorkflow"),
compose.WithCheckPointStore(cpStore),
}
if len(hitlIDs) > 0 {
compileOpts = append(compileOpts, compose.WithInterruptBeforeNodes(hitlIDs))
}
wf := compose.NewWorkflow[WorkflowInput, WorkflowOutput](
compose.WithGenLocalState(func(runCtx context.Context) *WorkflowLocalState {
if rt := workflowRuntimeFrom(runCtx); rt != nil && rt.state != nil {
return rt.state
}
return &WorkflowLocalState{
Outputs: make(map[string]any),
NodeOutputs: make(map[string]map[string]any),
NodeProceed: make(map[string]bool),
}
}),
)
nodeRefs := make(map[string]*compose.WorkflowNode, len(idx.nodes))
for id, node := range idx.nodes {
n := node
if strings.EqualFold(n.Type, "agent") {
sub, err := compileAgentSubgraph(ctx, n)
if err != nil {
return nil, fmt.Errorf("编译 Agent 子图 %s 失败: %w", id, err)
}
nodeRefs[id] = wf.AddGraphNode(id, sub)
continue
}
if strings.EqualFold(n.Type, "start") {
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowInput) (WorkflowNodeOutput, error) {
return runWorkflowNodeLambda(runCtx, n)
}))
continue
}
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) {
return runWorkflowNodeLambda(runCtx, n)
}))
}
for id, node := range idx.nodes {
if strings.EqualFold(node.Type, "condition") {
if err := wireConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
return nil, err
}
continue
}
if hasConditionalOutgoingEdges(idx, id) {
if err := wireEdgeConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
return nil, err
}
continue
}
for _, edge := range idx.outgoing[id] {
if target, ok := nodeRefs[edge.Target]; ok {
target.AddInput(id)
}
}
}
for _, startID := range findStartNodeIDs(idx) {
if ref, ok := nodeRefs[startID]; ok {
ref.AddInput(compose.START)
}
}
endNode := wf.End()
for id, node := range idx.nodes {
if len(idx.outgoing[id]) == 0 || strings.EqualFold(node.Type, "end") {
endNode.AddInput(id, compose.ToField(id))
}
}
runnable, err := wf.Compile(ctx, compileOpts...)
if err != nil {
return nil, err
}
return &compiledArtifact{runnable: runnable, idx: idx, hitlIDs: hitlIDs}, nil
}
func collectHITLNodeIDs(idx *graphIndex) []string {
var ids []string
for id, node := range idx.nodes {
if strings.EqualFold(node.Type, "hitl") {
ids = append(ids, id)
}
}
return ids
}
func runWorkflowNodeLambda(runCtx context.Context, n graphNode) (WorkflowNodeOutput, error) {
localRT := workflowRuntimeFrom(runCtx)
if localRT == nil {
return nil, fmt.Errorf("workflow runtime missing in context")
}
result, proceed, err := executeNode(runCtx, localRT.args, localRT.runID, n, localRT.state)
if err != nil {
return nil, err
}
localRT.state.NodeOutputs[n.ID] = result
localRT.state.LastOutput = result
if !proceed && !strings.EqualFold(n.Type, "end") {
label := firstNonEmpty(n.Label, n.ID)
if errText := cfgString(result, "error"); errText != "" {
return result, fmt.Errorf("节点「%s」失败: %s", label, errText)
}
return result, fmt.Errorf("节点「%s」未继续执行", label)
}
return result, nil
}