mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-04 11:37:57 +02:00
Add files via upload
This commit is contained in:
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
workflowrunner "cyberstrike-ai/internal/workflow"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) roleForWorkflow(req *ChatRequest) (config.RoleConfig, bool) {
|
||||
@@ -42,33 +44,108 @@ func (h *AgentHandler) runRoleWorkflowStreamIfBound(
|
||||
if !ok || prep == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
conversationID := prep.ConversationID
|
||||
assistantMessageID := prep.AssistantMessageID
|
||||
userMessage := ""
|
||||
if req != nil {
|
||||
userMessage = req.Message
|
||||
}
|
||||
|
||||
taskStatus := "completed"
|
||||
taskOwned := false
|
||||
defer func() {
|
||||
if taskOwned {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}
|
||||
}()
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
defer cancelWithCause(nil)
|
||||
progress := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, sendEvent)
|
||||
result, err := workflowrunner.RunRoleBoundWorkflow(baseCtx, workflowrunner.RunArgs{
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, userMessage, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
|
||||
sendEvent("error", errorMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"errorType": "task_already_running",
|
||||
})
|
||||
} else {
|
||||
errorMsg = "❌ 无法启动任务: " + err.Error()
|
||||
sendEvent("error", errorMsg, nil)
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
|
||||
}
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
progress := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
result, err := workflowrunner.RunRoleBoundWorkflow(taskCtx, workflowrunner.RunArgs{
|
||||
DB: h.db,
|
||||
Logger: h.logger,
|
||||
Role: role,
|
||||
AppCfg: h.config,
|
||||
Agent: h.agent,
|
||||
ConversationID: prep.ConversationID,
|
||||
ProjectID: h.conversationProjectID(prep.ConversationID),
|
||||
ConversationID: conversationID,
|
||||
ProjectID: h.conversationProjectID(conversationID),
|
||||
UserMessage: prep.FinalMessage,
|
||||
History: prep.History,
|
||||
RoleTools: prep.RoleTools,
|
||||
AgentsMarkdownDir: h.agentsMarkdownDir,
|
||||
SystemPromptExtra: h.agentSessionContextBlock(prep.ConversationID),
|
||||
AssistantMessageID: prep.AssistantMessageID,
|
||||
SystemPromptExtra: h.agentSessionContextBlock(conversationID),
|
||||
AssistantMessageID: assistantMessageID,
|
||||
Progress: progress,
|
||||
})
|
||||
if err != nil {
|
||||
errMsg := "执行角色绑定流程失败: " + err.Error()
|
||||
if prep.AssistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID)
|
||||
_ = h.db.AddProcessDetail(prep.AssistantMessageID, prep.ConversationID, "error", errMsg, nil)
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if assistantMessageID != "" {
|
||||
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
|
||||
}
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
sendEvent("cancelled", cancelMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
sendEvent("error", errMsg, map[string]interface{}{"conversationId": prep.ConversationID})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID})
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) {
|
||||
taskStatus = "timeout"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
timeoutMsg := "任务执行超时,已自动终止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
errMsg := "执行角色绑定流程失败: " + err.Error()
|
||||
taskStatus = "failed"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
sendEvent("error", errMsg, map[string]interface{}{"conversationId": conversationID})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
if prep.AssistantMessageID != "" {
|
||||
@@ -85,13 +162,6 @@ func (h *AgentHandler) runRoleWorkflowStreamIfBound(
|
||||
payload["awaitingHitl"] = true
|
||||
}
|
||||
sendEvent("response", result.Response, payload)
|
||||
if result.AwaitingHITL {
|
||||
sendEvent("done", "", map[string]interface{}{
|
||||
"conversationId": prep.ConversationID,
|
||||
"workflowStatus": "awaiting_hitl",
|
||||
})
|
||||
return true
|
||||
}
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID})
|
||||
return true
|
||||
}
|
||||
@@ -101,31 +171,80 @@ func (h *AgentHandler) runRoleWorkflowJSONIfBound(c *gin.Context, req *ChatReque
|
||||
if !ok || prep == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
conversationID := prep.ConversationID
|
||||
assistantMessageID := prep.AssistantMessageID
|
||||
userMessage := ""
|
||||
if req != nil {
|
||||
userMessage = req.Message
|
||||
}
|
||||
|
||||
taskStatus := "completed"
|
||||
taskOwned := false
|
||||
defer func() {
|
||||
if taskOwned {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}
|
||||
}()
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
progress := h.createProgressCallback(baseCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil)
|
||||
result, err := workflowrunner.RunRoleBoundWorkflow(baseCtx, workflowrunner.RunArgs{
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, userMessage, cancelWithCause); err != nil {
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。",
|
||||
"conversationId": conversationID,
|
||||
"errorType": "task_already_running",
|
||||
})
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "❌ 无法启动任务: " + err.Error()})
|
||||
}
|
||||
return true
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
progress := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil)
|
||||
result, err := workflowrunner.RunRoleBoundWorkflow(taskCtx, workflowrunner.RunArgs{
|
||||
DB: h.db,
|
||||
Logger: h.logger,
|
||||
Role: role,
|
||||
AppCfg: h.config,
|
||||
Agent: h.agent,
|
||||
ConversationID: prep.ConversationID,
|
||||
ProjectID: h.conversationProjectID(prep.ConversationID),
|
||||
ConversationID: conversationID,
|
||||
ProjectID: h.conversationProjectID(conversationID),
|
||||
UserMessage: prep.FinalMessage,
|
||||
History: prep.History,
|
||||
RoleTools: prep.RoleTools,
|
||||
AgentsMarkdownDir: h.agentsMarkdownDir,
|
||||
SystemPromptExtra: h.agentSessionContextBlock(prep.ConversationID),
|
||||
AssistantMessageID: prep.AssistantMessageID,
|
||||
SystemPromptExtra: h.agentSessionContextBlock(conversationID),
|
||||
AssistantMessageID: assistantMessageID,
|
||||
Progress: progress,
|
||||
})
|
||||
if err != nil {
|
||||
errMsg := "执行角色绑定流程失败: " + err.Error()
|
||||
if prep.AssistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID)
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if assistantMessageID != "" {
|
||||
_ = h.appendAssistantMessageNotice(assistantMessageID, cancelMsg)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "cancelled",
|
||||
"message": cancelMsg,
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
return true
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg, "conversationId": prep.ConversationID})
|
||||
errMsg := "执行角色绑定流程失败: " + err.Error()
|
||||
taskStatus = "failed"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg, "conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
if prep.AssistantMessageID != "" {
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
@@ -31,7 +32,8 @@ func (h *WorkflowHandler) GetRun(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) ListPendingRuns(c *gin.Context) {
|
||||
runs, err := h.db.ListWorkflowRunsAwaitingHITL(50)
|
||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||
runs, err := h.db.ListWorkflowRunsAwaitingHITLFiltered(conversationID, 50)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -73,6 +75,37 @@ func (h *WorkflowHandler) ResumeRun(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if run.Status != "awaiting_hitl" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作流运行不在等待审批状态: " + run.Status})
|
||||
return
|
||||
}
|
||||
if err := h.db.RecordWorkflowRunHITLDecision(runID, req.Approved, req.Comment); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
decision := workflowrunner.HITLDecision{
|
||||
Approved: req.Approved,
|
||||
Comment: strings.TrimSpace(req.Comment),
|
||||
}
|
||||
delegated := workflowrunner.NotifyHITLDecision(runID, decision)
|
||||
if !delegated {
|
||||
for i := 0; i < 10; i++ {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if workflowrunner.NotifyHITLDecision(runID, decision) {
|
||||
delegated = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if delegated {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"workflowRunId": runID,
|
||||
"status": "delegated",
|
||||
"streamResuming": true,
|
||||
"approved": req.Approved,
|
||||
})
|
||||
return
|
||||
}
|
||||
result, err := workflowrunner.ResumeWorkflowRun(c.Request.Context(), workflowrunner.RunArgs{
|
||||
DB: h.db,
|
||||
Logger: h.logger,
|
||||
@@ -87,9 +120,9 @@ func (h *WorkflowHandler) ResumeRun(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"response": result.Response,
|
||||
"response": result.Response,
|
||||
"workflowRunId": result.RunID,
|
||||
"status": result.Status,
|
||||
"awaitingHitl": result.AwaitingHITL,
|
||||
"status": result.Status,
|
||||
"awaitingHitl": result.AwaitingHITL,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user