mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-04 03:27:54 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,236 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
type compiledArtifact struct {
|
||||
runnable compose.Runnable[WorkflowInput, WorkflowOutput]
|
||||
idx *graphIndex
|
||||
hitlIDs []string
|
||||
}
|
||||
|
||||
// Engine compiles and caches Eino Workflow artifacts.
|
||||
type Engine struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]*compiledArtifact
|
||||
cpStore compose.CheckPointStore
|
||||
cpStoreMu sync.Once
|
||||
cpStoreErr error
|
||||
checkpointDir string
|
||||
}
|
||||
|
||||
var defaultEngine = &Engine{
|
||||
cache: make(map[string]*compiledArtifact),
|
||||
checkpointDir: "data/workflow-checkpoints",
|
||||
}
|
||||
|
||||
// SetCheckpointDir overrides the workflow checkpoint root (mainly for tests).
|
||||
func SetCheckpointDir(dir string) {
|
||||
defaultEngine.mu.Lock()
|
||||
defer defaultEngine.mu.Unlock()
|
||||
defaultEngine.checkpointDir = strings.TrimSpace(dir)
|
||||
defaultEngine.cpStore = nil
|
||||
defaultEngine.cpStoreErr = nil
|
||||
defaultEngine.cpStoreMu = sync.Once{}
|
||||
}
|
||||
|
||||
func (e *Engine) checkpointStore() (compose.CheckPointStore, error) {
|
||||
e.cpStoreMu.Do(func() {
|
||||
e.cpStore, e.cpStoreErr = newFileCheckPointStore(e.checkpointDir)
|
||||
})
|
||||
return e.cpStore, e.cpStoreErr
|
||||
}
|
||||
|
||||
// InvalidateCompiledCache drops cached compilations for a workflow id.
|
||||
func InvalidateCompiledCache(workflowID string) {
|
||||
workflowID = strings.TrimSpace(workflowID)
|
||||
if workflowID == "" {
|
||||
return
|
||||
}
|
||||
defaultEngine.mu.Lock()
|
||||
defer defaultEngine.mu.Unlock()
|
||||
for key := range defaultEngine.cache {
|
||||
if strings.HasPrefix(key, workflowID+":") {
|
||||
delete(defaultEngine.cache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateGraphJSON parses and trial-compiles a canvas graph (save-time gate).
|
||||
func ValidateGraphJSON(ctx context.Context, graphJSON string) error {
|
||||
g, err := parseGraph(graphJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idx := indexGraph(g)
|
||||
if len(findStartNodeIDs(idx)) == 0 {
|
||||
return fmt.Errorf("工作流缺少可执行的起点节点")
|
||||
}
|
||||
if !hasTerminalNode(idx) {
|
||||
return fmt.Errorf("工作流至少需要一个无出边的终点或 output/end 节点")
|
||||
}
|
||||
_, err = defaultEngine.compile(ctx, g)
|
||||
return err
|
||||
}
|
||||
|
||||
func hasTerminalNode(idx *graphIndex) bool {
|
||||
for id, node := range idx.nodes {
|
||||
if len(idx.outgoing[id]) == 0 {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(node.Type, "end") || strings.EqualFold(node.Type, "output") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *Engine) getOrCompile(ctx context.Context, workflowID string, version int, g *graphDef) (*compiledArtifact, error) {
|
||||
key := cacheKey(workflowID, version)
|
||||
e.mu.RLock()
|
||||
if art, ok := e.cache[key]; ok {
|
||||
e.mu.RUnlock()
|
||||
return art, nil
|
||||
}
|
||||
e.mu.RUnlock()
|
||||
|
||||
art, err := e.compile(ctx, g)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.mu.Lock()
|
||||
if existing, ok := e.cache[key]; ok {
|
||||
e.mu.Unlock()
|
||||
return existing, nil
|
||||
}
|
||||
e.cache[key] = art
|
||||
e.mu.Unlock()
|
||||
return art, nil
|
||||
}
|
||||
|
||||
func (e *Engine) compile(ctx context.Context, g *graphDef) (*compiledArtifact, error) {
|
||||
cpStore, err := e.checkpointStore()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idx := indexGraph(g)
|
||||
hitlIDs := collectHITLNodeIDs(idx)
|
||||
compileOpts := []compose.GraphCompileOption{
|
||||
compose.WithGraphName("CyberStrikeWorkflow"),
|
||||
compose.WithCheckPointStore(cpStore),
|
||||
}
|
||||
if len(hitlIDs) > 0 {
|
||||
compileOpts = append(compileOpts, compose.WithInterruptBeforeNodes(hitlIDs))
|
||||
}
|
||||
|
||||
wf := compose.NewWorkflow[WorkflowInput, WorkflowOutput](
|
||||
compose.WithGenLocalState(func(runCtx context.Context) *WorkflowLocalState {
|
||||
if rt := workflowRuntimeFrom(runCtx); rt != nil && rt.state != nil {
|
||||
return rt.state
|
||||
}
|
||||
return &WorkflowLocalState{
|
||||
Outputs: make(map[string]any),
|
||||
NodeOutputs: make(map[string]map[string]any),
|
||||
NodeProceed: make(map[string]bool),
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
nodeRefs := make(map[string]*compose.WorkflowNode, len(idx.nodes))
|
||||
for id, node := range idx.nodes {
|
||||
n := node
|
||||
if strings.EqualFold(n.Type, "agent") {
|
||||
sub, err := compileAgentSubgraph(ctx, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("编译 Agent 子图 %s 失败: %w", id, err)
|
||||
}
|
||||
nodeRefs[id] = wf.AddGraphNode(id, sub)
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(n.Type, "start") {
|
||||
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowInput) (WorkflowNodeOutput, error) {
|
||||
return runWorkflowNodeLambda(runCtx, n)
|
||||
}))
|
||||
continue
|
||||
}
|
||||
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) {
|
||||
return runWorkflowNodeLambda(runCtx, n)
|
||||
}))
|
||||
}
|
||||
|
||||
for id, node := range idx.nodes {
|
||||
if strings.EqualFold(node.Type, "condition") {
|
||||
if err := wireConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if hasConditionalOutgoingEdges(idx, id) {
|
||||
if err := wireEdgeConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, edge := range idx.outgoing[id] {
|
||||
if target, ok := nodeRefs[edge.Target]; ok {
|
||||
target.AddInput(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, startID := range findStartNodeIDs(idx) {
|
||||
if ref, ok := nodeRefs[startID]; ok {
|
||||
ref.AddInput(compose.START)
|
||||
}
|
||||
}
|
||||
|
||||
endNode := wf.End()
|
||||
for id, node := range idx.nodes {
|
||||
if len(idx.outgoing[id]) == 0 || strings.EqualFold(node.Type, "end") {
|
||||
endNode.AddInput(id, compose.ToField(id))
|
||||
}
|
||||
}
|
||||
|
||||
runnable, err := wf.Compile(ctx, compileOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &compiledArtifact{runnable: runnable, idx: idx, hitlIDs: hitlIDs}, nil
|
||||
}
|
||||
|
||||
func collectHITLNodeIDs(idx *graphIndex) []string {
|
||||
var ids []string
|
||||
for id, node := range idx.nodes {
|
||||
if strings.EqualFold(node.Type, "hitl") {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func runWorkflowNodeLambda(runCtx context.Context, n graphNode) (WorkflowNodeOutput, error) {
|
||||
localRT := workflowRuntimeFrom(runCtx)
|
||||
if localRT == nil {
|
||||
return nil, fmt.Errorf("workflow runtime missing in context")
|
||||
}
|
||||
result, proceed, err := executeNode(runCtx, localRT.args, localRT.runID, n, localRT.state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
localRT.state.NodeOutputs[n.ID] = result
|
||||
localRT.state.LastOutput = result
|
||||
if !proceed && !strings.EqualFold(n.Type, "end") {
|
||||
label := firstNonEmpty(n.Label, n.ID)
|
||||
if errText := cfgString(result, "error"); errText != "" {
|
||||
return result, fmt.Errorf("节点「%s」失败: %s", label, errText)
|
||||
}
|
||||
return result, fmt.Errorf("节点「%s」未继续执行", label)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
Reference in New Issue
Block a user