diff --git a/internal/app/app.go b/internal/app/app.go index 55a6d55b..0a429427 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -358,6 +358,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error projectHandler := handler.NewProjectHandler(db, log.Logger) workflowHandler := handler.NewWorkflowHandler(db, log.Logger) workflowHandler.SetAudit(auditSvc) + workflowHandler.SetRuntime(agent, cfg) vulnerabilityHandler.SetAudit(auditSvc) webshellHandler := handler.NewWebShellHandler(log.Logger, db) webshellHandler.SetAudit(auditSvc) @@ -1197,6 +1198,9 @@ func setupRoutes( protected.DELETE("/roles/:name", roleHandler.DeleteRole) // 图编排 / 工作流定义(图结构固定,业务字段保存在 graph_json 中) + protected.GET("/workflows/runs/pending", workflowHandler.ListPendingRuns) + protected.GET("/workflows/runs/:runId", workflowHandler.GetRun) + protected.POST("/workflows/runs/:runId/resume", workflowHandler.ResumeRun) protected.GET("/workflows", workflowHandler.List) protected.GET("/workflows/:id", workflowHandler.Get) protected.POST("/workflows", workflowHandler.Create) diff --git a/internal/handler/workflow.go b/internal/handler/workflow.go index 2e6a6ce0..4ea7166d 100644 --- a/internal/handler/workflow.go +++ b/internal/handler/workflow.go @@ -5,8 +5,11 @@ import ( "net/http" "strings" + "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" + workflowrunner "cyberstrike-ai/internal/workflow" "github.com/gin-gonic/gin" "go.uber.org/zap" @@ -16,6 +19,8 @@ type WorkflowHandler struct { db *database.DB logger *zap.Logger audit *audit.Service + agent *agent.Agent + cfg *config.Config } func NewWorkflowHandler(db *database.DB, logger *zap.Logger) *WorkflowHandler { @@ -94,6 +99,10 @@ func (h *WorkflowHandler) save(c *gin.Context, pathID string) { c.JSON(http.StatusBadRequest, gin.H{"error": "graph 必须是合法 JSON"}) return } + if err := workflowrunner.ValidateGraphJSON(c.Request.Context(), string(graph)); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "工作流图无法编译: " + err.Error()}) + return + } var probe interface{} if err := json.Unmarshal(graph, &probe); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "graph JSON 解析失败: " + err.Error()}) @@ -119,6 +128,7 @@ func (h *WorkflowHandler) save(c *gin.Context, pathID string) { return } saved, _ := h.db.GetWorkflowDefinition(id) + workflowrunner.InvalidateCompiledCache(id) if h.audit != nil { h.audit.RecordOK(c, "workflow", "save", "保存图编排流程", "workflow", id, map[string]interface{}{"name": name}) } @@ -135,6 +145,7 @@ func (h *WorkflowHandler) Delete(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + workflowrunner.InvalidateCompiledCache(id) if h.audit != nil { h.audit.RecordOK(c, "workflow", "delete", "删除图编排流程", "workflow", id, nil) } diff --git a/internal/handler/workflow_integration.go b/internal/handler/workflow_integration.go index b68896af..3eae0a0a 100644 --- a/internal/handler/workflow_integration.go +++ b/internal/handler/workflow_integration.go @@ -74,12 +74,24 @@ func (h *AgentHandler) runRoleWorkflowStreamIfBound( if prep.AssistantMessageID != "" { _ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, nil, "") } - sendEvent("response", result.Response, map[string]interface{}{ + payload := map[string]interface{}{ "conversationId": prep.ConversationID, "messageId": prep.AssistantMessageID, "agentMode": "workflow", "workflowRunId": result.RunID, - }) + } + if result.AwaitingHITL { + payload["workflowStatus"] = "awaiting_hitl" + payload["awaitingHitl"] = true + } + sendEvent("response", result.Response, payload) + if result.AwaitingHITL { + sendEvent("done", "", map[string]interface{}{ + "conversationId": prep.ConversationID, + "workflowStatus": "awaiting_hitl", + }) + return true + } sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID}) return true } @@ -125,6 +137,8 @@ func (h *AgentHandler) runRoleWorkflowJSONIfBound(c *gin.Context, req *ChatReque "assistantMessageId": prep.AssistantMessageID, "agentMode": "workflow", "workflowRunId": result.RunID, + "workflowStatus": result.Status, + "awaitingHitl": result.AwaitingHITL, }) return true } diff --git a/internal/handler/workflow_run.go b/internal/handler/workflow_run.go new file mode 100644 index 00000000..0a19c0b8 --- /dev/null +++ b/internal/handler/workflow_run.go @@ -0,0 +1,95 @@ +package handler + +import ( + "net/http" + "strings" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/config" + workflowrunner "cyberstrike-ai/internal/workflow" + + "github.com/gin-gonic/gin" +) + +func (h *WorkflowHandler) SetRuntime(agent *agent.Agent, cfg *config.Config) { + h.agent = agent + h.cfg = cfg +} + +func (h *WorkflowHandler) GetRun(c *gin.Context) { + runID := strings.TrimSpace(c.Param("runId")) + run, err := h.db.GetWorkflowRun(runID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if run == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "工作流运行不存在"}) + return + } + c.JSON(http.StatusOK, gin.H{"run": run}) +} + +func (h *WorkflowHandler) ListPendingRuns(c *gin.Context) { + runs, err := h.db.ListWorkflowRunsAwaitingHITL(50) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"runs": runs}) +} + +type workflowResumeRequest struct { + Approved bool `json:"approved"` + Comment string `json:"comment,omitempty"` +} + +func (h *WorkflowHandler) ResumeRun(c *gin.Context) { + if h.agent == nil || h.cfg == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "工作流运行时未初始化"}) + return + } + runID := strings.TrimSpace(c.Param("runId")) + var req workflowResumeRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + run, err := h.db.GetWorkflowRun(runID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if run == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "工作流运行不存在"}) + return + } + role := config.RoleConfig{Name: strings.TrimSpace(run.RoleID)} + if role.Name != "" && h.cfg.Roles != nil { + if r, ok := h.cfg.Roles[role.Name]; ok { + role = r + if role.Name == "" { + role.Name = run.RoleID + } + } + } + result, err := workflowrunner.ResumeWorkflowRun(c.Request.Context(), workflowrunner.RunArgs{ + DB: h.db, + Logger: h.logger, + Role: role, + AppCfg: h.cfg, + Agent: h.agent, + ConversationID: run.ConversationID, + ProjectID: run.ProjectID, + }, runID, req.Approved, req.Comment) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "response": result.Response, + "workflowRunId": result.RunID, + "status": result.Status, + "awaitingHitl": result.AwaitingHITL, + }) +}