diff --git a/data/conversations.db b/data/conversations.db new file mode 100644 index 00000000..4ebf78cf Binary files /dev/null and b/data/conversations.db differ diff --git a/data/conversations.db-shm b/data/conversations.db-shm new file mode 100644 index 00000000..5153d9cd Binary files /dev/null and b/data/conversations.db-shm differ diff --git a/data/conversations.db-wal b/data/conversations.db-wal new file mode 100644 index 00000000..cb381552 Binary files /dev/null and b/data/conversations.db-wal differ diff --git a/internal/app/app.go b/internal/app/app.go index 2480e3fe..265e66c8 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -42,7 +42,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { if dbPath == "" { dbPath = "data/conversations.db" } - + // 确保目录存在 if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil { return nil, fmt.Errorf("创建数据库目录失败: %w", err) @@ -95,10 +95,10 @@ func (a *App) Run() error { go func() { mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port) a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr)) - + mux := http.NewServeMux() mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP) - + if err := http.ListenAndServe(mcpAddr, mux); err != nil { a.logger.Error("MCP服务器启动失败", zap.Error(err)) } @@ -108,7 +108,7 @@ func (a *App) Run() error { // 启动主服务器 addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port) a.logger.Info("启动HTTP服务器", zap.String("address", addr)) - + return a.router.Run(addr) } @@ -121,19 +121,22 @@ func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitor api.POST("/agent-loop", agentHandler.AgentLoop) // Agent Loop 流式输出 api.POST("/agent-loop/stream", agentHandler.AgentLoopStream) - + // Agent Loop 取消与任务列表 + api.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop) + api.GET("/agent-loop/tasks", agentHandler.ListAgentTasks) + // 对话历史 api.POST("/conversations", conversationHandler.CreateConversation) api.GET("/conversations", conversationHandler.ListConversations) api.GET("/conversations/:id", conversationHandler.GetConversation) api.DELETE("/conversations/:id", conversationHandler.DeleteConversation) - + // 监控 api.GET("/monitor", monitorHandler.Monitor) api.GET("/monitor/execution/:id", monitorHandler.GetExecution) api.GET("/monitor/stats", monitorHandler.GetStats) api.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities) - + // MCP端点 api.POST("/mcp", func(c *gin.Context) { mcpServer.HandleHTTP(c.Writer, c.Request) @@ -143,7 +146,7 @@ func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitor // 静态文件 router.Static("/static", "./web/static") router.LoadHTMLGlob("web/templates/*") - + // 前端页面 router.GET("/", func(c *gin.Context) { c.HTML(http.StatusOK, "index.html", nil) @@ -166,4 +169,3 @@ func corsMiddleware() gin.HandlerFunc { c.Next() } } - diff --git a/internal/handler/agent.go b/internal/handler/agent.go index e9e50aa7..372108ec 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -3,6 +3,7 @@ package handler import ( "context" "encoding/json" + "errors" "fmt" "net/http" "time" @@ -18,6 +19,7 @@ type AgentHandler struct { agent *agent.Agent db *database.DB logger *zap.Logger + tasks *AgentTaskManager } // NewAgentHandler 创建新的Agent处理器 @@ -26,6 +28,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *A agent: agent, db: db, logger: logger, + tasks: NewAgentTaskManager(), } } @@ -101,7 +104,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { zap.String("content", contentPreview), ) } - + h.logger.Info("历史消息转换完成", zap.Int("originalCount", len(historyMessages)), zap.Int("convertedCount", len(agentHistoryMessages)), @@ -130,14 +133,14 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { c.JSON(http.StatusOK, ChatResponse{ Response: result.Response, MCPExecutionIDs: result.MCPExecutionIDs, - ConversationID: conversationID, + ConversationID: conversationID, Time: time.Now(), }) } // StreamEvent 流式事件 type StreamEvent struct { - Type string `json:"type"` // progress, tool_call, tool_result, response, error, done + Type string `json:"type"` // conversation, progress, tool_call, tool_result, response, error, cancelled, done Message string `json:"message"` // 显示消息 Data interface{} `json:"data,omitempty"` } @@ -174,13 +177,13 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { // 发送初始事件 // 用于跟踪客户端是否已断开连接 clientDisconnected := false - + sendEvent := func(eventType, message string, data interface{}) { // 如果客户端已断开,不再发送事件 if clientDisconnected { return } - + // 检查请求上下文是否被取消(客户端断开) select { case <-c.Request.Context().Done(): @@ -188,21 +191,21 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { return default: } - + event := StreamEvent{ Type: eventType, Message: message, Data: data, } eventJSON, _ := json.Marshal(event) - + // 尝试写入事件,如果失败则标记客户端断开 if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { clientDisconnected = true h.logger.Debug("客户端断开连接,停止发送SSE事件", zap.Error(err)) return } - + // 刷新响应,如果失败则标记客户端断开 if flusher, ok := c.Writer.(http.Flusher); ok { flusher.Flush() @@ -227,6 +230,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { conversationID = conv.ID } + sendEvent("conversation", "会话已创建", map[string]interface{}{ + "conversationId": conversationID, + }) + // 获取历史消息 historyMessages, err := h.db.GetMessages(conversationID) if err != nil { @@ -262,10 +269,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { if assistantMsg != nil { assistantMessageID = assistantMsg.ID } - + progressCallback := func(eventType, message string, data interface{}) { sendEvent(eventType, message, data) - + // 保存过程详情到数据库(排除response和done事件,它们会在后面单独处理) if assistantMessageID != "" && eventType != "response" && eventType != "done" { if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil { @@ -276,20 +283,101 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { // 创建一个独立的上下文用于任务执行,不随HTTP请求取消 // 这样即使客户端断开连接(如刷新页面),任务也能继续执行 - taskCtx, taskCancel := context.WithTimeout(context.Background(), 30*time.Minute) - defer taskCancel() - + baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 30*time.Minute) + defer timeoutCancel() + defer cancelWithCause(nil) + + if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { + if errors.Is(err, ErrTaskAlreadyRunning) { + sendEvent("error", "当前会话已有任务正在执行,请先停止后再尝试。", map[string]interface{}{ + "conversationId": conversationID, + }) + } else { + sendEvent("error", "无法启动任务: "+err.Error(), map[string]interface{}{ + "conversationId": conversationID, + }) + } + sendEvent("done", "", map[string]interface{}{ + "conversationId": conversationID, + }) + return + } + + taskStatus := "completed" + defer h.tasks.FinishTask(conversationID, taskStatus) + // 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断 sendEvent("progress", "正在分析您的请求...", nil) result, err := h.agent.AgentLoopWithProgress(taskCtx, req.Message, agentHistoryMessages, progressCallback) if err != nil { h.logger.Error("Agent Loop执行失败", zap.Error(err)) - sendEvent("error", "执行失败: "+err.Error(), nil) - // 保存错误事件 - if assistantMessageID != "" { - h.db.AddProcessDetail(assistantMessageID, conversationID, "error", "执行失败: "+err.Error(), nil) + cause := context.Cause(baseCtx) + + switch { + case errors.Is(cause, ErrTaskCancelled): + taskStatus = "cancelled" + cancelMsg := "任务已被用户取消,后续操作已停止。" + if assistantMessageID != "" { + if _, updateErr := h.db.Exec( + "UPDATE messages SET content = ? WHERE id = ?", + cancelMsg, + assistantMessageID, + ); updateErr != nil { + h.logger.Warn("更新取消后的助手消息失败", zap.Error(updateErr)) + } + h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) + } + sendEvent("cancelled", cancelMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + }) + sendEvent("done", "", map[string]interface{}{ + "conversationId": conversationID, + }) + return + case errors.Is(err, context.DeadlineExceeded) || errors.Is(cause, context.DeadlineExceeded): + taskStatus = "timeout" + timeoutMsg := "任务执行超时,已自动终止。" + if assistantMessageID != "" { + if _, updateErr := h.db.Exec( + "UPDATE messages SET content = ? WHERE id = ?", + timeoutMsg, + assistantMessageID, + ); updateErr != nil { + h.logger.Warn("更新超时后的助手消息失败", zap.Error(updateErr)) + } + h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) + } + sendEvent("error", timeoutMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + }) + sendEvent("done", "", map[string]interface{}{ + "conversationId": conversationID, + }) + return + default: + taskStatus = "failed" + errorMsg := "执行失败: " + err.Error() + if assistantMessageID != "" { + if _, updateErr := h.db.Exec( + "UPDATE messages SET content = ? WHERE id = ?", + errorMsg, + assistantMessageID, + ); updateErr != nil { + h.logger.Warn("更新失败后的助手消息失败", zap.Error(updateErr)) + } + h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil) + } + sendEvent("error", errorMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + }) + sendEvent("done", "", map[string]interface{}{ + "conversationId": conversationID, + }) } - sendEvent("done", "", nil) return } @@ -329,3 +417,39 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { }) } +// CancelAgentLoop 取消正在执行的任务 +func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { + var req struct { + ConversationID string `json:"conversationId" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + ok, err := h.tasks.CancelTask(req.ConversationID, ErrTaskCancelled) + if err != nil { + h.logger.Error("取消任务失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if !ok { + c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": "cancelling", + "conversationId": req.ConversationID, + "message": "已提交取消请求,任务将在当前步骤完成后停止。", + }) +} + +// ListAgentTasks 列出所有运行中的任务 +func (h *AgentHandler) ListAgentTasks(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "tasks": h.tasks.GetActiveTasks(), + }) +} diff --git a/internal/handler/task_manager.go b/internal/handler/task_manager.go new file mode 100644 index 00000000..02a549ad --- /dev/null +++ b/internal/handler/task_manager.go @@ -0,0 +1,124 @@ +package handler + +import ( + "context" + "errors" + "sync" + "time" +) + +// ErrTaskCancelled 用户取消任务的错误 +var ErrTaskCancelled = errors.New("agent task cancelled by user") + +// ErrTaskAlreadyRunning 会话已有任务正在执行 +var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation") + +// AgentTask 描述正在运行的Agent任务 +type AgentTask struct { + ConversationID string `json:"conversationId"` + Message string `json:"message,omitempty"` + StartedAt time.Time `json:"startedAt"` + Status string `json:"status"` + + cancel func(error) +} + +// AgentTaskManager 管理正在运行的Agent任务 +type AgentTaskManager struct { + mu sync.RWMutex + tasks map[string]*AgentTask +} + +// NewAgentTaskManager 创建任务管理器 +func NewAgentTaskManager() *AgentTaskManager { + return &AgentTaskManager{ + tasks: make(map[string]*AgentTask), + } +} + +// StartTask 注册并开始一个新的任务 +func (m *AgentTaskManager) StartTask(conversationID, message string, cancel context.CancelCauseFunc) (*AgentTask, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.tasks[conversationID]; exists { + return nil, ErrTaskAlreadyRunning + } + + task := &AgentTask{ + ConversationID: conversationID, + Message: message, + StartedAt: time.Now(), + Status: "running", + cancel: func(err error) { + if cancel != nil { + cancel(err) + } + }, + } + + m.tasks[conversationID] = task + return task, nil +} + +// CancelTask 取消指定会话的任务 +func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool, error) { + m.mu.Lock() + task, exists := m.tasks[conversationID] + if !exists { + m.mu.Unlock() + return false, nil + } + + // 如果已经处于取消流程,直接返回 + if task.Status == "cancelling" { + m.mu.Unlock() + return false, nil + } + + task.Status = "cancelling" + cancel := task.cancel + m.mu.Unlock() + + if cause == nil { + cause = ErrTaskCancelled + } + if cancel != nil { + cancel(cause) + } + return true, nil +} + +// FinishTask 完成任务并从管理器中移除 +func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string) { + m.mu.Lock() + defer m.mu.Unlock() + + task, exists := m.tasks[conversationID] + if !exists { + return + } + + if finalStatus != "" { + task.Status = finalStatus + } + + delete(m.tasks, conversationID) +} + +// GetActiveTasks 返回所有正在运行的任务 +func (m *AgentTaskManager) GetActiveTasks() []*AgentTask { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]*AgentTask, 0, len(m.tasks)) + for _, task := range m.tasks { + result = append(result, &AgentTask{ + ConversationID: task.ConversationID, + Message: task.Message, + StartedAt: task.StartedAt, + Status: task.Status, + }) + } + return result +} diff --git a/web/static/css/style.css b/web/static/css/style.css index 0638a539..00aceef4 100644 --- a/web/static/css/style.css +++ b/web/static/css/style.css @@ -993,12 +993,39 @@ header { border-bottom: 1px solid var(--border-color); } +.progress-actions { + display: flex; + align-items: center; + gap: 8px; +} + .progress-title { font-weight: 600; color: var(--text-primary); font-size: 0.9375rem; } +.progress-stop { + padding: 4px 12px; + background: rgba(220, 53, 69, 0.1); + border: 1px solid rgba(220, 53, 69, 0.4); + border-radius: 4px; + font-size: 0.8125rem; + color: var(--error-color); + cursor: pointer; + transition: all 0.2s; +} + +.progress-stop:hover { + background: rgba(220, 53, 69, 0.15); + border-color: var(--error-color); +} + +.progress-stop:disabled { + opacity: 0.6; + cursor: not-allowed; +} + .progress-toggle { padding: 4px 12px; background: var(--bg-tertiary); @@ -1070,6 +1097,11 @@ header { background: rgba(220, 53, 69, 0.1); } +.timeline-item-cancelled { + border-left-color: #ff7043; + background: rgba(255, 112, 67, 0.12); +} + .timeline-item-header { display: flex; align-items: center; @@ -1182,3 +1214,91 @@ header { font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; color: var(--text-secondary); } + +/* 活跃任务栏 */ +.active-tasks-bar { + display: none; + align-items: center; + gap: 12px; + padding: 12px 20px; + background: rgba(0, 102, 255, 0.06); + border-bottom: 1px solid rgba(0, 102, 255, 0.15); + color: var(--text-primary); +} + +.active-task-item { + display: flex; + align-items: center; + justify-content: space-between; + gap: 16px; + background: var(--bg-primary); + border: 1px solid rgba(0, 102, 255, 0.2); + border-radius: 8px; + padding: 8px 12px; + flex: 1; + min-width: 0; +} + +.active-task-info { + display: flex; + align-items: center; + gap: 8px; + min-width: 0; +} + +.active-task-status { + background: rgba(0, 102, 255, 0.12); + color: var(--accent-color); + padding: 2px 8px; + border-radius: 999px; + font-size: 0.75rem; + font-weight: 600; + flex-shrink: 0; +} + +.active-task-message { + font-size: 0.875rem; + color: var(--text-primary); + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + max-width: 320px; +} + +.active-task-actions { + display: flex; + align-items: center; + gap: 10px; + flex-shrink: 0; +} + +.active-task-time { + font-size: 0.75rem; + color: var(--text-muted); +} + +.active-task-cancel { + padding: 6px 12px; + background: rgba(220, 53, 69, 0.1); + border: 1px solid rgba(220, 53, 69, 0.4); + border-radius: 6px; + color: var(--error-color); + font-size: 0.8125rem; + cursor: pointer; + transition: all 0.2s; +} + +.active-task-cancel:hover { + background: rgba(220, 53, 69, 0.2); + border-color: var(--error-color); +} + +.active-task-cancel:disabled { + opacity: 0.6; + cursor: not-allowed; +} + +.active-task-error { + font-size: 0.875rem; + color: var(--error-color); +} diff --git a/web/static/js/app.js b/web/static/js/app.js index 941af6fb..7cbb5fa0 100644 --- a/web/static/js/app.js +++ b/web/static/js/app.js @@ -1,6 +1,63 @@ // 当前对话ID let currentConversationId = null; +// 进度ID与任务信息映射 +const progressTaskState = new Map(); +// 活跃任务刷新定时器 +let activeTaskInterval = null; +const ACTIVE_TASK_REFRESH_INTERVAL = 20000; + +function registerProgressTask(progressId, conversationId = null) { + const state = progressTaskState.get(progressId) || {}; + state.conversationId = conversationId !== undefined && conversationId !== null + ? conversationId + : (state.conversationId ?? currentConversationId); + state.cancelling = false; + progressTaskState.set(progressId, state); + + const progressElement = document.getElementById(progressId); + if (progressElement) { + progressElement.dataset.conversationId = state.conversationId || ''; + } +} + +function updateProgressConversation(progressId, conversationId) { + if (!conversationId) { + return; + } + registerProgressTask(progressId, conversationId); +} + +function markProgressCancelling(progressId) { + const state = progressTaskState.get(progressId); + if (state) { + state.cancelling = true; + } +} + +function finalizeProgressTask(progressId, finalLabel = '已完成') { + const stopBtn = document.getElementById(`${progressId}-stop-btn`); + if (stopBtn) { + stopBtn.disabled = true; + stopBtn.textContent = finalLabel; + } + progressTaskState.delete(progressId); +} + +async function requestCancel(conversationId) { + const response = await fetch('/api/agent-loop/cancel', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ conversationId }), + }); + const result = await response.json().catch(() => ({})); + if (!response.ok) { + throw new Error(result.error || '取消失败'); + } + return result; +} // 发送消息 async function sendMessage() { @@ -18,6 +75,8 @@ async function sendMessage() { // 创建进度消息容器(使用详细的进度展示) const progressId = addProgressMessage(); const progressElement = document.getElementById(progressId); + registerProgressTask(progressId, currentConversationId); + loadActiveTasks(); let assistantMessageId = null; let mcpExecutionIds = []; @@ -103,13 +162,17 @@ function addProgressMessage() { bubble.innerHTML = `