mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-04 19:48:02 +02:00
237 lines
6.2 KiB
Go
237 lines
6.2 KiB
Go
package workflow
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/cloudwego/eino/compose"
|
|
)
|
|
|
|
type compiledArtifact struct {
|
|
runnable compose.Runnable[WorkflowInput, WorkflowOutput]
|
|
idx *graphIndex
|
|
hitlIDs []string
|
|
}
|
|
|
|
// Engine compiles and caches Eino Workflow artifacts.
|
|
type Engine struct {
|
|
mu sync.RWMutex
|
|
cache map[string]*compiledArtifact
|
|
cpStore compose.CheckPointStore
|
|
cpStoreMu sync.Once
|
|
cpStoreErr error
|
|
checkpointDir string
|
|
}
|
|
|
|
var defaultEngine = &Engine{
|
|
cache: make(map[string]*compiledArtifact),
|
|
checkpointDir: "data/workflow-checkpoints",
|
|
}
|
|
|
|
// SetCheckpointDir overrides the workflow checkpoint root (mainly for tests).
|
|
func SetCheckpointDir(dir string) {
|
|
defaultEngine.mu.Lock()
|
|
defer defaultEngine.mu.Unlock()
|
|
defaultEngine.checkpointDir = strings.TrimSpace(dir)
|
|
defaultEngine.cpStore = nil
|
|
defaultEngine.cpStoreErr = nil
|
|
defaultEngine.cpStoreMu = sync.Once{}
|
|
}
|
|
|
|
func (e *Engine) checkpointStore() (compose.CheckPointStore, error) {
|
|
e.cpStoreMu.Do(func() {
|
|
e.cpStore, e.cpStoreErr = newFileCheckPointStore(e.checkpointDir)
|
|
})
|
|
return e.cpStore, e.cpStoreErr
|
|
}
|
|
|
|
// InvalidateCompiledCache drops cached compilations for a workflow id.
|
|
func InvalidateCompiledCache(workflowID string) {
|
|
workflowID = strings.TrimSpace(workflowID)
|
|
if workflowID == "" {
|
|
return
|
|
}
|
|
defaultEngine.mu.Lock()
|
|
defer defaultEngine.mu.Unlock()
|
|
for key := range defaultEngine.cache {
|
|
if strings.HasPrefix(key, workflowID+":") {
|
|
delete(defaultEngine.cache, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ValidateGraphJSON parses and trial-compiles a canvas graph (save-time gate).
|
|
func ValidateGraphJSON(ctx context.Context, graphJSON string) error {
|
|
g, err := parseGraph(graphJSON)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
idx := indexGraph(g)
|
|
if len(findStartNodeIDs(idx)) == 0 {
|
|
return fmt.Errorf("工作流缺少可执行的起点节点")
|
|
}
|
|
if !hasTerminalNode(idx) {
|
|
return fmt.Errorf("工作流至少需要一个无出边的终点或 output/end 节点")
|
|
}
|
|
_, err = defaultEngine.compile(ctx, g)
|
|
return err
|
|
}
|
|
|
|
func hasTerminalNode(idx *graphIndex) bool {
|
|
for id, node := range idx.nodes {
|
|
if len(idx.outgoing[id]) == 0 {
|
|
return true
|
|
}
|
|
if strings.EqualFold(node.Type, "end") || strings.EqualFold(node.Type, "output") {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (e *Engine) getOrCompile(ctx context.Context, workflowID string, version int, g *graphDef) (*compiledArtifact, error) {
|
|
key := cacheKey(workflowID, version)
|
|
e.mu.RLock()
|
|
if art, ok := e.cache[key]; ok {
|
|
e.mu.RUnlock()
|
|
return art, nil
|
|
}
|
|
e.mu.RUnlock()
|
|
|
|
art, err := e.compile(ctx, g)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
e.mu.Lock()
|
|
if existing, ok := e.cache[key]; ok {
|
|
e.mu.Unlock()
|
|
return existing, nil
|
|
}
|
|
e.cache[key] = art
|
|
e.mu.Unlock()
|
|
return art, nil
|
|
}
|
|
|
|
func (e *Engine) compile(ctx context.Context, g *graphDef) (*compiledArtifact, error) {
|
|
cpStore, err := e.checkpointStore()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
idx := indexGraph(g)
|
|
hitlIDs := collectHITLNodeIDs(idx)
|
|
compileOpts := []compose.GraphCompileOption{
|
|
compose.WithGraphName("CyberStrikeWorkflow"),
|
|
compose.WithCheckPointStore(cpStore),
|
|
}
|
|
if len(hitlIDs) > 0 {
|
|
compileOpts = append(compileOpts, compose.WithInterruptBeforeNodes(hitlIDs))
|
|
}
|
|
|
|
wf := compose.NewWorkflow[WorkflowInput, WorkflowOutput](
|
|
compose.WithGenLocalState(func(runCtx context.Context) *WorkflowLocalState {
|
|
if rt := workflowRuntimeFrom(runCtx); rt != nil && rt.state != nil {
|
|
return rt.state
|
|
}
|
|
return &WorkflowLocalState{
|
|
Outputs: make(map[string]any),
|
|
NodeOutputs: make(map[string]map[string]any),
|
|
NodeProceed: make(map[string]bool),
|
|
}
|
|
}),
|
|
)
|
|
|
|
nodeRefs := make(map[string]*compose.WorkflowNode, len(idx.nodes))
|
|
for id, node := range idx.nodes {
|
|
n := node
|
|
if strings.EqualFold(n.Type, "agent") {
|
|
sub, err := compileAgentSubgraph(ctx, n)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("编译 Agent 子图 %s 失败: %w", id, err)
|
|
}
|
|
nodeRefs[id] = wf.AddGraphNode(id, sub)
|
|
continue
|
|
}
|
|
if strings.EqualFold(n.Type, "start") {
|
|
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowInput) (WorkflowNodeOutput, error) {
|
|
return runWorkflowNodeLambda(runCtx, n)
|
|
}))
|
|
continue
|
|
}
|
|
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) {
|
|
return runWorkflowNodeLambda(runCtx, n)
|
|
}))
|
|
}
|
|
|
|
for id, node := range idx.nodes {
|
|
if strings.EqualFold(node.Type, "condition") {
|
|
if err := wireConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
|
|
return nil, err
|
|
}
|
|
continue
|
|
}
|
|
if hasConditionalOutgoingEdges(idx, id) {
|
|
if err := wireEdgeConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
|
|
return nil, err
|
|
}
|
|
continue
|
|
}
|
|
for _, edge := range idx.outgoing[id] {
|
|
if target, ok := nodeRefs[edge.Target]; ok {
|
|
target.AddInput(id)
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, startID := range findStartNodeIDs(idx) {
|
|
if ref, ok := nodeRefs[startID]; ok {
|
|
ref.AddInput(compose.START)
|
|
}
|
|
}
|
|
|
|
endNode := wf.End()
|
|
for id, node := range idx.nodes {
|
|
if len(idx.outgoing[id]) == 0 || strings.EqualFold(node.Type, "end") {
|
|
endNode.AddInput(id, compose.ToField(id))
|
|
}
|
|
}
|
|
|
|
runnable, err := wf.Compile(ctx, compileOpts...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &compiledArtifact{runnable: runnable, idx: idx, hitlIDs: hitlIDs}, nil
|
|
}
|
|
|
|
func collectHITLNodeIDs(idx *graphIndex) []string {
|
|
var ids []string
|
|
for id, node := range idx.nodes {
|
|
if strings.EqualFold(node.Type, "hitl") {
|
|
ids = append(ids, id)
|
|
}
|
|
}
|
|
return ids
|
|
}
|
|
|
|
func runWorkflowNodeLambda(runCtx context.Context, n graphNode) (WorkflowNodeOutput, error) {
|
|
localRT := workflowRuntimeFrom(runCtx)
|
|
if localRT == nil {
|
|
return nil, fmt.Errorf("workflow runtime missing in context")
|
|
}
|
|
result, proceed, err := executeNode(runCtx, localRT.args, localRT.runID, n, localRT.state)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
localRT.state.NodeOutputs[n.ID] = result
|
|
localRT.state.LastOutput = result
|
|
if !proceed && !strings.EqualFold(n.Type, "end") {
|
|
label := firstNonEmpty(n.Label, n.ID)
|
|
if errText := cfgString(result, "error"); errText != "" {
|
|
return result, fmt.Errorf("节点「%s」失败: %s", label, errText)
|
|
}
|
|
return result, fmt.Errorf("节点「%s」未继续执行", label)
|
|
}
|
|
return result, nil
|
|
}
|