From 3a995183a62fe984acf47478f9fc6bbf9c054fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 3 Jul 2026 17:03:37 +0800 Subject: [PATCH] Add files via upload --- internal/handler/eino_single_agent.go | 6 + internal/handler/multi_agent.go | 6 + internal/handler/workflow.go | 142 +++++++++++++++++++++++ internal/handler/workflow_integration.go | 130 +++++++++++++++++++++ 4 files changed, 284 insertions(+) create mode 100644 internal/handler/workflow.go create mode 100644 internal/handler/workflow_integration.go diff --git a/internal/handler/eino_single_agent.go b/internal/handler/eino_single_agent.go index 19b697be..f1d2bccb 100644 --- a/internal/handler/eino_single_agent.go +++ b/internal/handler/eino_single_agent.go @@ -116,6 +116,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { "userMessageId": prep.UserMessageID, }) } + if h.runRoleWorkflowStreamIfBound(&req, prep, sendEvent) { + return + } var cancelWithCause context.CancelCauseFunc curFinalMessage := prep.FinalMessage @@ -385,6 +388,9 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) { if h.hitlManager != nil { defer h.hitlManager.DeactivateConversation(prep.ConversationID) } + if h.runRoleWorkflowJSONIfBound(c, &req, prep) { + return + } var progressBuf strings.Builder progressCallbackRaw := func(eventType, message string, data interface{}) { diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go index 1c7828ac..f7e0f7ac 100644 --- a/internal/handler/multi_agent.go +++ b/internal/handler/multi_agent.go @@ -133,6 +133,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { "userMessageId": prep.UserMessageID, }) } + if h.runRoleWorkflowStreamIfBound(&req, prep, sendEvent) { + return + } var cancelWithCause context.CancelCauseFunc curFinalMessage := prep.FinalMessage @@ -407,6 +410,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { if h.hitlManager != nil { defer h.hitlManager.DeactivateConversation(prep.ConversationID) } + if h.runRoleWorkflowJSONIfBound(c, &req, prep) { + return + } baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context()) defer cancelWithCause(nil) diff --git a/internal/handler/workflow.go b/internal/handler/workflow.go new file mode 100644 index 00000000..2e6a6ce0 --- /dev/null +++ b/internal/handler/workflow.go @@ -0,0 +1,142 @@ +package handler + +import ( + "encoding/json" + "net/http" + "strings" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +type WorkflowHandler struct { + db *database.DB + logger *zap.Logger + audit *audit.Service +} + +func NewWorkflowHandler(db *database.DB, logger *zap.Logger) *WorkflowHandler { + return &WorkflowHandler{db: db, logger: logger} +} + +func (h *WorkflowHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +type workflowSaveRequest struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Version int `json:"version,omitempty"` + Enabled *bool `json:"enabled,omitempty"` + Graph json.RawMessage `json:"graph,omitempty"` + GraphJSON json.RawMessage `json:"graph_json,omitempty"` +} + +func (h *WorkflowHandler) List(c *gin.Context) { + includeDisabled := strings.EqualFold(c.Query("includeDisabled"), "true") || c.Query("include_disabled") == "1" + items, err := h.db.ListWorkflowDefinitions(includeDisabled) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"workflows": items}) +} + +func (h *WorkflowHandler) Get(c *gin.Context) { + id := strings.TrimSpace(c.Param("id")) + wf, err := h.db.GetWorkflowDefinition(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if wf == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "工作流不存在"}) + return + } + c.JSON(http.StatusOK, gin.H{"workflow": wf}) +} + +func (h *WorkflowHandler) Create(c *gin.Context) { + h.save(c, "") +} + +func (h *WorkflowHandler) Update(c *gin.Context) { + h.save(c, c.Param("id")) +} + +func (h *WorkflowHandler) save(c *gin.Context, pathID string) { + var req workflowSaveRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + id := strings.TrimSpace(req.ID) + if strings.TrimSpace(pathID) != "" { + id = strings.TrimSpace(pathID) + } + name := strings.TrimSpace(req.Name) + if id == "" || name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "工作流 id 和 name 不能为空"}) + return + } + graph := req.Graph + if len(graph) == 0 { + graph = req.GraphJSON + } + if len(graph) == 0 { + graph = []byte(`{"nodes":[],"edges":[],"config":{}}`) + } + if !json.Valid(graph) { + c.JSON(http.StatusBadRequest, gin.H{"error": "graph 必须是合法 JSON"}) + return + } + var probe interface{} + if err := json.Unmarshal(graph, &probe); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "graph JSON 解析失败: " + err.Error()}) + return + } + enabled := true + if req.Enabled != nil { + enabled = *req.Enabled + } + wf := &database.WorkflowDefinition{ + ID: id, + Name: name, + Description: strings.TrimSpace(req.Description), + Version: req.Version, + GraphJSON: string(graph), + Enabled: enabled, + } + if err := h.db.UpsertWorkflowDefinition(wf); err != nil { + if h.logger != nil { + h.logger.Warn("保存工作流失败", zap.String("id", id), zap.Error(err)) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + saved, _ := h.db.GetWorkflowDefinition(id) + if h.audit != nil { + h.audit.RecordOK(c, "workflow", "save", "保存图编排流程", "workflow", id, map[string]interface{}{"name": name}) + } + c.JSON(http.StatusOK, gin.H{"message": "工作流已保存", "workflow": saved}) +} + +func (h *WorkflowHandler) Delete(c *gin.Context) { + id := strings.TrimSpace(c.Param("id")) + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "工作流 id 不能为空"}) + return + } + if err := h.db.DeleteWorkflowDefinition(id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "workflow", "delete", "删除图编排流程", "workflow", id, nil) + } + c.JSON(http.StatusOK, gin.H{"message": "工作流已删除"}) +} diff --git a/internal/handler/workflow_integration.go b/internal/handler/workflow_integration.go new file mode 100644 index 00000000..b68896af --- /dev/null +++ b/internal/handler/workflow_integration.go @@ -0,0 +1,130 @@ +package handler + +import ( + "context" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + workflowrunner "cyberstrike-ai/internal/workflow" + + "github.com/gin-gonic/gin" +) + +func (h *AgentHandler) roleForWorkflow(req *ChatRequest) (config.RoleConfig, bool) { + if h == nil || h.config == nil || h.config.Roles == nil || req == nil { + return config.RoleConfig{}, false + } + roleName := strings.TrimSpace(req.Role) + if roleName == "" { + return config.RoleConfig{}, false + } + role, ok := h.config.Roles[roleName] + if !ok || !role.Enabled { + return config.RoleConfig{}, false + } + if role.Name == "" { + role.Name = roleName + } + if !workflowrunner.ShouldAutoRunRoleWorkflow(role) { + return config.RoleConfig{}, false + } + return role, true +} + +func (h *AgentHandler) runRoleWorkflowStreamIfBound( + req *ChatRequest, + prep *multiAgentPrepared, + sendEvent func(eventType, message string, data interface{}), +) bool { + role, ok := h.roleForWorkflow(req) + if !ok || prep == nil { + return false + } + baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) + defer cancelWithCause(nil) + progress := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, sendEvent) + result, err := workflowrunner.RunRoleBoundWorkflow(baseCtx, workflowrunner.RunArgs{ + DB: h.db, + Logger: h.logger, + Role: role, + AppCfg: h.config, + Agent: h.agent, + ConversationID: prep.ConversationID, + ProjectID: h.conversationProjectID(prep.ConversationID), + UserMessage: prep.FinalMessage, + History: prep.History, + RoleTools: prep.RoleTools, + AgentsMarkdownDir: h.agentsMarkdownDir, + SystemPromptExtra: h.agentSessionContextBlock(prep.ConversationID), + AssistantMessageID: prep.AssistantMessageID, + Progress: progress, + }) + if err != nil { + errMsg := "执行角色绑定流程失败: " + err.Error() + if prep.AssistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID) + _ = h.db.AddProcessDetail(prep.AssistantMessageID, prep.ConversationID, "error", errMsg, nil) + } + sendEvent("error", errMsg, map[string]interface{}{"conversationId": prep.ConversationID}) + sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID}) + return true + } + if prep.AssistantMessageID != "" { + _ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, nil, "") + } + sendEvent("response", result.Response, map[string]interface{}{ + "conversationId": prep.ConversationID, + "messageId": prep.AssistantMessageID, + "agentMode": "workflow", + "workflowRunId": result.RunID, + }) + sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID}) + return true +} + +func (h *AgentHandler) runRoleWorkflowJSONIfBound(c *gin.Context, req *ChatRequest, prep *multiAgentPrepared) bool { + role, ok := h.roleForWorkflow(req) + if !ok || prep == nil { + return false + } + baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context()) + defer cancelWithCause(nil) + progress := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil) + result, err := workflowrunner.RunRoleBoundWorkflow(baseCtx, workflowrunner.RunArgs{ + DB: h.db, + Logger: h.logger, + Role: role, + AppCfg: h.config, + Agent: h.agent, + ConversationID: prep.ConversationID, + ProjectID: h.conversationProjectID(prep.ConversationID), + UserMessage: prep.FinalMessage, + History: prep.History, + RoleTools: prep.RoleTools, + AgentsMarkdownDir: h.agentsMarkdownDir, + SystemPromptExtra: h.agentSessionContextBlock(prep.ConversationID), + AssistantMessageID: prep.AssistantMessageID, + Progress: progress, + }) + if err != nil { + errMsg := "执行角色绑定流程失败: " + err.Error() + if prep.AssistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg, "conversationId": prep.ConversationID}) + return true + } + if prep.AssistantMessageID != "" { + _ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, nil, "") + } + c.JSON(http.StatusOK, gin.H{ + "response": result.Response, + "conversationId": prep.ConversationID, + "assistantMessageId": prep.AssistantMessageID, + "agentMode": "workflow", + "workflowRunId": result.RunID, + }) + return true +}