mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-04 03:27:54 +02:00
Add files via upload
This commit is contained in:
@@ -5,8 +5,11 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
workflowrunner "cyberstrike-ai/internal/workflow"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -16,6 +19,8 @@ type WorkflowHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
agent *agent.Agent
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewWorkflowHandler(db *database.DB, logger *zap.Logger) *WorkflowHandler {
|
||||
@@ -94,6 +99,10 @@ func (h *WorkflowHandler) save(c *gin.Context, pathID string) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "graph 必须是合法 JSON"})
|
||||
return
|
||||
}
|
||||
if err := workflowrunner.ValidateGraphJSON(c.Request.Context(), string(graph)); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作流图无法编译: " + err.Error()})
|
||||
return
|
||||
}
|
||||
var probe interface{}
|
||||
if err := json.Unmarshal(graph, &probe); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "graph JSON 解析失败: " + err.Error()})
|
||||
@@ -119,6 +128,7 @@ func (h *WorkflowHandler) save(c *gin.Context, pathID string) {
|
||||
return
|
||||
}
|
||||
saved, _ := h.db.GetWorkflowDefinition(id)
|
||||
workflowrunner.InvalidateCompiledCache(id)
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "workflow", "save", "保存图编排流程", "workflow", id, map[string]interface{}{"name": name})
|
||||
}
|
||||
@@ -135,6 +145,7 @@ func (h *WorkflowHandler) Delete(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
workflowrunner.InvalidateCompiledCache(id)
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "workflow", "delete", "删除图编排流程", "workflow", id, nil)
|
||||
}
|
||||
|
||||
@@ -74,12 +74,24 @@ func (h *AgentHandler) runRoleWorkflowStreamIfBound(
|
||||
if prep.AssistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, nil, "")
|
||||
}
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
payload := map[string]interface{}{
|
||||
"conversationId": prep.ConversationID,
|
||||
"messageId": prep.AssistantMessageID,
|
||||
"agentMode": "workflow",
|
||||
"workflowRunId": result.RunID,
|
||||
})
|
||||
}
|
||||
if result.AwaitingHITL {
|
||||
payload["workflowStatus"] = "awaiting_hitl"
|
||||
payload["awaitingHitl"] = true
|
||||
}
|
||||
sendEvent("response", result.Response, payload)
|
||||
if result.AwaitingHITL {
|
||||
sendEvent("done", "", map[string]interface{}{
|
||||
"conversationId": prep.ConversationID,
|
||||
"workflowStatus": "awaiting_hitl",
|
||||
})
|
||||
return true
|
||||
}
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID})
|
||||
return true
|
||||
}
|
||||
@@ -125,6 +137,8 @@ func (h *AgentHandler) runRoleWorkflowJSONIfBound(c *gin.Context, req *ChatReque
|
||||
"assistantMessageId": prep.AssistantMessageID,
|
||||
"agentMode": "workflow",
|
||||
"workflowRunId": result.RunID,
|
||||
"workflowStatus": result.Status,
|
||||
"awaitingHitl": result.AwaitingHITL,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
workflowrunner "cyberstrike-ai/internal/workflow"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (h *WorkflowHandler) SetRuntime(agent *agent.Agent, cfg *config.Config) {
|
||||
h.agent = agent
|
||||
h.cfg = cfg
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) GetRun(c *gin.Context) {
|
||||
runID := strings.TrimSpace(c.Param("runId"))
|
||||
run, err := h.db.GetWorkflowRun(runID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if run == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "工作流运行不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"run": run})
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) ListPendingRuns(c *gin.Context) {
|
||||
runs, err := h.db.ListWorkflowRunsAwaitingHITL(50)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"runs": runs})
|
||||
}
|
||||
|
||||
type workflowResumeRequest struct {
|
||||
Approved bool `json:"approved"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) ResumeRun(c *gin.Context) {
|
||||
if h.agent == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "工作流运行时未初始化"})
|
||||
return
|
||||
}
|
||||
runID := strings.TrimSpace(c.Param("runId"))
|
||||
var req workflowResumeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
run, err := h.db.GetWorkflowRun(runID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if run == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "工作流运行不存在"})
|
||||
return
|
||||
}
|
||||
role := config.RoleConfig{Name: strings.TrimSpace(run.RoleID)}
|
||||
if role.Name != "" && h.cfg.Roles != nil {
|
||||
if r, ok := h.cfg.Roles[role.Name]; ok {
|
||||
role = r
|
||||
if role.Name == "" {
|
||||
role.Name = run.RoleID
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := workflowrunner.ResumeWorkflowRun(c.Request.Context(), workflowrunner.RunArgs{
|
||||
DB: h.db,
|
||||
Logger: h.logger,
|
||||
Role: role,
|
||||
AppCfg: h.cfg,
|
||||
Agent: h.agent,
|
||||
ConversationID: run.ConversationID,
|
||||
ProjectID: run.ProjectID,
|
||||
}, runID, req.Approved, req.Comment)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"response": result.Response,
|
||||
"workflowRunId": result.RunID,
|
||||
"status": result.Status,
|
||||
"awaitingHitl": result.AwaitingHITL,
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user