mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-03 19:17:55 +02:00
Add files via upload
This commit is contained in:
@@ -116,6 +116,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"userMessageId": prep.UserMessageID,
|
||||
})
|
||||
}
|
||||
if h.runRoleWorkflowStreamIfBound(&req, prep, sendEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
@@ -385,6 +388,9 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
|
||||
}
|
||||
if h.runRoleWorkflowJSONIfBound(c, &req, prep) {
|
||||
return
|
||||
}
|
||||
|
||||
var progressBuf strings.Builder
|
||||
progressCallbackRaw := func(eventType, message string, data interface{}) {
|
||||
|
||||
@@ -133,6 +133,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"userMessageId": prep.UserMessageID,
|
||||
})
|
||||
}
|
||||
if h.runRoleWorkflowStreamIfBound(&req, prep, sendEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
@@ -407,6 +410,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
|
||||
}
|
||||
if h.runRoleWorkflowJSONIfBound(c, &req, prep) {
|
||||
return
|
||||
}
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type WorkflowHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
func NewWorkflowHandler(db *database.DB, logger *zap.Logger) *WorkflowHandler {
|
||||
return &WorkflowHandler{db: db, logger: logger}
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
type workflowSaveRequest struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Version int `json:"version,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
Graph json.RawMessage `json:"graph,omitempty"`
|
||||
GraphJSON json.RawMessage `json:"graph_json,omitempty"`
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) List(c *gin.Context) {
|
||||
includeDisabled := strings.EqualFold(c.Query("includeDisabled"), "true") || c.Query("include_disabled") == "1"
|
||||
items, err := h.db.ListWorkflowDefinitions(includeDisabled)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"workflows": items})
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) Get(c *gin.Context) {
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
wf, err := h.db.GetWorkflowDefinition(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if wf == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "工作流不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"workflow": wf})
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) Create(c *gin.Context) {
|
||||
h.save(c, "")
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) Update(c *gin.Context) {
|
||||
h.save(c, c.Param("id"))
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) save(c *gin.Context, pathID string) {
|
||||
var req workflowSaveRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(req.ID)
|
||||
if strings.TrimSpace(pathID) != "" {
|
||||
id = strings.TrimSpace(pathID)
|
||||
}
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if id == "" || name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作流 id 和 name 不能为空"})
|
||||
return
|
||||
}
|
||||
graph := req.Graph
|
||||
if len(graph) == 0 {
|
||||
graph = req.GraphJSON
|
||||
}
|
||||
if len(graph) == 0 {
|
||||
graph = []byte(`{"nodes":[],"edges":[],"config":{}}`)
|
||||
}
|
||||
if !json.Valid(graph) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "graph 必须是合法 JSON"})
|
||||
return
|
||||
}
|
||||
var probe interface{}
|
||||
if err := json.Unmarshal(graph, &probe); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "graph JSON 解析失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
enabled := true
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
}
|
||||
wf := &database.WorkflowDefinition{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Description: strings.TrimSpace(req.Description),
|
||||
Version: req.Version,
|
||||
GraphJSON: string(graph),
|
||||
Enabled: enabled,
|
||||
}
|
||||
if err := h.db.UpsertWorkflowDefinition(wf); err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("保存工作流失败", zap.String("id", id), zap.Error(err))
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
saved, _ := h.db.GetWorkflowDefinition(id)
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "workflow", "save", "保存图编排流程", "workflow", id, map[string]interface{}{"name": name})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "工作流已保存", "workflow": saved})
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) Delete(c *gin.Context) {
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作流 id 不能为空"})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteWorkflowDefinition(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "workflow", "delete", "删除图编排流程", "workflow", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "工作流已删除"})
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
workflowrunner "cyberstrike-ai/internal/workflow"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) roleForWorkflow(req *ChatRequest) (config.RoleConfig, bool) {
|
||||
if h == nil || h.config == nil || h.config.Roles == nil || req == nil {
|
||||
return config.RoleConfig{}, false
|
||||
}
|
||||
roleName := strings.TrimSpace(req.Role)
|
||||
if roleName == "" {
|
||||
return config.RoleConfig{}, false
|
||||
}
|
||||
role, ok := h.config.Roles[roleName]
|
||||
if !ok || !role.Enabled {
|
||||
return config.RoleConfig{}, false
|
||||
}
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
if !workflowrunner.ShouldAutoRunRoleWorkflow(role) {
|
||||
return config.RoleConfig{}, false
|
||||
}
|
||||
return role, true
|
||||
}
|
||||
|
||||
func (h *AgentHandler) runRoleWorkflowStreamIfBound(
|
||||
req *ChatRequest,
|
||||
prep *multiAgentPrepared,
|
||||
sendEvent func(eventType, message string, data interface{}),
|
||||
) bool {
|
||||
role, ok := h.roleForWorkflow(req)
|
||||
if !ok || prep == nil {
|
||||
return false
|
||||
}
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
defer cancelWithCause(nil)
|
||||
progress := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, sendEvent)
|
||||
result, err := workflowrunner.RunRoleBoundWorkflow(baseCtx, workflowrunner.RunArgs{
|
||||
DB: h.db,
|
||||
Logger: h.logger,
|
||||
Role: role,
|
||||
AppCfg: h.config,
|
||||
Agent: h.agent,
|
||||
ConversationID: prep.ConversationID,
|
||||
ProjectID: h.conversationProjectID(prep.ConversationID),
|
||||
UserMessage: prep.FinalMessage,
|
||||
History: prep.History,
|
||||
RoleTools: prep.RoleTools,
|
||||
AgentsMarkdownDir: h.agentsMarkdownDir,
|
||||
SystemPromptExtra: h.agentSessionContextBlock(prep.ConversationID),
|
||||
AssistantMessageID: prep.AssistantMessageID,
|
||||
Progress: progress,
|
||||
})
|
||||
if err != nil {
|
||||
errMsg := "执行角色绑定流程失败: " + err.Error()
|
||||
if prep.AssistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID)
|
||||
_ = h.db.AddProcessDetail(prep.AssistantMessageID, prep.ConversationID, "error", errMsg, nil)
|
||||
}
|
||||
sendEvent("error", errMsg, map[string]interface{}{"conversationId": prep.ConversationID})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID})
|
||||
return true
|
||||
}
|
||||
if prep.AssistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, nil, "")
|
||||
}
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"conversationId": prep.ConversationID,
|
||||
"messageId": prep.AssistantMessageID,
|
||||
"agentMode": "workflow",
|
||||
"workflowRunId": result.RunID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID})
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *AgentHandler) runRoleWorkflowJSONIfBound(c *gin.Context, req *ChatRequest, prep *multiAgentPrepared) bool {
|
||||
role, ok := h.roleForWorkflow(req)
|
||||
if !ok || prep == nil {
|
||||
return false
|
||||
}
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
progress := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil)
|
||||
result, err := workflowrunner.RunRoleBoundWorkflow(baseCtx, workflowrunner.RunArgs{
|
||||
DB: h.db,
|
||||
Logger: h.logger,
|
||||
Role: role,
|
||||
AppCfg: h.config,
|
||||
Agent: h.agent,
|
||||
ConversationID: prep.ConversationID,
|
||||
ProjectID: h.conversationProjectID(prep.ConversationID),
|
||||
UserMessage: prep.FinalMessage,
|
||||
History: prep.History,
|
||||
RoleTools: prep.RoleTools,
|
||||
AgentsMarkdownDir: h.agentsMarkdownDir,
|
||||
SystemPromptExtra: h.agentSessionContextBlock(prep.ConversationID),
|
||||
AssistantMessageID: prep.AssistantMessageID,
|
||||
Progress: progress,
|
||||
})
|
||||
if err != nil {
|
||||
errMsg := "执行角色绑定流程失败: " + err.Error()
|
||||
if prep.AssistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID)
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg, "conversationId": prep.ConversationID})
|
||||
return true
|
||||
}
|
||||
if prep.AssistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, nil, "")
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"response": result.Response,
|
||||
"conversationId": prep.ConversationID,
|
||||
"assistantMessageId": prep.AssistantMessageID,
|
||||
"agentMode": "workflow",
|
||||
"workflowRunId": result.RunID,
|
||||
})
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user