From 5da2d461c6d8ec5ca40069d22359b352e7ee529f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sun, 19 Apr 2026 01:20:22 +0800 Subject: [PATCH] Delete handler directory --- handler/agent.go | 2579 ---------------- handler/attackchain.go | 173 -- handler/auth.go | 156 - handler/batch_task_manager.go | 1122 ------- handler/batch_task_mcp.go | 813 ----- handler/chat_uploads.go | 512 ---- handler/config.go | 1594 ---------- handler/conversation.go | 233 -- handler/external_mcp.go | 542 ---- handler/external_mcp_test.go | 518 ---- handler/fofa.go | 467 --- handler/group.go | 320 -- handler/knowledge.go | 517 ---- handler/markdown_agents.go | 299 -- handler/monitor.go | 420 --- handler/multi_agent.go | 316 -- handler/multi_agent_prepare.go | 138 - handler/openapi.go | 4596 ---------------------------- handler/openapi_i18n.go | 139 - handler/robot.go | 907 ------ handler/role.go | 487 --- handler/skills.go | 758 ----- handler/sse_keepalive.go | 58 - handler/task_manager.go | 276 -- handler/terminal.go | 257 -- handler/terminal_stream_unix.go | 46 - handler/terminal_stream_windows.go | 65 - handler/terminal_ws_unix.go | 112 - handler/vulnerability.go | 263 -- handler/webshell.go | 706 ----- 30 files changed, 19389 deletions(-) delete mode 100644 handler/agent.go delete mode 100644 handler/attackchain.go delete mode 100644 handler/auth.go delete mode 100644 handler/batch_task_manager.go delete mode 100644 handler/batch_task_mcp.go delete mode 100644 handler/chat_uploads.go delete mode 100644 handler/config.go delete mode 100644 handler/conversation.go delete mode 100644 handler/external_mcp.go delete mode 100644 handler/external_mcp_test.go delete mode 100644 handler/fofa.go delete mode 100644 handler/group.go delete mode 100644 handler/knowledge.go delete mode 100644 handler/markdown_agents.go delete mode 100644 handler/monitor.go delete mode 100644 handler/multi_agent.go delete mode 100644 handler/multi_agent_prepare.go delete mode 100644 handler/openapi.go delete mode 100644 handler/openapi_i18n.go delete mode 100644 handler/robot.go delete mode 100644 handler/role.go delete mode 100644 handler/skills.go delete mode 100644 handler/sse_keepalive.go delete mode 100644 handler/task_manager.go delete mode 100644 handler/terminal.go delete mode 100644 handler/terminal_stream_unix.go delete mode 100644 handler/terminal_stream_windows.go delete mode 100644 handler/terminal_ws_unix.go delete mode 100644 handler/vulnerability.go delete mode 100644 handler/webshell.go diff --git a/handler/agent.go b/handler/agent.go deleted file mode 100644 index 55bdeee3..00000000 --- a/handler/agent.go +++ /dev/null @@ -1,2579 +0,0 @@ -package handler - -import ( - "context" - "crypto/rand" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - "unicode/utf8" - - "cyberstrike-ai/internal/agent" - "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/database" - "cyberstrike-ai/internal/mcp/builtin" - "cyberstrike-ai/internal/multiagent" - - "github.com/gin-gonic/gin" - "github.com/robfig/cron/v3" - "go.uber.org/zap" -) - -// safeTruncateString 安全截断字符串,避免在 UTF-8 字符中间截断 -func safeTruncateString(s string, maxLen int) string { - if maxLen <= 0 { - return "" - } - if utf8.RuneCountInString(s) <= maxLen { - return s - } - - // 将字符串转换为 rune 切片以正确计算字符数 - runes := []rune(s) - if len(runes) <= maxLen { - return s - } - - // 截断到最大长度 - truncated := string(runes[:maxLen]) - - // 尝试在标点符号或空格处截断,使截断更自然 - // 在截断点往前查找合适的断点(不超过20%的长度) - searchRange := maxLen / 5 - if searchRange > maxLen { - searchRange = maxLen - } - breakChars := []rune(",。、 ,.;:!?!?/\\-_") - bestBreakPos := len(runes[:maxLen]) - - for i := bestBreakPos - 1; i >= bestBreakPos-searchRange && i >= 0; i-- { - for _, breakChar := range breakChars { - if runes[i] == breakChar { - bestBreakPos = i + 1 // 在标点符号后断开 - goto found - } - } - } - -found: - truncated = string(runes[:bestBreakPos]) - return truncated + "..." -} - -// responsePlanAgg buffers main-assistant response_stream chunks for one "planning" process_detail row. -type responsePlanAgg struct { - meta map[string]interface{} - b strings.Builder -} - -func normalizeProcessDetailText(s string) string { - s = strings.ReplaceAll(s, "\r\n", "\n") - s = strings.ReplaceAll(s, "\r", "\n") - return strings.TrimSpace(s) -} - -// discardPlanningIfEchoesToolResult drops buffered planning text when it only repeats the -// upcoming tool_result body. Streaming models often echo tool stdout in chunk.Content; flushing -// that into "planning" before persisting tool_result duplicates the output after page refresh. -func discardPlanningIfEchoesToolResult(respPlan *responsePlanAgg, toolData interface{}) { - if respPlan == nil { - return - } - plan := normalizeProcessDetailText(respPlan.b.String()) - if plan == "" { - return - } - dataMap, ok := toolData.(map[string]interface{}) - if !ok { - return - } - res, ok := dataMap["result"].(string) - if !ok { - return - } - r := normalizeProcessDetailText(res) - if r == "" { - return - } - if plan == r || strings.HasSuffix(plan, r) { - respPlan.meta = nil - respPlan.b.Reset() - } -} - -// AgentHandler Agent处理器 -type AgentHandler struct { - agent *agent.Agent - db *database.DB - logger *zap.Logger - tasks *AgentTaskManager - batchTaskManager *BatchTaskManager - config *config.Config // 配置引用,用于获取角色信息 - knowledgeManager interface { // 知识库管理器接口 - LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error - } - agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并) - batchCronParser cron.Parser - batchRunnerMu sync.Mutex - batchRunning map[string]struct{} -} - -// NewAgentHandler 创建新的Agent处理器 -func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, logger *zap.Logger) *AgentHandler { - batchTaskManager := NewBatchTaskManager(logger) - batchTaskManager.SetDB(db) - - // 从数据库加载所有批量任务队列 - if err := batchTaskManager.LoadFromDB(); err != nil { - logger.Warn("从数据库加载批量任务队列失败", zap.Error(err)) - } - - handler := &AgentHandler{ - agent: agent, - db: db, - logger: logger, - tasks: NewAgentTaskManager(), - batchTaskManager: batchTaskManager, - config: cfg, - batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), - batchRunning: make(map[string]struct{}), - } - go handler.batchQueueSchedulerLoop() - return handler -} - -// SetKnowledgeManager 设置知识库管理器(用于记录检索日志) -func (h *AgentHandler) SetKnowledgeManager(manager interface { - LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error -}) { - h.knowledgeManager = manager -} - -// SetAgentsMarkdownDir 设置 agents/*.md 子代理目录(绝对路径);空表示仅使用 config.yaml 中的 sub_agents。 -func (h *AgentHandler) SetAgentsMarkdownDir(absDir string) { - h.agentsMarkdownDir = strings.TrimSpace(absDir) -} - -// ChatAttachment 聊天附件(用户上传的文件) -type ChatAttachment struct { - FileName string `json:"fileName"` // 展示用文件名 - Content string `json:"content,omitempty"` // 文本或 base64;若已预先上传到服务器可留空 - MimeType string `json:"mimeType,omitempty"` - ServerPath string `json:"serverPath,omitempty"` // 已保存在 chat_uploads 下的绝对路径(由 POST /api/chat-uploads 返回) -} - -// ChatRequest 聊天请求 -type ChatRequest struct { - Message string `json:"message" binding:"required"` - ConversationID string `json:"conversationId,omitempty"` - Role string `json:"role,omitempty"` // 角色名称 - Attachments []ChatAttachment `json:"attachments,omitempty"` - WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具 -} - -const ( - maxAttachments = 10 - chatUploadsDirName = "chat_uploads" // 对话附件保存的根目录(相对当前工作目录) -) - -// validateChatAttachmentServerPath 校验绝对路径落在工作目录 chat_uploads 下且为普通文件(防路径穿越) -func validateChatAttachmentServerPath(abs string) (string, error) { - p := strings.TrimSpace(abs) - if p == "" { - return "", fmt.Errorf("empty path") - } - cwd, err := os.Getwd() - if err != nil { - return "", fmt.Errorf("获取当前工作目录失败: %w", err) - } - root := filepath.Join(cwd, chatUploadsDirName) - rootAbs, err := filepath.Abs(filepath.Clean(root)) - if err != nil { - return "", err - } - pathAbs, err := filepath.Abs(filepath.Clean(p)) - if err != nil { - return "", err - } - sep := string(filepath.Separator) - if pathAbs != rootAbs && !strings.HasPrefix(pathAbs, rootAbs+sep) { - return "", fmt.Errorf("path outside chat_uploads") - } - st, err := os.Stat(pathAbs) - if err != nil { - return "", err - } - if st.IsDir() { - return "", fmt.Errorf("not a regular file") - } - return pathAbs, nil -} - -// avoidChatUploadDestCollision 若 path 已存在则生成带时间戳+随机后缀的新文件名(与上传接口命名风格一致) -func avoidChatUploadDestCollision(path string) string { - if _, err := os.Stat(path); os.IsNotExist(err) { - return path - } - dir := filepath.Dir(path) - base := filepath.Base(path) - ext := filepath.Ext(base) - nameNoExt := strings.TrimSuffix(base, ext) - suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), shortRand(6)) - var unique string - if ext != "" { - unique = nameNoExt + suffix + ext - } else { - unique = base + suffix - } - return filepath.Join(dir, unique) -} - -// relocateManualOrNewUploadToConversation 无会话 ID 时前端会上传到 …/日期/_manual;首条消息创建会话后,将文件移入 …/日期/{conversationId}/ 以便按对话隔离。 -func relocateManualOrNewUploadToConversation(absPath, conversationID string, logger *zap.Logger) (string, error) { - conv := strings.TrimSpace(conversationID) - if conv == "" { - return absPath, nil - } - convSan := strings.ReplaceAll(conv, string(filepath.Separator), "_") - if convSan == "" || convSan == "_manual" || convSan == "_new" { - return absPath, nil - } - cwd, err := os.Getwd() - if err != nil { - return absPath, err - } - rootAbs, err := filepath.Abs(filepath.Join(cwd, chatUploadsDirName)) - if err != nil { - return absPath, err - } - rel, err := filepath.Rel(rootAbs, absPath) - if err != nil { - return absPath, nil - } - rel = filepath.ToSlash(filepath.Clean(rel)) - var segs []string - for _, p := range strings.Split(rel, "/") { - if p != "" && p != "." { - segs = append(segs, p) - } - } - // 仅处理扁平结构:日期/_manual|_new/文件名 - if len(segs) != 3 { - return absPath, nil - } - datePart, placeFolder, baseName := segs[0], segs[1], segs[2] - if placeFolder != "_manual" && placeFolder != "_new" { - return absPath, nil - } - targetDir := filepath.Join(rootAbs, datePart, convSan) - if err := os.MkdirAll(targetDir, 0755); err != nil { - return "", fmt.Errorf("创建会话附件目录失败: %w", err) - } - dest := filepath.Join(targetDir, baseName) - dest = avoidChatUploadDestCollision(dest) - if err := os.Rename(absPath, dest); err != nil { - return "", fmt.Errorf("将附件移入会话目录失败: %w", err) - } - out, _ := filepath.Abs(dest) - if logger != nil { - logger.Info("对话附件已从占位目录移入会话目录", - zap.String("from", absPath), - zap.String("to", out), - zap.String("conversationId", conv)) - } - return out, nil -} - -// saveAttachmentsToDateAndConversationDir 处理附件:若带 serverPath 则仅校验已存在文件;否则将 content 写入 chat_uploads/YYYY-MM-DD/{conversationID}/。 -// conversationID 为空时使用 "_new" 作为目录名(新对话尚未有 ID) -func saveAttachmentsToDateAndConversationDir(attachments []ChatAttachment, conversationID string, logger *zap.Logger) (savedPaths []string, err error) { - if len(attachments) == 0 { - return nil, nil - } - cwd, err := os.Getwd() - if err != nil { - return nil, fmt.Errorf("获取当前工作目录失败: %w", err) - } - dateDir := filepath.Join(cwd, chatUploadsDirName, time.Now().Format("2006-01-02")) - convDirName := strings.TrimSpace(conversationID) - if convDirName == "" { - convDirName = "_new" - } else { - convDirName = strings.ReplaceAll(convDirName, string(filepath.Separator), "_") - } - targetDir := filepath.Join(dateDir, convDirName) - if err = os.MkdirAll(targetDir, 0755); err != nil { - return nil, fmt.Errorf("创建上传目录失败: %w", err) - } - savedPaths = make([]string, 0, len(attachments)) - for i, a := range attachments { - if sp := strings.TrimSpace(a.ServerPath); sp != "" { - valid, verr := validateChatAttachmentServerPath(sp) - if verr != nil { - return nil, fmt.Errorf("附件 %s: %w", a.FileName, verr) - } - finalPath, rerr := relocateManualOrNewUploadToConversation(valid, conversationID, logger) - if rerr != nil { - return nil, fmt.Errorf("附件 %s: %w", a.FileName, rerr) - } - savedPaths = append(savedPaths, finalPath) - if logger != nil { - logger.Debug("对话附件使用已上传路径", zap.Int("index", i+1), zap.String("fileName", a.FileName), zap.String("path", finalPath)) - } - continue - } - if strings.TrimSpace(a.Content) == "" { - return nil, fmt.Errorf("附件 %s 缺少内容或未提供 serverPath", a.FileName) - } - raw, decErr := attachmentContentToBytes(a) - if decErr != nil { - return nil, fmt.Errorf("附件 %s 解码失败: %w", a.FileName, decErr) - } - baseName := filepath.Base(a.FileName) - if baseName == "" || baseName == "." { - baseName = "file" - } - baseName = strings.ReplaceAll(baseName, string(filepath.Separator), "_") - ext := filepath.Ext(baseName) - nameNoExt := strings.TrimSuffix(baseName, ext) - suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), shortRand(6)) - var unique string - if ext != "" { - unique = nameNoExt + suffix + ext - } else { - unique = baseName + suffix - } - fullPath := filepath.Join(targetDir, unique) - if err = os.WriteFile(fullPath, raw, 0644); err != nil { - return nil, fmt.Errorf("写入文件 %s 失败: %w", a.FileName, err) - } - absPath, _ := filepath.Abs(fullPath) - savedPaths = append(savedPaths, absPath) - if logger != nil { - logger.Debug("对话附件已保存", zap.Int("index", i+1), zap.String("fileName", a.FileName), zap.String("path", absPath)) - } - } - return savedPaths, nil -} - -func shortRand(n int) string { - const letters = "0123456789abcdef" - b := make([]byte, n) - _, _ = rand.Read(b) - for i := range b { - b[i] = letters[int(b[i])%len(letters)] - } - return string(b) -} - -func attachmentContentToBytes(a ChatAttachment) ([]byte, error) { - content := a.Content - if decoded, err := base64.StdEncoding.DecodeString(content); err == nil && len(decoded) > 0 { - return decoded, nil - } - return []byte(content), nil -} - -// userMessageContentForStorage 返回要存入数据库的用户消息内容:有附件时在正文后追加附件名(及路径),刷新后仍能显示,继续对话时大模型也能从历史中拿到路径 -func userMessageContentForStorage(message string, attachments []ChatAttachment, savedPaths []string) string { - if len(attachments) == 0 { - return message - } - var b strings.Builder - b.WriteString(message) - for i, a := range attachments { - b.WriteString("\n📎 ") - b.WriteString(a.FileName) - if i < len(savedPaths) && savedPaths[i] != "" { - b.WriteString(": ") - b.WriteString(savedPaths[i]) - } - } - return b.String() -} - -// appendAttachmentsToMessage 仅将附件的保存路径追加到用户消息末尾,不再内联附件内容,避免上下文过长 -func appendAttachmentsToMessage(msg string, attachments []ChatAttachment, savedPaths []string) string { - if len(attachments) == 0 { - return msg - } - var b strings.Builder - b.WriteString(msg) - b.WriteString("\n\n[用户上传的文件已保存到以下路径(请按需读取文件内容,而不是依赖内联内容)]\n") - for i, a := range attachments { - if i < len(savedPaths) && savedPaths[i] != "" { - b.WriteString(fmt.Sprintf("- %s: %s\n", a.FileName, savedPaths[i])) - } else { - b.WriteString(fmt.Sprintf("- %s: (路径未知,可能保存失败)\n", a.FileName)) - } - } - return b.String() -} - -// ChatResponse 聊天响应 -type ChatResponse struct { - Response string `json:"response"` - MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` // 本次对话中执行的MCP调用ID列表 - ConversationID string `json:"conversationId"` // 对话ID - Time time.Time `json:"time"` -} - -// AgentLoop 处理Agent Loop请求 -func (h *AgentHandler) AgentLoop(c *gin.Context) { - var req ChatRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - h.logger.Info("收到Agent Loop请求", - zap.String("message", req.Message), - zap.String("conversationId", req.ConversationID), - ) - - // 如果没有对话ID,创建新对话 - conversationID := req.ConversationID - if conversationID == "" { - title := safeTruncateString(req.Message, 50) - conv, err := h.db.CreateConversation(title) - if err != nil { - h.logger.Error("创建对话失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - conversationID = conv.ID - } else { - // 验证对话是否存在 - _, err := h.db.GetConversation(conversationID) - if err != nil { - h.logger.Error("对话不存在", zap.String("conversationId", conversationID), zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) - return - } - } - - // 优先尝试从保存的ReAct数据恢复历史上下文 - agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) - if err != nil { - h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) - // 回退到使用数据库消息表 - historyMessages, err := h.db.GetMessages(conversationID) - if err != nil { - h.logger.Warn("获取历史消息失败", zap.Error(err)) - agentHistoryMessages = []agent.ChatMessage{} - } else { - // 将数据库消息转换为Agent消息格式 - agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages)) - for _, msg := range historyMessages { - agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{ - Role: msg.Role, - Content: msg.Content, - }) - } - h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) - } - } else { - h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) - } - - // 校验附件数量(非流式) - if len(req.Attachments) > maxAttachments { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("附件最多 %d 个", maxAttachments)}) - return - } - - // 应用角色用户提示词和工具配置 - finalMessage := req.Message - var roleTools []string // 角色配置的工具列表 - var roleSkills []string // 角色配置的skills列表(用于提示AI,但不硬编码内容) - - // WebShell AI 助手模式:绑定当前连接,仅开放 webshell_* 工具并注入 connection_id - if req.WebShellConnectionID != "" { - conn, err := h.db.GetWebshellConnection(strings.TrimSpace(req.WebShellConnectionID)) - if err != nil || conn == nil { - h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": "未找到该 WebShell 连接"}) - return - } - remark := conn.Remark - if remark == "" { - remark = conn.URL - } - finalMessage = fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。\n\n用户请求:%s", - conn.ID, remark, conn.ID, req.Message) - roleTools = []string{ - builtin.ToolWebshellExec, - builtin.ToolWebshellFileList, - builtin.ToolWebshellFileRead, - builtin.ToolWebshellFileWrite, - builtin.ToolRecordVulnerability, - builtin.ToolListKnowledgeRiskTypes, - builtin.ToolSearchKnowledgeBase, - } - roleSkills = nil - } else if req.Role != "" && req.Role != "默认" { - if h.config.Roles != nil { - if role, exists := h.config.Roles[req.Role]; exists && role.Enabled { - // 应用用户提示词 - if role.UserPrompt != "" { - finalMessage = role.UserPrompt + "\n\n" + req.Message - h.logger.Info("应用角色用户提示词", zap.String("role", req.Role)) - } - // 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段) - if len(role.Tools) > 0 { - roleTools = role.Tools - h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools))) - } - // 获取角色配置的skills列表(用于在系统提示词中提示AI,但不硬编码内容) - if len(role.Skills) > 0 { - roleSkills = role.Skills - h.logger.Info("角色配置了skills,将在系统提示词中提示AI", zap.String("role", req.Role), zap.Int("skillCount", len(roleSkills)), zap.Strings("skills", roleSkills)) - } - } - } - } - var savedPaths []string - if len(req.Attachments) > 0 { - savedPaths, err = saveAttachmentsToDateAndConversationDir(req.Attachments, conversationID, h.logger) - if err != nil { - h.logger.Error("保存对话附件失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存上传文件失败: " + err.Error()}) - return - } - } - finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths) - - // 保存用户消息:有附件时一并保存附件名与路径,刷新后显示、继续对话时大模型也能从历史中拿到路径 - userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths) - _, err = h.db.AddMessage(conversationID, "user", userContent, nil) - if err != nil { - h.logger.Error("保存用户消息失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存用户消息失败: " + err.Error()}) - return - } - - // 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表) - // 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills - result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools, roleSkills) - if err != nil { - h.logger.Error("Agent Loop执行失败", zap.Error(err)) - - // 即使执行失败,也尝试保存ReAct数据(如果result中有) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil { - h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr)) - } else { - h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) - } - } - - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 保存助手回复 - _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs) - if err != nil { - h.logger.Error("保存助手消息失败", zap.Error(err)) - // 即使保存失败,也返回响应,但记录错误 - // 因为AI已经生成了回复,用户应该能看到 - } - - // 保存最后一轮ReAct的输入和输出 - if result.LastReActInput != "" || result.LastReActOutput != "" { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存ReAct数据失败", zap.Error(err)) - } else { - h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) - } - } - - c.JSON(http.StatusOK, ChatResponse{ - Response: result.Response, - MCPExecutionIDs: result.MCPExecutionIDs, - ConversationID: conversationID, - Time: time.Now(), - }) -} - -// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复 -func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationID, message, role string) (response string, convID string, err error) { - if conversationID == "" { - title := safeTruncateString(message, 50) - conv, createErr := h.db.CreateConversation(title) - if createErr != nil { - return "", "", fmt.Errorf("创建对话失败: %w", createErr) - } - conversationID = conv.ID - } else { - if _, getErr := h.db.GetConversation(conversationID); getErr != nil { - return "", "", fmt.Errorf("对话不存在") - } - } - - agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) - if err != nil { - historyMessages, getErr := h.db.GetMessages(conversationID) - if getErr != nil { - agentHistoryMessages = []agent.ChatMessage{} - } else { - agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages)) - for _, msg := range historyMessages { - agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{Role: msg.Role, Content: msg.Content}) - } - } - } - - finalMessage := message - var roleTools, roleSkills []string - if role != "" && role != "默认" && h.config.Roles != nil { - if r, exists := h.config.Roles[role]; exists && r.Enabled { - if r.UserPrompt != "" { - finalMessage = r.UserPrompt + "\n\n" + message - } - roleTools = r.Tools - roleSkills = r.Skills - } - } - - if _, err = h.db.AddMessage(conversationID, "user", message, nil); err != nil { - return "", "", fmt.Errorf("保存用户消息失败: %w", err) - } - - // 与 agent-loop/stream 一致:先创建助手消息占位,用 progressCallback 写过程详情(不发送 SSE) - assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) - if err != nil { - h.logger.Warn("机器人:创建助手消息占位失败", zap.Error(err)) - } - var assistantMessageID string - if assistantMsg != nil { - assistantMessageID = assistantMsg.ID - } - progressCallback := h.createProgressCallback(conversationID, assistantMessageID, nil) - - useRobotMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.RobotUseMultiAgent - if useRobotMulti { - resultMA, errMA := multiagent.RunDeepAgent( - ctx, - h.config, - &h.config.MultiAgent, - h.agent, - h.logger, - conversationID, - finalMessage, - agentHistoryMessages, - roleTools, - progressCallback, - h.agentsMarkdownDir, - ) - if errMA != nil { - errMsg := "执行失败: " + errMA.Error() - if assistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID) - _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil) - } - return "", conversationID, errMA - } - if assistantMessageID != "" { - mcpIDsJSON := "" - if len(resultMA.MCPExecutionIDs) > 0 { - jsonData, _ := json.Marshal(resultMA.MCPExecutionIDs) - mcpIDsJSON = string(jsonData) - } - _, err = h.db.Exec( - "UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?", - resultMA.Response, mcpIDsJSON, assistantMessageID, - ) - if err != nil { - h.logger.Warn("机器人:更新助手消息失败", zap.Error(err)) - } - } else { - if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil { - h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) - } - } - if resultMA.LastReActInput != "" || resultMA.LastReActOutput != "" { - _ = h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput) - } - return resultMA.Response, conversationID, nil - } - - result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, roleSkills) - if err != nil { - errMsg := "执行失败: " + err.Error() - if assistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID) - _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil) - } - return "", conversationID, err - } - - // 更新助手消息内容与 MCP 执行 ID(与 stream 一致) - if assistantMessageID != "" { - mcpIDsJSON := "" - if len(result.MCPExecutionIDs) > 0 { - jsonData, _ := json.Marshal(result.MCPExecutionIDs) - mcpIDsJSON = string(jsonData) - } - _, err = h.db.Exec( - "UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?", - result.Response, mcpIDsJSON, assistantMessageID, - ) - if err != nil { - h.logger.Warn("机器人:更新助手消息失败", zap.Error(err)) - } - } else { - if _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs); err != nil { - h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) - } - } - if result.LastReActInput != "" || result.LastReActOutput != "" { - _ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput) - } - return result.Response, conversationID, nil -} - -// StreamEvent 流式事件 -type StreamEvent struct { - Type string `json:"type"` // conversation, progress, tool_call, tool_result, response, error, cancelled, done - Message string `json:"message"` // 显示消息 - Data interface{} `json:"data,omitempty"` -} - -// createProgressCallback 创建进度回调函数,用于保存processDetails -// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件 -func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback { - // 用于保存tool_call事件中的参数,以便在tool_result时使用 - toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments - - // thinking_stream_*:不逐条落库,按 streamId 聚合,在后续关键事件前补一条可持久化的 thinking - type thinkingBuf struct { - b strings.Builder - meta map[string]interface{} - } - thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf - flushedThinking := make(map[string]bool) // streamId -> flushed - - // response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta; - // 聚合为一条 planning 写入 process_details,刷新后与线上一致。 - var respPlan responsePlanAgg - flushResponsePlan := func() { - if assistantMessageID == "" { - return - } - content := strings.TrimSpace(respPlan.b.String()) - if content == "" { - respPlan.meta = nil - respPlan.b.Reset() - return - } - data := map[string]interface{}{ - "source": "response_stream", - } - for k, v := range respPlan.meta { - data[k] = v - } - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "planning", content, data); err != nil { - h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "planning")) - } - respPlan.meta = nil - respPlan.b.Reset() - } - - flushThinkingStreams := func() { - if assistantMessageID == "" { - return - } - for sid, tb := range thinkingStreams { - if sid == "" || flushedThinking[sid] || tb == nil { - continue - } - content := strings.TrimSpace(tb.b.String()) - if content == "" { - flushedThinking[sid] = true - continue - } - data := map[string]interface{}{ - "streamId": sid, - } - for k, v := range tb.meta { - // 避免覆盖 streamId - if k == "streamId" { - continue - } - data[k] = v - } - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "thinking", content, data); err != nil { - h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "thinking")) - } - flushedThinking[sid] = true - } - } - - return func(eventType, message string, data interface{}) { - // 如果提供了sendEventFunc,发送流式事件 - if sendEventFunc != nil { - sendEventFunc(eventType, message, data) - } - - // 保存tool_call事件中的参数 - if eventType == "tool_call" { - if dataMap, ok := data.(map[string]interface{}); ok { - toolName, _ := dataMap["toolName"].(string) - if toolName == builtin.ToolSearchKnowledgeBase { - if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" { - if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { - toolCallCache[toolCallId] = argumentsObj - } - } - } - } - } - - // 处理知识检索日志记录 - if eventType == "tool_result" && h.knowledgeManager != nil { - if dataMap, ok := data.(map[string]interface{}); ok { - toolName, _ := dataMap["toolName"].(string) - if toolName == builtin.ToolSearchKnowledgeBase { - // 提取检索信息 - query := "" - riskType := "" - var retrievedItems []string - - // 首先尝试从tool_call缓存中获取参数 - if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" { - if cachedArgs, exists := toolCallCache[toolCallId]; exists { - if q, ok := cachedArgs["query"].(string); ok && q != "" { - query = q - } - if rt, ok := cachedArgs["risk_type"].(string); ok && rt != "" { - riskType = rt - } - // 使用后清理缓存 - delete(toolCallCache, toolCallId) - } - } - - // 如果缓存中没有,尝试从argumentsObj中提取 - if query == "" { - if arguments, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { - if q, ok := arguments["query"].(string); ok && q != "" { - query = q - } - if rt, ok := arguments["risk_type"].(string); ok && rt != "" { - riskType = rt - } - } - } - - // 如果query仍然为空,尝试从result中提取(从结果文本的第一行) - if query == "" { - if result, ok := dataMap["result"].(string); ok && result != "" { - // 尝试从结果中提取查询内容(如果结果包含"未找到与查询 'xxx' 相关的知识") - if strings.Contains(result, "未找到与查询 '") { - start := strings.Index(result, "未找到与查询 '") + len("未找到与查询 '") - end := strings.Index(result[start:], "'") - if end > 0 { - query = result[start : start+end] - } - } - } - // 如果还是为空,使用默认值 - if query == "" { - query = "未知查询" - } - } - - // 从工具结果中提取检索到的知识项ID - // 结果格式:"找到 X 条相关知识:\n\n--- 结果 1 (相似度: XX.XX%) ---\n来源: [分类] 标题\n...\n" - if result, ok := dataMap["result"].(string); ok && result != "" { - // 尝试从元数据中提取知识项ID - metadataMatch := strings.Index(result, "") - if metadataEnd > 0 { - metadataJSON := result[metadataStart : metadataStart+metadataEnd] - var metadata map[string]interface{} - if err := json.Unmarshal([]byte(metadataJSON), &metadata); err == nil { - if meta, ok := metadata["_metadata"].(map[string]interface{}); ok { - if ids, ok := meta["retrievedItemIDs"].([]interface{}); ok { - retrievedItems = make([]string, 0, len(ids)) - for _, id := range ids { - if idStr, ok := id.(string); ok { - retrievedItems = append(retrievedItems, idStr) - } - } - } - } - } - } - } - - // 如果没有从元数据中提取到,但结果包含"找到 X 条",至少标记为有结果 - if len(retrievedItems) == 0 && strings.Contains(result, "找到") && !strings.Contains(result, "未找到") { - // 有结果,但无法准确提取ID,使用特殊标记 - retrievedItems = []string{"_has_results"} - } - } - - // 记录检索日志(异步,不阻塞) - go func() { - if err := h.knowledgeManager.LogRetrieval(conversationID, assistantMessageID, query, riskType, retrievedItems); err != nil { - h.logger.Warn("记录知识检索日志失败", zap.Error(err)) - } - }() - - // 添加知识检索事件到processDetails - if assistantMessageID != "" { - retrievalData := map[string]interface{}{ - "query": query, - "riskType": riskType, - "toolName": toolName, - } - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "knowledge_retrieval", fmt.Sprintf("检索知识: %s", query), retrievalData); err != nil { - h.logger.Warn("保存知识检索详情失败", zap.Error(err)) - } - } - } - } - } - - // 子代理回复流式增量不落库;结束时合并为一条 eino_agent_reply - if assistantMessageID != "" && eventType == "eino_agent_reply_stream_end" { - flushResponsePlan() - // 确保思考流在子代理回复前能持久化(刷新后可读) - flushThinkingStreams() - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "eino_agent_reply", message, data); err != nil { - h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", eventType)) - } - return - } - - // 多代理主代理「规划中」:response_start / response_delta 仅用于 SSE,聚合落一条 planning - if eventType == "response_start" { - flushResponsePlan() - respPlan.meta = nil - if dataMap, ok := data.(map[string]interface{}); ok { - respPlan.meta = make(map[string]interface{}, len(dataMap)) - for k, v := range dataMap { - respPlan.meta[k] = v - } - } - respPlan.b.Reset() - return - } - if eventType == "response_delta" { - respPlan.b.WriteString(message) - if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil { - respPlan.meta = make(map[string]interface{}, len(dataMap)) - for k, v := range dataMap { - respPlan.meta[k] = v - } - } else if dataMap, ok := data.(map[string]interface{}); ok { - for k, v := range dataMap { - respPlan.meta[k] = v - } - } - return - } - if eventType == "response" { - flushResponsePlan() - return - } - - // 聚合 thinking_stream_*(ReasoningContent),不逐条落库 - if eventType == "thinking_stream_start" { - if dataMap, ok := data.(map[string]interface{}); ok { - if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" { - tb := thinkingStreams[sid] - if tb == nil { - tb = &thinkingBuf{meta: map[string]interface{}{}} - thinkingStreams[sid] = tb - } - // 记录元信息(source/einoAgent/einoRole/iteration 等) - for k, v := range dataMap { - tb.meta[k] = v - } - } - } - return - } - if eventType == "thinking_stream_delta" { - if dataMap, ok := data.(map[string]interface{}); ok { - if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" { - tb := thinkingStreams[sid] - if tb == nil { - tb = &thinkingBuf{meta: map[string]interface{}{}} - thinkingStreams[sid] = tb - } - // delta 片段直接拼接;message 本身就是 reasoning content - tb.b.WriteString(message) - // 有时 delta 先到 start 未到,补充元信息 - for k, v := range dataMap { - tb.meta[k] = v - } - } - } - return - } - - // 当 Agent 同时发送 thinking_stream_* 和 thinking(带同一 streamId)时, - // thinking_stream_* 已经会在 flushThinkingStreams() 聚合落库; - // 这里跳过同 streamId 的 thinking,避免 processDetails 双份展示。 - if eventType == "thinking" { - if dataMap, ok := data.(map[string]interface{}); ok { - if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" { - if tb, exists := thinkingStreams[sid]; exists && tb != nil { - if strings.TrimSpace(tb.b.String()) != "" { - return - } - } - if flushedThinking[sid] { - return - } - } - } - } - - // 保存过程详情到数据库(排除 response/done;response 正文已在 messages 表) - // response_start/response_delta 已聚合为 planning,不落逐条。 - if assistantMessageID != "" && - eventType != "response" && - eventType != "done" && - eventType != "response_start" && - eventType != "response_delta" && - eventType != "tool_result_delta" && - eventType != "eino_agent_reply_stream_start" && - eventType != "eino_agent_reply_stream_delta" && - eventType != "eino_agent_reply_stream_end" { - if eventType == "tool_result" { - discardPlanningIfEchoesToolResult(&respPlan, data) - } - // 在关键过程事件落库前,先把「规划中」与 thinking_stream 落库 - flushResponsePlan() - flushThinkingStreams() - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil { - h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", eventType)) - } - } - } -} - -// AgentLoopStream 处理Agent Loop流式请求 -func (h *AgentHandler) AgentLoopStream(c *gin.Context) { - var req ChatRequest - if err := c.ShouldBindJSON(&req); err != nil { - // 对于流式请求,也发送SSE格式的错误 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - event := StreamEvent{ - Type: "error", - Message: "请求参数错误: " + err.Error(), - } - eventJSON, _ := json.Marshal(event) - fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON) - c.Writer.Flush() - return - } - - h.logger.Info("收到Agent Loop流式请求", - zap.String("message", req.Message), - zap.String("conversationId", req.ConversationID), - ) - - // 设置SSE响应头 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") // 禁用nginx缓冲 - - // 发送初始事件 - // 用于跟踪客户端是否已断开连接 - clientDisconnected := false - // 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。 - var sseWriteMu sync.Mutex - // 用于快速确认模型是否真的产生了流式 delta - var responseDeltaCount int - var responseStartLogged bool - - sendEvent := func(eventType, message string, data interface{}) { - if eventType == "response_start" { - responseDeltaCount = 0 - responseStartLogged = true - h.logger.Info("SSE: response_start", - zap.Int("conversationIdPresent", func() int { - if m, ok := data.(map[string]interface{}); ok { - if v, ok2 := m["conversationId"]; ok2 && v != nil && fmt.Sprint(v) != "" { - return 1 - } - } - return 0 - }()), - zap.String("messageGeneratedBy", func() string { - if m, ok := data.(map[string]interface{}); ok { - if v, ok2 := m["messageGeneratedBy"]; ok2 { - if s, ok3 := v.(string); ok3 { - return s - } - return fmt.Sprint(v) - } - } - return "" - }()), - ) - } else if eventType == "response_delta" { - responseDeltaCount++ - // 只打前几条,避免刷屏 - if responseStartLogged && responseDeltaCount <= 3 { - h.logger.Info("SSE: response_delta", - zap.Int("index", responseDeltaCount), - zap.Int("deltaLen", len(message)), - zap.String("deltaPreview", func() string { - p := strings.ReplaceAll(message, "\n", "\\n") - if len(p) > 80 { - return p[:80] + "..." - } - return p - }()), - ) - } - } - - // 如果客户端已断开,不再发送事件 - if clientDisconnected { - return - } - - // 检查请求上下文是否被取消(客户端断开) - select { - case <-c.Request.Context().Done(): - clientDisconnected = true - return - default: - } - - event := StreamEvent{ - Type: eventType, - Message: message, - Data: data, - } - eventJSON, _ := json.Marshal(event) - - sseWriteMu.Lock() - _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON) - if err != nil { - sseWriteMu.Unlock() - clientDisconnected = true - h.logger.Debug("客户端断开连接,停止发送SSE事件", zap.Error(err)) - return - } - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } else { - c.Writer.Flush() - } - sseWriteMu.Unlock() - } - - // 如果没有对话ID,创建新对话(WebShell 助手模式下关联连接 ID 以便持久化展示) - conversationID := req.ConversationID - if conversationID == "" { - title := safeTruncateString(req.Message, 50) - var conv *database.Conversation - var err error - if req.WebShellConnectionID != "" { - conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title) - } else { - conv, err = h.db.CreateConversation(title) - } - if err != nil { - h.logger.Error("创建对话失败", zap.Error(err)) - sendEvent("error", "创建对话失败: "+err.Error(), nil) - return - } - conversationID = conv.ID - sendEvent("conversation", "会话已创建", map[string]interface{}{ - "conversationId": conversationID, - }) - } else { - // 验证对话是否存在 - _, err := h.db.GetConversation(conversationID) - if err != nil { - h.logger.Error("对话不存在", zap.String("conversationId", conversationID), zap.Error(err)) - sendEvent("error", "对话不存在", nil) - return - } - } - - // 优先尝试从保存的ReAct数据恢复历史上下文 - agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) - if err != nil { - h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) - // 回退到使用数据库消息表 - historyMessages, err := h.db.GetMessages(conversationID) - if err != nil { - h.logger.Warn("获取历史消息失败", zap.Error(err)) - agentHistoryMessages = []agent.ChatMessage{} - } else { - // 将数据库消息转换为Agent消息格式 - agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages)) - for _, msg := range historyMessages { - agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{ - Role: msg.Role, - Content: msg.Content, - }) - } - h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) - } - } else { - h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) - } - - // 校验附件数量 - if len(req.Attachments) > maxAttachments { - sendEvent("error", fmt.Sprintf("附件最多 %d 个", maxAttachments), nil) - return - } - - // 应用角色用户提示词和工具配置 - finalMessage := req.Message - var roleTools []string // 角色配置的工具列表 - var roleSkills []string - if req.WebShellConnectionID != "" { - conn, errConn := h.db.GetWebshellConnection(strings.TrimSpace(req.WebShellConnectionID)) - if errConn != nil || conn == nil { - h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn)) - sendEvent("error", "未找到该 WebShell 连接", nil) - return - } - remark := conn.Remark - if remark == "" { - remark = conn.URL - } - finalMessage = fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。\n\n用户请求:%s", - conn.ID, remark, conn.ID, req.Message) - roleTools = []string{ - builtin.ToolWebshellExec, - builtin.ToolWebshellFileList, - builtin.ToolWebshellFileRead, - builtin.ToolWebshellFileWrite, - builtin.ToolRecordVulnerability, - builtin.ToolListKnowledgeRiskTypes, - builtin.ToolSearchKnowledgeBase, - } - } else if req.Role != "" && req.Role != "默认" { - if h.config.Roles != nil { - if role, exists := h.config.Roles[req.Role]; exists && role.Enabled { - // 应用用户提示词 - if role.UserPrompt != "" { - finalMessage = role.UserPrompt + "\n\n" + req.Message - h.logger.Info("应用角色用户提示词", zap.String("role", req.Role)) - } - // 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段) - if len(role.Tools) > 0 { - roleTools = role.Tools - h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools))) - } else if len(role.MCPs) > 0 { - // 向后兼容:如果只有mcps字段,暂时使用空列表(表示使用所有工具) - // 因为mcps是MCP服务器名称,不是工具列表 - h.logger.Info("角色配置使用旧的mcps字段,将使用所有工具", zap.String("role", req.Role)) - } - // 注意:角色 skills 仅在系统提示词中提示;运行时加载请使用 Eino 多代理内置 `skill` 工具 - if len(role.Skills) > 0 { - roleSkills = role.Skills - h.logger.Info("角色配置了skills,AI可通过工具按需调用", zap.String("role", req.Role), zap.Int("skillCount", len(role.Skills)), zap.Strings("skills", role.Skills)) - } - } - } - } - var savedPaths []string - if len(req.Attachments) > 0 { - savedPaths, err = saveAttachmentsToDateAndConversationDir(req.Attachments, conversationID, h.logger) - if err != nil { - h.logger.Error("保存对话附件失败", zap.Error(err)) - sendEvent("error", "保存上传文件失败: "+err.Error(), nil) - return - } - } - // 仅将附件保存路径追加到 finalMessage,避免将文件内容内联到大模型上下文中 - finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths) - // 如果roleTools为空,表示使用所有工具(默认角色或未配置工具的角色) - - // 保存用户消息:有附件时一并保存附件名与路径,刷新后显示、继续对话时大模型也能从历史中拿到路径 - userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths) - userMsgRow, err := h.db.AddMessage(conversationID, "user", userContent, nil) - if err != nil { - h.logger.Error("保存用户消息失败", zap.Error(err)) - } - - // 预先创建助手消息,以便关联过程详情 - assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) - if err != nil { - h.logger.Error("创建助手消息失败", zap.Error(err)) - // 如果创建失败,继续执行但不保存过程详情 - assistantMsg = nil - } - - // 创建进度回调函数,同时保存到数据库 - var assistantMessageID string - if assistantMsg != nil { - assistantMessageID = assistantMsg.ID - } - - // 尽早下发消息 ID,便于前端在流式结束前挂上「删除本轮」等(无需等整段结束再刷新) - if userMsgRow != nil { - sendEvent("message_saved", "", map[string]interface{}{ - "conversationId": conversationID, - "userMessageId": userMsgRow.ID, - }) - } - - // 创建进度回调函数,复用统一逻辑 - progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent) - - // 创建一个独立的上下文用于任务执行,不随HTTP请求取消 - // 这样即使客户端断开连接(如刷新页面),任务也能继续执行 - baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) - taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) - defer timeoutCancel() - defer cancelWithCause(nil) - - if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { - var errorMsg string - if errors.Is(err, ErrTaskAlreadyRunning) { - errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」按钮后再尝试。" - sendEvent("error", errorMsg, map[string]interface{}{ - "conversationId": conversationID, - "errorType": "task_already_running", - }) - } else { - errorMsg = "❌ 无法启动任务: " + err.Error() - sendEvent("error", errorMsg, map[string]interface{}{ - "conversationId": conversationID, - "errorType": "task_start_failed", - }) - } - - // 更新助手消息内容并保存错误详情到数据库 - if assistantMessageID != "" { - if _, updateErr := h.db.Exec( - "UPDATE messages SET content = ? WHERE id = ?", - errorMsg, - assistantMessageID, - ); updateErr != nil { - h.logger.Warn("更新错误后的助手消息失败", zap.Error(updateErr)) - } - // 保存错误详情到数据库 - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, map[string]interface{}{ - "errorType": func() string { - if errors.Is(err, ErrTaskAlreadyRunning) { - return "task_already_running" - } - return "task_start_failed" - }(), - }); err != nil { - h.logger.Warn("保存错误详情失败", zap.Error(err)) - } - } - - sendEvent("done", "", map[string]interface{}{ - "conversationId": conversationID, - }) - return - } - - taskStatus := "completed" - defer h.tasks.FinishTask(conversationID, taskStatus) - - // 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断(使用包含角色提示词的finalMessage和角色工具列表) - sendEvent("progress", "正在分析您的请求...", nil) - // 注意:roleSkills 已在上方根据 req.Role 或 WebShell 模式设置 - stopKeepalive := make(chan struct{}) - go sseKeepalive(c, stopKeepalive, &sseWriteMu) - defer close(stopKeepalive) - - result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, roleSkills) - if err != nil { - h.logger.Error("Agent Loop执行失败", zap.Error(err)) - cause := context.Cause(baseCtx) - - // 检查是否是用户取消:context的cause是ErrTaskCancelled - // 如果cause是ErrTaskCancelled,无论错误是什么类型(包括context.Canceled),都视为用户取消 - // 这样可以正确处理在API调用过程中被取消的情况 - isCancelled := errors.Is(cause, ErrTaskCancelled) - - switch { - case isCancelled: - taskStatus = "cancelled" - cancelMsg := "任务已被用户取消,后续操作已停止。" - - // 在发送事件前更新任务状态,确保前端能及时看到状态变化 - h.tasks.UpdateTaskStatus(conversationID, taskStatus) - - if assistantMessageID != "" { - if _, updateErr := h.db.Exec( - "UPDATE messages SET content = ? WHERE id = ?", - cancelMsg, - assistantMessageID, - ); updateErr != nil { - h.logger.Warn("更新取消后的助手消息失败", zap.Error(updateErr)) - } - h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) - } - - // 即使任务被取消,也尝试保存ReAct数据(如果result中有) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err)) - } else { - h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID)) - } - } - - sendEvent("cancelled", cancelMsg, map[string]interface{}{ - "conversationId": conversationID, - "messageId": assistantMessageID, - }) - sendEvent("done", "", map[string]interface{}{ - "conversationId": conversationID, - }) - return - case errors.Is(err, context.DeadlineExceeded) || errors.Is(cause, context.DeadlineExceeded): - taskStatus = "timeout" - timeoutMsg := "任务执行超时,已自动终止。" - - // 在发送事件前更新任务状态,确保前端能及时看到状态变化 - h.tasks.UpdateTaskStatus(conversationID, taskStatus) - - if assistantMessageID != "" { - if _, updateErr := h.db.Exec( - "UPDATE messages SET content = ? WHERE id = ?", - timeoutMsg, - assistantMessageID, - ); updateErr != nil { - h.logger.Warn("更新超时后的助手消息失败", zap.Error(updateErr)) - } - h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) - } - - // 即使任务超时,也尝试保存ReAct数据(如果result中有) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err)) - } else { - h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID)) - } - } - - sendEvent("error", timeoutMsg, map[string]interface{}{ - "conversationId": conversationID, - "messageId": assistantMessageID, - }) - sendEvent("done", "", map[string]interface{}{ - "conversationId": conversationID, - }) - return - default: - taskStatus = "failed" - errorMsg := "执行失败: " + err.Error() - - // 在发送事件前更新任务状态,确保前端能及时看到状态变化 - h.tasks.UpdateTaskStatus(conversationID, taskStatus) - - if assistantMessageID != "" { - if _, updateErr := h.db.Exec( - "UPDATE messages SET content = ? WHERE id = ?", - errorMsg, - assistantMessageID, - ); updateErr != nil { - h.logger.Warn("更新失败后的助手消息失败", zap.Error(updateErr)) - } - h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil) - } - - // 即使任务失败,也尝试保存ReAct数据(如果result中有) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err)) - } else { - h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) - } - } - - sendEvent("error", errorMsg, map[string]interface{}{ - "conversationId": conversationID, - "messageId": assistantMessageID, - }) - sendEvent("done", "", map[string]interface{}{ - "conversationId": conversationID, - }) - } - return - } - - // 更新助手消息内容 - if assistantMsg != nil { - _, err = h.db.Exec( - "UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?", - result.Response, - func() string { - if len(result.MCPExecutionIDs) > 0 { - jsonData, _ := json.Marshal(result.MCPExecutionIDs) - return string(jsonData) - } - return "" - }(), - assistantMessageID, - ) - if err != nil { - h.logger.Error("更新助手消息失败", zap.Error(err)) - } - } else { - // 如果之前创建失败,现在创建 - _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs) - if err != nil { - h.logger.Error("保存助手消息失败", zap.Error(err)) - } - } - - // 保存最后一轮ReAct的输入和输出 - if result.LastReActInput != "" || result.LastReActOutput != "" { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存ReAct数据失败", zap.Error(err)) - } else { - h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) - } - } - - // 发送最终响应 - sendEvent("response", result.Response, map[string]interface{}{ - "mcpExecutionIds": result.MCPExecutionIDs, - "conversationId": conversationID, - "messageId": assistantMessageID, // 包含消息ID,以便前端关联过程详情 - }) - sendEvent("done", "", map[string]interface{}{ - "conversationId": conversationID, - }) -} - -// CancelAgentLoop 取消正在执行的任务 -func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { - var req struct { - ConversationID string `json:"conversationId" binding:"required"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - ok, err := h.tasks.CancelTask(req.ConversationID, ErrTaskCancelled) - if err != nil { - h.logger.Error("取消任务失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - if !ok { - c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "status": "cancelling", - "conversationId": req.ConversationID, - "message": "已提交取消请求,任务将在当前步骤完成后停止。", - }) -} - -// ListAgentTasks 列出所有运行中的任务 -func (h *AgentHandler) ListAgentTasks(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "tasks": h.tasks.GetActiveTasks(), - }) -} - -// ListCompletedTasks 列出最近完成的任务历史 -func (h *AgentHandler) ListCompletedTasks(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "tasks": h.tasks.GetCompletedTasks(), - }) -} - -// BatchTaskRequest 批量任务请求 -type BatchTaskRequest struct { - Title string `json:"title"` // 任务标题(可选) - Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务 - Role string `json:"role,omitempty"` // 角色名称(可选,空字符串表示默认角色) - AgentMode string `json:"agentMode,omitempty"` // single | multi - ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron - CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填 - ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false) -} - -func normalizeBatchQueueAgentMode(mode string) string { - if strings.TrimSpace(mode) == "multi" { - return "multi" - } - return "single" -} - -func normalizeBatchQueueScheduleMode(mode string) string { - if strings.TrimSpace(mode) == "cron" { - return "cron" - } - return "manual" -} - -// CreateBatchQueue 创建批量任务队列 -func (h *AgentHandler) CreateBatchQueue(c *gin.Context) { - var req BatchTaskRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if len(req.Tasks) == 0 { - c.JSON(http.StatusBadRequest, gin.H{"error": "任务列表不能为空"}) - return - } - - // 过滤空任务 - validTasks := make([]string, 0, len(req.Tasks)) - for _, task := range req.Tasks { - if task != "" { - validTasks = append(validTasks, task) - } - } - - if len(validTasks) == 0 { - c.JSON(http.StatusBadRequest, gin.H{"error": "没有有效的任务"}) - return - } - - agentMode := normalizeBatchQueueAgentMode(req.AgentMode) - scheduleMode := normalizeBatchQueueScheduleMode(req.ScheduleMode) - cronExpr := strings.TrimSpace(req.CronExpr) - var nextRunAt *time.Time - if scheduleMode == "cron" { - if cronExpr == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "启用 Cron 调度时,调度表达式不能为空"}) - return - } - schedule, err := h.batchCronParser.Parse(cronExpr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 Cron 表达式: " + err.Error()}) - return - } - next := schedule.Next(time.Now()) - nextRunAt = &next - } - - queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, nextRunAt, validTasks) - if createErr != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()}) - return - } - started := false - if req.ExecuteNow { - ok, err := h.startBatchQueueExecution(queue.ID, false) - if !ok { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error(), "queueId": queue.ID}) - return - } - started = true - if refreshed, exists := h.batchTaskManager.GetBatchQueue(queue.ID); exists { - queue = refreshed - } - } - c.JSON(http.StatusOK, gin.H{ - "queueId": queue.ID, - "queue": queue, - "started": started, - }) -} - -// GetBatchQueue 获取批量任务队列 -func (h *AgentHandler) GetBatchQueue(c *gin.Context) { - queueID := c.Param("queueId") - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - c.JSON(http.StatusOK, gin.H{"queue": queue}) -} - -// ListBatchQueuesResponse 批量任务队列列表响应 -type ListBatchQueuesResponse struct { - Queues []*BatchTaskQueue `json:"queues"` - Total int `json:"total"` - Page int `json:"page"` - PageSize int `json:"page_size"` - TotalPages int `json:"total_pages"` -} - -// ListBatchQueues 列出所有批量任务队列(支持筛选和分页) -func (h *AgentHandler) ListBatchQueues(c *gin.Context) { - limitStr := c.DefaultQuery("limit", "10") - offsetStr := c.DefaultQuery("offset", "0") - pageStr := c.Query("page") - status := c.Query("status") - keyword := c.Query("keyword") - - limit, _ := strconv.Atoi(limitStr) - offset, _ := strconv.Atoi(offsetStr) - page := 1 - - // 如果提供了page参数,优先使用page计算offset - if pageStr != "" { - if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { - page = p - offset = (page - 1) * limit - } - } - - // 限制pageSize范围 - if limit <= 0 || limit > 100 { - limit = 10 - } - if offset < 0 { - offset = 0 - } - // 防止恶意大 offset 导致 DB 性能问题 - const maxOffset = 100000 - if offset > maxOffset { - offset = maxOffset - } - - // 默认status为"all" - if status == "" { - status = "all" - } - - // 获取队列列表和总数 - queues, total, err := h.batchTaskManager.ListQueues(limit, offset, status, keyword) - if err != nil { - h.logger.Error("获取批量任务队列列表失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 计算总页数 - totalPages := (total + limit - 1) / limit - if totalPages == 0 { - totalPages = 1 - } - - // 如果使用offset计算page,需要重新计算 - if pageStr == "" { - page = (offset / limit) + 1 - } - - response := ListBatchQueuesResponse{ - Queues: queues, - Total: total, - Page: page, - PageSize: limit, - TotalPages: totalPages, - } - - c.JSON(http.StatusOK, response) -} - -// StartBatchQueue 开始执行批量任务队列 -func (h *AgentHandler) StartBatchQueue(c *gin.Context) { - queueID := c.Param("queueId") - ok, err := h.startBatchQueueExecution(queueID, false) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - if !ok { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID}) -} - -// RerunBatchQueue 重跑批量任务队列(重置所有子任务后重新执行) -func (h *AgentHandler) RerunBatchQueue(c *gin.Context) { - queueID := c.Param("queueId") - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - if queue.Status != "completed" && queue.Status != "cancelled" { - c.JSON(http.StatusBadRequest, gin.H{"error": "仅已完成或已取消的队列可以重跑"}) - return - } - if !h.batchTaskManager.ResetQueueForRerun(queueID) { - c.JSON(http.StatusInternalServerError, gin.H{"error": "重置队列失败"}) - return - } - ok, err := h.startBatchQueueExecution(queueID, false) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - if !ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID}) -} - -// PauseBatchQueue 暂停批量任务队列 -func (h *AgentHandler) PauseBatchQueue(c *gin.Context) { - queueID := c.Param("queueId") - success := h.batchTaskManager.PauseQueue(queueID) - if !success { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"}) -} - -// UpdateBatchQueueMetadata 修改批量任务队列的标题、角色和代理模式 -func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) { - queueID := c.Param("queueId") - var req struct { - Title string `json:"title"` - Role string `json:"role"` - AgentMode string `json:"agentMode"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - updated, _ := h.batchTaskManager.GetBatchQueue(queueID) - c.JSON(http.StatusOK, gin.H{"queue": updated}) -} - -// UpdateBatchQueueSchedule 修改批量任务队列的调度配置(scheduleMode / cronExpr) -func (h *AgentHandler) UpdateBatchQueueSchedule(c *gin.Context) { - queueID := c.Param("queueId") - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - // 仅在非 running 状态下允许修改调度 - if queue.Status == "running" { - c.JSON(http.StatusBadRequest, gin.H{"error": "队列正在运行中,无法修改调度配置"}) - return - } - var req struct { - ScheduleMode string `json:"scheduleMode"` - CronExpr string `json:"cronExpr"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - scheduleMode := normalizeBatchQueueScheduleMode(req.ScheduleMode) - cronExpr := strings.TrimSpace(req.CronExpr) - var nextRunAt *time.Time - if scheduleMode == "cron" { - if cronExpr == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "启用 Cron 调度时,调度表达式不能为空"}) - return - } - schedule, err := h.batchCronParser.Parse(cronExpr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 Cron 表达式: " + err.Error()}) - return - } - next := schedule.Next(time.Now()) - nextRunAt = &next - } - h.batchTaskManager.UpdateQueueSchedule(queueID, scheduleMode, cronExpr, nextRunAt) - updated, _ := h.batchTaskManager.GetBatchQueue(queueID) - c.JSON(http.StatusOK, gin.H{"queue": updated}) -} - -// SetBatchQueueScheduleEnabled 开启/关闭 Cron 自动调度(手工执行不受影响) -func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) { - queueID := c.Param("queueId") - if _, exists := h.batchTaskManager.GetBatchQueue(queueID); !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - var req struct { - ScheduleEnabled bool `json:"scheduleEnabled"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - if !h.batchTaskManager.SetScheduleEnabled(queueID, req.ScheduleEnabled) { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - queue, _ := h.batchTaskManager.GetBatchQueue(queueID) - c.JSON(http.StatusOK, gin.H{"queue": queue}) -} - -// DeleteBatchQueue 删除批量任务队列 -func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) { - queueID := c.Param("queueId") - success := h.batchTaskManager.DeleteQueue(queueID) - if !success { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"}) -} - -// UpdateBatchTask 更新批量任务消息 -func (h *AgentHandler) UpdateBatchTask(c *gin.Context) { - queueID := c.Param("queueId") - taskID := c.Param("taskId") - - var req struct { - Message string `json:"message" binding:"required"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - if req.Message == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"}) - return - } - - err := h.batchTaskManager.UpdateTaskMessage(queueID, taskID, req.Message) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // 返回更新后的队列信息 - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "任务已更新", "queue": queue}) -} - -// AddBatchTask 添加任务到批量任务队列 -func (h *AgentHandler) AddBatchTask(c *gin.Context) { - queueID := c.Param("queueId") - - var req struct { - Message string `json:"message" binding:"required"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - if req.Message == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"}) - return - } - - task, err := h.batchTaskManager.AddTaskToQueue(queueID, req.Message) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // 返回更新后的队列信息 - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue}) -} - -// DeleteBatchTask 删除批量任务 -func (h *AgentHandler) DeleteBatchTask(c *gin.Context) { - queueID := c.Param("queueId") - taskID := c.Param("taskId") - - err := h.batchTaskManager.DeleteTask(queueID, taskID) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // 返回更新后的队列信息 - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue}) -} - -func (h *AgentHandler) markBatchQueueRunning(queueID string) bool { - h.batchRunnerMu.Lock() - defer h.batchRunnerMu.Unlock() - if _, exists := h.batchRunning[queueID]; exists { - return false - } - h.batchRunning[queueID] = struct{}{} - return true -} - -func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) { - h.batchRunnerMu.Lock() - defer h.batchRunnerMu.Unlock() - delete(h.batchRunning, queueID) -} - -func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) { - expr := strings.TrimSpace(cronExpr) - if expr == "" { - return nil, nil - } - schedule, err := h.batchCronParser.Parse(expr) - if err != nil { - return nil, err - } - next := schedule.Next(from) - return &next, nil -} - -func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) { - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists { - return false, nil - } - if !h.markBatchQueueRunning(queueID) { - return true, nil - } - - if scheduled { - if queue.ScheduleMode != "cron" { - h.unmarkBatchQueueRunning(queueID) - err := fmt.Errorf("队列未启用 cron 调度") - h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) - return true, err - } - if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" { - h.unmarkBatchQueueRunning(queueID) - err := fmt.Errorf("当前队列状态不允许被调度执行") - h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) - return true, err - } - if !h.batchTaskManager.ResetQueueForRerun(queueID) { - h.unmarkBatchQueueRunning(queueID) - err := fmt.Errorf("重置队列失败") - h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) - return true, err - } - queue, _ = h.batchTaskManager.GetBatchQueue(queueID) - } else if queue.Status != "pending" && queue.Status != "paused" { - h.unmarkBatchQueueRunning(queueID) - return true, fmt.Errorf("队列状态不允许启动") - } - - if queue != nil && queue.AgentMode == "multi" && (h.config == nil || !h.config.MultiAgent.Enabled) { - h.unmarkBatchQueueRunning(queueID) - err := fmt.Errorf("当前队列配置为多代理,但系统未启用多代理") - if scheduled { - h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) - } - return true, err - } - - if scheduled { - h.batchTaskManager.RecordScheduledRunStart(queueID) - } - h.batchTaskManager.UpdateQueueStatus(queueID, "running") - if queue != nil && queue.ScheduleMode == "cron" { - nextRunAt, err := h.nextBatchQueueRunAt(queue.CronExpr, time.Now()) - if err == nil { - h.batchTaskManager.UpdateQueueSchedule(queueID, "cron", queue.CronExpr, nextRunAt) - } - } - - go h.executeBatchQueue(queueID) - return true, nil -} - -func (h *AgentHandler) batchQueueSchedulerLoop() { - ticker := time.NewTicker(20 * time.Second) - defer ticker.Stop() - for range ticker.C { - queues := h.batchTaskManager.GetLoadedQueues() - now := time.Now() - for _, queue := range queues { - if queue == nil || queue.ScheduleMode != "cron" || !queue.ScheduleEnabled || queue.Status == "cancelled" || queue.Status == "running" || queue.Status == "paused" { - continue - } - nextRunAt := queue.NextRunAt - if nextRunAt == nil { - next, err := h.nextBatchQueueRunAt(queue.CronExpr, now) - if err != nil { - h.logger.Warn("批量任务 cron 表达式无效,跳过调度", zap.String("queueId", queue.ID), zap.String("cronExpr", queue.CronExpr), zap.Error(err)) - continue - } - h.batchTaskManager.UpdateQueueSchedule(queue.ID, "cron", queue.CronExpr, next) - nextRunAt = next - } - if nextRunAt != nil && (nextRunAt.Before(now) || nextRunAt.Equal(now)) { - if _, err := h.startBatchQueueExecution(queue.ID, true); err != nil { - h.logger.Warn("自动调度批量任务失败", zap.String("queueId", queue.ID), zap.Error(err)) - } - } - } - } -} - -// executeBatchQueue 执行批量任务队列 -func (h *AgentHandler) executeBatchQueue(queueID string) { - defer h.unmarkBatchQueueRunning(queueID) - h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID)) - - for { - // 检查队列状态 - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" { - break - } - - // 获取下一个任务 - task, hasNext := h.batchTaskManager.GetNextTask(queueID) - if !hasNext { - // 所有任务完成:汇总子任务失败信息便于排障 - q, ok := h.batchTaskManager.GetBatchQueue(queueID) - lastRunErr := "" - if ok { - for _, t := range q.Tasks { - if t.Status == "failed" && t.Error != "" { - lastRunErr = t.Error - } - } - } - h.batchTaskManager.SetLastRunError(queueID, lastRunErr) - h.batchTaskManager.UpdateQueueStatus(queueID, "completed") - h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID)) - break - } - - // 更新任务状态为运行中 - h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "") - - // 创建新对话 - title := safeTruncateString(task.Message, 50) - conv, err := h.db.CreateConversation(title) - var conversationID string - if err != nil { - h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error()) - h.batchTaskManager.MoveToNextTask(queueID) - continue - } - conversationID = conv.ID - - // 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话) - h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID) - - // 应用角色用户提示词和工具配置 - finalMessage := task.Message - var roleTools []string // 角色配置的工具列表 - var roleSkills []string // 角色配置的skills列表(用于提示AI,但不硬编码内容) - if queue.Role != "" && queue.Role != "默认" { - if h.config.Roles != nil { - if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled { - // 应用用户提示词 - if role.UserPrompt != "" { - finalMessage = role.UserPrompt + "\n\n" + task.Message - h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role)) - } - // 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段) - if len(role.Tools) > 0 { - roleTools = role.Tools - h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools))) - } - // 获取角色配置的skills列表(用于在系统提示词中提示AI,但不硬编码内容) - if len(role.Skills) > 0 { - roleSkills = role.Skills - h.logger.Info("角色配置了skills,将在系统提示词中提示AI", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("skillCount", len(roleSkills)), zap.Strings("skills", roleSkills)) - } - } - } - } - - // 保存用户消息(保存原始消息,不包含角色提示词) - _, err = h.db.AddMessage(conversationID, "user", task.Message, nil) - if err != nil { - h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) - } - - // 预先创建助手消息,以便关联过程详情 - assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) - if err != nil { - h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) - // 如果创建失败,继续执行但不保存过程详情 - assistantMsg = nil - } - - // 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil) - var assistantMessageID string - if assistantMsg != nil { - assistantMessageID = assistantMsg.ID - } - progressCallback := h.createProgressCallback(conversationID, assistantMessageID, nil) - - // 执行任务(使用包含角色提示词的finalMessage和角色工具列表) - h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID)) - - // 单个子任务超时时间:从30分钟调整为6小时,适配长时间渗透/扫描任务 - ctx, cancel := context.WithTimeout(context.Background(), 6*time.Hour) - // 存储取消函数,以便在取消队列时能够取消当前任务 - h.batchTaskManager.SetTaskCancel(queueID, cancel) - // 使用队列配置的角色工具列表(如果为空,表示使用所有工具) - // 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills - useBatchMulti := false - if queue.AgentMode == "multi" { - useBatchMulti = h.config != nil && h.config.MultiAgent.Enabled - } else if queue.AgentMode == "" { - // 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关 - useBatchMulti = h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent - } - var result *agent.AgentLoopResult - var resultMA *multiagent.RunResult - var runErr error - if useBatchMulti { - resultMA, runErr = multiagent.RunDeepAgent(ctx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir) - } else { - result, runErr = h.agent.AgentLoopWithProgress(ctx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools, roleSkills) - } - // 任务执行完成,清理取消函数 - h.batchTaskManager.SetTaskCancel(queueID, nil) - cancel() - - if runErr != nil { - // 检查是否是取消错误 - // 1. 直接检查是否是 context.Canceled(包括包装后的错误) - // 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字 - // 3. 检查 result.Response 中是否包含取消相关的消息 - errStr := runErr.Error() - partialResp := "" - if result != nil { - partialResp = result.Response - } else if resultMA != nil { - partialResp = resultMA.Response - } - isCancelled := errors.Is(runErr, context.Canceled) || - strings.Contains(strings.ToLower(errStr), "context canceled") || - strings.Contains(strings.ToLower(errStr), "context cancelled") || - (partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断"))) - - if isCancelled { - h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) - cancelMsg := "任务已被用户取消,后续操作已停止。" - // 如果执行结果中有更具体的取消消息,使用它 - if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) { - cancelMsg = partialResp - } - // 更新助手消息内容 - if assistantMessageID != "" { - if _, updateErr := h.db.Exec( - "UPDATE messages SET content = ? WHERE id = ?", - cancelMsg, - assistantMessageID, - ); updateErr != nil { - h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) - } - // 保存取消详情到数据库 - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil { - h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - } - } else { - // 如果没有预先创建的助手消息,创建一个新的 - _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil) - if errMsg != nil { - h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg)) - } - } - // 保存ReAct数据(如果存在) - if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - } - } else if resultMA != nil && (resultMA.LastReActInput != "" || resultMA.LastReActOutput != "") { - if err := h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput); err != nil { - h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - } - } - h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID) - } else { - h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr)) - errorMsg := "执行失败: " + runErr.Error() - // 更新助手消息内容 - if assistantMessageID != "" { - if _, updateErr := h.db.Exec( - "UPDATE messages SET content = ? WHERE id = ?", - errorMsg, - assistantMessageID, - ); updateErr != nil { - h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) - } - // 保存错误详情到数据库 - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil { - h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - } - } - h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error()) - } - } else { - h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) - - var resText string - var mcpIDs []string - var lastIn, lastOut string - if useBatchMulti { - resText = resultMA.Response - mcpIDs = resultMA.MCPExecutionIDs - lastIn = resultMA.LastReActInput - lastOut = resultMA.LastReActOutput - } else { - resText = result.Response - mcpIDs = result.MCPExecutionIDs - lastIn = result.LastReActInput - lastOut = result.LastReActOutput - } - - // 更新助手消息内容 - if assistantMessageID != "" { - mcpIDsJSON := "" - if len(mcpIDs) > 0 { - jsonData, _ := json.Marshal(mcpIDs) - mcpIDsJSON = string(jsonData) - } - if _, updateErr := h.db.Exec( - "UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?", - resText, - mcpIDsJSON, - assistantMessageID, - ); updateErr != nil { - h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) - // 如果更新失败,尝试创建新消息 - _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs) - if err != nil { - h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) - } - } - } else { - // 如果没有预先创建的助手消息,创建一个新的 - _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs) - if err != nil { - h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) - } - } - - // 保存ReAct数据 - if lastIn != "" || lastOut != "" { - if err := h.db.SaveReActData(conversationID, lastIn, lastOut); err != nil { - h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - } else { - h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) - } - } - - // 保存结果 - h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID) - } - - // 移动到下一个任务 - h.batchTaskManager.MoveToNextTask(queueID) - - // 检查是否被取消或暂停 - queue, _ = h.batchTaskManager.GetBatchQueue(queueID) - if queue.Status == "cancelled" || queue.Status == "paused" { - break - } - } -} - -// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文 -// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表 -func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) { - // 获取保存的ReAct输入和输出 - reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID) - if err != nil { - return nil, fmt.Errorf("获取ReAct数据失败: %w", err) - } - - // 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致) - if reactInputJSON == "" { - return nil, fmt.Errorf("ReAct数据为空,将使用消息表") - } - - dataSource := "database_last_react_input" - - // 解析JSON格式的messages数组 - var messagesArray []map[string]interface{} - if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil { - return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err) - } - - messageCount := len(messagesArray) - - h.logger.Info("使用保存的ReAct数据恢复历史上下文", - zap.String("conversationId", conversationID), - zap.String("dataSource", dataSource), - zap.Int("reactInputSize", len(reactInputJSON)), - zap.Int("messageCount", messageCount), - zap.Int("reactOutputSize", len(reactOutput)), - ) - // fmt.Println("messagesArray:", messagesArray)//debug - - // 转换为Agent消息格式 - agentMessages := make([]agent.ChatMessage, 0, len(messagesArray)) - for _, msgMap := range messagesArray { - msg := agent.ChatMessage{} - - // 解析role - if role, ok := msgMap["role"].(string); ok { - msg.Role = role - } else { - continue // 跳过无效消息 - } - - // 跳过system消息(AgentLoop会重新添加) - if msg.Role == "system" { - continue - } - - // 解析content - if content, ok := msgMap["content"].(string); ok { - msg.Content = content - } - - // 解析tool_calls(如果存在) - if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil { - if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok { - msg.ToolCalls = make([]agent.ToolCall, 0, len(toolCallsArray)) - for _, tcRaw := range toolCallsArray { - if tcMap, ok := tcRaw.(map[string]interface{}); ok { - toolCall := agent.ToolCall{} - - // 解析ID - if id, ok := tcMap["id"].(string); ok { - toolCall.ID = id - } - - // 解析Type - if toolType, ok := tcMap["type"].(string); ok { - toolCall.Type = toolType - } - - // 解析Function - if funcMap, ok := tcMap["function"].(map[string]interface{}); ok { - toolCall.Function = agent.FunctionCall{} - - // 解析函数名 - if name, ok := funcMap["name"].(string); ok { - toolCall.Function.Name = name - } - - // 解析arguments(可能是字符串或对象) - if argsRaw, ok := funcMap["arguments"]; ok { - if argsStr, ok := argsRaw.(string); ok { - // 如果是字符串,解析为JSON - var argsMap map[string]interface{} - if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil { - toolCall.Function.Arguments = argsMap - } - } else if argsMap, ok := argsRaw.(map[string]interface{}); ok { - // 如果已经是对象,直接使用 - toolCall.Function.Arguments = argsMap - } - } - } - - if toolCall.ID != "" { - msg.ToolCalls = append(msg.ToolCalls, toolCall) - } - } - } - } - } - - // 解析tool_call_id(tool角色消息) - if toolCallID, ok := msgMap["tool_call_id"].(string); ok { - msg.ToolCallID = toolCallID - } - - agentMessages = append(agentMessages, msg) - } - - // 如果存在last_react_output,需要将其作为最后一条assistant消息 - // 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出 - if reactOutput != "" { - // 检查最后一条消息是否是assistant消息且没有tool_calls - // 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复 - if len(agentMessages) > 0 { - lastMsg := &agentMessages[len(agentMessages)-1] - if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 { - // 最后一条是assistant消息且没有tool_calls,用最终输出更新其content - lastMsg.Content = reactOutput - } else { - // 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息 - agentMessages = append(agentMessages, agent.ChatMessage{ - Role: "assistant", - Content: reactOutput, - }) - } - } else { - // 如果没有消息,直接添加最终输出 - agentMessages = append(agentMessages, agent.ChatMessage{ - Role: "assistant", - Content: reactOutput, - }) - } - } - - if len(agentMessages) == 0 { - return nil, fmt.Errorf("从ReAct数据解析的消息为空") - } - - // 修复可能存在的失配tool消息,避免OpenAI报错 - // 这可以防止出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误 - if h.agent != nil { - if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed { - h.logger.Info("修复了从ReAct数据恢复的历史消息中的失配tool消息", - zap.String("conversationId", conversationID), - ) - } - } - - h.logger.Info("从ReAct数据恢复历史消息完成", - zap.String("conversationId", conversationID), - zap.String("dataSource", dataSource), - zap.Int("originalMessageCount", messageCount), - zap.Int("finalMessageCount", len(agentMessages)), - zap.Bool("hasReactOutput", reactOutput != ""), - ) - fmt.Println("agentMessages:", agentMessages) //debug - return agentMessages, nil -} diff --git a/handler/attackchain.go b/handler/attackchain.go deleted file mode 100644 index 2b78b9bf..00000000 --- a/handler/attackchain.go +++ /dev/null @@ -1,173 +0,0 @@ -package handler - -import ( - "context" - "net/http" - "sync" - "time" - - "cyberstrike-ai/internal/attackchain" - "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/database" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// AttackChainHandler 攻击链处理器 -type AttackChainHandler struct { - db *database.DB - logger *zap.Logger - openAIConfig *config.OpenAIConfig - mu sync.RWMutex // 保护 openAIConfig 的并发访问 - // 用于防止同一对话的并发生成 - generatingLocks sync.Map // map[string]*sync.Mutex -} - -// NewAttackChainHandler 创建新的攻击链处理器 -func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *AttackChainHandler { - return &AttackChainHandler{ - db: db, - logger: logger, - openAIConfig: openAIConfig, - } -} - -// UpdateConfig 更新OpenAI配置 -func (h *AttackChainHandler) UpdateConfig(cfg *config.OpenAIConfig) { - h.mu.Lock() - defer h.mu.Unlock() - h.openAIConfig = cfg - h.logger.Info("AttackChainHandler配置已更新", - zap.String("base_url", cfg.BaseURL), - zap.String("model", cfg.Model), - ) -} - -// getOpenAIConfig 获取OpenAI配置(线程安全) -func (h *AttackChainHandler) getOpenAIConfig() *config.OpenAIConfig { - h.mu.RLock() - defer h.mu.RUnlock() - return h.openAIConfig -} - -// GetAttackChain 获取攻击链(按需生成) -// GET /api/attack-chain/:conversationId -func (h *AttackChainHandler) GetAttackChain(c *gin.Context) { - conversationID := c.Param("conversationId") - if conversationID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) - return - } - - // 检查对话是否存在 - _, err := h.db.GetConversation(conversationID) - if err != nil { - h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) - return - } - - // 先尝试从数据库加载(如果已生成过) - openAIConfig := h.getOpenAIConfig() - builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger) - chain, err := builder.LoadChainFromDatabase(conversationID) - if err == nil && len(chain.Nodes) > 0 { - // 如果已存在,直接返回 - h.logger.Info("返回已存在的攻击链", zap.String("conversationId", conversationID)) - c.JSON(http.StatusOK, chain) - return - } - - // 如果不存在,则生成新的攻击链(按需生成) - // 使用锁机制防止同一对话的并发生成 - lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) - lock := lockInterface.(*sync.Mutex) - - // 尝试获取锁,如果正在生成则返回错误 - acquired := lock.TryLock() - if !acquired { - h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID)) - c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"}) - return - } - defer lock.Unlock() - - // 再次检查是否已生成(可能在等待锁的过程中已经生成完成) - chain, err = builder.LoadChainFromDatabase(conversationID) - if err == nil && len(chain.Nodes) > 0 { - h.logger.Info("返回已存在的攻击链(在锁等待期间已生成)", zap.String("conversationId", conversationID)) - c.JSON(http.StatusOK, chain) - return - } - - h.logger.Info("开始生成攻击链", zap.String("conversationId", conversationID)) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - chain, err = builder.BuildChainFromConversation(ctx, conversationID) - if err != nil { - h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()}) - return - } - - // 生成完成后,从锁映射中删除(可选,保留也可以用于防止短时间内重复生成) - // h.generatingLocks.Delete(conversationID) - - c.JSON(http.StatusOK, chain) -} - -// RegenerateAttackChain 重新生成攻击链 -// POST /api/attack-chain/:conversationId/regenerate -func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) { - conversationID := c.Param("conversationId") - if conversationID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) - return - } - - // 检查对话是否存在 - _, err := h.db.GetConversation(conversationID) - if err != nil { - h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) - return - } - - // 删除旧的攻击链 - if err := h.db.DeleteAttackChain(conversationID); err != nil { - h.logger.Warn("删除旧攻击链失败", zap.Error(err)) - } - - // 使用锁机制防止并发生成 - lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) - lock := lockInterface.(*sync.Mutex) - - acquired := lock.TryLock() - if !acquired { - h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID)) - c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"}) - return - } - defer lock.Unlock() - - // 生成新的攻击链 - h.logger.Info("重新生成攻击链", zap.String("conversationId", conversationID)) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - openAIConfig := h.getOpenAIConfig() - builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger) - chain, err := builder.BuildChainFromConversation(ctx, conversationID) - if err != nil { - h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()}) - return - } - - c.JSON(http.StatusOK, chain) -} - diff --git a/handler/auth.go b/handler/auth.go deleted file mode 100644 index 508553c1..00000000 --- a/handler/auth.go +++ /dev/null @@ -1,156 +0,0 @@ -package handler - -import ( - "net/http" - "strings" - "time" - - "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/security" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// AuthHandler handles authentication-related endpoints. -type AuthHandler struct { - manager *security.AuthManager - config *config.Config - configPath string - logger *zap.Logger -} - -// NewAuthHandler creates a new AuthHandler. -func NewAuthHandler(manager *security.AuthManager, cfg *config.Config, configPath string, logger *zap.Logger) *AuthHandler { - return &AuthHandler{ - manager: manager, - config: cfg, - configPath: configPath, - logger: logger, - } -} - -type loginRequest struct { - Password string `json:"password" binding:"required"` -} - -type changePasswordRequest struct { - OldPassword string `json:"oldPassword"` - NewPassword string `json:"newPassword"` -} - -// Login verifies password and returns a session token. -func (h *AuthHandler) Login(c *gin.Context) { - var req loginRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"}) - return - } - - token, expiresAt, err := h.manager.Authenticate(req.Password) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "token": token, - "expires_at": expiresAt.UTC().Format(time.RFC3339), - "session_duration_hr": h.manager.SessionDurationHours(), - }) -} - -// Logout revokes the current session token. -func (h *AuthHandler) Logout(c *gin.Context) { - token := c.GetString(security.ContextAuthTokenKey) - if token == "" { - authHeader := c.GetHeader("Authorization") - if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { - token = strings.TrimSpace(authHeader[7:]) - } else { - token = strings.TrimSpace(authHeader) - } - } - - h.manager.RevokeToken(token) - c.JSON(http.StatusOK, gin.H{"message": "已退出登录"}) -} - -// ChangePassword updates the login password. -func (h *AuthHandler) ChangePassword(c *gin.Context) { - var req changePasswordRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "参数无效"}) - return - } - - oldPassword := strings.TrimSpace(req.OldPassword) - newPassword := strings.TrimSpace(req.NewPassword) - - if oldPassword == "" || newPassword == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码和新密码均不能为空"}) - return - } - - if len(newPassword) < 8 { - c.JSON(http.StatusBadRequest, gin.H{"error": "新密码长度至少需要 8 位"}) - return - } - - if oldPassword == newPassword { - c.JSON(http.StatusBadRequest, gin.H{"error": "新密码不能与旧密码相同"}) - return - } - - if !h.manager.CheckPassword(oldPassword) { - c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"}) - return - } - - if err := config.PersistAuthPassword(h.configPath, newPassword); err != nil { - if h.logger != nil { - h.logger.Error("保存新密码失败", zap.Error(err)) - } - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存新密码失败,请重试"}) - return - } - - if err := h.manager.UpdateConfig(newPassword, h.config.Auth.SessionDurationHours); err != nil { - if h.logger != nil { - h.logger.Error("更新认证配置失败", zap.Error(err)) - } - c.JSON(http.StatusInternalServerError, gin.H{"error": "更新认证配置失败"}) - return - } - - h.config.Auth.Password = newPassword - h.config.Auth.GeneratedPassword = "" - h.config.Auth.GeneratedPasswordPersisted = false - h.config.Auth.GeneratedPasswordPersistErr = "" - - if h.logger != nil { - h.logger.Info("登录密码已更新,所有会话已失效") - } - - c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"}) -} - -// Validate returns the current session status. -func (h *AuthHandler) Validate(c *gin.Context) { - token := c.GetString(security.ContextAuthTokenKey) - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "会话无效"}) - return - } - - session, ok := h.manager.ValidateToken(token) - if !ok { - c.JSON(http.StatusUnauthorized, gin.H{"error": "会话已过期"}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "token": session.Token, - "expires_at": session.ExpiresAt.UTC().Format(time.RFC3339), - }) -} diff --git a/handler/batch_task_manager.go b/handler/batch_task_manager.go deleted file mode 100644 index aef4c9e5..00000000 --- a/handler/batch_task_manager.go +++ /dev/null @@ -1,1122 +0,0 @@ -package handler - -import ( - "context" - "crypto/rand" - "encoding/hex" - "fmt" - "sort" - "strings" - "sync" - "time" - "unicode/utf8" - - "cyberstrike-ai/internal/database" - - "go.uber.org/zap" -) - -// 批量任务状态常量 -const ( - BatchQueueStatusPending = "pending" - BatchQueueStatusRunning = "running" - BatchQueueStatusPaused = "paused" - BatchQueueStatusCompleted = "completed" - BatchQueueStatusCancelled = "cancelled" - - BatchTaskStatusPending = "pending" - BatchTaskStatusRunning = "running" - BatchTaskStatusCompleted = "completed" - BatchTaskStatusFailed = "failed" - BatchTaskStatusCancelled = "cancelled" - - // MaxBatchTasksPerQueue 单个队列最大任务数 - MaxBatchTasksPerQueue = 10000 - - // MaxBatchQueueTitleLen 队列标题最大长度 - MaxBatchQueueTitleLen = 200 - - // MaxBatchQueueRoleLen 角色名最大长度 - MaxBatchQueueRoleLen = 100 -) - -// BatchTask 批量任务项 -type BatchTask struct { - ID string `json:"id"` - Message string `json:"message"` - ConversationID string `json:"conversationId,omitempty"` - Status string `json:"status"` // pending, running, completed, failed, cancelled - StartedAt *time.Time `json:"startedAt,omitempty"` - CompletedAt *time.Time `json:"completedAt,omitempty"` - Error string `json:"error,omitempty"` - Result string `json:"result,omitempty"` -} - -// BatchTaskQueue 批量任务队列 -type BatchTaskQueue struct { - ID string `json:"id"` - Title string `json:"title,omitempty"` - Role string `json:"role,omitempty"` // 角色名称(空字符串表示默认角色) - AgentMode string `json:"agentMode"` // single | multi - ScheduleMode string `json:"scheduleMode"` // manual | cron - CronExpr string `json:"cronExpr,omitempty"` - NextRunAt *time.Time `json:"nextRunAt,omitempty"` - ScheduleEnabled bool `json:"scheduleEnabled"` - LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"` - LastScheduleError string `json:"lastScheduleError,omitempty"` - LastRunError string `json:"lastRunError,omitempty"` - Tasks []*BatchTask `json:"tasks"` - Status string `json:"status"` // pending, running, paused, completed, cancelled - CreatedAt time.Time `json:"createdAt"` - StartedAt *time.Time `json:"startedAt,omitempty"` - CompletedAt *time.Time `json:"completedAt,omitempty"` - CurrentIndex int `json:"currentIndex"` -} - -// BatchTaskManager 批量任务管理器 -type BatchTaskManager struct { - db *database.DB - logger *zap.Logger - queues map[string]*BatchTaskQueue - taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数 - mu sync.RWMutex -} - -// NewBatchTaskManager 创建批量任务管理器 -func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager { - if logger == nil { - logger = zap.NewNop() - } - return &BatchTaskManager{ - logger: logger, - queues: make(map[string]*BatchTaskQueue), - taskCancels: make(map[string]context.CancelFunc), - } -} - -// SetDB 设置数据库连接 -func (m *BatchTaskManager) SetDB(db *database.DB) { - m.mu.Lock() - defer m.mu.Unlock() - m.db = db -} - -// CreateBatchQueue 创建批量任务队列 -func (m *BatchTaskManager) CreateBatchQueue( - title, role, agentMode, scheduleMode, cronExpr string, - nextRunAt *time.Time, - tasks []string, -) (*BatchTaskQueue, error) { - // 输入校验 - if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { - return nil, fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) - } - if utf8.RuneCountInString(role) > MaxBatchQueueRoleLen { - return nil, fmt.Errorf("角色名不能超过 %d 个字符", MaxBatchQueueRoleLen) - } - if len(tasks) > MaxBatchTasksPerQueue { - return nil, fmt.Errorf("单个队列最多 %d 条任务", MaxBatchTasksPerQueue) - } - - m.mu.Lock() - defer m.mu.Unlock() - - queueID := time.Now().Format("20060102150405") + "-" + generateShortID() - queue := &BatchTaskQueue{ - ID: queueID, - Title: title, - Role: role, - AgentMode: normalizeBatchQueueAgentMode(agentMode), - ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode), - CronExpr: strings.TrimSpace(cronExpr), - NextRunAt: nextRunAt, - ScheduleEnabled: true, - Tasks: make([]*BatchTask, 0, len(tasks)), - Status: BatchQueueStatusPending, - CreatedAt: time.Now(), - CurrentIndex: 0, - } - if queue.ScheduleMode != "cron" { - queue.CronExpr = "" - queue.NextRunAt = nil - } - - // 准备数据库保存的任务数据 - dbTasks := make([]map[string]interface{}, 0, len(tasks)) - - for _, message := range tasks { - if message == "" { - continue // 跳过空行 - } - taskID := generateShortID() - task := &BatchTask{ - ID: taskID, - Message: message, - Status: BatchTaskStatusPending, - } - queue.Tasks = append(queue.Tasks, task) - dbTasks = append(dbTasks, map[string]interface{}{ - "id": taskID, - "message": message, - }) - } - - // 保存到数据库 - if m.db != nil { - if err := m.db.CreateBatchQueue( - queueID, - title, - role, - queue.AgentMode, - queue.ScheduleMode, - queue.CronExpr, - queue.NextRunAt, - dbTasks, - ); err != nil { - m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err)) - } - } - - m.queues[queueID] = queue - return queue, nil -} - -// GetBatchQueue 获取批量任务队列 -func (m *BatchTaskManager) GetBatchQueue(queueID string) (*BatchTaskQueue, bool) { - m.mu.RLock() - queue, exists := m.queues[queueID] - m.mu.RUnlock() - - if exists { - return queue, true - } - - // 如果内存中不存在,尝试从数据库加载 - if m.db != nil { - if queue := m.loadQueueFromDB(queueID); queue != nil { - m.mu.Lock() - m.queues[queueID] = queue - m.mu.Unlock() - return queue, true - } - } - - return nil, false -} - -// loadQueueFromDB 从数据库加载单个队列 -func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue { - if m.db == nil { - return nil - } - - queueRow, err := m.db.GetBatchQueue(queueID) - if err != nil || queueRow == nil { - return nil - } - - taskRows, err := m.db.GetBatchTasks(queueID) - if err != nil { - return nil - } - - queue := &BatchTaskQueue{ - ID: queueRow.ID, - AgentMode: "single", - ScheduleMode: "manual", - Status: queueRow.Status, - CreatedAt: queueRow.CreatedAt, - CurrentIndex: queueRow.CurrentIndex, - Tasks: make([]*BatchTask, 0, len(taskRows)), - } - - if queueRow.Title.Valid { - queue.Title = queueRow.Title.String - } - if queueRow.Role.Valid { - queue.Role = queueRow.Role.String - } - if queueRow.AgentMode.Valid { - queue.AgentMode = normalizeBatchQueueAgentMode(queueRow.AgentMode.String) - } - if queueRow.ScheduleMode.Valid { - queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String) - } - if queueRow.CronExpr.Valid && queue.ScheduleMode == "cron" { - queue.CronExpr = strings.TrimSpace(queueRow.CronExpr.String) - } - if queueRow.NextRunAt.Valid && queue.ScheduleMode == "cron" { - t := queueRow.NextRunAt.Time - queue.NextRunAt = &t - } - queue.ScheduleEnabled = true - if queueRow.ScheduleEnabled.Valid && queueRow.ScheduleEnabled.Int64 == 0 { - queue.ScheduleEnabled = false - } - if queueRow.LastScheduleTriggerAt.Valid { - t := queueRow.LastScheduleTriggerAt.Time - queue.LastScheduleTriggerAt = &t - } - if queueRow.LastScheduleError.Valid { - queue.LastScheduleError = strings.TrimSpace(queueRow.LastScheduleError.String) - } - if queueRow.LastRunError.Valid { - queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) - } - if queueRow.StartedAt.Valid { - queue.StartedAt = &queueRow.StartedAt.Time - } - if queueRow.CompletedAt.Valid { - queue.CompletedAt = &queueRow.CompletedAt.Time - } - - for _, taskRow := range taskRows { - task := &BatchTask{ - ID: taskRow.ID, - Message: taskRow.Message, - Status: taskRow.Status, - } - if taskRow.ConversationID.Valid { - task.ConversationID = taskRow.ConversationID.String - } - if taskRow.StartedAt.Valid { - task.StartedAt = &taskRow.StartedAt.Time - } - if taskRow.CompletedAt.Valid { - task.CompletedAt = &taskRow.CompletedAt.Time - } - if taskRow.Error.Valid { - task.Error = taskRow.Error.String - } - if taskRow.Result.Valid { - task.Result = taskRow.Result.String - } - queue.Tasks = append(queue.Tasks, task) - } - - return queue -} - -// GetLoadedQueues 获取内存中已加载的队列(不触发 DB 加载,仅用 RLock) -func (m *BatchTaskManager) GetLoadedQueues() []*BatchTaskQueue { - m.mu.RLock() - result := make([]*BatchTaskQueue, 0, len(m.queues)) - for _, queue := range m.queues { - result = append(result, queue) - } - m.mu.RUnlock() - return result -} - -// GetAllQueues 获取所有队列 -func (m *BatchTaskManager) GetAllQueues() []*BatchTaskQueue { - m.mu.RLock() - result := make([]*BatchTaskQueue, 0, len(m.queues)) - for _, queue := range m.queues { - result = append(result, queue) - } - m.mu.RUnlock() - - // 如果数据库可用,确保所有数据库中的队列都已加载到内存 - if m.db != nil { - dbQueues, err := m.db.GetAllBatchQueues() - if err == nil { - m.mu.Lock() - for _, queueRow := range dbQueues { - if _, exists := m.queues[queueRow.ID]; !exists { - if queue := m.loadQueueFromDB(queueRow.ID); queue != nil { - m.queues[queueRow.ID] = queue - result = append(result, queue) - } - } - } - m.mu.Unlock() - } - } - - return result -} - -// ListQueues 列出队列(支持筛选和分页) -func (m *BatchTaskManager) ListQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueue, int, error) { - var queues []*BatchTaskQueue - var total int - - // 如果数据库可用,从数据库查询 - if m.db != nil { - // 获取总数 - count, err := m.db.CountBatchQueues(status, keyword) - if err != nil { - return nil, 0, fmt.Errorf("统计队列总数失败: %w", err) - } - total = count - - // 获取队列列表(只获取ID) - queueRows, err := m.db.ListBatchQueues(limit, offset, status, keyword) - if err != nil { - return nil, 0, fmt.Errorf("查询队列列表失败: %w", err) - } - - // 加载完整的队列信息(从内存或数据库) - m.mu.Lock() - for _, queueRow := range queueRows { - var queue *BatchTaskQueue - // 先从内存查找 - if cached, exists := m.queues[queueRow.ID]; exists { - queue = cached - } else { - // 从数据库加载 - queue = m.loadQueueFromDB(queueRow.ID) - if queue != nil { - m.queues[queueRow.ID] = queue - } - } - if queue != nil { - queues = append(queues, queue) - } - } - m.mu.Unlock() - } else { - // 没有数据库,从内存中筛选和分页 - m.mu.RLock() - allQueues := make([]*BatchTaskQueue, 0, len(m.queues)) - for _, queue := range m.queues { - allQueues = append(allQueues, queue) - } - m.mu.RUnlock() - - // 筛选 - filtered := make([]*BatchTaskQueue, 0) - for _, queue := range allQueues { - // 状态筛选 - if status != "" && status != "all" && queue.Status != status { - continue - } - // 关键字搜索(搜索队列ID和标题) - if keyword != "" { - keywordLower := strings.ToLower(keyword) - queueIDLower := strings.ToLower(queue.ID) - queueTitleLower := strings.ToLower(queue.Title) - if !strings.Contains(queueIDLower, keywordLower) && !strings.Contains(queueTitleLower, keywordLower) { - // 也可以搜索创建时间 - createdAtStr := queue.CreatedAt.Format("2006-01-02 15:04:05") - if !strings.Contains(createdAtStr, keyword) { - continue - } - } - } - filtered = append(filtered, queue) - } - - // 按创建时间倒序排序 - sort.Slice(filtered, func(i, j int) bool { - return filtered[i].CreatedAt.After(filtered[j].CreatedAt) - }) - - total = len(filtered) - - // 分页 - start := offset - if start > len(filtered) { - start = len(filtered) - } - end := start + limit - if end > len(filtered) { - end = len(filtered) - } - if start < len(filtered) { - queues = filtered[start:end] - } - } - - return queues, total, nil -} - -// LoadFromDB 从数据库加载所有队列 -func (m *BatchTaskManager) LoadFromDB() error { - if m.db == nil { - return nil - } - - queueRows, err := m.db.GetAllBatchQueues() - if err != nil { - return err - } - - m.mu.Lock() - defer m.mu.Unlock() - - for _, queueRow := range queueRows { - if _, exists := m.queues[queueRow.ID]; exists { - continue // 已存在,跳过 - } - - taskRows, err := m.db.GetBatchTasks(queueRow.ID) - if err != nil { - continue // 跳过加载失败的任务 - } - - queue := &BatchTaskQueue{ - ID: queueRow.ID, - AgentMode: "single", - ScheduleMode: "manual", - Status: queueRow.Status, - CreatedAt: queueRow.CreatedAt, - CurrentIndex: queueRow.CurrentIndex, - Tasks: make([]*BatchTask, 0, len(taskRows)), - } - - if queueRow.Title.Valid { - queue.Title = queueRow.Title.String - } - if queueRow.Role.Valid { - queue.Role = queueRow.Role.String - } - if queueRow.AgentMode.Valid { - queue.AgentMode = normalizeBatchQueueAgentMode(queueRow.AgentMode.String) - } - if queueRow.ScheduleMode.Valid { - queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String) - } - if queueRow.CronExpr.Valid && queue.ScheduleMode == "cron" { - queue.CronExpr = strings.TrimSpace(queueRow.CronExpr.String) - } - if queueRow.NextRunAt.Valid && queue.ScheduleMode == "cron" { - t := queueRow.NextRunAt.Time - queue.NextRunAt = &t - } - queue.ScheduleEnabled = true - if queueRow.ScheduleEnabled.Valid && queueRow.ScheduleEnabled.Int64 == 0 { - queue.ScheduleEnabled = false - } - if queueRow.LastScheduleTriggerAt.Valid { - t := queueRow.LastScheduleTriggerAt.Time - queue.LastScheduleTriggerAt = &t - } - if queueRow.LastScheduleError.Valid { - queue.LastScheduleError = strings.TrimSpace(queueRow.LastScheduleError.String) - } - if queueRow.LastRunError.Valid { - queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) - } - if queueRow.StartedAt.Valid { - queue.StartedAt = &queueRow.StartedAt.Time - } - if queueRow.CompletedAt.Valid { - queue.CompletedAt = &queueRow.CompletedAt.Time - } - - for _, taskRow := range taskRows { - task := &BatchTask{ - ID: taskRow.ID, - Message: taskRow.Message, - Status: taskRow.Status, - } - if taskRow.ConversationID.Valid { - task.ConversationID = taskRow.ConversationID.String - } - if taskRow.StartedAt.Valid { - task.StartedAt = &taskRow.StartedAt.Time - } - if taskRow.CompletedAt.Valid { - task.CompletedAt = &taskRow.CompletedAt.Time - } - if taskRow.Error.Valid { - task.Error = taskRow.Error.String - } - if taskRow.Result.Valid { - task.Result = taskRow.Result.String - } - queue.Tasks = append(queue.Tasks, task) - } - - m.queues[queueRow.ID] = queue - } - - return nil -} - -// UpdateTaskStatus 更新任务状态 -func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, result, errorMsg string) { - m.UpdateTaskStatusWithConversationID(queueID, taskID, status, result, errorMsg, "") -} - -// UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId) -func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) { - var needDBUpdate bool - - // 在锁内只更新内存状态 - m.mu.Lock() - queue, exists := m.queues[queueID] - if !exists { - m.mu.Unlock() - return - } - - for _, task := range queue.Tasks { - if task.ID == taskID { - task.Status = status - if result != "" { - task.Result = result - } - if errorMsg != "" { - task.Error = errorMsg - } - if conversationID != "" { - task.ConversationID = conversationID - } - now := time.Now() - if status == BatchTaskStatusRunning && task.StartedAt == nil { - task.StartedAt = &now - } - if status == BatchTaskStatusCompleted || status == BatchTaskStatusFailed || status == BatchTaskStatusCancelled { - task.CompletedAt = &now - } - break - } - } - - needDBUpdate = m.db != nil - m.mu.Unlock() - - // 释放锁后写 DB - if needDBUpdate { - if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil { - m.logger.Warn("batch task DB status update failed", zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err)) - } - } -} - -// UpdateQueueStatus 更新队列状态 -func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) { - var needDBUpdate bool - - // 在锁内只更新内存状态 - m.mu.Lock() - queue, exists := m.queues[queueID] - if !exists { - m.mu.Unlock() - return - } - - queue.Status = status - now := time.Now() - if status == BatchQueueStatusRunning && queue.StartedAt == nil { - queue.StartedAt = &now - } - if status == BatchQueueStatusCompleted || status == BatchQueueStatusCancelled { - queue.CompletedAt = &now - } - - needDBUpdate = m.db != nil - m.mu.Unlock() - - // 释放锁后写 DB - if needDBUpdate { - if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil { - m.logger.Warn("batch queue DB status update failed", zap.String("queueId", queueID), zap.Error(err)) - } - } -} - -// UpdateQueueSchedule 更新队列调度配置 -func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return - } - - queue.ScheduleMode = normalizeBatchQueueScheduleMode(scheduleMode) - if queue.ScheduleMode == "cron" { - queue.CronExpr = strings.TrimSpace(cronExpr) - queue.NextRunAt = nextRunAt - } else { - queue.CronExpr = "" - queue.NextRunAt = nil - } - - if m.db != nil { - if err := m.db.UpdateBatchQueueSchedule(queueID, queue.ScheduleMode, queue.CronExpr, queue.NextRunAt); err != nil { - m.logger.Warn("batch queue DB schedule update failed", zap.String("queueId", queueID), zap.Error(err)) - } - } -} - -// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用) -func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error { - if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { - return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) - } - if utf8.RuneCountInString(role) > MaxBatchQueueRoleLen { - return fmt.Errorf("角色名不能超过 %d 个字符", MaxBatchQueueRoleLen) - } - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return fmt.Errorf("队列不存在") - } - if queue.Status == BatchQueueStatusRunning { - return fmt.Errorf("队列正在运行中,无法修改") - } - - // 如果未传 agentMode,保留原值 - if strings.TrimSpace(agentMode) != "" { - agentMode = normalizeBatchQueueAgentMode(agentMode) - } else { - agentMode = queue.AgentMode - } - - queue.Title = title - queue.Role = role - queue.AgentMode = agentMode - - if m.db != nil { - if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil { - m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err)) - } - } - return nil -} - -// SetScheduleEnabled 暂停/恢复 Cron 自动调度(不影响手工执行) -func (m *BatchTaskManager) SetScheduleEnabled(queueID string, enabled bool) bool { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return false - } - queue.ScheduleEnabled = enabled - if m.db != nil { - _ = m.db.UpdateBatchQueueScheduleEnabled(queueID, enabled) - } - return true -} - -// RecordScheduledRunStart Cron 触发成功、即将执行子任务时调用 -func (m *BatchTaskManager) RecordScheduledRunStart(queueID string) { - now := time.Now() - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return - } - queue.LastScheduleTriggerAt = &now - queue.LastScheduleError = "" - if m.db != nil { - _ = m.db.RecordBatchQueueScheduledTriggerStart(queueID, now) - } -} - -// SetLastScheduleError 调度层失败(未成功开始执行) -func (m *BatchTaskManager) SetLastScheduleError(queueID, msg string) { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return - } - queue.LastScheduleError = strings.TrimSpace(msg) - if m.db != nil { - _ = m.db.SetBatchQueueLastScheduleError(queueID, queue.LastScheduleError) - } -} - -// SetLastRunError 最近一轮批量执行中的失败摘要 -func (m *BatchTaskManager) SetLastRunError(queueID, msg string) { - msg = strings.TrimSpace(msg) - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return - } - queue.LastRunError = msg - if m.db != nil { - _ = m.db.SetBatchQueueLastRunError(queueID, msg) - } -} - -// ResetQueueForRerun 重置队列与子任务状态,供 cron 下一轮执行 -func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return false - } - queue.Status = BatchQueueStatusPending - queue.CurrentIndex = 0 - queue.StartedAt = nil - queue.CompletedAt = nil - queue.NextRunAt = nil - queue.LastRunError = "" - queue.LastScheduleError = "" - for _, task := range queue.Tasks { - task.Status = BatchTaskStatusPending - task.ConversationID = "" - task.StartedAt = nil - task.CompletedAt = nil - task.Error = "" - task.Result = "" - } - - if m.db != nil { - if err := m.db.ResetBatchQueueForRerun(queueID); err != nil { - return false - } - } - return true -} - -// UpdateTaskMessage 更新任务消息(队列空闲时可改;任务需非 running) -func (m *BatchTaskManager) UpdateTaskMessage(queueID, taskID, message string) error { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return fmt.Errorf("队列不存在") - } - - if !queueAllowsTaskListMutationLocked(queue) { - return fmt.Errorf("队列正在执行或未就绪,无法编辑任务") - } - - // 查找并更新任务 - for _, task := range queue.Tasks { - if task.ID == taskID { - if task.Status == BatchTaskStatusRunning { - return fmt.Errorf("执行中的任务不能编辑") - } - task.Message = message - - // 同步到数据库 - if m.db != nil { - if err := m.db.UpdateBatchTaskMessage(queueID, taskID, message); err != nil { - return fmt.Errorf("更新任务消息失败: %w", err) - } - } - return nil - } - } - - return fmt.Errorf("任务不存在") -} - -// AddTaskToQueue 添加任务到队列(队列空闲时可添加:含 cron 本轮 completed、手动暂停后等) -func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask, error) { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return nil, fmt.Errorf("队列不存在") - } - - if !queueAllowsTaskListMutationLocked(queue) { - return nil, fmt.Errorf("队列正在执行或未就绪,无法添加任务") - } - - if message == "" { - return nil, fmt.Errorf("任务消息不能为空") - } - - // 生成任务ID - taskID := generateShortID() - task := &BatchTask{ - ID: taskID, - Message: message, - Status: BatchTaskStatusPending, - } - - // 添加到内存队列 - queue.Tasks = append(queue.Tasks, task) - - // 同步到数据库 - if m.db != nil { - if err := m.db.AddBatchTask(queueID, taskID, message); err != nil { - // 如果数据库保存失败,从内存中移除 - queue.Tasks = queue.Tasks[:len(queue.Tasks)-1] - return nil, fmt.Errorf("添加任务失败: %w", err) - } - } - - return task, nil -} - -// DeleteTask 删除任务(队列空闲时可删;执行中任务不可删) -func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return fmt.Errorf("队列不存在") - } - - if !queueAllowsTaskListMutationLocked(queue) { - return fmt.Errorf("队列正在执行或未就绪,无法删除任务") - } - - // 查找并删除任务 - taskIndex := -1 - for i, task := range queue.Tasks { - if task.ID == taskID { - if task.Status == BatchTaskStatusRunning { - return fmt.Errorf("执行中的任务不能删除") - } - taskIndex = i - break - } - } - - if taskIndex == -1 { - return fmt.Errorf("任务不存在") - } - - // 从内存队列中删除 - queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...) - - // 同步到数据库 - if m.db != nil { - if err := m.db.DeleteBatchTask(queueID, taskID); err != nil { - // 如果数据库删除失败,恢复内存中的任务 - // 这里需要重新插入,但为了简化,我们只记录错误 - return fmt.Errorf("删除任务失败: %w", err) - } - } - - return nil -} - -func queueHasRunningTaskLocked(queue *BatchTaskQueue) bool { - if queue == nil { - return false - } - for _, t := range queue.Tasks { - if t != nil && t.Status == BatchTaskStatusRunning { - return true - } - } - return false -} - -// queueAllowsTaskListMutationLocked 是否允许增删改子任务文案/列表(必须在持有 BatchTaskManager.mu 下调用) -func queueAllowsTaskListMutationLocked(queue *BatchTaskQueue) bool { - if queue == nil { - return false - } - if queue.Status == BatchQueueStatusRunning { - return false - } - if queueHasRunningTaskLocked(queue) { - return false - } - switch queue.Status { - case BatchQueueStatusPending, BatchQueueStatusPaused, BatchQueueStatusCompleted, BatchQueueStatusCancelled: - return true - default: - return false - } -} - -// GetNextTask 获取下一个待执行的任务 -func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return nil, false - } - - for i := queue.CurrentIndex; i < len(queue.Tasks); i++ { - task := queue.Tasks[i] - if task.Status == BatchTaskStatusPending { - queue.CurrentIndex = i - return task, true - } - } - - return nil, false -} - -// MoveToNextTask 移动到下一个任务 -func (m *BatchTaskManager) MoveToNextTask(queueID string) { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return - } - - queue.CurrentIndex++ - - // 同步到数据库 - if m.db != nil { - if err := m.db.UpdateBatchQueueCurrentIndex(queueID, queue.CurrentIndex); err != nil { - m.logger.Warn("batch queue DB index update failed", zap.String("queueId", queueID), zap.Error(err)) - } - } -} - -// SetTaskCancel 设置当前任务的取消函数 -func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) { - m.mu.Lock() - defer m.mu.Unlock() - if cancel != nil { - m.taskCancels[queueID] = cancel - } else { - delete(m.taskCancels, queueID) - } -} - -// PauseQueue 暂停队列 -func (m *BatchTaskManager) PauseQueue(queueID string) bool { - var cancelFunc context.CancelFunc - var needDBUpdate bool - - // 在锁内只更新内存状态 - m.mu.Lock() - queue, exists := m.queues[queueID] - if !exists { - m.mu.Unlock() - return false - } - - if queue.Status != BatchQueueStatusRunning { - m.mu.Unlock() - return false - } - - queue.Status = BatchQueueStatusPaused - - // 取消当前正在执行的任务(通过取消context) - if cancel, ok := m.taskCancels[queueID]; ok { - cancelFunc = cancel - delete(m.taskCancels, queueID) - } - - needDBUpdate = m.db != nil - m.mu.Unlock() - - // 释放锁后执行取消回调 - if cancelFunc != nil { - cancelFunc() - } - - // 释放锁后写 DB - if needDBUpdate { - if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil { - m.logger.Warn("batch queue DB pause update failed", zap.String("queueId", queueID), zap.Error(err)) - } - } - - return true -} - -// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue) -func (m *BatchTaskManager) CancelQueue(queueID string) bool { - now := time.Now() - var cancelFunc context.CancelFunc - var needDBUpdate bool - - // 在锁内只更新内存状态,不做 DB 操作 - m.mu.Lock() - queue, exists := m.queues[queueID] - if !exists { - m.mu.Unlock() - return false - } - - if queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled { - m.mu.Unlock() - return false - } - - queue.Status = BatchQueueStatusCancelled - queue.CompletedAt = &now - - // 内存中批量标记所有 pending 任务为 cancelled - for _, task := range queue.Tasks { - if task.Status == BatchTaskStatusPending { - task.Status = BatchTaskStatusCancelled - task.CompletedAt = &now - } - } - - // 取消当前正在执行的任务 - if cancel, ok := m.taskCancels[queueID]; ok { - cancelFunc = cancel - delete(m.taskCancels, queueID) - } - - needDBUpdate = m.db != nil - m.mu.Unlock() - - // 释放锁后执行取消回调 - if cancelFunc != nil { - cancelFunc() - } - - // 释放锁后批量写 DB(单条 SQL 取消所有 pending 任务) - if needDBUpdate { - if err := m.db.CancelPendingBatchTasks(queueID, now); err != nil { - m.logger.Warn("batch task DB batch cancel failed", zap.String("queueId", queueID), zap.Error(err)) - } - if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil { - m.logger.Warn("batch queue DB cancel update failed", zap.String("queueId", queueID), zap.Error(err)) - } - } - - return true -} - -// DeleteQueue 删除队列(运行中的队列不允许删除) -func (m *BatchTaskManager) DeleteQueue(queueID string) bool { - m.mu.Lock() - defer m.mu.Unlock() - - queue, exists := m.queues[queueID] - if !exists { - return false - } - - // 运行中的队列不允许删除,防止孤儿协程和数据丢失 - if queue.Status == BatchQueueStatusRunning { - return false - } - - // 清理取消函数 - delete(m.taskCancels, queueID) - - // 从数据库删除 - if m.db != nil { - if err := m.db.DeleteBatchQueue(queueID); err != nil { - m.logger.Warn("batch queue DB delete failed", zap.String("queueId", queueID), zap.Error(err)) - } - } - - delete(m.queues, queueID) - return true -} - -// generateShortID 生成短ID -func generateShortID() string { - b := make([]byte, 4) - rand.Read(b) - return time.Now().Format("150405") + "-" + hex.EncodeToString(b) -} diff --git a/handler/batch_task_mcp.go b/handler/batch_task_mcp.go deleted file mode 100644 index 72ae8457..00000000 --- a/handler/batch_task_mcp.go +++ /dev/null @@ -1,813 +0,0 @@ -package handler - -import ( - "context" - "encoding/json" - "fmt" - "strconv" - "strings" - "time" - - "cyberstrike-ai/internal/mcp" - "cyberstrike-ai/internal/mcp/builtin" - - "go.uber.org/zap" -) - -// RegisterBatchTaskMCPTools 注册批量任务队列相关 MCP 工具(需传入已初始化 DB 的 AgentHandler) -func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *zap.Logger) { - if mcpServer == nil || h == nil || logger == nil { - return - } - - reg := func(tool mcp.Tool, fn func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error)) { - mcpServer.RegisterTool(tool, fn) - } - - // --- list --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskList, - Description: "列出批量任务队列(精简摘要,省上下文)。含队列元数据、子任务 id/status/截断后的 message、各状态计数。完整子任务(含 result/error/conversationId/时间等)请用 batch_task_get(queue_id)。", - ShortDescription: "列出批量任务队列", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "status": map[string]interface{}{ - "type": "string", - "description": "筛选状态:all(默认)、pending、running、paused、completed、cancelled", - "enum": []string{"all", "pending", "running", "paused", "completed", "cancelled"}, - }, - "keyword": map[string]interface{}{ - "type": "string", - "description": "按队列 ID 或标题模糊搜索", - }, - "page": map[string]interface{}{ - "type": "integer", - "description": "页码,从 1 开始,默认 1", - }, - "page_size": map[string]interface{}{ - "type": "integer", - "description": "每页条数,默认 20,最大 100", - }, - }, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - status := mcpArgString(args, "status") - if status == "" { - status = "all" - } - keyword := mcpArgString(args, "keyword") - page := int(mcpArgFloat(args, "page")) - if page <= 0 { - page = 1 - } - pageSize := int(mcpArgFloat(args, "page_size")) - if pageSize <= 0 { - pageSize = 20 - } - if pageSize > 100 { - pageSize = 100 - } - offset := (page - 1) * pageSize - if offset > 100000 { - offset = 100000 - } - queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword) - if err != nil { - return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil - } - totalPages := (total + pageSize - 1) / pageSize - if totalPages == 0 { - totalPages = 1 - } - slim := make([]batchTaskQueueMCPListItem, 0, len(queues)) - for _, q := range queues { - if q == nil { - continue - } - slim = append(slim, toBatchTaskQueueMCPListItem(q)) - } - payload := map[string]interface{}{ - "queues": slim, - "total": total, - "page": page, - "page_size": pageSize, - "total_pages": totalPages, - } - logger.Info("MCP batch_task_list", zap.String("status", status), zap.Int("total", total)) - return batchMCPJSONResult(payload) - }) - - // --- get --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskGet, - Description: "根据 queue_id 获取单个批量任务队列详情(含子任务列表、Cron、调度开关与最近错误信息)。", - ShortDescription: "获取批量任务队列详情", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - }, - "required": []string{"queue_id"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - if qid == "" { - return batchMCPTextResult("queue_id 不能为空", true), nil - } - queue, ok := h.batchTaskManager.GetBatchQueue(qid) - if !ok { - return batchMCPTextResult("队列不存在: "+qid, true), nil - } - return batchMCPJSONResult(queue) - }) - - // --- create --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskCreate, - Description: `创建新的批量任务队列。任务列表使用 tasks(字符串数组)或 tasks_text(多行,每行一条)。 -agent_mode: single(默认)或 multi(需系统启用多代理)。schedule_mode: manual(默认)或 cron;为 cron 时必须提供 cron_expr(如 "0 */6 * * *")。 -默认创建后不会立即执行。可通过 execute_now=true 在创建后立即启动;也可后续调用 batch_task_start 手工启动。Cron 队列若需按表达式自动触发下一轮,还需保持调度开关开启(可用 batch_task_schedule_enabled)。`, - ShortDescription: "创建批量任务队列(可选立即执行)", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "title": map[string]interface{}{ - "type": "string", - "description": "可选标题", - }, - "role": map[string]interface{}{ - "type": "string", - "description": "角色名称,空表示默认", - }, - "tasks": map[string]interface{}{ - "type": "array", - "description": "任务指令列表,每项一条", - "items": map[string]interface{}{"type": "string"}, - }, - "tasks_text": map[string]interface{}{ - "type": "string", - "description": "多行文本,每行一条任务(与 tasks 二选一)", - }, - "agent_mode": map[string]interface{}{ - "type": "string", - "description": "single 或 multi", - "enum": []string{"single", "multi"}, - }, - "schedule_mode": map[string]interface{}{ - "type": "string", - "description": "manual 或 cron", - "enum": []string{"manual", "cron"}, - }, - "cron_expr": map[string]interface{}{ - "type": "string", - "description": "schedule_mode 为 cron 时必填。标准 5 段格式:分钟 小时 日 月 星期,例如 \"0 */6 * * *\"(每6小时)、\"30 2 * * 1-5\"(工作日凌晨2:30)", - }, - "execute_now": map[string]interface{}{ - "type": "boolean", - "description": "是否创建后立即执行,默认 false", - }, - }, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - tasks, errMsg := batchMCPTasksFromArgs(args) - if errMsg != "" { - return batchMCPTextResult(errMsg, true), nil - } - title := mcpArgString(args, "title") - role := mcpArgString(args, "role") - agentMode := normalizeBatchQueueAgentMode(mcpArgString(args, "agent_mode")) - scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode")) - cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr")) - var nextRunAt *time.Time - if scheduleMode == "cron" { - if cronExpr == "" { - return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil - } - sch, err := h.batchCronParser.Parse(cronExpr) - if err != nil { - return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil - } - n := sch.Next(time.Now()) - nextRunAt = &n - } - executeNow, ok := mcpArgBool(args, "execute_now") - if !ok { - executeNow = false - } - queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks) - if createErr != nil { - return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil - } - started := false - if executeNow { - ok, err := h.startBatchQueueExecution(queue.ID, false) - if !ok { - return batchMCPTextResult("队列不存在: "+queue.ID, true), nil - } - if err != nil { - return batchMCPTextResult("创建成功但启动失败: "+err.Error(), true), nil - } - started = true - if refreshed, exists := h.batchTaskManager.GetBatchQueue(queue.ID); exists { - queue = refreshed - } - } - logger.Info("MCP batch_task_create", zap.String("queueId", queue.ID), zap.Int("taskCount", len(tasks))) - return batchMCPJSONResult(map[string]interface{}{ - "queue_id": queue.ID, - "queue": queue, - "started": started, - "execute_now": executeNow, - "reminder": func() string { - if started { - return "队列已创建并立即启动。" - } - return "队列已创建,当前为 pending。需要开始执行时请调用 MCP 工具 batch_task_start(queue_id 同上)。Cron 自动调度需 schedule_enabled 为 true,可用 batch_task_schedule_enabled。" - }(), - }) - }) - - // --- start --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskStart, - Description: `启动或继续执行批量任务队列(pending / paused)。 -与 batch_task_create 配合使用:仅创建队列不会自动执行,需调用本工具才会开始跑子任务。`, - ShortDescription: "启动/继续批量任务队列(创建后需调用才会执行)", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - }, - "required": []string{"queue_id"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - if qid == "" { - return batchMCPTextResult("queue_id 不能为空", true), nil - } - ok, err := h.startBatchQueueExecution(qid, false) - if !ok { - return batchMCPTextResult("队列不存在: "+qid, true), nil - } - if err != nil { - return batchMCPTextResult("启动失败: "+err.Error(), true), nil - } - logger.Info("MCP batch_task_start", zap.String("queueId", qid)) - return batchMCPTextResult("已提交启动,队列将开始执行。", false), nil - }) - - // --- rerun (reset + start for completed/cancelled queues) --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskRerun, - Description: "重跑已完成或已取消的批量任务队列。会重置所有子任务状态后重新执行一轮。", - ShortDescription: "重跑批量任务队列", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - }, - "required": []string{"queue_id"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - if qid == "" { - return batchMCPTextResult("queue_id 不能为空", true), nil - } - queue, exists := h.batchTaskManager.GetBatchQueue(qid) - if !exists { - return batchMCPTextResult("队列不存在: "+qid, true), nil - } - if queue.Status != "completed" && queue.Status != "cancelled" { - return batchMCPTextResult("仅已完成或已取消的队列可以重跑,当前状态: "+queue.Status, true), nil - } - if !h.batchTaskManager.ResetQueueForRerun(qid) { - return batchMCPTextResult("重置队列失败", true), nil - } - ok, err := h.startBatchQueueExecution(qid, false) - if !ok { - return batchMCPTextResult("启动失败", true), nil - } - if err != nil { - return batchMCPTextResult("启动失败: "+err.Error(), true), nil - } - logger.Info("MCP batch_task_rerun", zap.String("queueId", qid)) - return batchMCPTextResult("已重置并重新启动队列。", false), nil - }) - - // --- pause --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskPause, - Description: "暂停正在运行的批量任务队列(当前子任务会被取消)。", - ShortDescription: "暂停批量任务队列", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - }, - "required": []string{"queue_id"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - if qid == "" { - return batchMCPTextResult("queue_id 不能为空", true), nil - } - if !h.batchTaskManager.PauseQueue(qid) { - return batchMCPTextResult("无法暂停:队列不存在或当前非 running 状态", true), nil - } - logger.Info("MCP batch_task_pause", zap.String("queueId", qid)) - return batchMCPTextResult("队列已暂停。", false), nil - }) - - // --- delete queue --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskDelete, - Description: "删除批量任务队列及其子任务记录。", - ShortDescription: "删除批量任务队列", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - }, - "required": []string{"queue_id"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - if qid == "" { - return batchMCPTextResult("queue_id 不能为空", true), nil - } - if !h.batchTaskManager.DeleteQueue(qid) { - return batchMCPTextResult("删除失败:队列不存在", true), nil - } - logger.Info("MCP batch_task_delete", zap.String("queueId", qid)) - return batchMCPTextResult("队列已删除。", false), nil - }) - - // --- update metadata (title/role/agentMode) --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskUpdateMetadata, - Description: "修改批量任务队列的标题、角色和代理模式。仅在队列非 running 状态下可修改。", - ShortDescription: "修改批量任务队列标题/角色/代理模式", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - "title": map[string]interface{}{ - "type": "string", - "description": "新标题(空字符串清除标题)", - }, - "role": map[string]interface{}{ - "type": "string", - "description": "新角色名(空字符串使用默认角色)", - }, - "agent_mode": map[string]interface{}{ - "type": "string", - "description": "代理模式:single(单代理 ReAct)或 multi(多代理)", - "enum": []string{"single", "multi"}, - }, - }, - "required": []string{"queue_id"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - if qid == "" { - return batchMCPTextResult("queue_id 不能为空", true), nil - } - title := mcpArgString(args, "title") - role := mcpArgString(args, "role") - agentMode := mcpArgString(args, "agent_mode") - if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil { - return batchMCPTextResult(err.Error(), true), nil - } - updated, _ := h.batchTaskManager.GetBatchQueue(qid) - logger.Info("MCP batch_task_update_metadata", zap.String("queueId", qid)) - return batchMCPJSONResult(updated) - }) - - // --- update schedule --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskUpdateSchedule, - Description: `修改批量任务队列的调度方式和 Cron 表达式。仅在队列非 running 状态下可修改。 -schedule_mode 为 cron 时必须提供有效 cron_expr;为 manual 时会清除 Cron 配置。`, - ShortDescription: "修改批量任务调度配置(Cron 表达式)", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - "schedule_mode": map[string]interface{}{ - "type": "string", - "description": "manual 或 cron", - "enum": []string{"manual", "cron"}, - }, - "cron_expr": map[string]interface{}{ - "type": "string", - "description": "Cron 表达式(schedule_mode 为 cron 时必填)。标准 5 段格式:分钟 小时 日 月 星期,如 \"0 */6 * * *\"(每6小时)、\"30 2 * * 1-5\"(工作日凌晨2:30)", - }, - }, - "required": []string{"queue_id", "schedule_mode"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - if qid == "" { - return batchMCPTextResult("queue_id 不能为空", true), nil - } - queue, exists := h.batchTaskManager.GetBatchQueue(qid) - if !exists { - return batchMCPTextResult("队列不存在: "+qid, true), nil - } - if queue.Status == "running" { - return batchMCPTextResult("队列正在运行中,无法修改调度配置", true), nil - } - scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode")) - cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr")) - var nextRunAt *time.Time - if scheduleMode == "cron" { - if cronExpr == "" { - return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil - } - sch, err := h.batchCronParser.Parse(cronExpr) - if err != nil { - return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil - } - n := sch.Next(time.Now()) - nextRunAt = &n - } - h.batchTaskManager.UpdateQueueSchedule(qid, scheduleMode, cronExpr, nextRunAt) - updated, _ := h.batchTaskManager.GetBatchQueue(qid) - logger.Info("MCP batch_task_update_schedule", zap.String("queueId", qid), zap.String("scheduleMode", scheduleMode), zap.String("cronExpr", cronExpr)) - return batchMCPJSONResult(updated) - }) - - // --- schedule enabled --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskScheduleEnabled, - Description: `设置是否允许 Cron 自动触发该队列。关闭后仍保留 Cron 表达式,仅停止定时自动跑;可用手工「启动」执行。 -仅对 schedule_mode 为 cron 的队列有意义。`, - ShortDescription: "开关批量任务 Cron 自动调度", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - "schedule_enabled": map[string]interface{}{ - "type": "boolean", - "description": "true 允许定时触发,false 仅手工执行", - }, - }, - "required": []string{"queue_id", "schedule_enabled"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - if qid == "" { - return batchMCPTextResult("queue_id 不能为空", true), nil - } - en, ok := mcpArgBool(args, "schedule_enabled") - if !ok { - return batchMCPTextResult("schedule_enabled 必须为布尔值", true), nil - } - if _, exists := h.batchTaskManager.GetBatchQueue(qid); !exists { - return batchMCPTextResult("队列不存在", true), nil - } - if !h.batchTaskManager.SetScheduleEnabled(qid, en) { - return batchMCPTextResult("更新失败", true), nil - } - queue, _ := h.batchTaskManager.GetBatchQueue(qid) - logger.Info("MCP batch_task_schedule_enabled", zap.String("queueId", qid), zap.Bool("enabled", en)) - return batchMCPJSONResult(queue) - }) - - // --- add task --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskAdd, - Description: "向处于 pending 状态的队列追加一条子任务。", - ShortDescription: "批量队列添加子任务", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - "message": map[string]interface{}{ - "type": "string", - "description": "任务指令内容", - }, - }, - "required": []string{"queue_id", "message"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - msg := strings.TrimSpace(mcpArgString(args, "message")) - if qid == "" || msg == "" { - return batchMCPTextResult("queue_id 与 message 均不能为空", true), nil - } - task, err := h.batchTaskManager.AddTaskToQueue(qid, msg) - if err != nil { - return batchMCPTextResult(err.Error(), true), nil - } - queue, _ := h.batchTaskManager.GetBatchQueue(qid) - logger.Info("MCP batch_task_add_task", zap.String("queueId", qid), zap.String("taskId", task.ID)) - return batchMCPJSONResult(map[string]interface{}{"task": task, "queue": queue}) - }) - - // --- update task --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskUpdate, - Description: "修改 pending 队列中仍为 pending 的子任务文案。", - ShortDescription: "更新批量子任务内容", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - "task_id": map[string]interface{}{ - "type": "string", - "description": "子任务 ID", - }, - "message": map[string]interface{}{ - "type": "string", - "description": "新的任务指令", - }, - }, - "required": []string{"queue_id", "task_id", "message"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - tid := mcpArgString(args, "task_id") - msg := strings.TrimSpace(mcpArgString(args, "message")) - if qid == "" || tid == "" || msg == "" { - return batchMCPTextResult("queue_id、task_id、message 均不能为空", true), nil - } - if err := h.batchTaskManager.UpdateTaskMessage(qid, tid, msg); err != nil { - return batchMCPTextResult(err.Error(), true), nil - } - queue, _ := h.batchTaskManager.GetBatchQueue(qid) - logger.Info("MCP batch_task_update_task", zap.String("queueId", qid), zap.String("taskId", tid)) - return batchMCPJSONResult(queue) - }) - - // --- remove task --- - reg(mcp.Tool{ - Name: builtin.ToolBatchTaskRemove, - Description: "从 pending 队列中删除仍为 pending 的子任务。", - ShortDescription: "删除批量子任务", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queue_id": map[string]interface{}{ - "type": "string", - "description": "队列 ID", - }, - "task_id": map[string]interface{}{ - "type": "string", - "description": "子任务 ID", - }, - }, - "required": []string{"queue_id", "task_id"}, - }, - }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - qid := mcpArgString(args, "queue_id") - tid := mcpArgString(args, "task_id") - if qid == "" || tid == "" { - return batchMCPTextResult("queue_id 与 task_id 均不能为空", true), nil - } - if err := h.batchTaskManager.DeleteTask(qid, tid); err != nil { - return batchMCPTextResult(err.Error(), true), nil - } - queue, _ := h.batchTaskManager.GetBatchQueue(qid) - logger.Info("MCP batch_task_remove_task", zap.String("queueId", qid), zap.String("taskId", tid)) - return batchMCPJSONResult(queue) - }) - - logger.Info("批量任务 MCP 工具已注册", zap.Int("count", 12)) -} - -// --- batch_task_list 精简结构(避免把每条子任务的 result 等大段文本塞进列表上下文) --- - -const mcpBatchListTaskMessageMaxRunes = 160 - -// batchTaskMCPListSummary 列表中的子任务摘要(完整字段用 batch_task_get) -type batchTaskMCPListSummary struct { - ID string `json:"id"` - Status string `json:"status"` - Message string `json:"message,omitempty"` -} - -// batchTaskQueueMCPListItem 列表中的队列摘要 -type batchTaskQueueMCPListItem struct { - ID string `json:"id"` - Title string `json:"title,omitempty"` - Role string `json:"role,omitempty"` - AgentMode string `json:"agentMode"` - ScheduleMode string `json:"scheduleMode"` - CronExpr string `json:"cronExpr,omitempty"` - NextRunAt *time.Time `json:"nextRunAt,omitempty"` - ScheduleEnabled bool `json:"scheduleEnabled"` - LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"` - Status string `json:"status"` - CreatedAt time.Time `json:"createdAt"` - StartedAt *time.Time `json:"startedAt,omitempty"` - CompletedAt *time.Time `json:"completedAt,omitempty"` - CurrentIndex int `json:"currentIndex"` - TaskTotal int `json:"task_total"` - TaskCounts map[string]int `json:"task_counts"` - Tasks []batchTaskMCPListSummary `json:"tasks"` -} - -func truncateStringRunes(s string, maxRunes int) string { - if maxRunes <= 0 { - return "" - } - n := 0 - for i := range s { - if n == maxRunes { - out := strings.TrimSpace(s[:i]) - if out == "" { - return "…" - } - return out + "…" - } - n++ - } - return s -} - -const mcpBatchListMaxTasksPerQueue = 200 // 列表中每个队列最多返回的子任务摘要数 - -func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem { - counts := map[string]int{ - "pending": 0, - "running": 0, - "completed": 0, - "failed": 0, - "cancelled": 0, - } - tasks := make([]batchTaskMCPListSummary, 0, len(q.Tasks)) - for _, t := range q.Tasks { - if t == nil { - continue - } - counts[t.Status]++ - // 列表视图限制子任务摘要数量,完整列表通过 batch_task_get 查看 - if len(tasks) < mcpBatchListMaxTasksPerQueue { - tasks = append(tasks, batchTaskMCPListSummary{ - ID: t.ID, - Status: t.Status, - Message: truncateStringRunes(t.Message, mcpBatchListTaskMessageMaxRunes), - }) - } - } - return batchTaskQueueMCPListItem{ - ID: q.ID, - Title: q.Title, - Role: q.Role, - AgentMode: q.AgentMode, - ScheduleMode: q.ScheduleMode, - CronExpr: q.CronExpr, - NextRunAt: q.NextRunAt, - ScheduleEnabled: q.ScheduleEnabled, - LastScheduleTriggerAt: q.LastScheduleTriggerAt, - Status: q.Status, - CreatedAt: q.CreatedAt, - StartedAt: q.StartedAt, - CompletedAt: q.CompletedAt, - CurrentIndex: q.CurrentIndex, - TaskTotal: len(tasks), - TaskCounts: counts, - Tasks: tasks, - } -} - -func batchMCPTextResult(text string, isErr bool) *mcp.ToolResult { - return &mcp.ToolResult{ - Content: []mcp.Content{{Type: "text", Text: text}}, - IsError: isErr, - } -} - -func batchMCPJSONResult(v interface{}) (*mcp.ToolResult, error) { - b, err := json.MarshalIndent(v, "", " ") - if err != nil { - return batchMCPTextResult(fmt.Sprintf("JSON 编码失败: %v", err), true), nil - } - return &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: string(b)}}}, nil -} - -func batchMCPTasksFromArgs(args map[string]interface{}) ([]string, string) { - if raw, ok := args["tasks"]; ok && raw != nil { - switch t := raw.(type) { - case []interface{}: - out := make([]string, 0, len(t)) - for _, x := range t { - if s, ok := x.(string); ok { - if tr := strings.TrimSpace(s); tr != "" { - out = append(out, tr) - } - } - } - if len(out) > 0 { - return out, "" - } - } - } - if txt := mcpArgString(args, "tasks_text"); txt != "" { - lines := strings.Split(txt, "\n") - out := make([]string, 0, len(lines)) - for _, line := range lines { - if tr := strings.TrimSpace(line); tr != "" { - out = append(out, tr) - } - } - if len(out) > 0 { - return out, "" - } - } - return nil, "需要提供 tasks(字符串数组)或 tasks_text(多行文本,每行一条任务)" -} - -func mcpArgString(args map[string]interface{}, key string) string { - v, ok := args[key] - if !ok || v == nil { - return "" - } - switch t := v.(type) { - case string: - return strings.TrimSpace(t) - case float64: - return strings.TrimSpace(strconv.FormatFloat(t, 'f', -1, 64)) - case json.Number: - return strings.TrimSpace(t.String()) - default: - return strings.TrimSpace(fmt.Sprint(t)) - } -} - -func mcpArgFloat(args map[string]interface{}, key string) float64 { - v, ok := args[key] - if !ok || v == nil { - return 0 - } - switch t := v.(type) { - case float64: - return t - case int: - return float64(t) - case int64: - return float64(t) - case json.Number: - f, _ := t.Float64() - return f - case string: - f, _ := strconv.ParseFloat(strings.TrimSpace(t), 64) - return f - default: - return 0 - } -} - -func mcpArgBool(args map[string]interface{}, key string) (val bool, ok bool) { - v, exists := args[key] - if !exists { - return false, false - } - switch t := v.(type) { - case bool: - return t, true - case string: - s := strings.ToLower(strings.TrimSpace(t)) - if s == "true" || s == "1" || s == "yes" { - return true, true - } - if s == "false" || s == "0" || s == "no" { - return false, true - } - case float64: - return t != 0, true - } - return false, false -} diff --git a/handler/chat_uploads.go b/handler/chat_uploads.go deleted file mode 100644 index c3e25fec..00000000 --- a/handler/chat_uploads.go +++ /dev/null @@ -1,512 +0,0 @@ -package handler - -import ( - "crypto/rand" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "sort" - "strings" - "time" - "unicode/utf8" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -const ( - chatUploadsRootDirName = "chat_uploads" - maxChatUploadEditBytes = 2 * 1024 * 1024 // 文本编辑上限 -) - -// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API -type ChatUploadsHandler struct { - logger *zap.Logger -} - -// NewChatUploadsHandler 创建处理器 -func NewChatUploadsHandler(logger *zap.Logger) *ChatUploadsHandler { - return &ChatUploadsHandler{logger: logger} -} - -func (h *ChatUploadsHandler) absRoot() (string, error) { - cwd, err := os.Getwd() - if err != nil { - return "", err - } - return filepath.Abs(filepath.Join(cwd, chatUploadsRootDirName)) -} - -// resolveUnderChatUploads 校验 relativePath(使用 / 分隔)对应文件必须在 chat_uploads 根下 -func (h *ChatUploadsHandler) resolveUnderChatUploads(relativePath string) (abs string, err error) { - root, err := h.absRoot() - if err != nil { - return "", err - } - rel := strings.TrimSpace(relativePath) - if rel == "" { - return "", fmt.Errorf("empty path") - } - rel = filepath.Clean(filepath.FromSlash(rel)) - if rel == "." || strings.HasPrefix(rel, "..") { - return "", fmt.Errorf("invalid path") - } - full := filepath.Join(root, rel) - full, err = filepath.Abs(full) - if err != nil { - return "", err - } - rootAbs, _ := filepath.Abs(root) - if full != rootAbs && !strings.HasPrefix(full, rootAbs+string(filepath.Separator)) { - return "", fmt.Errorf("path escapes chat_uploads root") - } - return full, nil -} - -// ChatUploadFileItem 列表项 -type ChatUploadFileItem struct { - RelativePath string `json:"relativePath"` - AbsolutePath string `json:"absolutePath"` // 服务器上的绝对路径,便于在对话中引用(与附件落盘路径一致) - Name string `json:"name"` - Size int64 `json:"size"` - ModifiedUnix int64 `json:"modifiedUnix"` - Date string `json:"date"` - ConversationID string `json:"conversationId"` - // SubPath 为日期、会话目录之下的子路径(不含文件名),如 date/conv/a/b/file 则为 "a/b";无嵌套则为 ""。 - SubPath string `json:"subPath"` -} - -// List GET /api/chat-uploads -func (h *ChatUploadsHandler) List(c *gin.Context) { - conversationFilter := strings.TrimSpace(c.Query("conversation")) - root, err := h.absRoot() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - // 保证根目录存在,否则「按文件夹」浏览时无法 mkdir,且首次列表为空时界面无路径工具栏 - if err := os.MkdirAll(root, 0755); err != nil { - h.logger.Warn("创建 chat_uploads 根目录失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - var files []ChatUploadFileItem - var folders []string - err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error { - if walkErr != nil { - return walkErr - } - rel, err := filepath.Rel(root, path) - if err != nil { - return err - } - if rel == "." { - return nil - } - relSlash := filepath.ToSlash(rel) - if d.IsDir() { - folders = append(folders, relSlash) - return nil - } - info, err := d.Info() - if err != nil { - return err - } - parts := strings.Split(relSlash, "/") - var dateStr, convID string - if len(parts) >= 2 { - dateStr = parts[0] - } - if len(parts) >= 3 { - convID = parts[1] - } - var subPath string - if len(parts) >= 4 { - subPath = strings.Join(parts[2:len(parts)-1], "/") - } - if conversationFilter != "" && convID != conversationFilter { - return nil - } - absPath, _ := filepath.Abs(path) - files = append(files, ChatUploadFileItem{ - RelativePath: relSlash, - AbsolutePath: absPath, - Name: d.Name(), - Size: info.Size(), - ModifiedUnix: info.ModTime().Unix(), - Date: dateStr, - ConversationID: convID, - SubPath: subPath, - }) - return nil - }) - if err != nil { - h.logger.Warn("列举对话附件失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if conversationFilter != "" { - filteredFolders := make([]string, 0, len(folders)) - for _, rel := range folders { - parts := strings.Split(rel, "/") - if len(parts) >= 2 && parts[1] == conversationFilter { - filteredFolders = append(filteredFolders, rel) - continue - } - if len(parts) == 1 { - prefix := rel + "/" - for _, f := range files { - if strings.HasPrefix(f.RelativePath, prefix) { - filteredFolders = append(filteredFolders, rel) - break - } - } - } - } - folders = filteredFolders - } - sort.Strings(folders) - sort.Slice(files, func(i, j int) bool { - return files[i].ModifiedUnix > files[j].ModifiedUnix - }) - c.JSON(http.StatusOK, gin.H{"files": files, "folders": folders}) -} - -// Download GET /api/chat-uploads/download?path=... -func (h *ChatUploadsHandler) Download(c *gin.Context) { - p := c.Query("path") - abs, err := h.resolveUnderChatUploads(p) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - st, err := os.Stat(abs) - if err != nil || st.IsDir() { - c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) - return - } - c.FileAttachment(abs, filepath.Base(abs)) -} - -type chatUploadPathBody struct { - Path string `json:"path"` -} - -// Delete DELETE /api/chat-uploads -func (h *ChatUploadsHandler) Delete(c *gin.Context) { - var body chatUploadPathBody - if err := c.ShouldBindJSON(&body); err != nil || strings.TrimSpace(body.Path) == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) - return - } - abs, err := h.resolveUnderChatUploads(body.Path) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - st, err := os.Stat(abs) - if err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if st.IsDir() { - if err := os.RemoveAll(abs); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } else { - if err := os.Remove(abs); err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } - c.JSON(http.StatusOK, gin.H{"ok": true}) -} - -type chatUploadMkdirBody struct { - Parent string `json:"parent"` - Name string `json:"name"` -} - -// Mkdir POST /api/chat-uploads/mkdir — 在 parent 目录下新建子目录(parent 为 chat_uploads 下相对路径,空表示根目录;name 为单段目录名) -func (h *ChatUploadsHandler) Mkdir(c *gin.Context) { - var body chatUploadMkdirBody - if err := c.ShouldBindJSON(&body); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) - return - } - name := strings.TrimSpace(body.Name) - if name == "" || strings.ContainsAny(name, `/\`) || name == "." || name == ".." { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid name"}) - return - } - if utf8.RuneCountInString(name) > 200 { - c.JSON(http.StatusBadRequest, gin.H{"error": "name too long"}) - return - } - - parent := strings.TrimSpace(body.Parent) - parent = filepath.ToSlash(filepath.Clean(filepath.FromSlash(parent))) - parent = strings.Trim(parent, "/") - if parent == "." { - parent = "" - } - - root, err := h.absRoot() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - if parent != "" { - absParent, err := h.resolveUnderChatUploads(parent) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - st, err := os.Stat(absParent) - if err != nil || !st.IsDir() { - c.JSON(http.StatusBadRequest, gin.H{"error": "parent not found"}) - return - } - } - - var rel string - if parent == "" { - rel = name - } else { - rel = parent + "/" + name - } - absNew, err := h.resolveUnderChatUploads(rel) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - if _, err := os.Stat(absNew); err == nil { - c.JSON(http.StatusConflict, gin.H{"error": "already exists"}) - return - } - if err := os.Mkdir(absNew, 0755); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - relOut, _ := filepath.Rel(root, absNew) - c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(relOut)}) -} - -type chatUploadRenameBody struct { - Path string `json:"path"` - NewName string `json:"newName"` -} - -// Rename PUT /api/chat-uploads/rename -func (h *ChatUploadsHandler) Rename(c *gin.Context) { - var body chatUploadRenameBody - if err := c.ShouldBindJSON(&body); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) - return - } - newName := strings.TrimSpace(body.NewName) - if newName == "" || strings.ContainsAny(newName, `/\`) { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid newName"}) - return - } - abs, err := h.resolveUnderChatUploads(body.Path) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - dir := filepath.Dir(abs) - newAbs := filepath.Join(dir, filepath.Base(newName)) - root, _ := h.absRoot() - newAbs, _ = filepath.Abs(newAbs) - if newAbs != root && !strings.HasPrefix(newAbs, root+string(filepath.Separator)) { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid target path"}) - return - } - if err := os.Rename(abs, newAbs); err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - newRel, _ := filepath.Rel(root, newAbs) - c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(newRel)}) -} - -type chatUploadContentBody struct { - Path string `json:"path"` - Content string `json:"content"` -} - -// GetContent GET /api/chat-uploads/content?path=... -func (h *ChatUploadsHandler) GetContent(c *gin.Context) { - p := c.Query("path") - abs, err := h.resolveUnderChatUploads(p) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - st, err := os.Stat(abs) - if err != nil || st.IsDir() { - c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) - return - } - if st.Size() > maxChatUploadEditBytes { - c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "file too large for editor"}) - return - } - b, err := os.ReadFile(abs) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if !utf8.Valid(b) { - c.JSON(http.StatusBadRequest, gin.H{"error": "binary file not editable in UI"}) - return - } - c.JSON(http.StatusOK, gin.H{"content": string(b)}) -} - -// PutContent PUT /api/chat-uploads/content -func (h *ChatUploadsHandler) PutContent(c *gin.Context) { - var body chatUploadContentBody - if err := c.ShouldBindJSON(&body); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) - return - } - if !utf8.ValidString(body.Content) { - c.JSON(http.StatusBadRequest, gin.H{"error": "content must be valid UTF-8"}) - return - } - if len(body.Content) > maxChatUploadEditBytes { - c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "content too large"}) - return - } - abs, err := h.resolveUnderChatUploads(body.Path) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - if err := os.WriteFile(abs, []byte(body.Content), 0644); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"ok": true}) -} - -func chatUploadShortRand(n int) string { - const letters = "0123456789abcdef" - b := make([]byte, n) - _, _ = rand.Read(b) - for i := range b { - b[i] = letters[int(b[i])%len(letters)] - } - return string(b) -} - -// Upload POST /api/chat-uploads multipart: file;conversationId 可选;relativeDir 可选(chat_uploads 下目录的相对路径,将文件直接上传至该目录) -func (h *ChatUploadsHandler) Upload(c *gin.Context) { - fh, err := c.FormFile("file") - if err != nil || fh == nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "missing file"}) - return - } - root, err := h.absRoot() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - var targetDir string - targetRel := strings.TrimSpace(c.PostForm("relativeDir")) - if targetRel != "" { - absDir, err := h.resolveUnderChatUploads(targetRel) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - st, err := os.Stat(absDir) - if err != nil { - if os.IsNotExist(err) { - if err := os.MkdirAll(absDir, 0755); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } else if !st.IsDir() { - c.JSON(http.StatusBadRequest, gin.H{"error": "relativeDir is not a directory"}) - return - } - targetDir = absDir - } else { - convID := strings.TrimSpace(c.PostForm("conversationId")) - convDir := convID - if convDir == "" { - convDir = "_manual" - } else { - convDir = strings.ReplaceAll(convDir, string(filepath.Separator), "_") - } - dateStr := time.Now().Format("2006-01-02") - targetDir = filepath.Join(root, dateStr, convDir) - if err := os.MkdirAll(targetDir, 0755); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } - baseName := filepath.Base(fh.Filename) - if baseName == "" || baseName == "." { - baseName = "file" - } - baseName = strings.ReplaceAll(baseName, string(filepath.Separator), "_") - ext := filepath.Ext(baseName) - nameNoExt := strings.TrimSuffix(baseName, ext) - suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), chatUploadShortRand(6)) - var unique string - if ext != "" { - unique = nameNoExt + suffix + ext - } else { - unique = baseName + suffix - } - fullPath := filepath.Join(targetDir, unique) - src, err := fh.Open() - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - defer src.Close() - dst, err := os.Create(fullPath) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - defer dst.Close() - if _, err := io.Copy(dst, src); err != nil { - _ = os.Remove(fullPath) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - rel, _ := filepath.Rel(root, fullPath) - absSaved, _ := filepath.Abs(fullPath) - c.JSON(http.StatusOK, gin.H{ - "ok": true, - "relativePath": filepath.ToSlash(rel), - "absolutePath": absSaved, - "name": unique, - }) -} diff --git a/handler/config.go b/handler/config.go deleted file mode 100644 index 54bb19f0..00000000 --- a/handler/config.go +++ /dev/null @@ -1,1594 +0,0 @@ -package handler - -import ( - "bytes" - "context" - "fmt" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - - "cyberstrike-ai/internal/agents" - "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/knowledge" - "cyberstrike-ai/internal/mcp" - "cyberstrike-ai/internal/openai" - "cyberstrike-ai/internal/security" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" - "gopkg.in/yaml.v3" -) - -// KnowledgeToolRegistrar 知识库工具注册器接口 -type KnowledgeToolRegistrar func() error - -// VulnerabilityToolRegistrar 漏洞工具注册器接口 -type VulnerabilityToolRegistrar func() error - -// WebshellToolRegistrar WebShell 工具注册器接口(ApplyConfig 时重新注册) -type WebshellToolRegistrar func() error - -// SkillsToolRegistrar Skills工具注册器接口 -type SkillsToolRegistrar func() error - -// BatchTaskToolRegistrar 批量任务 MCP 工具注册器(ApplyConfig 时重新注册) -type BatchTaskToolRegistrar func() error - -// RetrieverUpdater 检索器更新接口 -type RetrieverUpdater interface { - UpdateConfig(config *knowledge.RetrievalConfig) -} - -// KnowledgeInitializer 知识库初始化器接口 -type KnowledgeInitializer func() (*KnowledgeHandler, error) - -// AppUpdater App更新接口(用于更新App中的知识库组件) -type AppUpdater interface { - UpdateKnowledgeComponents(handler *KnowledgeHandler, manager interface{}, retriever interface{}, indexer interface{}) -} - -// RobotRestarter 机器人连接重启器(用于配置应用后重启钉钉/飞书长连接) -type RobotRestarter interface { - RestartRobotConnections() -} - -// ConfigHandler 配置处理器 -type ConfigHandler struct { - configPath string - config *config.Config - mcpServer *mcp.Server - executor *security.Executor - agent AgentUpdater // Agent接口,用于更新Agent配置 - attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 - externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 - knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选) - vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选) - webshellToolRegistrar WebshellToolRegistrar // WebShell 工具注册器(可选) - skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选) - batchTaskToolRegistrar BatchTaskToolRegistrar // 批量任务 MCP 工具(可选) - retrieverUpdater RetrieverUpdater // 检索器更新器(可选) - knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) - appUpdater AppUpdater // App更新器(可选) - robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书 - logger *zap.Logger - mu sync.RWMutex - lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更) -} - -// AttackChainUpdater 攻击链处理器更新接口 -type AttackChainUpdater interface { - UpdateConfig(cfg *config.OpenAIConfig) -} - -// AgentUpdater Agent更新接口 -type AgentUpdater interface { - UpdateConfig(cfg *config.OpenAIConfig) - UpdateMaxIterations(maxIterations int) -} - -// NewConfigHandler 创建新的配置处理器 -func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler { - // 保存初始的嵌入模型配置(如果知识库已启用) - var lastEmbeddingConfig *config.EmbeddingConfig - if cfg.Knowledge.Enabled { - lastEmbeddingConfig = &config.EmbeddingConfig{ - Provider: cfg.Knowledge.Embedding.Provider, - Model: cfg.Knowledge.Embedding.Model, - BaseURL: cfg.Knowledge.Embedding.BaseURL, - APIKey: cfg.Knowledge.Embedding.APIKey, - } - } - return &ConfigHandler{ - configPath: configPath, - config: cfg, - mcpServer: mcpServer, - executor: executor, - agent: agent, - attackChainHandler: attackChainHandler, - externalMCPMgr: externalMCPMgr, - logger: logger, - lastEmbeddingConfig: lastEmbeddingConfig, - } -} - -// SetKnowledgeToolRegistrar 设置知识库工具注册器 -func (h *ConfigHandler) SetKnowledgeToolRegistrar(registrar KnowledgeToolRegistrar) { - h.mu.Lock() - defer h.mu.Unlock() - h.knowledgeToolRegistrar = registrar -} - -// SetVulnerabilityToolRegistrar 设置漏洞工具注册器 -func (h *ConfigHandler) SetVulnerabilityToolRegistrar(registrar VulnerabilityToolRegistrar) { - h.mu.Lock() - defer h.mu.Unlock() - h.vulnerabilityToolRegistrar = registrar -} - -// SetWebshellToolRegistrar 设置 WebShell 工具注册器 -func (h *ConfigHandler) SetWebshellToolRegistrar(registrar WebshellToolRegistrar) { - h.mu.Lock() - defer h.mu.Unlock() - h.webshellToolRegistrar = registrar -} - -// SetSkillsToolRegistrar 设置Skills工具注册器 -func (h *ConfigHandler) SetSkillsToolRegistrar(registrar SkillsToolRegistrar) { - h.mu.Lock() - defer h.mu.Unlock() - h.skillsToolRegistrar = registrar -} - -// SetBatchTaskToolRegistrar 设置批量任务 MCP 工具注册器 -func (h *ConfigHandler) SetBatchTaskToolRegistrar(registrar BatchTaskToolRegistrar) { - h.mu.Lock() - defer h.mu.Unlock() - h.batchTaskToolRegistrar = registrar -} - -// SetRetrieverUpdater 设置检索器更新器 -func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) { - h.mu.Lock() - defer h.mu.Unlock() - h.retrieverUpdater = updater -} - -// SetKnowledgeInitializer 设置知识库初始化器 -func (h *ConfigHandler) SetKnowledgeInitializer(initializer KnowledgeInitializer) { - h.mu.Lock() - defer h.mu.Unlock() - h.knowledgeInitializer = initializer -} - -// SetAppUpdater 设置App更新器 -func (h *ConfigHandler) SetAppUpdater(updater AppUpdater) { - h.mu.Lock() - defer h.mu.Unlock() - h.appUpdater = updater -} - -// SetRobotRestarter 设置机器人连接重启器(ApplyConfig 时用于重启钉钉/飞书长连接) -func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) { - h.mu.Lock() - defer h.mu.Unlock() - h.robotRestarter = restarter -} - -// GetConfigResponse 获取配置响应 -type GetConfigResponse struct { - OpenAI config.OpenAIConfig `json:"openai"` - FOFA config.FofaConfig `json:"fofa"` - MCP config.MCPConfig `json:"mcp"` - Tools []ToolConfigInfo `json:"tools"` - Agent config.AgentConfig `json:"agent"` - Knowledge config.KnowledgeConfig `json:"knowledge"` - Robots config.RobotsConfig `json:"robots,omitempty"` - MultiAgent config.MultiAgentPublic `json:"multi_agent,omitempty"` -} - -// ToolConfigInfo 工具配置信息 -type ToolConfigInfo struct { - Name string `json:"name"` - Description string `json:"description"` - Enabled bool `json:"enabled"` - IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具 - ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具) - RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具) -} - -// GetConfig 获取当前配置 -func (h *ConfigHandler) GetConfig(c *gin.Context) { - h.mu.RLock() - defer h.mu.RUnlock() - - // 获取工具列表(包含内部和外部工具) - // 首先从配置文件获取工具 - configToolMap := make(map[string]bool) - tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) - for _, tool := range h.config.Security.Tools { - configToolMap[tool.Name] = true - tools = append(tools, ToolConfigInfo{ - Name: tool.Name, - Description: h.pickToolDescription(tool.ShortDescription, tool.Description), - Enabled: tool.Enabled, - IsExternal: false, - }) - } - - // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) - if h.mcpServer != nil { - mcpTools := h.mcpServer.GetAllTools() - for _, mcpTool := range mcpTools { - // 跳过已经在配置文件中的工具(避免重复) - if configToolMap[mcpTool.Name] { - continue - } - // 添加直接注册到MCP服务器的工具(如知识检索工具) - description := mcpTool.ShortDescription - if description == "" { - description = mcpTool.Description - } - if len(description) > 10000 { - description = description[:10000] + "..." - } - tools = append(tools, ToolConfigInfo{ - Name: mcpTool.Name, - Description: description, - Enabled: true, // 直接注册的工具默认启用 - IsExternal: false, - }) - } - } - - // 获取外部MCP工具 - if h.externalMCPMgr != nil { - ctx := context.Background() - externalTools := h.getExternalMCPTools(ctx) - for _, toolInfo := range externalTools { - tools = append(tools, toolInfo) - } - } - - subAgentCount := len(h.config.MultiAgent.SubAgents) - agentsDir := strings.TrimSpace(h.config.AgentsDir) - if agentsDir == "" { - agentsDir = "agents" - } - if !filepath.IsAbs(agentsDir) { - agentsDir = filepath.Join(filepath.Dir(h.configPath), agentsDir) - } - if load, err := agents.LoadMarkdownAgentsDir(agentsDir); err == nil { - subAgentCount = len(agents.MergeYAMLAndMarkdown(h.config.MultiAgent.SubAgents, load.SubAgents)) - } - multiPub := config.MultiAgentPublic{ - Enabled: h.config.MultiAgent.Enabled, - DefaultMode: h.config.MultiAgent.DefaultMode, - RobotUseMultiAgent: h.config.MultiAgent.RobotUseMultiAgent, - BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent, - SubAgentCount: subAgentCount, - } - if strings.TrimSpace(multiPub.DefaultMode) == "" { - multiPub.DefaultMode = "single" - } - - c.JSON(http.StatusOK, GetConfigResponse{ - OpenAI: h.config.OpenAI, - FOFA: h.config.FOFA, - MCP: h.config.MCP, - Tools: tools, - Agent: h.config.Agent, - Knowledge: h.config.Knowledge, - Robots: h.config.Robots, - MultiAgent: multiPub, - }) -} - -// GetToolsResponse 获取工具列表响应(分页) -type GetToolsResponse struct { - Tools []ToolConfigInfo `json:"tools"` - Total int `json:"total"` - TotalEnabled int `json:"total_enabled"` // 已启用的工具总数 - Page int `json:"page"` - PageSize int `json:"page_size"` - TotalPages int `json:"total_pages"` -} - -// GetTools 获取工具列表(支持分页和搜索) -func (h *ConfigHandler) GetTools(c *gin.Context) { - h.mu.RLock() - defer h.mu.RUnlock() - - // 解析分页参数 - page := 1 - pageSize := 20 - if pageStr := c.Query("page"); pageStr != "" { - if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { - page = p - } - } - if pageSizeStr := c.Query("page_size"); pageSizeStr != "" { - if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 && ps <= 100 { - pageSize = ps - } - } - - // 解析搜索参数 - searchTerm := c.Query("search") - searchTermLower := "" - if searchTerm != "" { - searchTermLower = strings.ToLower(searchTerm) - } - - // 解析状态筛选参数: "true" = 仅已启用, "false" = 仅已停用, "" = 全部 - enabledFilter := c.Query("enabled") - var filterEnabled *bool - if enabledFilter == "true" { - v := true - filterEnabled = &v - } else if enabledFilter == "false" { - v := false - filterEnabled = &v - } - - // 解析角色参数,用于过滤工具并标注启用状态 - roleName := c.Query("role") - var roleToolsSet map[string]bool // 角色配置的工具集合 - var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色) - if roleName != "" && roleName != "默认" && h.config.Roles != nil { - if role, exists := h.config.Roles[roleName]; exists && role.Enabled { - if len(role.Tools) > 0 { - // 角色配置了工具列表,只使用这些工具 - roleToolsSet = make(map[string]bool) - for _, toolKey := range role.Tools { - roleToolsSet[toolKey] = true - } - roleUsesAllTools = false - } - } - } - - // 获取所有内部工具并应用搜索过滤 - configToolMap := make(map[string]bool) - allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) - for _, tool := range h.config.Security.Tools { - configToolMap[tool.Name] = true - toolInfo := ToolConfigInfo{ - Name: tool.Name, - Description: h.pickToolDescription(tool.ShortDescription, tool.Description), - Enabled: tool.Enabled, - IsExternal: false, - } - - // 根据角色配置标注工具状态 - if roleName != "" { - if roleUsesAllTools { - // 角色使用所有工具,标注启用的工具为role_enabled=true - if tool.Enabled { - roleEnabled := true - toolInfo.RoleEnabled = &roleEnabled - } else { - roleEnabled := false - toolInfo.RoleEnabled = &roleEnabled - } - } else { - // 角色配置了工具列表,检查工具是否在列表中 - // 内部工具使用工具名称作为key - if roleToolsSet[tool.Name] { - roleEnabled := tool.Enabled // 工具必须在角色列表中且本身启用 - toolInfo.RoleEnabled = &roleEnabled - } else { - // 不在角色列表中,标记为false - roleEnabled := false - toolInfo.RoleEnabled = &roleEnabled - } - } - } - - // 如果有关键词,进行搜索过滤 - if searchTermLower != "" { - nameLower := strings.ToLower(toolInfo.Name) - descLower := strings.ToLower(toolInfo.Description) - if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) { - continue // 不匹配,跳过 - } - } - - // 状态筛选 - if filterEnabled != nil && toolInfo.Enabled != *filterEnabled { - continue - } - - allTools = append(allTools, toolInfo) - } - - // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) - if h.mcpServer != nil { - mcpTools := h.mcpServer.GetAllTools() - for _, mcpTool := range mcpTools { - // 跳过已经在配置文件中的工具(避免重复) - if configToolMap[mcpTool.Name] { - continue - } - - description := mcpTool.ShortDescription - if description == "" { - description = mcpTool.Description - } - if len(description) > 10000 { - description = description[:10000] + "..." - } - - toolInfo := ToolConfigInfo{ - Name: mcpTool.Name, - Description: description, - Enabled: true, // 直接注册的工具默认启用 - IsExternal: false, - } - - // 根据角色配置标注工具状态 - if roleName != "" { - if roleUsesAllTools { - // 角色使用所有工具,直接注册的工具默认启用 - roleEnabled := true - toolInfo.RoleEnabled = &roleEnabled - } else { - // 角色配置了工具列表,检查工具是否在列表中 - // 内部工具使用工具名称作为key - if roleToolsSet[mcpTool.Name] { - roleEnabled := true // 在角色列表中且工具本身启用 - toolInfo.RoleEnabled = &roleEnabled - } else { - // 不在角色列表中,标记为false - roleEnabled := false - toolInfo.RoleEnabled = &roleEnabled - } - } - } - - // 如果有关键词,进行搜索过滤 - if searchTermLower != "" { - nameLower := strings.ToLower(toolInfo.Name) - descLower := strings.ToLower(toolInfo.Description) - if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) { - continue // 不匹配,跳过 - } - } - - // 状态筛选 - if filterEnabled != nil && toolInfo.Enabled != *filterEnabled { - continue - } - - allTools = append(allTools, toolInfo) - } - } - - // 获取外部MCP工具 - if h.externalMCPMgr != nil { - // 创建context用于获取外部工具 - ctx := context.Background() - externalTools := h.getExternalMCPTools(ctx) - - // 应用搜索过滤和角色配置 - for _, toolInfo := range externalTools { - // 搜索过滤 - if searchTermLower != "" { - nameLower := strings.ToLower(toolInfo.Name) - descLower := strings.ToLower(toolInfo.Description) - if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) { - continue // 不匹配,跳过 - } - } - - // 根据角色配置标注工具状态 - if roleName != "" { - if roleUsesAllTools { - // 角色使用所有工具,标注启用的工具为role_enabled=true - roleEnabled := toolInfo.Enabled - toolInfo.RoleEnabled = &roleEnabled - } else { - // 角色配置了工具列表,检查工具是否在列表中 - // 外部工具使用 "mcpName::toolName" 格式作为key - externalToolKey := fmt.Sprintf("%s::%s", toolInfo.ExternalMCP, toolInfo.Name) - if roleToolsSet[externalToolKey] { - roleEnabled := toolInfo.Enabled // 工具必须在角色列表中且本身启用 - toolInfo.RoleEnabled = &roleEnabled - } else { - // 不在角色列表中,标记为false - roleEnabled := false - toolInfo.RoleEnabled = &roleEnabled - } - } - } - - // 状态筛选 - if filterEnabled != nil && toolInfo.Enabled != *filterEnabled { - continue - } - - allTools = append(allTools, toolInfo) - } - } - - // 如果角色配置了工具列表,过滤工具(只保留列表中的工具,但保留其他工具并标记为禁用) - // 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态 - // 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用 - - total := len(allTools) - // 统计已启用的工具数(在角色中的启用工具数) - totalEnabled := 0 - for _, tool := range allTools { - if tool.RoleEnabled != nil && *tool.RoleEnabled { - totalEnabled++ - } else if tool.RoleEnabled == nil && tool.Enabled { - // 如果未指定角色,统计所有启用的工具 - totalEnabled++ - } - } - - totalPages := (total + pageSize - 1) / pageSize - if totalPages == 0 { - totalPages = 1 - } - - // 计算分页范围 - offset := (page - 1) * pageSize - end := offset + pageSize - if end > total { - end = total - } - - var tools []ToolConfigInfo - if offset < total { - tools = allTools[offset:end] - } else { - tools = []ToolConfigInfo{} - } - - c.JSON(http.StatusOK, GetToolsResponse{ - Tools: tools, - Total: total, - TotalEnabled: totalEnabled, - Page: page, - PageSize: pageSize, - TotalPages: totalPages, - }) -} - -// UpdateConfigRequest 更新配置请求 -type UpdateConfigRequest struct { - OpenAI *config.OpenAIConfig `json:"openai,omitempty"` - FOFA *config.FofaConfig `json:"fofa,omitempty"` - MCP *config.MCPConfig `json:"mcp,omitempty"` - Tools []ToolEnableStatus `json:"tools,omitempty"` - Agent *config.AgentConfig `json:"agent,omitempty"` - Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"` - Robots *config.RobotsConfig `json:"robots,omitempty"` - MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"` -} - -// ToolEnableStatus 工具启用状态 -type ToolEnableStatus struct { - Name string `json:"name"` - Enabled bool `json:"enabled"` - IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具 - ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具) -} - -// UpdateConfig 更新配置 -func (h *ConfigHandler) UpdateConfig(c *gin.Context) { - var req UpdateConfigRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - h.mu.Lock() - defer h.mu.Unlock() - - // 更新OpenAI配置 - if req.OpenAI != nil { - h.config.OpenAI = *req.OpenAI - h.logger.Info("更新OpenAI配置", - zap.String("base_url", h.config.OpenAI.BaseURL), - zap.String("model", h.config.OpenAI.Model), - ) - } - - // 更新FOFA配置 - if req.FOFA != nil { - h.config.FOFA = *req.FOFA - h.logger.Info("更新FOFA配置", zap.String("email", h.config.FOFA.Email)) - } - - // 更新MCP配置 - if req.MCP != nil { - h.config.MCP = *req.MCP - h.logger.Info("更新MCP配置", - zap.Bool("enabled", h.config.MCP.Enabled), - zap.String("host", h.config.MCP.Host), - zap.Int("port", h.config.MCP.Port), - ) - } - - // 更新Agent配置 - if req.Agent != nil { - h.config.Agent = *req.Agent - h.logger.Info("更新Agent配置", - zap.Int("max_iterations", h.config.Agent.MaxIterations), - ) - } - - // 更新Knowledge配置 - if req.Knowledge != nil { - // 保存旧的嵌入模型配置(用于检测变更) - if h.config.Knowledge.Enabled { - h.lastEmbeddingConfig = &config.EmbeddingConfig{ - Provider: h.config.Knowledge.Embedding.Provider, - Model: h.config.Knowledge.Embedding.Model, - BaseURL: h.config.Knowledge.Embedding.BaseURL, - APIKey: h.config.Knowledge.Embedding.APIKey, - } - } - h.config.Knowledge = *req.Knowledge - h.logger.Info("更新Knowledge配置", - zap.Bool("enabled", h.config.Knowledge.Enabled), - zap.String("base_path", h.config.Knowledge.BasePath), - zap.String("embedding_model", h.config.Knowledge.Embedding.Model), - zap.Int("retrieval_top_k", h.config.Knowledge.Retrieval.TopK), - zap.Float64("similarity_threshold", h.config.Knowledge.Retrieval.SimilarityThreshold), - ) - } - - // 更新机器人配置 - if req.Robots != nil { - h.config.Robots = *req.Robots - h.logger.Info("更新机器人配置", - zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled), - zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled), - zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled), - ) - } - - // 多代理标量(sub_agents 等仍由 config.yaml 维护) - if req.MultiAgent != nil { - h.config.MultiAgent.Enabled = req.MultiAgent.Enabled - dm := strings.TrimSpace(req.MultiAgent.DefaultMode) - if dm == "multi" || dm == "single" { - h.config.MultiAgent.DefaultMode = dm - } - h.config.MultiAgent.RobotUseMultiAgent = req.MultiAgent.RobotUseMultiAgent - h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent - h.logger.Info("更新多代理配置", - zap.Bool("enabled", h.config.MultiAgent.Enabled), - zap.String("default_mode", h.config.MultiAgent.DefaultMode), - zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent), - zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent), - ) - } - - // 更新工具启用状态 - if req.Tools != nil { - // 分离内部工具和外部工具 - internalToolMap := make(map[string]bool) - // 外部工具状态:MCP名称 -> 工具名称 -> 启用状态 - externalMCPToolMap := make(map[string]map[string]bool) - - for _, toolStatus := range req.Tools { - if toolStatus.IsExternal && toolStatus.ExternalMCP != "" { - // 外部工具:保存每个工具的独立状态 - mcpName := toolStatus.ExternalMCP - if externalMCPToolMap[mcpName] == nil { - externalMCPToolMap[mcpName] = make(map[string]bool) - } - externalMCPToolMap[mcpName][toolStatus.Name] = toolStatus.Enabled - } else { - // 内部工具 - internalToolMap[toolStatus.Name] = toolStatus.Enabled - } - } - - // 更新内部工具状态 - for i := range h.config.Security.Tools { - if enabled, ok := internalToolMap[h.config.Security.Tools[i].Name]; ok { - h.config.Security.Tools[i].Enabled = enabled - h.logger.Info("更新工具启用状态", - zap.String("tool", h.config.Security.Tools[i].Name), - zap.Bool("enabled", enabled), - ) - } - } - - // 更新外部MCP工具状态 - if h.externalMCPMgr != nil { - for mcpName, toolStates := range externalMCPToolMap { - // 更新配置中的工具启用状态 - if h.config.ExternalMCP.Servers == nil { - h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) - } - cfg, exists := h.config.ExternalMCP.Servers[mcpName] - if !exists { - h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName)) - continue - } - - // 初始化ToolEnabled map - if cfg.ToolEnabled == nil { - cfg.ToolEnabled = make(map[string]bool) - } - - // 更新每个工具的启用状态 - for toolName, enabled := range toolStates { - cfg.ToolEnabled[toolName] = enabled - h.logger.Info("更新外部工具启用状态", - zap.String("mcp", mcpName), - zap.String("tool", toolName), - zap.Bool("enabled", enabled), - ) - } - - // 检查是否有任何工具启用,如果有则启用MCP - hasEnabledTool := false - for _, enabled := range cfg.ToolEnabled { - if enabled { - hasEnabledTool = true - break - } - } - - // 如果MCP之前未启用,但现在有工具启用,则启用MCP - // 如果MCP之前已启用,保持启用状态(允许部分工具禁用) - if !cfg.ExternalMCPEnable && hasEnabledTool { - cfg.ExternalMCPEnable = true - h.logger.Info("自动启用外部MCP(因为有工具启用)", zap.String("mcp", mcpName)) - } - - h.config.ExternalMCP.Servers[mcpName] = cfg - } - - // 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置 - // 在循环外部统一更新,避免重复调用 - h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP) - - // 处理MCP连接状态(异步启动,避免阻塞) - for mcpName := range externalMCPToolMap { - cfg := h.config.ExternalMCP.Servers[mcpName] - // 如果MCP需要启用,确保客户端已启动 - if cfg.ExternalMCPEnable { - // 启动外部MCP(如果未启动)- 异步执行,避免阻塞 - client, exists := h.externalMCPMgr.GetClient(mcpName) - if !exists || !client.IsConnected() { - go func(name string) { - if err := h.externalMCPMgr.StartClient(name); err != nil { - h.logger.Warn("启动外部MCP失败", - zap.String("mcp", name), - zap.Error(err), - ) - } else { - h.logger.Info("启动外部MCP", - zap.String("mcp", name), - ) - } - }(mcpName) - } - } - } - } - } - - // 保存配置到文件 - if err := h.saveConfig(); err != nil { - h.logger.Error("保存配置失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) -} - -// TestOpenAIRequest 测试OpenAI连接请求 -type TestOpenAIRequest struct { - Provider string `json:"provider"` - BaseURL string `json:"base_url"` - APIKey string `json:"api_key"` - Model string `json:"model"` -} - -// TestOpenAI 测试OpenAI API连接是否可用 -func (h *ConfigHandler) TestOpenAI(c *gin.Context) { - var req TestOpenAIRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - if strings.TrimSpace(req.APIKey) == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空"}) - return - } - if strings.TrimSpace(req.Model) == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "模型不能为空"}) - return - } - - baseURL := strings.TrimSuffix(strings.TrimSpace(req.BaseURL), "/") - if baseURL == "" { - if strings.EqualFold(strings.TrimSpace(req.Provider), "claude") { - baseURL = "https://api.anthropic.com" - } else { - baseURL = "https://api.openai.com/v1" - } - } - - // 构造一个最小的 chat completion 请求 - payload := map[string]interface{}{ - "model": req.Model, - "messages": []map[string]string{ - {"role": "user", "content": "Hi"}, - }, - "max_tokens": 5, - } - - // 使用内部 openai Client 进行测试,若 provider 为 claude 会自动走桥接层 - tmpCfg := &config.OpenAIConfig{ - Provider: req.Provider, - BaseURL: baseURL, - APIKey: strings.TrimSpace(req.APIKey), - Model: req.Model, - } - client := openai.NewClient(tmpCfg, nil, h.logger) - - ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second) - defer cancel() - - start := time.Now() - var chatResp struct { - ID string `json:"id"` - Object string `json:"object"` - Model string `json:"model"` - Choices []struct { - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` - } - err := client.ChatCompletion(ctx, payload, &chatResp) - latency := time.Since(start) - - if err != nil { - if apiErr, ok := err.(*openai.APIError); ok { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body), - "status_code": apiErr.StatusCode, - }) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": false, - "error": "连接失败: " + err.Error(), - }) - return - } - - // 严格校验:必须包含 choices 且有 assistant 回复 - if len(chatResp.Choices) == 0 { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "error": "API 响应缺少 choices 字段,请检查 Base URL 路径是否正确", - }) - return - } - if chatResp.ID == "" && chatResp.Model == "" { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "error": "API 响应格式不符合预期,请检查 Base URL 是否正确", - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "model": chatResp.Model, - "latency_ms": latency.Milliseconds(), - }) -} - -// ApplyConfig 应用配置(重新加载并重启相关服务) -func (h *ConfigHandler) ApplyConfig(c *gin.Context) { - // 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求) - var needInitKnowledge bool - var knowledgeInitializer KnowledgeInitializer - - h.mu.RLock() - needInitKnowledge = h.config.Knowledge.Enabled && h.knowledgeToolRegistrar == nil && h.knowledgeInitializer != nil - if needInitKnowledge { - knowledgeInitializer = h.knowledgeInitializer - } - h.mu.RUnlock() - - // 如果需要动态初始化知识库,在锁外执行(这是耗时操作) - if needInitKnowledge { - h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件") - if _, err := knowledgeInitializer(); err != nil { - h.logger.Error("动态初始化知识库失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()}) - return - } - h.logger.Info("知识库动态初始化完成,工具已注册") - } - - // 检查嵌入模型配置是否变更(需要在锁外执行,避免阻塞) - var needReinitKnowledge bool - var reinitKnowledgeInitializer KnowledgeInitializer - h.mu.RLock() - if h.config.Knowledge.Enabled && h.knowledgeInitializer != nil && h.lastEmbeddingConfig != nil { - // 检查嵌入模型配置是否变更 - currentEmbedding := h.config.Knowledge.Embedding - if currentEmbedding.Provider != h.lastEmbeddingConfig.Provider || - currentEmbedding.Model != h.lastEmbeddingConfig.Model || - currentEmbedding.BaseURL != h.lastEmbeddingConfig.BaseURL || - currentEmbedding.APIKey != h.lastEmbeddingConfig.APIKey { - needReinitKnowledge = true - reinitKnowledgeInitializer = h.knowledgeInitializer - h.logger.Info("检测到嵌入模型配置变更,需要重新初始化知识库组件", - zap.String("old_model", h.lastEmbeddingConfig.Model), - zap.String("new_model", currentEmbedding.Model), - zap.String("old_base_url", h.lastEmbeddingConfig.BaseURL), - zap.String("new_base_url", currentEmbedding.BaseURL), - ) - } - } - h.mu.RUnlock() - - // 如果需要重新初始化知识库(嵌入模型配置变更),在锁外执行 - if needReinitKnowledge { - h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)") - if _, err := reinitKnowledgeInitializer(); err != nil { - h.logger.Error("重新初始化知识库失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()}) - return - } - h.logger.Info("知识库组件重新初始化完成") - } - - // 现在获取写锁,执行快速的操作 - h.mu.Lock() - defer h.mu.Unlock() - - // 如果重新初始化了知识库,更新嵌入模型配置记录 - if needReinitKnowledge && h.config.Knowledge.Enabled { - h.lastEmbeddingConfig = &config.EmbeddingConfig{ - Provider: h.config.Knowledge.Embedding.Provider, - Model: h.config.Knowledge.Embedding.Model, - BaseURL: h.config.Knowledge.Embedding.BaseURL, - APIKey: h.config.Knowledge.Embedding.APIKey, - } - h.logger.Info("已更新嵌入模型配置记录") - } - - // 重新注册工具(根据新的启用状态) - h.logger.Info("重新注册工具") - - // 清空MCP服务器中的工具 - h.mcpServer.ClearTools() - - // 重新注册安全工具 - h.executor.RegisterTools(h.mcpServer) - - // 重新注册漏洞记录工具(内置工具,必须注册) - if h.vulnerabilityToolRegistrar != nil { - h.logger.Info("重新注册漏洞记录工具") - if err := h.vulnerabilityToolRegistrar(); err != nil { - h.logger.Error("重新注册漏洞记录工具失败", zap.Error(err)) - } else { - h.logger.Info("漏洞记录工具已重新注册") - } - } - - // 重新注册 WebShell 工具(内置工具,必须注册) - if h.webshellToolRegistrar != nil { - h.logger.Info("重新注册 WebShell 工具") - if err := h.webshellToolRegistrar(); err != nil { - h.logger.Error("重新注册 WebShell 工具失败", zap.Error(err)) - } else { - h.logger.Info("WebShell 工具已重新注册") - } - } - - // 重新注册Skills工具(内置工具,必须注册) - if h.skillsToolRegistrar != nil { - h.logger.Info("重新注册Skills工具") - if err := h.skillsToolRegistrar(); err != nil { - h.logger.Error("重新注册Skills工具失败", zap.Error(err)) - } else { - h.logger.Info("Skills工具已重新注册") - } - } - - // 重新注册批量任务 MCP 工具 - if h.batchTaskToolRegistrar != nil { - h.logger.Info("重新注册批量任务 MCP 工具") - if err := h.batchTaskToolRegistrar(); err != nil { - h.logger.Error("重新注册批量任务 MCP 工具失败", zap.Error(err)) - } else { - h.logger.Info("批量任务 MCP 工具已重新注册") - } - } - - // 如果知识库启用,重新注册知识库工具 - if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil { - h.logger.Info("重新注册知识库工具") - if err := h.knowledgeToolRegistrar(); err != nil { - h.logger.Error("重新注册知识库工具失败", zap.Error(err)) - } else { - h.logger.Info("知识库工具已重新注册") - } - } - - // 更新Agent的OpenAI配置 - if h.agent != nil { - h.agent.UpdateConfig(&h.config.OpenAI) - h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations) - h.logger.Info("Agent配置已更新") - } - - // 更新AttackChainHandler的OpenAI配置 - if h.attackChainHandler != nil { - h.attackChainHandler.UpdateConfig(&h.config.OpenAI) - h.logger.Info("AttackChainHandler配置已更新") - } - - // 更新检索器配置(如果知识库启用) - if h.config.Knowledge.Enabled && h.retrieverUpdater != nil { - retrievalConfig := &knowledge.RetrievalConfig{ - TopK: h.config.Knowledge.Retrieval.TopK, - SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold, - SubIndexFilter: h.config.Knowledge.Retrieval.SubIndexFilter, - PostRetrieve: h.config.Knowledge.Retrieval.PostRetrieve, - } - h.retrieverUpdater.UpdateConfig(retrievalConfig) - h.logger.Info("检索器配置已更新", - zap.Int("top_k", retrievalConfig.TopK), - zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold), - ) - } - - // 更新嵌入模型配置记录(如果知识库启用) - if h.config.Knowledge.Enabled { - h.lastEmbeddingConfig = &config.EmbeddingConfig{ - Provider: h.config.Knowledge.Embedding.Provider, - Model: h.config.Knowledge.Embedding.Model, - BaseURL: h.config.Knowledge.Embedding.BaseURL, - APIKey: h.config.Knowledge.Embedding.APIKey, - } - } - - // 重启钉钉/飞书长连接,使前端修改的机器人配置立即生效(无需重启服务) - if h.robotRestarter != nil { - h.robotRestarter.RestartRobotConnections() - h.logger.Info("已触发机器人连接重启(钉钉/飞书)") - } - - h.logger.Info("配置已应用", - zap.Int("tools_count", len(h.config.Security.Tools)), - ) - - c.JSON(http.StatusOK, gin.H{ - "message": "配置已应用", - "tools_count": len(h.config.Security.Tools), - }) -} - -// saveConfig 保存配置到文件 -func (h *ConfigHandler) saveConfig() error { - // 读取现有配置文件并创建备份 - data, err := os.ReadFile(h.configPath) - if err != nil { - return fmt.Errorf("读取配置文件失败: %w", err) - } - - if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil { - h.logger.Warn("创建配置备份失败", zap.Error(err)) - } - - root, err := loadYAMLDocument(h.configPath) - if err != nil { - return fmt.Errorf("解析配置文件失败: %w", err) - } - - updateAgentConfig(root, h.config.Agent.MaxIterations) - updateMCPConfig(root, h.config.MCP) - updateOpenAIConfig(root, h.config.OpenAI) - updateFOFAConfig(root, h.config.FOFA) - updateKnowledgeConfig(root, h.config.Knowledge) - updateRobotsConfig(root, h.config.Robots) - updateMultiAgentConfig(root, h.config.MultiAgent) - // 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用) - // 读取原始配置以保持向后兼容 - originalConfigs := make(map[string]map[string]bool) - externalMCPNode := findMapValue(root, "external_mcp") - if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode { - serversNode := findMapValue(externalMCPNode, "servers") - if serversNode != nil && serversNode.Kind == yaml.MappingNode { - for i := 0; i < len(serversNode.Content); i += 2 { - if i+1 >= len(serversNode.Content) { - break - } - nameNode := serversNode.Content[i] - serverNode := serversNode.Content[i+1] - if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode { - serverName := nameNode.Value - originalConfigs[serverName] = make(map[string]bool) - if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil { - originalConfigs[serverName]["enabled"] = *enabledVal - } - if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil { - originalConfigs[serverName]["disabled"] = *disabledVal - } - } - } - } - } - updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs) - - if err := writeYAMLDocument(h.configPath, root); err != nil { - return fmt.Errorf("保存配置文件失败: %w", err) - } - - // 更新工具配置文件中的enabled状态 - if h.config.Security.ToolsDir != "" { - configDir := filepath.Dir(h.configPath) - toolsDir := h.config.Security.ToolsDir - if !filepath.IsAbs(toolsDir) { - toolsDir = filepath.Join(configDir, toolsDir) - } - - for _, tool := range h.config.Security.Tools { - toolFile := filepath.Join(toolsDir, tool.Name+".yaml") - // 检查文件是否存在 - if _, err := os.Stat(toolFile); os.IsNotExist(err) { - // 尝试.yml扩展名 - toolFile = filepath.Join(toolsDir, tool.Name+".yml") - if _, err := os.Stat(toolFile); os.IsNotExist(err) { - h.logger.Warn("工具配置文件不存在", zap.String("tool", tool.Name)) - continue - } - } - - toolDoc, err := loadYAMLDocument(toolFile) - if err != nil { - h.logger.Warn("解析工具配置失败", zap.String("tool", tool.Name), zap.Error(err)) - continue - } - - setBoolInMap(toolDoc.Content[0], "enabled", tool.Enabled) - - if err := writeYAMLDocument(toolFile, toolDoc); err != nil { - h.logger.Warn("保存工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err)) - continue - } - - h.logger.Info("更新工具配置", zap.String("tool", tool.Name), zap.Bool("enabled", tool.Enabled)) - } - } - - h.logger.Info("配置已保存", zap.String("path", h.configPath)) - return nil -} - -func loadYAMLDocument(path string) (*yaml.Node, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - - if len(bytes.TrimSpace(data)) == 0 { - return newEmptyYAMLDocument(), nil - } - - var doc yaml.Node - if err := yaml.Unmarshal(data, &doc); err != nil { - return nil, err - } - - if doc.Kind != yaml.DocumentNode || len(doc.Content) == 0 { - return newEmptyYAMLDocument(), nil - } - - if doc.Content[0].Kind != yaml.MappingNode { - root := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} - doc.Content = []*yaml.Node{root} - } - - return &doc, nil -} - -func newEmptyYAMLDocument() *yaml.Node { - root := &yaml.Node{ - Kind: yaml.DocumentNode, - Content: []*yaml.Node{{Kind: yaml.MappingNode, Tag: "!!map"}}, - } - return root -} - -func writeYAMLDocument(path string, doc *yaml.Node) error { - var buf bytes.Buffer - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - if err := encoder.Encode(doc); err != nil { - return err - } - if err := encoder.Close(); err != nil { - return err - } - return os.WriteFile(path, buf.Bytes(), 0644) -} - -func updateAgentConfig(doc *yaml.Node, maxIterations int) { - root := doc.Content[0] - agentNode := ensureMap(root, "agent") - setIntInMap(agentNode, "max_iterations", maxIterations) -} - -func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) { - root := doc.Content[0] - mcpNode := ensureMap(root, "mcp") - setBoolInMap(mcpNode, "enabled", cfg.Enabled) - setStringInMap(mcpNode, "host", cfg.Host) - setIntInMap(mcpNode, "port", cfg.Port) -} - -func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) { - root := doc.Content[0] - openaiNode := ensureMap(root, "openai") - if cfg.Provider != "" { - setStringInMap(openaiNode, "provider", cfg.Provider) - } - setStringInMap(openaiNode, "api_key", cfg.APIKey) - setStringInMap(openaiNode, "base_url", cfg.BaseURL) - setStringInMap(openaiNode, "model", cfg.Model) - if cfg.MaxTotalTokens > 0 { - setIntInMap(openaiNode, "max_total_tokens", cfg.MaxTotalTokens) - } -} - -func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) { - root := doc.Content[0] - fofaNode := ensureMap(root, "fofa") - setStringInMap(fofaNode, "base_url", cfg.BaseURL) - setStringInMap(fofaNode, "email", cfg.Email) - setStringInMap(fofaNode, "api_key", cfg.APIKey) -} - -func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) { - root := doc.Content[0] - knowledgeNode := ensureMap(root, "knowledge") - setBoolInMap(knowledgeNode, "enabled", cfg.Enabled) - setStringInMap(knowledgeNode, "base_path", cfg.BasePath) - - // 更新嵌入配置 - embeddingNode := ensureMap(knowledgeNode, "embedding") - setStringInMap(embeddingNode, "provider", cfg.Embedding.Provider) - setStringInMap(embeddingNode, "model", cfg.Embedding.Model) - if cfg.Embedding.BaseURL != "" { - setStringInMap(embeddingNode, "base_url", cfg.Embedding.BaseURL) - } - if cfg.Embedding.APIKey != "" { - setStringInMap(embeddingNode, "api_key", cfg.Embedding.APIKey) - } - - // 更新检索配置 - retrievalNode := ensureMap(knowledgeNode, "retrieval") - setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK) - setFloatInMap(retrievalNode, "similarity_threshold", cfg.Retrieval.SimilarityThreshold) - setStringInMap(retrievalNode, "sub_index_filter", cfg.Retrieval.SubIndexFilter) - postNode := ensureMap(retrievalNode, "post_retrieve") - setIntInMap(postNode, "prefetch_top_k", cfg.Retrieval.PostRetrieve.PrefetchTopK) - setIntInMap(postNode, "max_context_chars", cfg.Retrieval.PostRetrieve.MaxContextChars) - setIntInMap(postNode, "max_context_tokens", cfg.Retrieval.PostRetrieve.MaxContextTokens) - - // 更新索引配置 - indexingNode := ensureMap(knowledgeNode, "indexing") - setStringInMap(indexingNode, "chunk_strategy", cfg.Indexing.ChunkStrategy) - setIntInMap(indexingNode, "request_timeout_seconds", cfg.Indexing.RequestTimeoutSeconds) - setIntInMap(indexingNode, "chunk_size", cfg.Indexing.ChunkSize) - setIntInMap(indexingNode, "chunk_overlap", cfg.Indexing.ChunkOverlap) - setIntInMap(indexingNode, "max_chunks_per_item", cfg.Indexing.MaxChunksPerItem) - setBoolInMap(indexingNode, "prefer_source_file", cfg.Indexing.PreferSourceFile) - setIntInMap(indexingNode, "batch_size", cfg.Indexing.BatchSize) - setStringSliceInMap(indexingNode, "sub_indexes", cfg.Indexing.SubIndexes) - setIntInMap(indexingNode, "max_rpm", cfg.Indexing.MaxRPM) - setIntInMap(indexingNode, "rate_limit_delay_ms", cfg.Indexing.RateLimitDelayMs) - setIntInMap(indexingNode, "max_retries", cfg.Indexing.MaxRetries) - setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs) -} - -func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) { - root := doc.Content[0] - robotsNode := ensureMap(root, "robots") - - wecomNode := ensureMap(robotsNode, "wecom") - setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled) - setStringInMap(wecomNode, "token", cfg.Wecom.Token) - setStringInMap(wecomNode, "encoding_aes_key", cfg.Wecom.EncodingAESKey) - setStringInMap(wecomNode, "corp_id", cfg.Wecom.CorpID) - setStringInMap(wecomNode, "secret", cfg.Wecom.Secret) - setIntInMap(wecomNode, "agent_id", int(cfg.Wecom.AgentID)) - - dingtalkNode := ensureMap(robotsNode, "dingtalk") - setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled) - setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID) - setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret) - - larkNode := ensureMap(robotsNode, "lark") - setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled) - setStringInMap(larkNode, "app_id", cfg.Lark.AppID) - setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret) - setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken) -} - -func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) { - root := doc.Content[0] - maNode := ensureMap(root, "multi_agent") - setBoolInMap(maNode, "enabled", cfg.Enabled) - setStringInMap(maNode, "default_mode", cfg.DefaultMode) - setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent) - setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent) -} - -func ensureMap(parent *yaml.Node, path ...string) *yaml.Node { - current := parent - for _, key := range path { - value := findMapValue(current, key) - if value == nil { - keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key} - mapNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} - current.Content = append(current.Content, keyNode, mapNode) - value = mapNode - } - - if value.Kind != yaml.MappingNode { - value.Kind = yaml.MappingNode - value.Tag = "!!map" - value.Style = 0 - value.Content = nil - } - - current = value - } - - return current -} - -func findMapValue(mapNode *yaml.Node, key string) *yaml.Node { - if mapNode == nil || mapNode.Kind != yaml.MappingNode { - return nil - } - - for i := 0; i < len(mapNode.Content); i += 2 { - if mapNode.Content[i].Value == key { - return mapNode.Content[i+1] - } - } - return nil -} - -func ensureKeyValue(mapNode *yaml.Node, key string) (*yaml.Node, *yaml.Node) { - if mapNode == nil || mapNode.Kind != yaml.MappingNode { - return nil, nil - } - - for i := 0; i < len(mapNode.Content); i += 2 { - if mapNode.Content[i].Value == key { - return mapNode.Content[i], mapNode.Content[i+1] - } - } - - keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key} - valueNode := &yaml.Node{} - mapNode.Content = append(mapNode.Content, keyNode, valueNode) - return keyNode, valueNode -} - -func setStringInMap(mapNode *yaml.Node, key, value string) { - _, valueNode := ensureKeyValue(mapNode, key) - valueNode.Kind = yaml.ScalarNode - valueNode.Tag = "!!str" - valueNode.Style = 0 - valueNode.Value = value -} - -func setStringSliceInMap(mapNode *yaml.Node, key string, values []string) { - _, valueNode := ensureKeyValue(mapNode, key) - valueNode.Kind = yaml.SequenceNode - valueNode.Tag = "!!seq" - valueNode.Style = 0 - valueNode.Content = nil - for _, v := range values { - valueNode.Content = append(valueNode.Content, &yaml.Node{ - Kind: yaml.ScalarNode, - Tag: "!!str", - Value: v, - }) - } -} - -func setIntInMap(mapNode *yaml.Node, key string, value int) { - _, valueNode := ensureKeyValue(mapNode, key) - valueNode.Kind = yaml.ScalarNode - valueNode.Tag = "!!int" - valueNode.Style = 0 - valueNode.Value = fmt.Sprintf("%d", value) -} - -func findBoolInMap(mapNode *yaml.Node, key string) *bool { - if mapNode == nil || mapNode.Kind != yaml.MappingNode { - return nil - } - - for i := 0; i < len(mapNode.Content); i += 2 { - if i+1 >= len(mapNode.Content) { - break - } - keyNode := mapNode.Content[i] - valueNode := mapNode.Content[i+1] - - if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key { - if valueNode.Kind == yaml.ScalarNode { - if valueNode.Value == "true" { - result := true - return &result - } else if valueNode.Value == "false" { - result := false - return &result - } - } - return nil - } - } - return nil -} - -func setBoolInMap(mapNode *yaml.Node, key string, value bool) { - _, valueNode := ensureKeyValue(mapNode, key) - valueNode.Kind = yaml.ScalarNode - valueNode.Tag = "!!bool" - valueNode.Style = 0 - if value { - valueNode.Value = "true" - } else { - valueNode.Value = "false" - } -} - -func setFloatInMap(mapNode *yaml.Node, key string, value float64) { - _, valueNode := ensureKeyValue(mapNode, key) - valueNode.Kind = yaml.ScalarNode - valueNode.Tag = "!!float" - valueNode.Style = 0 - // 对于0.0到1.0之间的值(如 similarity_threshold),使用%.1f确保0.0被明确序列化为"0.0" - // 对于其他值,使用%g自动选择最合适的格式 - if value >= 0.0 && value <= 1.0 { - valueNode.Value = fmt.Sprintf("%.1f", value) - } else { - valueNode.Value = fmt.Sprintf("%g", value) - } -} - -// getExternalMCPTools 获取外部MCP工具列表(公共方法) -// 返回 ToolConfigInfo 列表,已处理启用状态和描述信息 -func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo { - var result []ToolConfigInfo - - if h.externalMCPMgr == nil { - return result - } - - // 使用较短的超时时间(5秒)进行快速失败,避免阻塞页面加载 - timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - externalTools, err := h.externalMCPMgr.GetAllTools(timeoutCtx) - if err != nil { - // 记录警告但不阻塞,继续返回已缓存的工具(如果有) - h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具", - zap.Error(err), - zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"), - ) - } - - // 如果获取到了工具(即使有错误),继续处理 - if len(externalTools) == 0 { - return result - } - - externalMCPConfigs := h.externalMCPMgr.GetConfigs() - - for _, externalTool := range externalTools { - // 解析工具名称:mcpName::toolName - mcpName, actualToolName := h.parseExternalToolName(externalTool.Name) - if mcpName == "" || actualToolName == "" { - continue // 跳过格式不正确的工具 - } - - // 计算启用状态 - enabled := h.calculateExternalToolEnabled(mcpName, actualToolName, externalMCPConfigs) - - // 处理描述信息 - description := h.pickToolDescription(externalTool.ShortDescription, externalTool.Description) - - result = append(result, ToolConfigInfo{ - Name: actualToolName, - Description: description, - Enabled: enabled, - IsExternal: true, - ExternalMCP: mcpName, - }) - } - - return result -} - -// parseExternalToolName 解析外部工具名称(格式:mcpName::toolName) -func (h *ConfigHandler) parseExternalToolName(fullName string) (mcpName, toolName string) { - idx := strings.Index(fullName, "::") - if idx > 0 { - return fullName[:idx], fullName[idx+2:] - } - return "", "" -} - -// calculateExternalToolEnabled 计算外部工具的启用状态 -func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool { - cfg, exists := configs[mcpName] - if !exists { - return false - } - - // 首先检查外部MCP是否启用 - if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) { - return false // MCP未启用,所有工具都禁用 - } - - // MCP已启用,检查单个工具的启用状态 - // 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容) - if cfg.ToolEnabled == nil { - // 未设置工具状态,默认为启用 - } else if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists { - // 使用配置的工具状态 - if !toolEnabled { - return false - } - } - // 工具未在配置中,默认为启用 - - // 最后检查外部MCP是否已连接 - client, exists := h.externalMCPMgr.GetClient(mcpName) - if !exists || !client.IsConnected() { - return false // 未连接时视为禁用 - } - - return true -} - -// pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度 -func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string { - useFull := strings.TrimSpace(strings.ToLower(h.config.Security.ToolDescriptionMode)) == "full" - description := shortDesc - if useFull { - description = fullDesc - } else if description == "" { - description = fullDesc - } - if len(description) > 10000 { - description = description[:10000] + "..." - } - return description -} diff --git a/handler/conversation.go b/handler/conversation.go deleted file mode 100644 index 4bb72bbe..00000000 --- a/handler/conversation.go +++ /dev/null @@ -1,233 +0,0 @@ -package handler - -import ( - "encoding/json" - "net/http" - "strconv" - - "cyberstrike-ai/internal/database" - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// ConversationHandler 对话处理器 -type ConversationHandler struct { - db *database.DB - logger *zap.Logger -} - -// NewConversationHandler 创建新的对话处理器 -func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler { - return &ConversationHandler{ - db: db, - logger: logger, - } -} - -// CreateConversationRequest 创建对话请求 -type CreateConversationRequest struct { - Title string `json:"title"` -} - -// CreateConversation 创建新对话 -func (h *ConversationHandler) CreateConversation(c *gin.Context) { - var req CreateConversationRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - title := req.Title - if title == "" { - title = "新对话" - } - - conv, err := h.db.CreateConversation(title) - if err != nil { - h.logger.Error("创建对话失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, conv) -} - -// ListConversations 列出对话 -func (h *ConversationHandler) ListConversations(c *gin.Context) { - limitStr := c.DefaultQuery("limit", "50") - offsetStr := c.DefaultQuery("offset", "0") - search := c.Query("search") // 获取搜索参数 - - limit, _ := strconv.Atoi(limitStr) - offset, _ := strconv.Atoi(offsetStr) - - if limit <= 0 || limit > 100 { - limit = 50 - } - - conversations, err := h.db.ListConversations(limit, offset, search) - if err != nil { - h.logger.Error("获取对话列表失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, conversations) -} - -// GetConversation 获取对话 -func (h *ConversationHandler) GetConversation(c *gin.Context) { - id := c.Param("id") - - // 默认轻量加载,只有用户需要展开详情时再按需拉取 - // include_process_details=1/true 时返回全量 processDetails(兼容旧行为) - includeStr := c.DefaultQuery("include_process_details", "0") - include := includeStr == "1" || includeStr == "true" || includeStr == "yes" - - var ( - conv *database.Conversation - err error - ) - if include { - conv, err = h.db.GetConversation(id) - } else { - conv, err = h.db.GetConversationLite(id) - } - if err != nil { - h.logger.Error("获取对话失败", zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) - return - } - - c.JSON(http.StatusOK, conv) -} - -// GetMessageProcessDetails 获取指定消息的过程详情(按需加载) -func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) { - messageID := c.Param("id") - if messageID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "message id required"}) - return - } - - details, err := h.db.GetProcessDetails(messageID) - if err != nil { - h.logger.Error("获取过程详情失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致) - out := make([]map[string]interface{}, 0, len(details)) - for _, d := range details { - var data interface{} - if d.Data != "" { - if err := json.Unmarshal([]byte(d.Data), &data); err != nil { - h.logger.Warn("解析过程详情数据失败", zap.Error(err)) - } - } - out = append(out, map[string]interface{}{ - "id": d.ID, - "messageId": d.MessageID, - "conversationId": d.ConversationID, - "eventType": d.EventType, - "message": d.Message, - "data": data, - "createdAt": d.CreatedAt, - }) - } - - c.JSON(http.StatusOK, gin.H{"processDetails": out}) -} - -// UpdateConversationRequest 更新对话请求 -type UpdateConversationRequest struct { - Title string `json:"title"` -} - -// UpdateConversation 更新对话 -func (h *ConversationHandler) UpdateConversation(c *gin.Context) { - id := c.Param("id") - - var req UpdateConversationRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if req.Title == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "标题不能为空"}) - return - } - - if err := h.db.UpdateConversationTitle(id, req.Title); err != nil { - h.logger.Error("更新对话失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 返回更新后的对话 - conv, err := h.db.GetConversation(id) - if err != nil { - h.logger.Error("获取更新后的对话失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, conv) -} - -// DeleteConversation 删除对话 -func (h *ConversationHandler) DeleteConversation(c *gin.Context) { - id := c.Param("id") - - if err := h.db.DeleteConversation(id); err != nil { - h.logger.Error("删除对话失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) -} - -// DeleteTurnRequest 删除一轮对话(POST /api/conversations/:id/delete-turn) -type DeleteTurnRequest struct { - MessageID string `json:"messageId"` -} - -// DeleteConversationTurn 删除锚点消息所在轮次(从该轮 user 到下一轮 user 之前),并清空 last_react_*。 -func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) { - conversationID := c.Param("id") - if conversationID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "conversation id required"}) - return - } - - var req DeleteTurnRequest - if err := c.ShouldBindJSON(&req); err != nil || req.MessageID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "messageId required"}) - return - } - - if _, err := h.db.GetConversation(conversationID); err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) - return - } - - deletedIDs, err := h.db.DeleteConversationTurn(conversationID, req.MessageID) - if err != nil { - h.logger.Warn("删除对话轮次失败", - zap.String("conversationId", conversationID), - zap.String("messageId", req.MessageID), - zap.Error(err), - ) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "deletedMessageIds": deletedIDs, - "message": "ok", - }) -} - diff --git a/handler/external_mcp.go b/handler/external_mcp.go deleted file mode 100644 index a8b57ae6..00000000 --- a/handler/external_mcp.go +++ /dev/null @@ -1,542 +0,0 @@ -package handler - -import ( - "fmt" - "net/http" - "os" - "sync" - - "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/mcp" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" - "gopkg.in/yaml.v3" -) - -// ExternalMCPHandler 外部MCP处理器 -type ExternalMCPHandler struct { - manager *mcp.ExternalMCPManager - config *config.Config - configPath string - logger *zap.Logger - mu sync.RWMutex -} - -// NewExternalMCPHandler 创建外部MCP处理器 -func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler { - return &ExternalMCPHandler{ - manager: manager, - config: cfg, - configPath: configPath, - logger: logger, - } -} - -// GetExternalMCPs 获取所有外部MCP配置 -func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) { - h.mu.RLock() - defer h.mu.RUnlock() - - configs := h.manager.GetConfigs() - - // 获取所有外部MCP的工具数量 - toolCounts := h.manager.GetToolCounts() - - // 转换为响应格式 - result := make(map[string]ExternalMCPResponse) - for name, cfg := range configs { - client, exists := h.manager.GetClient(name) - status := "disconnected" - if exists { - status = client.GetStatus() - } else if h.isEnabled(cfg) { - status = "disconnected" - } else { - status = "disabled" - } - - toolCount := toolCounts[name] - errorMsg := "" - if status == "error" { - errorMsg = h.manager.GetError(name) - } - - result[name] = ExternalMCPResponse{ - Config: cfg, - Status: status, - ToolCount: toolCount, - Error: errorMsg, - } - } - - c.JSON(http.StatusOK, gin.H{ - "servers": result, - "stats": h.manager.GetStats(), - }) -} - -// GetExternalMCP 获取单个外部MCP配置 -func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) { - name := c.Param("name") - - h.mu.RLock() - defer h.mu.RUnlock() - - configs := h.manager.GetConfigs() - cfg, exists := configs[name] - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"}) - return - } - - client, clientExists := h.manager.GetClient(name) - status := "disconnected" - if clientExists { - status = client.GetStatus() - } else if h.isEnabled(cfg) { - status = "disconnected" - } else { - status = "disabled" - } - - // 获取工具数量 - toolCount := 0 - if clientExists && client.IsConnected() { - if count, err := h.manager.GetToolCount(name); err == nil { - toolCount = count - } - } - - // 获取错误信息 - errorMsg := "" - if status == "error" { - errorMsg = h.manager.GetError(name) - } - - c.JSON(http.StatusOK, ExternalMCPResponse{ - Config: cfg, - Status: status, - ToolCount: toolCount, - Error: errorMsg, - }) -} - -// AddOrUpdateExternalMCP 添加或更新外部MCP配置 -func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) { - var req AddOrUpdateExternalMCPRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - name := c.Param("name") - if name == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"}) - return - } - - // 验证配置 - if err := h.validateConfig(req.Config); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - h.mu.Lock() - defer h.mu.Unlock() - - // 添加或更新配置 - if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil { - h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()}) - return - } - - // 更新内存中的配置 - if h.config.ExternalMCP.Servers == nil { - h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) - } - - // 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容 - // 同时将值迁移到 external_mcp_enable - cfg := req.Config - - if req.Config.Disabled { - // 用户设置了 disabled: true - cfg.ExternalMCPEnable = false - cfg.Disabled = true - cfg.Enabled = false - } else if req.Config.Enabled { - // 用户设置了 enabled: true - cfg.ExternalMCPEnable = true - cfg.Enabled = true - cfg.Disabled = false - } else if !req.Config.ExternalMCPEnable { - // 用户没有设置任何字段,且 external_mcp_enable 为 false - // 检查现有配置是否有旧字段 - if existingCfg, exists := h.config.ExternalMCP.Servers[name]; exists { - // 保留现有的旧字段 - cfg.Enabled = existingCfg.Enabled - cfg.Disabled = existingCfg.Disabled - } - } else { - // 用户通过新字段启用了(external_mcp_enable: true),但没有设置旧字段 - // 为了向后兼容,我们设置 enabled: true - // 这样即使原始配置中有 disabled: false,也会被转换为 enabled: true - cfg.Enabled = true - cfg.Disabled = false - } - - h.config.ExternalMCP.Servers[name] = cfg - - // 保存到配置文件 - if err := h.saveConfig(); err != nil { - h.logger.Error("保存配置失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) - return - } - - h.logger.Info("外部MCP配置已更新", zap.String("name", name)) - c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) -} - -// DeleteExternalMCP 删除外部MCP配置 -func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) { - name := c.Param("name") - - h.mu.Lock() - defer h.mu.Unlock() - - // 移除配置 - if err := h.manager.RemoveConfig(name); err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"}) - return - } - - // 从内存配置中删除 - if h.config.ExternalMCP.Servers != nil { - delete(h.config.ExternalMCP.Servers, name) - } - - // 保存到配置文件 - if err := h.saveConfig(); err != nil { - h.logger.Error("保存配置失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) - return - } - - h.logger.Info("外部MCP配置已删除", zap.String("name", name)) - c.JSON(http.StatusOK, gin.H{"message": "配置已删除"}) -} - -// StartExternalMCP 启动外部MCP -func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) { - name := c.Param("name") - - h.mu.Lock() - defer h.mu.Unlock() - - // 更新配置为启用 - if h.config.ExternalMCP.Servers == nil { - h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) - } - cfg := h.config.ExternalMCP.Servers[name] - cfg.ExternalMCPEnable = true - h.config.ExternalMCP.Servers[name] = cfg - - // 保存到配置文件 - if err := h.saveConfig(); err != nil { - h.logger.Error("保存配置失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) - return - } - - // 启动客户端(立即创建客户端并设置状态为connecting,实际连接在后台进行) - h.logger.Info("开始启动外部MCP", zap.String("name", name)) - if err := h.manager.StartClient(name); err != nil { - h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - "status": "error", - }) - return - } - - // 获取客户端状态(应该是connecting) - client, exists := h.manager.GetClient(name) - status := "connecting" - if exists { - status = client.GetStatus() - } - - // 立即返回,不等待连接完成 - // 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态 - c.JSON(http.StatusOK, gin.H{ - "message": "外部MCP启动请求已提交,正在后台连接中", - "status": status, - }) -} - -// StopExternalMCP 停止外部MCP -func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) { - name := c.Param("name") - - h.mu.Lock() - defer h.mu.Unlock() - - // 停止客户端 - if err := h.manager.StopClient(name); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // 更新配置 - if h.config.ExternalMCP.Servers == nil { - h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) - } - cfg := h.config.ExternalMCP.Servers[name] - cfg.ExternalMCPEnable = false - h.config.ExternalMCP.Servers[name] = cfg - - // 保存到配置文件 - if err := h.saveConfig(); err != nil { - h.logger.Error("保存配置失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) - return - } - - h.logger.Info("外部MCP已停止", zap.String("name", name)) - c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"}) -} - -// GetExternalMCPStats 获取统计信息 -func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) { - stats := h.manager.GetStats() - c.JSON(http.StatusOK, stats) -} - -// validateConfig 验证配置 -func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error { - transport := cfg.Transport - if transport == "" { - // 如果没有指定transport,根据是否有command或url判断 - if cfg.Command != "" { - transport = "stdio" - } else if cfg.URL != "" { - transport = "http" - } else { - return fmt.Errorf("需要指定command(stdio模式)或url(http/sse模式)") - } - } - - switch transport { - case "http": - if cfg.URL == "" { - return fmt.Errorf("HTTP模式需要URL") - } - case "stdio": - if cfg.Command == "" { - return fmt.Errorf("stdio模式需要command") - } - case "sse": - if cfg.URL == "" { - return fmt.Errorf("SSE模式需要URL") - } - default: - return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport) - } - - return nil -} - -// isEnabled 检查是否启用 -func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool { - // 优先使用 ExternalMCPEnable 字段 - // 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容) - if cfg.ExternalMCPEnable { - return true - } - // 向后兼容:检查旧字段 - if cfg.Disabled { - return false - } - if cfg.Enabled { - return true - } - // 都没有设置,默认为启用 - return true -} - -// saveConfig 保存配置到文件 -func (h *ExternalMCPHandler) saveConfig() error { - // 读取现有配置文件并创建备份 - data, err := os.ReadFile(h.configPath) - if err != nil { - return fmt.Errorf("读取配置文件失败: %w", err) - } - - if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil { - h.logger.Warn("创建配置备份失败", zap.Error(err)) - } - - root, err := loadYAMLDocument(h.configPath) - if err != nil { - return fmt.Errorf("解析配置文件失败: %w", err) - } - - // 在更新前,读取原始配置中的 enabled/disabled 字段,以便保持向后兼容 - originalConfigs := make(map[string]map[string]bool) - externalMCPNode := findMapValue(root.Content[0], "external_mcp") - if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode { - serversNode := findMapValue(externalMCPNode, "servers") - if serversNode != nil && serversNode.Kind == yaml.MappingNode { - // 遍历现有的服务器配置,保存 enabled/disabled 字段 - for i := 0; i < len(serversNode.Content); i += 2 { - if i+1 >= len(serversNode.Content) { - break - } - nameNode := serversNode.Content[i] - serverNode := serversNode.Content[i+1] - if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode { - serverName := nameNode.Value - originalConfigs[serverName] = make(map[string]bool) - // 检查是否有 enabled 字段 - if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil { - originalConfigs[serverName]["enabled"] = *enabledVal - } - // 检查是否有 disabled 字段 - if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil { - originalConfigs[serverName]["disabled"] = *disabledVal - } - } - } - } - } - - // 更新外部MCP配置 - updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs) - - if err := writeYAMLDocument(h.configPath, root); err != nil { - return fmt.Errorf("保存配置文件失败: %w", err) - } - - h.logger.Info("配置已保存", zap.String("path", h.configPath)) - return nil -} - -// updateExternalMCPConfig 更新外部MCP配置 -func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, originalConfigs map[string]map[string]bool) { - root := doc.Content[0] - externalMCPNode := ensureMap(root, "external_mcp") - serversNode := ensureMap(externalMCPNode, "servers") - - // 清空现有服务器配置 - serversNode.Content = nil - - // 添加新的服务器配置 - for name, serverCfg := range cfg.Servers { - // 添加服务器名称键 - nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name} - serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} - serversNode.Content = append(serversNode.Content, nameNode, serverNode) - - // 设置服务器配置字段 - if serverCfg.Command != "" { - setStringInMap(serverNode, "command", serverCfg.Command) - } - if len(serverCfg.Args) > 0 { - setStringArrayInMap(serverNode, "args", serverCfg.Args) - } - // 保存 env 字段(环境变量) - if serverCfg.Env != nil && len(serverCfg.Env) > 0 { - envNode := ensureMap(serverNode, "env") - for envKey, envValue := range serverCfg.Env { - setStringInMap(envNode, envKey, envValue) - } - } - if serverCfg.Transport != "" { - setStringInMap(serverNode, "transport", serverCfg.Transport) - } - if serverCfg.URL != "" { - setStringInMap(serverNode, "url", serverCfg.URL) - } - // 保存 headers 字段(HTTP/SSE 请求头) - if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 { - headersNode := ensureMap(serverNode, "headers") - for k, v := range serverCfg.Headers { - setStringInMap(headersNode, k, v) - } - } - if serverCfg.Description != "" { - setStringInMap(serverNode, "description", serverCfg.Description) - } - if serverCfg.Timeout > 0 { - setIntInMap(serverNode, "timeout", serverCfg.Timeout) - } - // 保存 external_mcp_enable 字段(新字段) - setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable) - // 保存 tool_enabled 字段(每个工具的启用状态) - if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 { - toolEnabledNode := ensureMap(serverNode, "tool_enabled") - for toolName, enabled := range serverCfg.ToolEnabled { - setBoolInMap(toolEnabledNode, toolName, enabled) - } - } - // 保留旧的 enabled/disabled 字段以保持向后兼容 - originalFields, hasOriginal := originalConfigs[name] - - // 如果原始配置中有 enabled 字段,保留它 - if hasOriginal { - if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled { - setBoolInMap(serverNode, "enabled", enabledVal) - } - // 如果原始配置中有 disabled 字段,保留它 - // 注意:由于 omitempty,disabled: false 不会被保存,但 disabled: true 会被保存 - if disabledVal, hasDisabled := originalFields["disabled"]; hasDisabled { - if disabledVal { - setBoolInMap(serverNode, "disabled", disabledVal) - } else { - // 如果原始配置中有 disabled: false,我们保存 enabled: true 来等效表示 - // 因为 disabled: false 等价于 enabled: true - setBoolInMap(serverNode, "enabled", true) - } - } - } - - // 如果用户在当前请求中明确设置了这些字段,也保存它们 - if serverCfg.Enabled { - setBoolInMap(serverNode, "enabled", serverCfg.Enabled) - } - if serverCfg.Disabled { - setBoolInMap(serverNode, "disabled", serverCfg.Disabled) - } else if !hasOriginal && serverCfg.ExternalMCPEnable { - // 如果用户通过新字段启用了,且原始配置中没有旧字段,保存 enabled: true 以保持向后兼容 - setBoolInMap(serverNode, "enabled", true) - } - } -} - -// setStringArrayInMap 设置字符串数组 -func setStringArrayInMap(mapNode *yaml.Node, key string, values []string) { - _, valueNode := ensureKeyValue(mapNode, key) - valueNode.Kind = yaml.SequenceNode - valueNode.Tag = "!!seq" - valueNode.Content = nil - for _, v := range values { - itemNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v} - valueNode.Content = append(valueNode.Content, itemNode) - } -} - -// AddOrUpdateExternalMCPRequest 添加或更新外部MCP请求 -type AddOrUpdateExternalMCPRequest struct { - Config config.ExternalMCPServerConfig `json:"config"` -} - -// ExternalMCPResponse 外部MCP响应 -type ExternalMCPResponse struct { - Config config.ExternalMCPServerConfig `json:"config"` - Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting" - ToolCount int `json:"tool_count"` // 工具数量 - Error string `json:"error,omitempty"` // 错误信息(仅在status为error时存在) -} diff --git a/handler/external_mcp_test.go b/handler/external_mcp_test.go deleted file mode 100644 index a663c489..00000000 --- a/handler/external_mcp_test.go +++ /dev/null @@ -1,518 +0,0 @@ -package handler - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" - - "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/mcp" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) { - gin.SetMode(gin.TestMode) - router := gin.New() - - // 创建临时配置文件 - tmpFile, err := os.CreateTemp("", "test-config-*.yaml") - if err != nil { - panic(err) - } - tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n") - tmpFile.Close() - configPath := tmpFile.Name() - - logger := zap.NewNop() - manager := mcp.NewExternalMCPManager(logger) - cfg := &config.Config{ - ExternalMCP: config.ExternalMCPConfig{ - Servers: make(map[string]config.ExternalMCPServerConfig), - }, - } - - handler := NewExternalMCPHandler(manager, cfg, configPath, logger) - - api := router.Group("/api") - api.GET("/external-mcp", handler.GetExternalMCPs) - api.GET("/external-mcp/stats", handler.GetExternalMCPStats) - api.GET("/external-mcp/:name", handler.GetExternalMCP) - api.PUT("/external-mcp/:name", handler.AddOrUpdateExternalMCP) - api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP) - api.POST("/external-mcp/:name/start", handler.StartExternalMCP) - api.POST("/external-mcp/:name/stop", handler.StopExternalMCP) - - return router, handler, configPath -} - -func cleanupTestConfig(configPath string) { - os.Remove(configPath) - os.Remove(configPath + ".backup") -} - -func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) { - router, _, configPath := setupTestRouter() - defer cleanupTestConfig(configPath) - - // 测试添加stdio模式的配置 - configJSON := `{ - "command": "python3", - "args": ["/path/to/script.py", "--server", "http://example.com"], - "description": "Test stdio MCP", - "timeout": 300, - "enabled": true - }` - - var configObj config.ExternalMCPServerConfig - if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil { - t.Fatalf("解析配置JSON失败: %v", err) - } - - reqBody := AddOrUpdateExternalMCPRequest{ - Config: configObj, - } - - body, _ := json.Marshal(reqBody) - req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - router.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) - } - - // 验证配置已添加 - req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil) - w2 := httptest.NewRecorder() - router.ServeHTTP(w2, req2) - - if w2.Code != http.StatusOK { - t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) - } - - var response ExternalMCPResponse - if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { - t.Fatalf("解析响应失败: %v", err) - } - - if response.Config.Command != "python3" { - t.Errorf("期望command为python3,实际%s", response.Config.Command) - } - if len(response.Config.Args) != 3 { - t.Errorf("期望args长度为3,实际%d", len(response.Config.Args)) - } - if response.Config.Description != "Test stdio MCP" { - t.Errorf("期望description为'Test stdio MCP',实际%s", response.Config.Description) - } - if response.Config.Timeout != 300 { - t.Errorf("期望timeout为300,实际%d", response.Config.Timeout) - } - if !response.Config.Enabled { - t.Error("期望enabled为true") - } -} - -func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) { - router, _, configPath := setupTestRouter() - defer cleanupTestConfig(configPath) - - // 测试添加HTTP模式的配置 - configJSON := `{ - "transport": "http", - "url": "http://127.0.0.1:8081/mcp", - "enabled": true - }` - - var configObj config.ExternalMCPServerConfig - if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil { - t.Fatalf("解析配置JSON失败: %v", err) - } - - reqBody := AddOrUpdateExternalMCPRequest{ - Config: configObj, - } - - body, _ := json.Marshal(reqBody) - req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - router.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) - } - - // 验证配置已添加 - req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil) - w2 := httptest.NewRecorder() - router.ServeHTTP(w2, req2) - - if w2.Code != http.StatusOK { - t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) - } - - var response ExternalMCPResponse - if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { - t.Fatalf("解析响应失败: %v", err) - } - - if response.Config.Transport != "http" { - t.Errorf("期望transport为http,实际%s", response.Config.Transport) - } - if response.Config.URL != "http://127.0.0.1:8081/mcp" { - t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL) - } - if !response.Config.Enabled { - t.Error("期望enabled为true") - } -} - -func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) { - router, _, configPath := setupTestRouter() - defer cleanupTestConfig(configPath) - - testCases := []struct { - name string - configJSON string - expectedErr string - }{ - { - name: "缺少command和url", - configJSON: `{"enabled": true}`, - expectedErr: "需要指定command(stdio模式)或url(http/sse模式)", - }, - { - name: "stdio模式缺少command", - configJSON: `{"args": ["test"], "enabled": true}`, - expectedErr: "stdio模式需要command", - }, - { - name: "http模式缺少url", - configJSON: `{"transport": "http", "enabled": true}`, - expectedErr: "HTTP模式需要URL", - }, - { - name: "无效的transport", - configJSON: `{"transport": "invalid", "enabled": true}`, - expectedErr: "不支持的传输模式", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var configObj config.ExternalMCPServerConfig - if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil { - t.Fatalf("解析配置JSON失败: %v", err) - } - - reqBody := AddOrUpdateExternalMCPRequest{ - Config: configObj, - } - - body, _ := json.Marshal(reqBody) - req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - router.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String()) - } - - var response map[string]interface{} - if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { - t.Fatalf("解析响应失败: %v", err) - } - - errorMsg := response["error"].(string) - // 对于stdio模式缺少command的情况,错误信息可能略有不同 - if tc.name == "stdio模式缺少command" { - if !strings.Contains(errorMsg, "stdio") && !strings.Contains(errorMsg, "command") { - t.Errorf("期望错误信息包含'stdio'或'command',实际'%s'", errorMsg) - } - } else if !strings.Contains(errorMsg, tc.expectedErr) { - t.Errorf("期望错误信息包含'%s',实际'%s'", tc.expectedErr, errorMsg) - } - }) - } -} - -func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) { - router, handler, configPath := setupTestRouter() - defer cleanupTestConfig(configPath) - - // 先添加一个配置 - configObj := config.ExternalMCPServerConfig{ - Command: "python3", - Enabled: true, - } - handler.manager.AddOrUpdateConfig("test-delete", configObj) - - // 删除配置 - req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) - } - - // 验证配置已删除 - req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil) - w2 := httptest.NewRecorder() - router.ServeHTTP(w2, req2) - - if w2.Code != http.StatusNotFound { - t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String()) - } -} - -func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { - router, handler, _ := setupTestRouter() - - // 添加多个配置 - handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{ - Command: "python3", - Enabled: true, - }) - handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{ - URL: "http://127.0.0.1:8081/mcp", - Enabled: false, - }) - - req := httptest.NewRequest("GET", "/api/external-mcp", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) - } - - var response map[string]interface{} - if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { - t.Fatalf("解析响应失败: %v", err) - } - - servers := response["servers"].(map[string]interface{}) - if len(servers) != 2 { - t.Errorf("期望2个服务器,实际%d", len(servers)) - } - if _, ok := servers["test1"]; !ok { - t.Error("期望包含test1") - } - if _, ok := servers["test2"]; !ok { - t.Error("期望包含test2") - } - - stats := response["stats"].(map[string]interface{}) - if int(stats["total"].(float64)) != 2 { - t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64))) - } -} - -func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) { - router, handler, _ := setupTestRouter() - - // 添加配置 - handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ - Command: "python3", - Enabled: true, - }) - handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ - URL: "http://127.0.0.1:8081/mcp", - Enabled: true, - }) - handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ - Command: "python3", - Enabled: false, - Disabled: true, - }) - - req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) - } - - var stats map[string]interface{} - if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil { - t.Fatalf("解析响应失败: %v", err) - } - - if int(stats["total"].(float64)) != 3 { - t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64))) - } - if int(stats["enabled"].(float64)) != 2 { - t.Errorf("期望启用数为2,实际%d", int(stats["enabled"].(float64))) - } - if int(stats["disabled"].(float64)) != 1 { - t.Errorf("期望停用数为1,实际%d", int(stats["disabled"].(float64))) - } -} - -func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) { - router, handler, configPath := setupTestRouter() - defer cleanupTestConfig(configPath) - - // 添加一个禁用的配置 - handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{ - Command: "python3", - Enabled: false, - Disabled: true, - }) - - // 测试启动(可能会失败,因为没有真实的服务器) - req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - // 启动可能会失败,但应该返回合理的状态码 - if w.Code != http.StatusOK { - // 如果启动失败,应该是400或500 - if w.Code != http.StatusBadRequest && w.Code != http.StatusInternalServerError { - t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String()) - } - } - - // 测试停止 - req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil) - w2 := httptest.NewRecorder() - router.ServeHTTP(w2, req2) - - if w2.Code != http.StatusOK { - t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) - } -} - -func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) { - router, _, _ := setupTestRouter() - - req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - if w.Code != http.StatusNotFound { - t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String()) - } -} - -func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) { - router, _, configPath := setupTestRouter() - defer cleanupTestConfig(configPath) - - req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - // 删除不存在的配置可能返回200(幂等操作)或404,都是合理的 - if w.Code != http.StatusNotFound && w.Code != http.StatusOK { - t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String()) - } -} - -func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) { - router, _, _ := setupTestRouter() - - configObj := config.ExternalMCPServerConfig{ - Command: "python3", - Enabled: true, - } - - reqBody := AddOrUpdateExternalMCPRequest{ - Config: configObj, - } - - body, _ := json.Marshal(reqBody) - req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - router.ServeHTTP(w, req) - - // 空名称应该返回404或400 - if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest { - t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String()) - } -} - -func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) { - router, _, _ := setupTestRouter() - - // 发送无效的JSON - body := []byte(`{"config": invalid json}`) - req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - router.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String()) - } -} - -func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) { - router, handler, configPath := setupTestRouter() - defer cleanupTestConfig(configPath) - - // 先添加配置 - config1 := config.ExternalMCPServerConfig{ - Command: "python3", - Enabled: true, - } - handler.manager.AddOrUpdateConfig("test-update", config1) - - // 更新配置 - config2 := config.ExternalMCPServerConfig{ - URL: "http://127.0.0.1:8081/mcp", - Enabled: true, - } - - reqBody := AddOrUpdateExternalMCPRequest{ - Config: config2, - } - - body, _ := json.Marshal(reqBody) - req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - router.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) - } - - // 验证配置已更新 - req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil) - w2 := httptest.NewRecorder() - router.ServeHTTP(w2, req2) - - if w2.Code != http.StatusOK { - t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) - } - - var response ExternalMCPResponse - if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { - t.Fatalf("解析响应失败: %v", err) - } - - if response.Config.URL != "http://127.0.0.1:8081/mcp" { - t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL) - } - if response.Config.Command != "" { - t.Errorf("期望command为空,实际%s", response.Config.Command) - } -} diff --git a/handler/fofa.go b/handler/fofa.go deleted file mode 100644 index 1b8d1db4..00000000 --- a/handler/fofa.go +++ /dev/null @@ -1,467 +0,0 @@ -package handler - -import ( - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "os" - "strings" - "time" - - "cyberstrike-ai/internal/config" - openaiClient "cyberstrike-ai/internal/openai" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -type FofaHandler struct { - cfg *config.Config - logger *zap.Logger - client *http.Client - openAIClient *openaiClient.Client -} - -func NewFofaHandler(cfg *config.Config, logger *zap.Logger) *FofaHandler { - // LLM 请求通常比 FOFA 查询更慢一点,单独给一个更宽松的超时。 - llmHTTPClient := &http.Client{Timeout: 2 * time.Minute} - var llmCfg *config.OpenAIConfig - if cfg != nil { - llmCfg = &cfg.OpenAI - } - return &FofaHandler{ - cfg: cfg, - logger: logger, - client: &http.Client{Timeout: 30 * time.Second}, - openAIClient: openaiClient.NewClient(llmCfg, llmHTTPClient, logger), - } -} - -type fofaSearchRequest struct { - Query string `json:"query" binding:"required"` - Size int `json:"size,omitempty"` - Page int `json:"page,omitempty"` - Fields string `json:"fields,omitempty"` - Full bool `json:"full,omitempty"` -} - -type fofaParseRequest struct { - Text string `json:"text" binding:"required"` -} - -type fofaParseResponse struct { - Query string `json:"query"` - Explanation string `json:"explanation,omitempty"` - Warnings []string `json:"warnings,omitempty"` -} - -type fofaAPIResponse struct { - Error bool `json:"error"` - ErrMsg string `json:"errmsg"` - Size int `json:"size"` - Page int `json:"page"` - Total int `json:"total"` - Mode string `json:"mode"` - Query string `json:"query"` - Results [][]interface{} `json:"results"` -} - -type fofaSearchResponse struct { - Query string `json:"query"` - Size int `json:"size"` - Page int `json:"page"` - Total int `json:"total"` - Fields []string `json:"fields"` - ResultsCount int `json:"results_count"` - Results []map[string]interface{} `json:"results"` -} - -func (h *FofaHandler) resolveCredentials() (email, apiKey string) { - // 优先环境变量(便于容器部署),其次配置文件 - email = strings.TrimSpace(os.Getenv("FOFA_EMAIL")) - apiKey = strings.TrimSpace(os.Getenv("FOFA_API_KEY")) - if email != "" && apiKey != "" { - return email, apiKey - } - if h.cfg != nil { - if email == "" { - email = strings.TrimSpace(h.cfg.FOFA.Email) - } - if apiKey == "" { - apiKey = strings.TrimSpace(h.cfg.FOFA.APIKey) - } - } - return email, apiKey -} - -func (h *FofaHandler) resolveBaseURL() string { - if h.cfg != nil { - if v := strings.TrimSpace(h.cfg.FOFA.BaseURL); v != "" { - return v - } - } - return "https://fofa.info/api/v1/search/all" -} - -// ParseNaturalLanguage 将自然语言解析为 FOFA 查询语法(仅生成,不执行查询) -func (h *FofaHandler) ParseNaturalLanguage(c *gin.Context) { - var req fofaParseRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - req.Text = strings.TrimSpace(req.Text) - if req.Text == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "text 不能为空"}) - return - } - - if h.cfg == nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "系统配置未初始化"}) - return - } - if strings.TrimSpace(h.cfg.OpenAI.APIKey) == "" || strings.TrimSpace(h.cfg.OpenAI.Model) == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "未配置 AI 模型:请在系统设置中填写 openai.api_key 与 openai.model(支持 OpenAI 兼容 API,如 DeepSeek)", - "need": []string{"openai.api_key", "openai.model"}, - }) - return - } - if h.openAIClient == nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "AI 客户端未初始化"}) - return - } - - systemPrompt := strings.TrimSpace(` -你是“FOFA 查询语法生成器”。任务:把用户输入的自然语言搜索意图,转换成 FOFA 查询语法。 - -输出要求(非常重要): -1) 只输出 JSON(不要 markdown、不要代码块、不要额外解释文本) -2) JSON 结构必须是: -{ - "query": "string,FOFA查询语法(可直接粘贴到 FOFA 或本系统查询框)", - "explanation": "string,可选,解释你如何映射字段/逻辑", - "warnings": ["string"...] 可选,列出歧义/风险/需要人工确认的点 -} -3) 如果用户输入本身已经是 FOFA 查询语法(或非常接近 FOFA 语法的表达式),应当“原样返回”为 query: - - 不要擅自改写字段名、操作符、括号结构 - - 不要改写任何字符串值(尤其是地理位置类值),不要做缩写/同义词替换/翻译/音译 - -查询语法要点(来自 FOFA 语法参考): -- 逻辑连接符:&&(与)、||(或),必要时用 () 包住子表达式以确认优先级(括号优先级最高) -- 当同一层级同时出现 && 与 ||(混用)时,用 () 明确优先级(避免歧义) -- 比较/匹配: - - = 匹配;当字段="" 时,可查询“不存在该字段”或“值为空”的情况 - - == 完全匹配;当字段=="" 时,可查询“字段存在且值为空”的情况 - - != 不匹配;当字段!="" 时,可查询“值不为空”的情况 - - *= 模糊匹配;可使用 * 或 ? 进行搜索 -- 直接输入关键词(不带字段)会在标题、HTML内容、HTTP头、URL字段中搜索;但当意图明确时优先用字段表达(更可控、更准确) - -字段示例速查(来自用户提供的案例,可直接套用/拼接): -- 高级搜索操作符示例: - - title="beijing" (= 匹配) - - title=="" (== 完全匹配,字段存在且值为空) - - title="" (= 匹配,可能表示字段不存在或值为空) - - title!="" (!= 不匹配,可用于值不为空) - - title*="*Home*" (*= 模糊匹配,用 * 或 ?) - - (app="Apache" || app="Nginx") && country="CN" (混用 && / || 时用括号) -- 基础类(General): - - ip="1.1.1.1" - - ip="220.181.111.1/24" - - ip="2600:9000:202a:2600:18:4ab7:f600:93a1" - - port="6379" - - domain="qq.com" - - host=".fofa.info" - - os="centos" - - server="Microsoft-IIS/10" - - asn="19551" - - org="LLC Baxet" - - is_domain=true / is_domain=false - - is_ipv6=true / is_ipv6=false -- 标记类(Special Label): - - app="Microsoft-Exchange" - - fid="sSXXGNUO2FefBTcCLIT/2Q==" - - product="NGINX" - - product="Roundcube-Webmail" && product.version="1.6.10" - - category="服务" - - type="service" / type="subdomain" - - cloud_name="Aliyundun" - - is_cloud=true / is_cloud=false - - is_fraud=true / is_fraud=false - - is_honeypot=true / is_honeypot=false -- 协议类(type=service): - - protocol="quic" - - banner="users" - - banner_hash="7330105010150477363" - - banner_fid="zRpqmn0FXQRjZpH8MjMX55zpMy9SgsW8" - - base_protocol="udp" / base_protocol="tcp" -- 网站类(type=subdomain): - - title="beijing" - - header="elastic" - - header_hash="1258854265" - - body="网络空间测绘" - - body_hash="-2090962452" - - js_name="js/jquery.js" - - js_md5="82ac3f14327a8b7ba49baa208d4eaa15" - - cname="customers.spektrix.com" - - cname_domain="siteforce.com" - - icon_hash="-247388890" - - status_code="402" - - icp="京ICP证030173号" - - sdk_hash="Are3qNnP2Eqn7q5kAoUO3l+w3mgVIytO" -- 地理位置(Location): - - country="CN" 或 country="中国" - - region="Zhejiang" 或 region="浙江"(仅支持中国地区中文) - - city="Hangzhou" -- 证书类(Certificate): - - cert="baidu" - - cert.subject="Oracle Corporation" - - cert.issuer="DigiCert" - - cert.subject.org="Oracle Corporation" - - cert.subject.cn="baidu.com" - - cert.issuer.org="cPanel, Inc." - - cert.issuer.cn="Synology Inc. CA" - - cert.domain="huawei.com" - - cert.is_equal=true / cert.is_equal=false - - cert.is_valid=true / cert.is_valid=false - - cert.is_match=true / cert.is_match=false - - cert.is_expired=true / cert.is_expired=false - - jarm="2ad2ad0002ad2ad22c2ad2ad2ad2ad2eac92ec34bcc0cf7520e97547f83e81" - - tls.version="TLS 1.3" - - tls.ja3s="15af977ce25de452b96affa2addb1036" - - cert.sn="356078156165546797850343536942784588840297" - - cert.not_after.after="2025-03-01" / cert.not_after.before="2025-03-01" - - cert.not_before.after="2025-03-01" / cert.not_before.before="2025-03-01" -- 时间类(Last update time): - - after="2023-01-01" - - before="2023-12-01" - - after="2023-01-01" && before="2023-12-01" -- 独立IP语法(需配合 ip_filter / ip_exclude): - - ip_filter(banner="SSH-2.0-OpenSSH_6.7p2") && ip_filter(icon_hash="-1057022626") - - ip_filter(banner="SSH-2.0-OpenSSH_6.7p2" && asn="3462") && ip_exclude(title="EdgeOS") - - port_size="6" / port_size_gt="6" / port_size_lt="12" - - ip_ports="80,161" - - ip_country="CN" - - ip_region="Zhejiang" - - ip_city="Hangzhou" - - ip_after="2021-03-18" - - ip_before="2019-09-09" - -生成约束与注意事项: -- 字符串值一律用英文双引号包裹,例如 title="登录"、country="CN" -- 字符串值保持字面一致:不要缩写(例如 city="beijing" 不要变成 city="BJ"),不要用别名(例如 Beijing/Peking),不要擅自翻译/音译/改写大小写 -- 地理位置字段(country/region/city)更倾向于“按用户给定值输出”;不确定合法取值时,不要猜测,把备选写进 warnings -- 不要捏造不存在的 FOFA 字段;不确定时把不确定点写进 warnings,并输出一个保守的 query -- 当用户描述里有“多个与/或条件”,优先加 () 明确优先级,例如:(app="Apache" || app="Nginx") && country="CN" -- 当用户缺少关键条件导致范围过大或歧义(如地点/协议/端口/服务类型未说明),允许 query 为空字符串,并在 warnings 里明确需要补充的信息 -`) - - userPrompt := fmt.Sprintf("自然语言意图:%s", req.Text) - - requestBody := map[string]interface{}{ - "model": h.cfg.OpenAI.Model, - "messages": []map[string]interface{}{ - {"role": "system", "content": systemPrompt}, - {"role": "user", "content": userPrompt}, - }, - "temperature": 0.1, - "max_tokens": 1200, - } - - // OpenAI 返回结构:只需要 choices[0].message.content - var apiResponse struct { - Choices []struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` - } - - ctx, cancel := context.WithTimeout(c.Request.Context(), 90*time.Second) - defer cancel() - - if err := h.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil { - var apiErr *openaiClient.APIError - if errors.As(err, &apiErr) { - h.logger.Warn("FOFA自然语言解析:LLM返回错误", zap.Int("status", apiErr.StatusCode)) - c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败(上游返回非 200),请检查模型配置或稍后重试"}) - return - } - c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败: " + err.Error()}) - return - } - if len(apiResponse.Choices) == 0 { - c.JSON(http.StatusBadGateway, gin.H{"error": "AI 未返回有效结果"}) - return - } - - content := strings.TrimSpace(apiResponse.Choices[0].Message.Content) - // 兼容模型偶尔返回 ```json ... ``` 的情况 - content = strings.TrimPrefix(content, "```json") - content = strings.TrimPrefix(content, "```") - content = strings.TrimSuffix(content, "```") - content = strings.TrimSpace(content) - - var parsed fofaParseResponse - if err := json.Unmarshal([]byte(content), &parsed); err != nil { - // 直接回传一部分原文,方便排查,但避免太大 - snippet := content - if len(snippet) > 1200 { - snippet = snippet[:1200] - } - c.JSON(http.StatusBadGateway, gin.H{ - "error": "AI 返回内容无法解析为 JSON,请稍后重试或换个描述方式", - "snippet": snippet, - }) - return - } - parsed.Query = strings.TrimSpace(parsed.Query) - if parsed.Query == "" { - // query 允许为空(表示需求不明确),但前端需要明确提示 - if len(parsed.Warnings) == 0 { - parsed.Warnings = []string{"需求信息不足,未能生成可用的 FOFA 查询语法,请补充关键条件(如国家/端口/产品/域名等)。"} - } - } - - c.JSON(http.StatusOK, parsed) -} - -// Search FOFA 查询(后端代理,避免前端暴露 key) -func (h *FofaHandler) Search(c *gin.Context) { - var req fofaSearchRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - req.Query = strings.TrimSpace(req.Query) - if req.Query == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "query 不能为空"}) - return - } - if req.Size <= 0 { - req.Size = 100 - } - if req.Page <= 0 { - req.Page = 1 - } - // FOFA 接口 size 上限和账户权限相关,这里只做一个合理的保护 - if req.Size > 10000 { - req.Size = 10000 - } - if req.Fields == "" { - req.Fields = "host,ip,port,domain,title,protocol,country,province,city,server" - } - - email, apiKey := h.resolveCredentials() - if email == "" || apiKey == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "FOFA 未配置:请在系统设置中填写 FOFA Email/API Key,或设置环境变量 FOFA_EMAIL/FOFA_API_KEY", - "need": []string{"fofa.email", "fofa.api_key"}, - "env_key": []string{"FOFA_EMAIL", "FOFA_API_KEY"}, - }) - return - } - - baseURL := h.resolveBaseURL() - qb64 := base64.StdEncoding.EncodeToString([]byte(req.Query)) - - u, err := url.Parse(baseURL) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "FOFA base_url 无效: " + err.Error()}) - return - } - - params := u.Query() - params.Set("email", email) - params.Set("key", apiKey) - params.Set("qbase64", qb64) - params.Set("size", fmt.Sprintf("%d", req.Size)) - params.Set("page", fmt.Sprintf("%d", req.Page)) - params.Set("fields", strings.TrimSpace(req.Fields)) - if req.Full { - params.Set("full", "true") - } else { - // 明确传 false,便于排查 - params.Set("full", "false") - } - u.RawQuery = params.Encode() - - httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, u.String(), nil) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败: " + err.Error()}) - return - } - - resp, err := h.client.Do(httpReq) - if err != nil { - c.JSON(http.StatusBadGateway, gin.H{"error": "请求 FOFA 失败: " + err.Error()}) - return - } - defer resp.Body.Close() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("FOFA 返回非 2xx: %d", resp.StatusCode)}) - return - } - - var apiResp fofaAPIResponse - if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil { - c.JSON(http.StatusBadGateway, gin.H{"error": "解析 FOFA 响应失败: " + err.Error()}) - return - } - if apiResp.Error { - msg := strings.TrimSpace(apiResp.ErrMsg) - if msg == "" { - msg = "FOFA 返回错误" - } - c.JSON(http.StatusBadGateway, gin.H{"error": msg}) - return - } - - fields := splitAndCleanCSV(req.Fields) - results := make([]map[string]interface{}, 0, len(apiResp.Results)) - for _, row := range apiResp.Results { - item := make(map[string]interface{}, len(fields)) - for i, f := range fields { - if i < len(row) { - item[f] = row[i] - } else { - item[f] = nil - } - } - results = append(results, item) - } - - c.JSON(http.StatusOK, fofaSearchResponse{ - Query: req.Query, - Size: apiResp.Size, - Page: apiResp.Page, - Total: apiResp.Total, - Fields: fields, - ResultsCount: len(results), - Results: results, - }) -} - -func splitAndCleanCSV(s string) []string { - parts := strings.Split(s, ",") - out := make([]string, 0, len(parts)) - seen := make(map[string]struct{}, len(parts)) - for _, p := range parts { - v := strings.TrimSpace(p) - if v == "" { - continue - } - if _, ok := seen[v]; ok { - continue - } - seen[v] = struct{}{} - out = append(out, v) - } - return out -} diff --git a/handler/group.go b/handler/group.go deleted file mode 100644 index 495e7695..00000000 --- a/handler/group.go +++ /dev/null @@ -1,320 +0,0 @@ -package handler - -import ( - "net/http" - "time" - - "cyberstrike-ai/internal/database" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// GroupHandler 分组处理器 -type GroupHandler struct { - db *database.DB - logger *zap.Logger -} - -// NewGroupHandler 创建新的分组处理器 -func NewGroupHandler(db *database.DB, logger *zap.Logger) *GroupHandler { - return &GroupHandler{ - db: db, - logger: logger, - } -} - -// CreateGroupRequest 创建分组请求 -type CreateGroupRequest struct { - Name string `json:"name"` - Icon string `json:"icon"` -} - -// CreateGroup 创建分组 -func (h *GroupHandler) CreateGroup(c *gin.Context) { - var req CreateGroupRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if req.Name == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"}) - return - } - - group, err := h.db.CreateGroup(req.Name, req.Icon) - if err != nil { - h.logger.Error("创建分组失败", zap.Error(err)) - // 如果是名称重复错误,返回400状态码 - if err.Error() == "分组名称已存在" { - c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, group) -} - -// ListGroups 列出所有分组 -func (h *GroupHandler) ListGroups(c *gin.Context) { - groups, err := h.db.ListGroups() - if err != nil { - h.logger.Error("获取分组列表失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, groups) -} - -// GetGroup 获取分组 -func (h *GroupHandler) GetGroup(c *gin.Context) { - id := c.Param("id") - - group, err := h.db.GetGroup(id) - if err != nil { - h.logger.Error("获取分组失败", zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": "分组不存在"}) - return - } - - c.JSON(http.StatusOK, group) -} - -// UpdateGroupRequest 更新分组请求 -type UpdateGroupRequest struct { - Name string `json:"name"` - Icon string `json:"icon"` -} - -// UpdateGroup 更新分组 -func (h *GroupHandler) UpdateGroup(c *gin.Context) { - id := c.Param("id") - - var req UpdateGroupRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if req.Name == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"}) - return - } - - if err := h.db.UpdateGroup(id, req.Name, req.Icon); err != nil { - h.logger.Error("更新分组失败", zap.Error(err)) - // 如果是名称重复错误,返回400状态码 - if err.Error() == "分组名称已存在" { - c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - group, err := h.db.GetGroup(id) - if err != nil { - h.logger.Error("获取更新后的分组失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, group) -} - -// DeleteGroup 删除分组 -func (h *GroupHandler) DeleteGroup(c *gin.Context) { - id := c.Param("id") - - if err := h.db.DeleteGroup(id); err != nil { - h.logger.Error("删除分组失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) -} - -// AddConversationToGroupRequest 添加对话到分组请求 -type AddConversationToGroupRequest struct { - ConversationID string `json:"conversationId"` - GroupID string `json:"groupId"` -} - -// AddConversationToGroup 将对话添加到分组 -func (h *GroupHandler) AddConversationToGroup(c *gin.Context) { - var req AddConversationToGroupRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if err := h.db.AddConversationToGroup(req.ConversationID, req.GroupID); err != nil { - h.logger.Error("添加对话到分组失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "添加成功"}) -} - -// RemoveConversationFromGroup 从分组中移除对话 -func (h *GroupHandler) RemoveConversationFromGroup(c *gin.Context) { - conversationID := c.Param("conversationId") - groupID := c.Param("id") - - if err := h.db.RemoveConversationFromGroup(conversationID, groupID); err != nil { - h.logger.Error("从分组中移除对话失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "移除成功"}) -} - -// GroupConversation 分组对话响应结构 -type GroupConversation struct { - ID string `json:"id"` - Title string `json:"title"` - Pinned bool `json:"pinned"` - GroupPinned bool `json:"groupPinned"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` -} - -// GetGroupConversations 获取分组中的所有对话 -func (h *GroupHandler) GetGroupConversations(c *gin.Context) { - groupID := c.Param("id") - searchQuery := c.Query("search") // 获取搜索参数 - - var conversations []*database.Conversation - var err error - - // 如果有搜索关键词,使用搜索方法;否则使用普通方法 - if searchQuery != "" { - conversations, err = h.db.SearchConversationsByGroup(groupID, searchQuery) - } else { - conversations, err = h.db.GetConversationsByGroup(groupID) - } - - if err != nil { - h.logger.Error("获取分组对话失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 获取每个对话在分组中的置顶状态 - groupConvs := make([]GroupConversation, 0, len(conversations)) - for _, conv := range conversations { - // 查询分组内置顶状态 - var groupPinned int - err := h.db.QueryRow( - "SELECT COALESCE(pinned, 0) FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?", - conv.ID, groupID, - ).Scan(&groupPinned) - if err != nil { - h.logger.Warn("查询分组内置顶状态失败", zap.String("conversationId", conv.ID), zap.Error(err)) - groupPinned = 0 - } - - groupConvs = append(groupConvs, GroupConversation{ - ID: conv.ID, - Title: conv.Title, - Pinned: conv.Pinned, - GroupPinned: groupPinned != 0, - CreatedAt: conv.CreatedAt, - UpdatedAt: conv.UpdatedAt, - }) - } - - c.JSON(http.StatusOK, groupConvs) -} - -// GetAllMappings 批量获取所有分组映射(消除前端 N+1 请求) -func (h *GroupHandler) GetAllMappings(c *gin.Context) { - mappings, err := h.db.GetAllGroupMappings() - if err != nil { - h.logger.Error("获取分组映射失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, mappings) -} - -// UpdateConversationPinnedRequest 更新对话置顶状态请求 -type UpdateConversationPinnedRequest struct { - Pinned bool `json:"pinned"` -} - -// UpdateConversationPinned 更新对话置顶状态 -func (h *GroupHandler) UpdateConversationPinned(c *gin.Context) { - conversationID := c.Param("id") - - var req UpdateConversationPinnedRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if err := h.db.UpdateConversationPinned(conversationID, req.Pinned); err != nil { - h.logger.Error("更新对话置顶状态失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "更新成功"}) -} - -// UpdateGroupPinnedRequest 更新分组置顶状态请求 -type UpdateGroupPinnedRequest struct { - Pinned bool `json:"pinned"` -} - -// UpdateGroupPinned 更新分组置顶状态 -func (h *GroupHandler) UpdateGroupPinned(c *gin.Context) { - groupID := c.Param("id") - - var req UpdateGroupPinnedRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if err := h.db.UpdateGroupPinned(groupID, req.Pinned); err != nil { - h.logger.Error("更新分组置顶状态失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "更新成功"}) -} - -// UpdateConversationPinnedInGroupRequest 更新分组对话置顶状态请求 -type UpdateConversationPinnedInGroupRequest struct { - Pinned bool `json:"pinned"` -} - -// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态 -func (h *GroupHandler) UpdateConversationPinnedInGroup(c *gin.Context) { - groupID := c.Param("id") - conversationID := c.Param("conversationId") - - var req UpdateConversationPinnedInGroupRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if err := h.db.UpdateConversationPinnedInGroup(conversationID, groupID, req.Pinned); err != nil { - h.logger.Error("更新分组对话置顶状态失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "更新成功"}) -} diff --git a/handler/knowledge.go b/handler/knowledge.go deleted file mode 100644 index 76d7b974..00000000 --- a/handler/knowledge.go +++ /dev/null @@ -1,517 +0,0 @@ -package handler - -import ( - "context" - "fmt" - "net/http" - "time" - - "cyberstrike-ai/internal/database" - "cyberstrike-ai/internal/knowledge" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// KnowledgeHandler 知识库处理器 -type KnowledgeHandler struct { - manager *knowledge.Manager - retriever *knowledge.Retriever - indexer *knowledge.Indexer - db *database.DB - logger *zap.Logger -} - -// NewKnowledgeHandler 创建新的知识库处理器 -func NewKnowledgeHandler( - manager *knowledge.Manager, - retriever *knowledge.Retriever, - indexer *knowledge.Indexer, - db *database.DB, - logger *zap.Logger, -) *KnowledgeHandler { - return &KnowledgeHandler{ - manager: manager, - retriever: retriever, - indexer: indexer, - db: db, - logger: logger, - } -} - -// GetCategories 获取所有分类 -func (h *KnowledgeHandler) GetCategories(c *gin.Context) { - categories, err := h.manager.GetCategories() - if err != nil { - h.logger.Error("获取分类失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"categories": categories}) -} - -// GetItems 获取知识项列表(支持按分类分页和关键字搜索,默认不返回完整内容) -func (h *KnowledgeHandler) GetItems(c *gin.Context) { - category := c.Query("category") - searchKeyword := c.Query("search") // 搜索关键字 - - // 如果提供了搜索关键字,执行关键字搜索(在所有数据中搜索) - if searchKeyword != "" { - items, err := h.manager.SearchItemsByKeyword(searchKeyword, category) - if err != nil { - h.logger.Error("搜索知识项失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 按分类分组结果 - groupedByCategory := make(map[string][]*knowledge.KnowledgeItemSummary) - for _, item := range items { - cat := item.Category - if cat == "" { - cat = "未分类" - } - groupedByCategory[cat] = append(groupedByCategory[cat], item) - } - - // 转换为 CategoryWithItems 格式 - categoriesWithItems := make([]*knowledge.CategoryWithItems, 0, len(groupedByCategory)) - for cat, catItems := range groupedByCategory { - categoriesWithItems = append(categoriesWithItems, &knowledge.CategoryWithItems{ - Category: cat, - ItemCount: len(catItems), - Items: catItems, - }) - } - - // 按分类名称排序 - for i := 0; i < len(categoriesWithItems)-1; i++ { - for j := i + 1; j < len(categoriesWithItems); j++ { - if categoriesWithItems[i].Category > categoriesWithItems[j].Category { - categoriesWithItems[i], categoriesWithItems[j] = categoriesWithItems[j], categoriesWithItems[i] - } - } - } - - c.JSON(http.StatusOK, gin.H{ - "categories": categoriesWithItems, - "total": len(categoriesWithItems), - "search": searchKeyword, - "is_search": true, - }) - return - } - - // 分页模式:categoryPage=true 表示按分类分页,否则按项分页(向后兼容) - categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页 - - // 分页参数 - limit := 50 // 默认每页 50 条(分类分页时为分类数,项分页时为项数) - offset := 0 - if limitStr := c.Query("limit"); limitStr != "" { - if parsed, err := parseInt(limitStr); err == nil && parsed > 0 && parsed <= 500 { - limit = parsed - } - } - if offsetStr := c.Query("offset"); offsetStr != "" { - if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 { - offset = parsed - } - } - - // 如果指定了 category 参数,且使用分类分页模式,则只返回该分类 - if category != "" && categoryPageMode { - // 单分类模式:返回该分类的所有知识项(不分页) - items, total, err := h.manager.GetItemsSummary(category, 0, 0) - if err != nil { - h.logger.Error("获取知识项失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 包装成分类结构 - categoriesWithItems := []*knowledge.CategoryWithItems{ - { - Category: category, - ItemCount: total, - Items: items, - }, - } - - c.JSON(http.StatusOK, gin.H{ - "categories": categoriesWithItems, - "total": 1, // 只有一个分类 - "limit": limit, - "offset": offset, - }) - return - } - - if categoryPageMode { - // 按分类分页模式(默认) - // limit 表示每页分类数,推荐 5-10 个分类 - if limit <= 0 || limit > 100 { - limit = 10 // 默认每页 10 个分类 - } - - categoriesWithItems, totalCategories, err := h.manager.GetCategoriesWithItems(limit, offset) - if err != nil { - h.logger.Error("获取分类知识项失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "categories": categoriesWithItems, - "total": totalCategories, - "limit": limit, - "offset": offset, - }) - return - } - - // 按项分页模式(向后兼容) - // 是否包含完整内容(默认 false,只返回摘要) - includeContent := c.Query("includeContent") == "true" - - if includeContent { - // 返回完整内容(向后兼容) - items, err := h.manager.GetItemsWithOptions(category, limit, offset, true) - if err != nil { - h.logger.Error("获取知识项失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 获取总数 - total, err := h.manager.GetItemsCount(category) - if err != nil { - h.logger.Warn("获取知识项总数失败", zap.Error(err)) - total = len(items) - } - - c.JSON(http.StatusOK, gin.H{ - "items": items, - "total": total, - "limit": limit, - "offset": offset, - }) - } else { - // 返回摘要(不包含完整内容,推荐方式) - items, total, err := h.manager.GetItemsSummary(category, limit, offset) - if err != nil { - h.logger.Error("获取知识项失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "items": items, - "total": total, - "limit": limit, - "offset": offset, - }) - } -} - -// GetItem 获取单个知识项 -func (h *KnowledgeHandler) GetItem(c *gin.Context) { - id := c.Param("id") - - item, err := h.manager.GetItem(id) - if err != nil { - h.logger.Error("获取知识项失败", zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, item) -} - -// CreateItem 创建知识项 -func (h *KnowledgeHandler) CreateItem(c *gin.Context) { - var req struct { - Category string `json:"category" binding:"required"` - Title string `json:"title" binding:"required"` - Content string `json:"content" binding:"required"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - item, err := h.manager.CreateItem(req.Category, req.Title, req.Content) - if err != nil { - h.logger.Error("创建知识项失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 异步索引 - go func() { - ctx := context.Background() - if err := h.indexer.IndexItem(ctx, item.ID); err != nil { - h.logger.Warn("索引知识项失败", zap.String("itemId", item.ID), zap.Error(err)) - } - }() - - c.JSON(http.StatusOK, item) -} - -// UpdateItem 更新知识项 -func (h *KnowledgeHandler) UpdateItem(c *gin.Context) { - id := c.Param("id") - - var req struct { - Category string `json:"category" binding:"required"` - Title string `json:"title" binding:"required"` - Content string `json:"content" binding:"required"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - item, err := h.manager.UpdateItem(id, req.Category, req.Title, req.Content) - if err != nil { - h.logger.Error("更新知识项失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 异步重新索引 - go func() { - ctx := context.Background() - if err := h.indexer.IndexItem(ctx, item.ID); err != nil { - h.logger.Warn("重新索引知识项失败", zap.String("itemId", item.ID), zap.Error(err)) - } - }() - - c.JSON(http.StatusOK, item) -} - -// DeleteItem 删除知识项 -func (h *KnowledgeHandler) DeleteItem(c *gin.Context) { - id := c.Param("id") - - if err := h.manager.DeleteItem(id); err != nil { - h.logger.Error("删除知识项失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) -} - -// RebuildIndex 重建索引 -func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) { - // 异步重建索引 - go func() { - ctx := context.Background() - if err := h.indexer.RebuildIndex(ctx); err != nil { - h.logger.Error("重建索引失败", zap.Error(err)) - } - }() - - c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"}) -} - -// ScanKnowledgeBase 扫描知识库 -func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) { - itemsToIndex, err := h.manager.ScanKnowledgeBase() - if err != nil { - h.logger.Error("扫描知识库失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - if len(itemsToIndex) == 0 { - c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"}) - return - } - - // 异步索引新添加或更新的项(增量索引) - go func() { - ctx := context.Background() - h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex))) - failedCount := 0 - consecutiveFailures := 0 - var firstFailureItemID string - var firstFailureError error - - for i, itemID := range itemsToIndex { - if err := h.indexer.IndexItem(ctx, itemID); err != nil { - failedCount++ - consecutiveFailures++ - - // 只在第一个失败时记录详细日志 - if consecutiveFailures == 1 { - firstFailureItemID = itemID - firstFailureError = err - h.logger.Warn("索引知识项失败", - zap.String("itemId", itemID), - zap.Int("totalItems", len(itemsToIndex)), - zap.Error(err), - ) - } - - // 如果连续失败 2 次,立即停止增量索引 - if consecutiveFailures >= 2 { - h.logger.Error("连续索引失败次数过多,立即停止增量索引", - zap.Int("consecutiveFailures", consecutiveFailures), - zap.Int("totalItems", len(itemsToIndex)), - zap.Int("processedItems", i+1), - zap.String("firstFailureItemId", firstFailureItemID), - zap.Error(firstFailureError), - ) - break - } - continue - } - - // 成功时重置连续失败计数 - if consecutiveFailures > 0 { - consecutiveFailures = 0 - firstFailureItemID = "" - firstFailureError = nil - } - - // 减少进度日志频率 - if (i+1)%10 == 0 || i+1 == len(itemsToIndex) { - h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount)) - } - } - h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount)) - }() - - c.JSON(http.StatusOK, gin.H{ - "message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)), - "items_to_index": len(itemsToIndex), - }) -} - -// GetRetrievalLogs 获取检索日志 -func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) { - conversationID := c.Query("conversationId") - messageID := c.Query("messageId") - limit := 50 // 默认 50 条 - - if limitStr := c.Query("limit"); limitStr != "" { - if parsed, err := parseInt(limitStr); err == nil && parsed > 0 { - limit = parsed - } - } - - logs, err := h.manager.GetRetrievalLogs(conversationID, messageID, limit) - if err != nil { - h.logger.Error("获取检索日志失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"logs": logs}) -} - -// DeleteRetrievalLog 删除检索日志 -func (h *KnowledgeHandler) DeleteRetrievalLog(c *gin.Context) { - id := c.Param("id") - - if err := h.manager.DeleteRetrievalLog(id); err != nil { - h.logger.Error("删除检索日志失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) -} - -// GetIndexStatus 获取索引状态 -func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) { - status, err := h.manager.GetIndexStatus() - if err != nil { - h.logger.Error("获取索引状态失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 获取索引器的错误信息 - if h.indexer != nil { - lastError, lastErrorTime := h.indexer.GetLastError() - if lastError != "" { - // 如果错误是最近发生的(5 分钟内),则返回错误信息 - if time.Since(lastErrorTime) < 5*time.Minute { - status["last_error"] = lastError - status["last_error_time"] = lastErrorTime.Format(time.RFC3339) - } - } - - // 获取重建索引状态 - isRebuilding, totalItems, current, failed, lastItemID, lastChunks, startTime := h.indexer.GetRebuildStatus() - if isRebuilding { - status["is_rebuilding"] = true - status["rebuild_total"] = totalItems - status["rebuild_current"] = current - status["rebuild_failed"] = failed - status["rebuild_start_time"] = startTime.Format(time.RFC3339) - if lastItemID != "" { - status["rebuild_last_item_id"] = lastItemID - } - if lastChunks > 0 { - status["rebuild_last_chunks"] = lastChunks - } - // 重建中时,is_complete 为 false - status["is_complete"] = false - // 计算重建进度百分比 - if totalItems > 0 { - status["progress_percent"] = float64(current) / float64(totalItems) * 100 - } - } - } - - c.JSON(http.StatusOK, status) -} - -// Search 搜索知识库(用于 API 调用,Agent 内部使用 Retriever) -func (h *KnowledgeHandler) Search(c *gin.Context) { - var req knowledge.SearchRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // Retriever.Search 经 Eino VectorEinoRetriever,与 MCP 工具链一致。 - results, err := h.retriever.Search(c.Request.Context(), &req) - if err != nil { - h.logger.Error("搜索知识库失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"results": results}) -} - -// GetStats 获取知识库统计信息 -func (h *KnowledgeHandler) GetStats(c *gin.Context) { - totalCategories, totalItems, err := h.manager.GetStats() - if err != nil { - h.logger.Error("获取知识库统计信息失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "enabled": true, - "total_categories": totalCategories, - "total_items": totalItems, - }) -} - -// 辅助函数:解析整数 -func parseInt(s string) (int, error) { - var result int - _, err := fmt.Sscanf(s, "%d", &result) - return result, err -} diff --git a/handler/markdown_agents.go b/handler/markdown_agents.go deleted file mode 100644 index 90295540..00000000 --- a/handler/markdown_agents.go +++ /dev/null @@ -1,299 +0,0 @@ -package handler - -import ( - "fmt" - "net/http" - "os" - "path/filepath" - "regexp" - "strings" - - "cyberstrike-ai/internal/agents" - "cyberstrike-ai/internal/config" - - "github.com/gin-gonic/gin" -) - -var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.md$`) - -// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。 -type MarkdownAgentsHandler struct { - dir string -} - -// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。 -func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler { - return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)} -} - -func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) { - filename = strings.TrimSpace(filename) - if filename == "" || !markdownAgentFilenameRe.MatchString(filename) { - return "", fmt.Errorf("非法文件名") - } - clean := filepath.Clean(filename) - if clean != filename || strings.Contains(clean, "..") { - return "", fmt.Errorf("非法文件名") - } - return filepath.Join(h.dir, clean), nil -} - -// existingOtherOrchestrator 若目录中已有别的主代理文件,返回其文件名;writingBasename 为当前正在写入的文件名时视为同一文件不冲突。 -func existingOtherOrchestrator(dir, writingBasename string) (other string, err error) { - load, err := agents.LoadMarkdownAgentsDir(dir) - if err != nil { - return "", err - } - if load.Orchestrator == nil { - return "", nil - } - if strings.EqualFold(load.Orchestrator.Filename, writingBasename) { - return "", nil - } - return load.Orchestrator.Filename, nil -} - -// ListMarkdownAgents GET /api/multi-agent/markdown-agents -func (h *MarkdownAgentsHandler) ListMarkdownAgents(c *gin.Context) { - if h.dir == "" { - c.JSON(http.StatusOK, gin.H{"agents": []any{}, "dir": "", "error": "未配置 agents 目录"}) - return - } - files, err := agents.LoadMarkdownAgentFiles(h.dir) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - out := make([]gin.H, 0, len(files)) - for _, fa := range files { - sub := fa.Config - out = append(out, gin.H{ - "filename": fa.Filename, - "id": sub.ID, - "name": sub.Name, - "description": sub.Description, - "is_orchestrator": fa.IsOrchestrator, - "kind": sub.Kind, - }) - } - c.JSON(http.StatusOK, gin.H{"agents": out, "dir": h.dir}) -} - -// GetMarkdownAgent GET /api/multi-agent/markdown-agents/:filename -func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) { - filename := c.Param("filename") - path, err := h.safeJoin(filename) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - b, err := os.ReadFile(path) - if err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - sub, err := agents.ParseMarkdownSubAgent(filename, string(b)) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - isOrch := agents.IsOrchestratorMarkdown(filename, agents.FrontMatter{Kind: sub.Kind}) - c.JSON(http.StatusOK, gin.H{ - "filename": filename, - "raw": string(b), - "id": sub.ID, - "name": sub.Name, - "description": sub.Description, - "tools": sub.RoleTools, - "instruction": sub.Instruction, - "bind_role": sub.BindRole, - "max_iterations": sub.MaxIterations, - "kind": sub.Kind, - "is_orchestrator": isOrch, - }) -} - -type markdownAgentBody struct { - Filename string `json:"filename"` - ID string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Tools []string `json:"tools"` - Instruction string `json:"instruction"` - BindRole string `json:"bind_role"` - MaxIterations int `json:"max_iterations"` - Kind string `json:"kind"` - Raw string `json:"raw"` -} - -// CreateMarkdownAgent POST /api/multi-agent/markdown-agents -func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) { - if h.dir == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "未配置 agents 目录"}) - return - } - var body markdownAgentBody - if err := c.ShouldBindJSON(&body); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - filename := strings.TrimSpace(body.Filename) - if filename == "" { - if strings.EqualFold(strings.TrimSpace(body.Kind), "orchestrator") { - filename = agents.OrchestratorMarkdownFilename - } else { - base := agents.SlugID(body.Name) - if base == "" { - base = "agent" - } - filename = base + ".md" - } - } - path, err := h.safeJoin(filename) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - if _, err := os.Stat(path); err == nil { - c.JSON(http.StatusConflict, gin.H{"error": "文件已存在"}) - return - } - sub := config.MultiAgentSubConfig{ - ID: strings.TrimSpace(body.ID), - Name: strings.TrimSpace(body.Name), - Description: strings.TrimSpace(body.Description), - Instruction: strings.TrimSpace(body.Instruction), - RoleTools: body.Tools, - BindRole: strings.TrimSpace(body.BindRole), - MaxIterations: body.MaxIterations, - Kind: strings.TrimSpace(body.Kind), - } - if strings.EqualFold(filepath.Base(path), agents.OrchestratorMarkdownFilename) && sub.Kind == "" { - sub.Kind = "orchestrator" - } - if sub.ID == "" { - sub.ID = agents.SlugID(sub.Name) - } - if sub.Name == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"}) - return - } - var out []byte - if strings.TrimSpace(body.Raw) != "" { - out = []byte(body.Raw) - } else { - out, err = agents.BuildMarkdownFile(sub) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } - if want := agents.WantsMarkdownOrchestrator(filepath.Base(path), body.Kind, string(out)); want { - other, oerr := existingOtherOrchestrator(h.dir, filepath.Base(path)) - if oerr != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()}) - return - } - if other != "" { - c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)}) - return - } - } - if err := os.MkdirAll(h.dir, 0755); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if err := os.WriteFile(path, out, 0644); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"}) -} - -// UpdateMarkdownAgent PUT /api/multi-agent/markdown-agents/:filename -func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) { - filename := c.Param("filename") - path, err := h.safeJoin(filename) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - var body markdownAgentBody - if err := c.ShouldBindJSON(&body); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - sub := config.MultiAgentSubConfig{ - ID: strings.TrimSpace(body.ID), - Name: strings.TrimSpace(body.Name), - Description: strings.TrimSpace(body.Description), - Instruction: strings.TrimSpace(body.Instruction), - RoleTools: body.Tools, - BindRole: strings.TrimSpace(body.BindRole), - MaxIterations: body.MaxIterations, - Kind: strings.TrimSpace(body.Kind), - } - if strings.EqualFold(filename, agents.OrchestratorMarkdownFilename) && sub.Kind == "" { - sub.Kind = "orchestrator" - } - if sub.Name == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"}) - return - } - if sub.ID == "" { - sub.ID = agents.SlugID(sub.Name) - } - var out []byte - if strings.TrimSpace(body.Raw) != "" { - out = []byte(body.Raw) - } else { - out, err = agents.BuildMarkdownFile(sub) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } - if want := agents.WantsMarkdownOrchestrator(filename, body.Kind, string(out)); want { - other, oerr := existingOtherOrchestrator(h.dir, filename) - if oerr != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()}) - return - } - if other != "" { - c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)}) - return - } - } - if err := os.WriteFile(path, out, 0644); err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "已保存"}) -} - -// DeleteMarkdownAgent DELETE /api/multi-agent/markdown-agents/:filename -func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) { - filename := c.Param("filename") - path, err := h.safeJoin(filename) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - if err := os.Remove(path); err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "已删除"}) -} diff --git a/handler/monitor.go b/handler/monitor.go deleted file mode 100644 index c337c374..00000000 --- a/handler/monitor.go +++ /dev/null @@ -1,420 +0,0 @@ -package handler - -import ( - "net/http" - "strconv" - "strings" - "time" - - "cyberstrike-ai/internal/database" - "cyberstrike-ai/internal/mcp" - "cyberstrike-ai/internal/security" - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// MonitorHandler 监控处理器 -type MonitorHandler struct { - mcpServer *mcp.Server - externalMCPMgr *mcp.ExternalMCPManager - executor *security.Executor - db *database.DB - logger *zap.Logger -} - -// NewMonitorHandler 创建新的监控处理器 -func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, db *database.DB, logger *zap.Logger) *MonitorHandler { - return &MonitorHandler{ - mcpServer: mcpServer, - externalMCPMgr: nil, // 将在创建后设置 - executor: executor, - db: db, - logger: logger, - } -} - -// SetExternalMCPManager 设置外部MCP管理器 -func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) { - h.externalMCPMgr = mgr -} - -// MonitorResponse 监控响应 -type MonitorResponse struct { - Executions []*mcp.ToolExecution `json:"executions"` - Stats map[string]*mcp.ToolStats `json:"stats"` - Timestamp time.Time `json:"timestamp"` - Total int `json:"total,omitempty"` - Page int `json:"page,omitempty"` - PageSize int `json:"page_size,omitempty"` - TotalPages int `json:"total_pages,omitempty"` -} - -// Monitor 获取监控信息 -func (h *MonitorHandler) Monitor(c *gin.Context) { - // 解析分页参数 - page := 1 - pageSize := 20 - if pageStr := c.Query("page"); pageStr != "" { - if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { - page = p - } - } - if pageSizeStr := c.Query("page_size"); pageSizeStr != "" { - if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 && ps <= 100 { - pageSize = ps - } - } - - // 解析状态筛选参数 - status := c.Query("status") - // 解析工具筛选参数 - toolName := c.Query("tool") - - executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName) - stats := h.loadStats() - - totalPages := (total + pageSize - 1) / pageSize - if totalPages == 0 { - totalPages = 1 - } - - c.JSON(http.StatusOK, MonitorResponse{ - Executions: executions, - Stats: stats, - Timestamp: time.Now(), - Total: total, - Page: page, - PageSize: pageSize, - TotalPages: totalPages, - }) -} - -func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution { - executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "") - return executions -} - -func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) { - if h.db == nil { - allExecutions := h.mcpServer.GetAllExecutions() - // 如果指定了状态筛选或工具筛选,先进行筛选 - if status != "" || toolName != "" { - filtered := make([]*mcp.ToolExecution, 0) - for _, exec := range allExecutions { - matchStatus := status == "" || exec.Status == status - // 支持部分匹配(模糊搜索) - matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName)) - if matchStatus && matchTool { - filtered = append(filtered, exec) - } - } - allExecutions = filtered - } - total := len(allExecutions) - offset := (page - 1) * pageSize - end := offset + pageSize - if end > total { - end = total - } - if offset >= total { - return []*mcp.ToolExecution{}, total - } - return allExecutions[offset:end], total - } - - offset := (page - 1) * pageSize - executions, err := h.db.LoadToolExecutionsWithPagination(offset, pageSize, status, toolName) - if err != nil { - h.logger.Warn("从数据库加载执行记录失败,回退到内存数据", zap.Error(err)) - allExecutions := h.mcpServer.GetAllExecutions() - // 如果指定了状态筛选或工具筛选,先进行筛选 - if status != "" || toolName != "" { - filtered := make([]*mcp.ToolExecution, 0) - for _, exec := range allExecutions { - matchStatus := status == "" || exec.Status == status - // 支持部分匹配(模糊搜索) - matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName)) - if matchStatus && matchTool { - filtered = append(filtered, exec) - } - } - allExecutions = filtered - } - total := len(allExecutions) - offset := (page - 1) * pageSize - end := offset + pageSize - if end > total { - end = total - } - if offset >= total { - return []*mcp.ToolExecution{}, total - } - return allExecutions[offset:end], total - } - - // 获取总数(考虑状态筛选和工具筛选) - total, err := h.db.CountToolExecutions(status, toolName) - if err != nil { - h.logger.Warn("获取执行记录总数失败", zap.Error(err)) - // 回退:使用已加载的记录数估算 - total = offset + len(executions) - if len(executions) == pageSize { - total = offset + len(executions) + 1 - } - } - - return executions, total -} - -func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats { - // 合并内部MCP服务器和外部MCP管理器的统计信息 - stats := make(map[string]*mcp.ToolStats) - - // 加载内部MCP服务器的统计信息 - if h.db == nil { - internalStats := h.mcpServer.GetStats() - for k, v := range internalStats { - stats[k] = v - } - } else { - dbStats, err := h.db.LoadToolStats() - if err != nil { - h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err)) - internalStats := h.mcpServer.GetStats() - for k, v := range internalStats { - stats[k] = v - } - } else { - for k, v := range dbStats { - stats[k] = v - } - } - } - - // 合并外部MCP管理器的统计信息 - if h.externalMCPMgr != nil { - externalStats := h.externalMCPMgr.GetToolStats() - for k, v := range externalStats { - // 如果已存在,合并统计信息 - if existing, exists := stats[k]; exists { - existing.TotalCalls += v.TotalCalls - existing.SuccessCalls += v.SuccessCalls - existing.FailedCalls += v.FailedCalls - // 使用最新的调用时间 - if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) { - existing.LastCallTime = v.LastCallTime - } - } else { - stats[k] = v - } - } - } - - return stats -} - - -// GetExecution 获取特定执行记录 -func (h *MonitorHandler) GetExecution(c *gin.Context) { - id := c.Param("id") - - // 先从内部MCP服务器查找 - exec, exists := h.mcpServer.GetExecution(id) - if exists { - c.JSON(http.StatusOK, exec) - return - } - - // 如果找不到,尝试从外部MCP管理器查找 - if h.externalMCPMgr != nil { - exec, exists = h.externalMCPMgr.GetExecution(id) - if exists { - c.JSON(http.StatusOK, exec) - return - } - } - - // 如果都找不到,尝试从数据库查找(如果使用数据库存储) - if h.db != nil { - exec, err := h.db.GetToolExecution(id) - if err == nil && exec != nil { - c.JSON(http.StatusOK, exec) - return - } - } - - c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"}) -} - -// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求) -func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) { - var req struct { - IDs []string `json:"ids"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - result := make(map[string]string, len(req.IDs)) - for _, id := range req.IDs { - // 先从内部MCP服务器查找 - if exec, exists := h.mcpServer.GetExecution(id); exists { - result[id] = exec.ToolName - continue - } - // 再从外部MCP管理器查找 - if h.externalMCPMgr != nil { - if exec, exists := h.externalMCPMgr.GetExecution(id); exists { - result[id] = exec.ToolName - continue - } - } - // 最后从数据库查找 - if h.db != nil { - if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil { - result[id] = exec.ToolName - } - } - } - - c.JSON(http.StatusOK, result) -} - -// GetStats 获取统计信息 -func (h *MonitorHandler) GetStats(c *gin.Context) { - stats := h.loadStats() - c.JSON(http.StatusOK, stats) -} - -// DeleteExecution 删除执行记录 -func (h *MonitorHandler) DeleteExecution(c *gin.Context) { - id := c.Param("id") - if id == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"}) - return - } - - // 如果使用数据库,先获取执行记录信息,然后删除并更新统计 - if h.db != nil { - // 先获取执行记录信息(用于更新统计) - exec, err := h.db.GetToolExecution(id) - if err != nil { - // 如果找不到记录,可能已经被删除,直接返回成功 - h.logger.Warn("执行记录不存在,可能已被删除", zap.String("executionId", id), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{"message": "执行记录不存在或已被删除"}) - return - } - - // 删除执行记录 - err = h.db.DeleteToolExecution(id) - if err != nil { - h.logger.Error("删除执行记录失败", zap.Error(err), zap.String("executionId", id)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "删除执行记录失败: " + err.Error()}) - return - } - - // 更新统计信息(减少相应的计数) - totalCalls := 1 - successCalls := 0 - failedCalls := 0 - if exec.Status == "failed" { - failedCalls = 1 - } else if exec.Status == "completed" { - successCalls = 1 - } - - if exec.ToolName != "" { - if err := h.db.DecreaseToolStats(exec.ToolName, totalCalls, successCalls, failedCalls); err != nil { - h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", exec.ToolName)) - // 不返回错误,因为记录已经删除成功 - } - } - - h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName)) - c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"}) - return - } - - // 如果不使用数据库,尝试从内存中删除(内部MCP服务器) - // 注意:内存中的记录可能已经被清理,所以这里只记录日志 - h.logger.Info("尝试删除内存中的执行记录", zap.String("executionId", id)) - c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) -} - -// DeleteExecutions 批量删除执行记录 -func (h *MonitorHandler) DeleteExecutions(c *gin.Context) { - var request struct { - IDs []string `json:"ids"` - } - - if err := c.ShouldBindJSON(&request); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()}) - return - } - - if len(request.IDs) == 0 { - c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID列表不能为空"}) - return - } - - // 如果使用数据库,先获取执行记录信息,然后删除并更新统计 - if h.db != nil { - // 先获取执行记录信息(用于更新统计) - executions, err := h.db.GetToolExecutionsByIds(request.IDs) - if err != nil { - h.logger.Error("获取执行记录失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "获取执行记录失败: " + err.Error()}) - return - } - - // 按工具名称分组统计需要减少的数量 - toolStats := make(map[string]struct { - totalCalls int - successCalls int - failedCalls int - }) - - for _, exec := range executions { - if exec.ToolName == "" { - continue - } - - stats := toolStats[exec.ToolName] - stats.totalCalls++ - if exec.Status == "failed" { - stats.failedCalls++ - } else if exec.Status == "completed" { - stats.successCalls++ - } - toolStats[exec.ToolName] = stats - } - - // 批量删除执行记录 - err = h.db.DeleteToolExecutions(request.IDs) - if err != nil { - h.logger.Error("批量删除执行记录失败", zap.Error(err), zap.Int("count", len(request.IDs))) - c.JSON(http.StatusInternalServerError, gin.H{"error": "批量删除执行记录失败: " + err.Error()}) - return - } - - // 更新统计信息(减少相应的计数) - for toolName, stats := range toolStats { - if err := h.db.DecreaseToolStats(toolName, stats.totalCalls, stats.successCalls, stats.failedCalls); err != nil { - h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", toolName)) - // 不返回错误,因为记录已经删除成功 - } - } - - h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs))) - c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)}) - return - } - - // 如果不使用数据库,尝试从内存中删除(内部MCP服务器) - // 注意:内存中的记录可能已经被清理,所以这里只记录日志 - h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs))) - c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) -} - - diff --git a/handler/multi_agent.go b/handler/multi_agent.go deleted file mode 100644 index d8a54625..00000000 --- a/handler/multi_agent.go +++ /dev/null @@ -1,316 +0,0 @@ -package handler - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "strings" - "sync" - "time" - - "cyberstrike-ai/internal/multiagent" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// MultiAgentLoopStream Eino DeepAgent 流式对话(需 config.multi_agent.enabled)。 -func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - if h.config == nil || !h.config.MultiAgent.Enabled { - ev := StreamEvent{Type: "error", Message: "多代理未启用,请在设置或 config.yaml 中开启 multi_agent.enabled"} - b, _ := json.Marshal(ev) - fmt.Fprintf(c.Writer, "data: %s\n\n", b) - done := StreamEvent{Type: "done", Message: ""} - db, _ := json.Marshal(done) - fmt.Fprintf(c.Writer, "data: %s\n\n", db) - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } - return - } - - var req ChatRequest - if err := c.ShouldBindJSON(&req); err != nil { - event := StreamEvent{Type: "error", Message: "请求参数错误: " + err.Error()} - b, _ := json.Marshal(event) - fmt.Fprintf(c.Writer, "data: %s\n\n", b) - c.Writer.Flush() - return - } - - c.Header("X-Accel-Buffering", "no") - - // 用于在 sendEvent 中判断是否为用户主动停止导致的取消。 - // 注意:baseCtx 会在后面创建;该变量用于闭包提前捕获引用。 - var baseCtx context.Context - - clientDisconnected := false - // 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。 - var sseWriteMu sync.Mutex - sendEvent := func(eventType, message string, data interface{}) { - if clientDisconnected { - return - } - // 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。 - // 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。 - if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) { - return - } - select { - case <-c.Request.Context().Done(): - clientDisconnected = true - return - default: - } - ev := StreamEvent{Type: eventType, Message: message, Data: data} - b, _ := json.Marshal(ev) - sseWriteMu.Lock() - _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", b) - if err != nil { - sseWriteMu.Unlock() - clientDisconnected = true - return - } - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } else { - c.Writer.Flush() - } - sseWriteMu.Unlock() - } - - h.logger.Info("收到 Eino DeepAgent 流式请求", - zap.String("conversationId", req.ConversationID), - ) - - prep, err := h.prepareMultiAgentSession(&req) - if err != nil { - sendEvent("error", err.Error(), nil) - sendEvent("done", "", nil) - return - } - if prep.CreatedNew { - sendEvent("conversation", "会话已创建", map[string]interface{}{ - "conversationId": prep.ConversationID, - }) - } - - conversationID := prep.ConversationID - assistantMessageID := prep.AssistantMessageID - - if prep.UserMessageID != "" { - sendEvent("message_saved", "", map[string]interface{}{ - "conversationId": conversationID, - "userMessageId": prep.UserMessageID, - }) - } - - progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent) - - baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) - taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) - defer timeoutCancel() - defer cancelWithCause(nil) - - if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { - var errorMsg string - if errors.Is(err, ErrTaskAlreadyRunning) { - errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。" - sendEvent("error", errorMsg, map[string]interface{}{ - "conversationId": conversationID, - "errorType": "task_already_running", - }) - } else { - errorMsg = "❌ 无法启动任务: " + err.Error() - sendEvent("error", errorMsg, nil) - } - if assistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errorMsg, assistantMessageID) - } - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) - return - } - - taskStatus := "completed" - defer h.tasks.FinishTask(conversationID, taskStatus) - - sendEvent("progress", "正在启动 Eino DeepAgent...", map[string]interface{}{ - "conversationId": conversationID, - }) - - stopKeepalive := make(chan struct{}) - go sseKeepalive(c, stopKeepalive, &sseWriteMu) - defer close(stopKeepalive) - - result, runErr := multiagent.RunDeepAgent( - taskCtx, - h.config, - &h.config.MultiAgent, - h.agent, - h.logger, - conversationID, - prep.FinalMessage, - prep.History, - prep.RoleTools, - progressCallback, - h.agentsMarkdownDir, - ) - - if runErr != nil { - cause := context.Cause(baseCtx) - if errors.Is(cause, ErrTaskCancelled) { - taskStatus = "cancelled" - h.tasks.UpdateTaskStatus(conversationID, taskStatus) - cancelMsg := "任务已被用户取消,后续操作已停止。" - if assistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", cancelMsg, assistantMessageID) - _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) - } - sendEvent("cancelled", cancelMsg, map[string]interface{}{ - "conversationId": conversationID, - "messageId": assistantMessageID, - }) - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) - return - } - - h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) - taskStatus = "failed" - h.tasks.UpdateTaskStatus(conversationID, taskStatus) - errMsg := "执行失败: " + runErr.Error() - if assistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID) - _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil) - } - sendEvent("error", errMsg, map[string]interface{}{ - "conversationId": conversationID, - "messageId": assistantMessageID, - }) - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) - return - } - - if assistantMessageID != "" { - mcpIDsJSON := "" - if len(result.MCPExecutionIDs) > 0 { - jsonData, _ := json.Marshal(result.MCPExecutionIDs) - mcpIDsJSON = string(jsonData) - } - _, _ = h.db.Exec( - "UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?", - result.Response, - mcpIDsJSON, - assistantMessageID, - ) - } - - if result.LastReActInput != "" || result.LastReActOutput != "" { - if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) - } - } - - sendEvent("response", result.Response, map[string]interface{}{ - "mcpExecutionIds": result.MCPExecutionIDs, - "conversationId": conversationID, - "messageId": assistantMessageID, - "agentMode": "eino_deep", - }) - sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) -} - -// MultiAgentLoop Eino DeepAgent 非流式对话(与 POST /api/agent-loop 对齐,需 multi_agent.enabled)。 -func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { - if h.config == nil || !h.config.MultiAgent.Enabled { - c.JSON(http.StatusNotFound, gin.H{"error": "多代理未启用,请在 config.yaml 中设置 multi_agent.enabled: true"}) - return - } - - var req ChatRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID)) - - prep, err := h.prepareMultiAgentSession(&req) - if err != nil { - status, msg := multiAgentHTTPErrorStatus(err) - c.JSON(status, gin.H{"error": msg}) - return - } - - result, runErr := multiagent.RunDeepAgent( - c.Request.Context(), - h.config, - &h.config.MultiAgent, - h.agent, - h.logger, - prep.ConversationID, - prep.FinalMessage, - prep.History, - prep.RoleTools, - nil, - h.agentsMarkdownDir, - ) - if runErr != nil { - h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) - errMsg := "执行失败: " + runErr.Error() - if prep.AssistantMessageID != "" { - _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, prep.AssistantMessageID) - } - c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg}) - return - } - - if prep.AssistantMessageID != "" { - mcpIDsJSON := "" - if len(result.MCPExecutionIDs) > 0 { - jsonData, _ := json.Marshal(result.MCPExecutionIDs) - mcpIDsJSON = string(jsonData) - } - _, _ = h.db.Exec( - "UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?", - result.Response, - mcpIDsJSON, - prep.AssistantMessageID, - ) - } - - if result.LastReActInput != "" || result.LastReActOutput != "" { - if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil { - h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) - } - } - - c.JSON(http.StatusOK, ChatResponse{ - Response: result.Response, - MCPExecutionIDs: result.MCPExecutionIDs, - ConversationID: prep.ConversationID, - Time: time.Now(), - }) -} - -func multiAgentHTTPErrorStatus(err error) (int, string) { - msg := err.Error() - switch { - case strings.Contains(msg, "对话不存在"): - return http.StatusNotFound, msg - case strings.Contains(msg, "未找到该 WebShell"): - return http.StatusBadRequest, msg - case strings.Contains(msg, "附件最多"): - return http.StatusBadRequest, msg - case strings.Contains(msg, "保存用户消息失败"), strings.Contains(msg, "创建对话失败"): - return http.StatusInternalServerError, msg - case strings.Contains(msg, "保存上传文件失败"): - return http.StatusInternalServerError, msg - default: - return http.StatusBadRequest, msg - } -} diff --git a/handler/multi_agent_prepare.go b/handler/multi_agent_prepare.go deleted file mode 100644 index 27190013..00000000 --- a/handler/multi_agent_prepare.go +++ /dev/null @@ -1,138 +0,0 @@ -package handler - -import ( - "fmt" - "strings" - - "cyberstrike-ai/internal/agent" - "cyberstrike-ai/internal/database" - "cyberstrike-ai/internal/mcp/builtin" - - "go.uber.org/zap" -) - -// multiAgentPrepared 多代理请求在调用 Eino 前的会话与消息准备结果。 -type multiAgentPrepared struct { - ConversationID string - CreatedNew bool - History []agent.ChatMessage - FinalMessage string - RoleTools []string - AssistantMessageID string - UserMessageID string -} - -func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) { - if len(req.Attachments) > maxAttachments { - return nil, fmt.Errorf("附件最多 %d 个", maxAttachments) - } - - conversationID := strings.TrimSpace(req.ConversationID) - createdNew := false - if conversationID == "" { - title := safeTruncateString(req.Message, 50) - var conv *database.Conversation - var err error - if strings.TrimSpace(req.WebShellConnectionID) != "" { - conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title) - } else { - conv, err = h.db.CreateConversation(title) - } - if err != nil { - return nil, fmt.Errorf("创建对话失败: %w", err) - } - conversationID = conv.ID - createdNew = true - } else { - if _, err := h.db.GetConversation(conversationID); err != nil { - return nil, fmt.Errorf("对话不存在") - } - } - - agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) - if err != nil { - historyMessages, getErr := h.db.GetMessages(conversationID) - if getErr != nil { - agentHistoryMessages = []agent.ChatMessage{} - } else { - agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages)) - for _, msg := range historyMessages { - agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{ - Role: msg.Role, - Content: msg.Content, - }) - } - } - } - - finalMessage := req.Message - var roleTools []string - if req.WebShellConnectionID != "" { - conn, errConn := h.db.GetWebshellConnection(strings.TrimSpace(req.WebShellConnectionID)) - if errConn != nil || conn == nil { - h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn)) - return nil, fmt.Errorf("未找到该 WebShell 连接") - } - remark := conn.Remark - if remark == "" { - remark = conn.URL - } - finalMessage = fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用 Eino 多代理内置 `skill` 工具。\n\n用户请求:%s", - conn.ID, remark, conn.ID, req.Message) - roleTools = []string{ - builtin.ToolWebshellExec, - builtin.ToolWebshellFileList, - builtin.ToolWebshellFileRead, - builtin.ToolWebshellFileWrite, - builtin.ToolRecordVulnerability, - builtin.ToolListKnowledgeRiskTypes, - builtin.ToolSearchKnowledgeBase, - } - } else if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil { - if role, exists := h.config.Roles[req.Role]; exists && role.Enabled { - if role.UserPrompt != "" { - finalMessage = role.UserPrompt + "\n\n" + req.Message - } - roleTools = role.Tools - } - } - - var savedPaths []string - if len(req.Attachments) > 0 { - var aerr error - savedPaths, aerr = saveAttachmentsToDateAndConversationDir(req.Attachments, conversationID, h.logger) - if aerr != nil { - return nil, fmt.Errorf("保存上传文件失败: %w", aerr) - } - } - finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths) - - userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths) - userMsgRow, uerr := h.db.AddMessage(conversationID, "user", userContent, nil) - if uerr != nil { - h.logger.Error("保存用户消息失败", zap.Error(uerr)) - return nil, fmt.Errorf("保存用户消息失败: %w", uerr) - } - userMessageID := "" - if userMsgRow != nil { - userMessageID = userMsgRow.ID - } - - assistantMsg, aerr := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) - var assistantMessageID string - if aerr != nil { - h.logger.Warn("创建助手消息占位失败", zap.Error(aerr)) - } else if assistantMsg != nil { - assistantMessageID = assistantMsg.ID - } - - return &multiAgentPrepared{ - ConversationID: conversationID, - CreatedNew: createdNew, - History: agentHistoryMessages, - FinalMessage: finalMessage, - RoleTools: roleTools, - AssistantMessageID: assistantMessageID, - UserMessageID: userMessageID, - }, nil -} diff --git a/handler/openapi.go b/handler/openapi.go deleted file mode 100644 index 5b1b80c0..00000000 --- a/handler/openapi.go +++ /dev/null @@ -1,4596 +0,0 @@ -package handler - -import ( - "net/http" - "time" - - "cyberstrike-ai/internal/database" - "cyberstrike-ai/internal/storage" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// OpenAPIHandler OpenAPI处理器 -type OpenAPIHandler struct { - db *database.DB - logger *zap.Logger - resultStorage storage.ResultStorage - conversationHdlr *ConversationHandler - agentHdlr *AgentHandler -} - -// NewOpenAPIHandler 创建新的OpenAPI处理器 -func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, resultStorage storage.ResultStorage, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler { - return &OpenAPIHandler{ - db: db, - logger: logger, - resultStorage: resultStorage, - conversationHdlr: conversationHdlr, - agentHdlr: agentHdlr, - } -} - -// GetOpenAPISpec 获取OpenAPI规范 -func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { - host := c.Request.Host - scheme := "http" - if c.Request.TLS != nil { - scheme = "https" - } - - spec := map[string]interface{}{ - "openapi": "3.0.0", - "info": map[string]interface{}{ - "title": "CyberStrikeAI API", - "description": "AI驱动的自动化安全测试平台API文档", - "version": "1.0.0", - "contact": map[string]interface{}{ - "name": "CyberStrikeAI", - }, - }, - "servers": []map[string]interface{}{ - { - "url": scheme + "://" + host, - "description": "当前服务器", - }, - }, - "components": map[string]interface{}{ - "securitySchemes": map[string]interface{}{ - "bearerAuth": map[string]interface{}{ - "type": "http", - "scheme": "bearer", - "bearerFormat": "JWT", - "description": "使用Bearer Token进行认证。Token通过 /api/auth/login 接口获取。", - }, - }, - "schemas": map[string]interface{}{ - "CreateConversationRequest": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "title": map[string]interface{}{ - "type": "string", - "description": "对话标题", - "example": "Web应用安全测试", - }, - }, - }, - "Conversation": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "type": "string", - "description": "对话ID", - "example": "550e8400-e29b-41d4-a716-446655440000", - }, - "title": map[string]interface{}{ - "type": "string", - "description": "对话标题", - "example": "Web应用安全测试", - }, - "createdAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "创建时间", - }, - "updatedAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "更新时间", - }, - }, - }, - "ConversationDetail": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "type": "string", - "description": "对话ID", - }, - "title": map[string]interface{}{ - "type": "string", - "description": "对话标题", - }, - "status": map[string]interface{}{ - "type": "string", - "description": "对话状态:active(进行中)、completed(已完成)、failed(失败)", - "enum": []string{"active", "completed", "failed"}, - }, - "createdAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "创建时间", - }, - "updatedAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "更新时间", - }, - "messages": map[string]interface{}{ - "type": "array", - "description": "消息列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/Message", - }, - }, - "messageCount": map[string]interface{}{ - "type": "integer", - "description": "消息数量", - }, - }, - }, - "Message": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "type": "string", - "description": "消息ID", - }, - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话ID", - }, - "role": map[string]interface{}{ - "type": "string", - "description": "消息角色:user(用户)、assistant(助手)", - "enum": []string{"user", "assistant"}, - }, - "content": map[string]interface{}{ - "type": "string", - "description": "消息内容", - }, - "createdAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "创建时间", - }, - }, - }, - "ConversationResults": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话ID", - }, - "messages": map[string]interface{}{ - "type": "array", - "description": "消息列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/Message", - }, - }, - "vulnerabilities": map[string]interface{}{ - "type": "array", - "description": "发现的漏洞列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/Vulnerability", - }, - }, - "executionResults": map[string]interface{}{ - "type": "array", - "description": "执行结果列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/ExecutionResult", - }, - }, - }, - }, - "Vulnerability": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "type": "string", - "description": "漏洞ID", - }, - "title": map[string]interface{}{ - "type": "string", - "description": "漏洞标题", - }, - "description": map[string]interface{}{ - "type": "string", - "description": "漏洞描述", - }, - "severity": map[string]interface{}{ - "type": "string", - "description": "严重程度", - "enum": []string{"critical", "high", "medium", "low", "info"}, - }, - "status": map[string]interface{}{ - "type": "string", - "description": "状态", - "enum": []string{"open", "closed", "fixed"}, - }, - "target": map[string]interface{}{ - "type": "string", - "description": "受影响的目标", - }, - }, - }, - "ExecutionResult": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "type": "string", - "description": "执行ID", - }, - "toolName": map[string]interface{}{ - "type": "string", - "description": "工具名称", - }, - "status": map[string]interface{}{ - "type": "string", - "description": "执行状态", - "enum": []string{"success", "failed", "running"}, - }, - "result": map[string]interface{}{ - "type": "string", - "description": "执行结果", - }, - "createdAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "创建时间", - }, - }, - }, - "Error": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "error": map[string]interface{}{ - "type": "string", - "description": "错误信息", - }, - }, - }, - "LoginRequest": map[string]interface{}{ - "type": "object", - "required": []string{"password"}, - "properties": map[string]interface{}{ - "password": map[string]interface{}{ - "type": "string", - "description": "登录密码", - }, - }, - }, - "LoginResponse": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "token": map[string]interface{}{ - "type": "string", - "description": "认证Token", - }, - "expires_at": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "Token过期时间", - }, - "session_duration_hr": map[string]interface{}{ - "type": "integer", - "description": "会话持续时间(小时)", - }, - }, - }, - "ChangePasswordRequest": map[string]interface{}{ - "type": "object", - "required": []string{"oldPassword", "newPassword"}, - "properties": map[string]interface{}{ - "oldPassword": map[string]interface{}{ - "type": "string", - "description": "当前密码", - }, - "newPassword": map[string]interface{}{ - "type": "string", - "description": "新密码(至少8位)", - }, - }, - }, - "UpdateConversationRequest": map[string]interface{}{ - "type": "object", - "required": []string{"title"}, - "properties": map[string]interface{}{ - "title": map[string]interface{}{ - "type": "string", - "description": "对话标题", - }, - }, - }, - "Group": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "type": "string", - "description": "分组ID", - }, - "name": map[string]interface{}{ - "type": "string", - "description": "分组名称", - }, - "icon": map[string]interface{}{ - "type": "string", - "description": "分组图标", - }, - "createdAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "创建时间", - }, - "updatedAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "更新时间", - }, - }, - }, - "CreateGroupRequest": map[string]interface{}{ - "type": "object", - "required": []string{"name"}, - "properties": map[string]interface{}{ - "name": map[string]interface{}{ - "type": "string", - "description": "分组名称", - }, - "icon": map[string]interface{}{ - "type": "string", - "description": "分组图标(可选)", - }, - }, - }, - "UpdateGroupRequest": map[string]interface{}{ - "type": "object", - "required": []string{"name"}, - "properties": map[string]interface{}{ - "name": map[string]interface{}{ - "type": "string", - "description": "分组名称", - }, - "icon": map[string]interface{}{ - "type": "string", - "description": "分组图标", - }, - }, - }, - "AddConversationToGroupRequest": map[string]interface{}{ - "type": "object", - "required": []string{"conversationId", "groupId"}, - "properties": map[string]interface{}{ - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话ID", - }, - "groupId": map[string]interface{}{ - "type": "string", - "description": "分组ID", - }, - }, - }, - "BatchTaskRequest": map[string]interface{}{ - "type": "object", - "required": []string{"tasks"}, - "properties": map[string]interface{}{ - "title": map[string]interface{}{ - "type": "string", - "description": "任务标题(可选)", - }, - "tasks": map[string]interface{}{ - "type": "array", - "description": "任务列表,每行一个任务", - "items": map[string]interface{}{ - "type": "string", - }, - }, - "role": map[string]interface{}{ - "type": "string", - "description": "角色名称(可选)", - }, - "agentMode": map[string]interface{}{ - "type": "string", - "description": "代理模式(single | multi)", - "enum": []string{"single", "multi"}, - }, - "scheduleMode": map[string]interface{}{ - "type": "string", - "description": "调度方式(manual | cron)", - "enum": []string{"manual", "cron"}, - }, - "cronExpr": map[string]interface{}{ - "type": "string", - "description": "Cron 表达式(scheduleMode=cron 时必填)", - }, - "executeNow": map[string]interface{}{ - "type": "boolean", - "description": "是否创建后立即执行(默认 false)", - }, - }, - }, - "BatchQueue": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "type": "string", - "description": "队列ID", - }, - "title": map[string]interface{}{ - "type": "string", - "description": "队列标题", - }, - "status": map[string]interface{}{ - "type": "string", - "description": "队列状态", - "enum": []string{"pending", "running", "paused", "completed", "failed"}, - }, - "tasks": map[string]interface{}{ - "type": "array", - "description": "任务列表", - "items": map[string]interface{}{ - "type": "object", - }, - }, - "createdAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "创建时间", - }, - }, - }, - "CancelAgentLoopRequest": map[string]interface{}{ - "type": "object", - "required": []string{"conversationId"}, - "properties": map[string]interface{}{ - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话ID", - }, - }, - }, - "AgentTask": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话ID", - }, - "status": map[string]interface{}{ - "type": "string", - "description": "任务状态", - "enum": []string{"running", "completed", "failed", "cancelled", "timeout"}, - }, - "startedAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "开始时间", - }, - }, - }, - "CreateVulnerabilityRequest": map[string]interface{}{ - "type": "object", - "required": []string{"conversation_id", "title", "severity"}, - "properties": map[string]interface{}{ - "conversation_id": map[string]interface{}{ - "type": "string", - "description": "对话ID", - }, - "title": map[string]interface{}{ - "type": "string", - "description": "漏洞标题", - }, - "description": map[string]interface{}{ - "type": "string", - "description": "漏洞描述", - }, - "severity": map[string]interface{}{ - "type": "string", - "description": "严重程度", - "enum": []string{"critical", "high", "medium", "low", "info"}, - }, - "status": map[string]interface{}{ - "type": "string", - "description": "状态", - "enum": []string{"open", "closed", "fixed"}, - }, - "type": map[string]interface{}{ - "type": "string", - "description": "漏洞类型", - }, - "target": map[string]interface{}{ - "type": "string", - "description": "受影响的目标", - }, - "proof": map[string]interface{}{ - "type": "string", - "description": "漏洞证明", - }, - "impact": map[string]interface{}{ - "type": "string", - "description": "影响", - }, - "recommendation": map[string]interface{}{ - "type": "string", - "description": "修复建议", - }, - }, - }, - "UpdateVulnerabilityRequest": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "title": map[string]interface{}{ - "type": "string", - "description": "漏洞标题", - }, - "description": map[string]interface{}{ - "type": "string", - "description": "漏洞描述", - }, - "severity": map[string]interface{}{ - "type": "string", - "description": "严重程度", - "enum": []string{"critical", "high", "medium", "low", "info"}, - }, - "status": map[string]interface{}{ - "type": "string", - "description": "状态", - "enum": []string{"open", "closed", "fixed"}, - }, - "type": map[string]interface{}{ - "type": "string", - "description": "漏洞类型", - }, - "target": map[string]interface{}{ - "type": "string", - "description": "受影响的目标", - }, - "proof": map[string]interface{}{ - "type": "string", - "description": "漏洞证明", - }, - "impact": map[string]interface{}{ - "type": "string", - "description": "影响", - }, - "recommendation": map[string]interface{}{ - "type": "string", - "description": "修复建议", - }, - }, - }, - "ListVulnerabilitiesResponse": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "vulnerabilities": map[string]interface{}{ - "type": "array", - "description": "漏洞列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/Vulnerability", - }, - }, - "total": map[string]interface{}{ - "type": "integer", - "description": "总数", - }, - "page": map[string]interface{}{ - "type": "integer", - "description": "当前页", - }, - "page_size": map[string]interface{}{ - "type": "integer", - "description": "每页数量", - }, - "total_pages": map[string]interface{}{ - "type": "integer", - "description": "总页数", - }, - }, - }, - "VulnerabilityStats": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "total": map[string]interface{}{ - "type": "integer", - "description": "总漏洞数", - }, - "by_severity": map[string]interface{}{ - "type": "object", - "description": "按严重程度统计", - }, - "by_status": map[string]interface{}{ - "type": "object", - "description": "按状态统计", - }, - }, - }, - "RoleConfig": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "name": map[string]interface{}{ - "type": "string", - "description": "角色名称", - }, - "description": map[string]interface{}{ - "type": "string", - "description": "角色描述", - }, - "enabled": map[string]interface{}{ - "type": "boolean", - "description": "是否启用", - }, - "systemPrompt": map[string]interface{}{ - "type": "string", - "description": "系统提示词", - }, - "userPrompt": map[string]interface{}{ - "type": "string", - "description": "用户提示词", - }, - "tools": map[string]interface{}{ - "type": "array", - "description": "工具列表", - "items": map[string]interface{}{ - "type": "string", - }, - }, - "skills": map[string]interface{}{ - "type": "array", - "description": "Skills列表", - "items": map[string]interface{}{ - "type": "string", - }, - }, - }, - }, - "Skill": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "name": map[string]interface{}{ - "type": "string", - "description": "Skill名称", - }, - "description": map[string]interface{}{ - "type": "string", - "description": "Skill描述", - }, - "path": map[string]interface{}{ - "type": "string", - "description": "Skill路径", - }, - }, - }, - "CreateSkillRequest": map[string]interface{}{ - "type": "object", - "required": []string{"name", "description"}, - "properties": map[string]interface{}{ - "name": map[string]interface{}{ - "type": "string", - "description": "Skill名称", - }, - "description": map[string]interface{}{ - "type": "string", - "description": "Skill描述", - }, - }, - }, - "UpdateSkillRequest": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "description": map[string]interface{}{ - "type": "string", - "description": "Skill描述", - }, - }, - }, - "ToolExecution": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "type": "string", - "description": "执行ID", - }, - "toolName": map[string]interface{}{ - "type": "string", - "description": "工具名称", - }, - "status": map[string]interface{}{ - "type": "string", - "description": "执行状态", - "enum": []string{"success", "failed", "running"}, - }, - "createdAt": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "创建时间", - }, - }, - }, - "MonitorResponse": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "executions": map[string]interface{}{ - "type": "array", - "description": "执行记录列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/ToolExecution", - }, - }, - "stats": map[string]interface{}{ - "type": "object", - "description": "统计信息", - }, - "timestamp": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "时间戳", - }, - "total": map[string]interface{}{ - "type": "integer", - "description": "总数", - }, - "page": map[string]interface{}{ - "type": "integer", - "description": "当前页", - }, - "page_size": map[string]interface{}{ - "type": "integer", - "description": "每页数量", - }, - "total_pages": map[string]interface{}{ - "type": "integer", - "description": "总页数", - }, - }, - }, - "ConfigResponse": map[string]interface{}{ - "type": "object", - "description": "配置信息", - }, - "UpdateConfigRequest": map[string]interface{}{ - "type": "object", - "description": "更新配置请求", - }, - "ExternalMCPConfig": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "enabled": map[string]interface{}{ - "type": "boolean", - "description": "是否启用", - }, - "command": map[string]interface{}{ - "type": "string", - "description": "命令", - }, - "args": map[string]interface{}{ - "type": "array", - "description": "参数列表", - "items": map[string]interface{}{ - "type": "string", - }, - }, - }, - }, - "ExternalMCPResponse": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "config": map[string]interface{}{ - "$ref": "#/components/schemas/ExternalMCPConfig", - }, - "status": map[string]interface{}{ - "type": "string", - "description": "状态", - "enum": []string{"connected", "disconnected", "error", "disabled"}, - }, - "toolCount": map[string]interface{}{ - "type": "integer", - "description": "工具数量", - }, - "error": map[string]interface{}{ - "type": "string", - "description": "错误信息", - }, - }, - }, - "AddOrUpdateExternalMCPRequest": map[string]interface{}{ - "type": "object", - "required": []string{"config"}, - "properties": map[string]interface{}{ - "config": map[string]interface{}{ - "$ref": "#/components/schemas/ExternalMCPConfig", - }, - }, - }, - "AttackChain": map[string]interface{}{ - "type": "object", - "description": "攻击链数据", - }, - "MCPMessage": map[string]interface{}{ - "type": "object", - "description": "MCP消息(符合JSON-RPC 2.0规范)", - "required": []string{"jsonrpc"}, - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "description": "消息ID,可以是字符串、数字或null。对于请求,必须提供;对于通知,可以省略", - "oneOf": []map[string]interface{}{ - {"type": "string"}, - {"type": "number"}, - {"type": "null"}, - }, - "example": "550e8400-e29b-41d4-a716-446655440000", - }, - "method": map[string]interface{}{ - "type": "string", - "description": "方法名。支持的方法:\n- `initialize`: 初始化MCP连接\n- `tools/list`: 列出所有可用工具\n- `tools/call`: 调用工具\n- `prompts/list`: 列出所有提示词模板\n- `prompts/get`: 获取提示词模板\n- `resources/list`: 列出所有资源\n- `resources/read`: 读取资源内容\n- `sampling/request`: 采样请求", - "enum": []string{ - "initialize", - "tools/list", - "tools/call", - "prompts/list", - "prompts/get", - "resources/list", - "resources/read", - "sampling/request", - }, - "example": "tools/list", - }, - "params": map[string]interface{}{ - "description": "方法参数(JSON对象),根据不同的method有不同的结构", - "type": "object", - }, - "jsonrpc": map[string]interface{}{ - "type": "string", - "description": "JSON-RPC版本,固定为\"2.0\"", - "enum": []string{"2.0"}, - "example": "2.0", - }, - }, - }, - "MCPInitializeParams": map[string]interface{}{ - "type": "object", - "required": []string{"protocolVersion", "capabilities", "clientInfo"}, - "properties": map[string]interface{}{ - "protocolVersion": map[string]interface{}{ - "type": "string", - "description": "协议版本", - "example": "2024-11-05", - }, - "capabilities": map[string]interface{}{ - "type": "object", - "description": "客户端能力", - }, - "clientInfo": map[string]interface{}{ - "type": "object", - "required": []string{"name", "version"}, - "properties": map[string]interface{}{ - "name": map[string]interface{}{ - "type": "string", - "description": "客户端名称", - "example": "MyClient", - }, - "version": map[string]interface{}{ - "type": "string", - "description": "客户端版本", - "example": "1.0.0", - }, - }, - }, - }, - }, - "MCPCallToolParams": map[string]interface{}{ - "type": "object", - "required": []string{"name", "arguments"}, - "properties": map[string]interface{}{ - "name": map[string]interface{}{ - "type": "string", - "description": "工具名称", - "example": "nmap", - }, - "arguments": map[string]interface{}{ - "type": "object", - "description": "工具参数(键值对),具体参数取决于工具定义", - "example": map[string]interface{}{ - "target": "192.168.1.1", - "ports": "80,443", - }, - }, - }, - }, - "MCPResponse": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "id": map[string]interface{}{ - "description": "消息ID(与请求中的id相同)", - "oneOf": []map[string]interface{}{ - {"type": "string"}, - {"type": "number"}, - {"type": "null"}, - }, - }, - "result": map[string]interface{}{ - "description": "方法执行结果(JSON对象),结构取决于调用的方法", - "type": "object", - }, - "error": map[string]interface{}{ - "type": "object", - "description": "错误信息(如果执行失败)", - "properties": map[string]interface{}{ - "code": map[string]interface{}{ - "type": "integer", - "description": "错误代码", - "example": -32600, - }, - "message": map[string]interface{}{ - "type": "string", - "description": "错误消息", - "example": "Invalid Request", - }, - "data": map[string]interface{}{ - "description": "错误详情(可选)", - }, - }, - }, - "jsonrpc": map[string]interface{}{ - "type": "string", - "description": "JSON-RPC版本", - "example": "2.0", - }, - }, - }, - }, - }, - "security": []map[string]interface{}{ - { - "bearerAuth": []string{}, - }, - }, - "paths": map[string]interface{}{ - "/api/auth/login": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"认证"}, - "summary": "用户登录", - "description": "使用密码登录获取认证Token", - "operationId": "login", - "security": []map[string]interface{}{}, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/LoginRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "登录成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/LoginResponse", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "密码错误", - }, - }, - }, - }, - "/api/auth/logout": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"认证"}, - "summary": "用户登出", - "description": "登出当前会话,使Token失效", - "operationId": "logout", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "登出成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "example": "已退出登录", - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/auth/change-password": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"认证"}, - "summary": "修改密码", - "description": "修改登录密码,修改后所有会话将失效", - "operationId": "changePassword", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/ChangePasswordRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "密码修改成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "example": "密码已更新,请使用新密码重新登录", - }, - }, - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/auth/validate": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"认证"}, - "summary": "验证Token", - "description": "验证当前Token是否有效", - "operationId": "validateToken", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "Token有效", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "token": map[string]interface{}{ - "type": "string", - "description": "Token", - }, - "expires_at": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "过期时间", - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "Token无效或已过期", - }, - }, - }, - }, - "/api/conversations": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"对话管理"}, - "summary": "创建对话", - "description": "创建一个新的安全测试对话。\n**重要说明**:\n- ✅ 创建的对话会**立即保存到数据库**\n- ✅ 前端页面会**自动刷新**显示新对话\n- ✅ 与前端创建的对话**完全一致**\n**创建对话的两种方式**:\n**方式1(推荐):** 直接使用 `/api/agent-loop` 发送消息,**不提供** `conversationId` 参数,系统会自动创建新对话并发送消息。这是最简单的方式,一步完成创建和发送。\n**方式2:** 先调用此端点创建空对话,然后使用返回的 `conversationId` 调用 `/api/agent-loop` 发送消息。适用于需要先创建对话,稍后再发送消息的场景。\n**示例**:\n```json\n{\n \"title\": \"Web应用安全测试\"\n}\n```", - "operationId": "createConversation", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/CreateConversationRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "对话创建成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Conversation", - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权,需要有效的Token", - }, - "500": map[string]interface{}{ - "description": "服务器内部错误", - }, - }, - }, - "get": map[string]interface{}{ - "tags": []string{"对话管理"}, - "summary": "列出对话", - "description": "获取对话列表,支持分页和搜索", - "operationId": "listConversations", - "parameters": []map[string]interface{}{ - { - "name": "limit", - "in": "query", - "required": false, - "description": "返回数量限制", - "schema": map[string]interface{}{ - "type": "integer", - "default": 50, - "minimum": 1, - "maximum": 100, - }, - }, - { - "name": "offset", - "in": "query", - "required": false, - "description": "偏移量", - "schema": map[string]interface{}{ - "type": "integer", - "default": 0, - "minimum": 0, - }, - }, - { - "name": "search", - "in": "query", - "required": false, - "description": "搜索关键词", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "array", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/Conversation", - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权,需要有效的Token", - }, - }, - }, - }, - "/api/conversations/{id}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"对话管理"}, - "summary": "查看对话详情", - "description": "获取指定对话的详细信息,包括对话信息和消息列表", - "operationId": "getConversation", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/ConversationDetail", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "对话不存在", - }, - "401": map[string]interface{}{ - "description": "未授权,需要有效的Token", - }, - }, - }, - "put": map[string]interface{}{ - "tags": []string{"对话管理"}, - "summary": "更新对话", - "description": "更新对话标题", - "operationId": "updateConversation", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/UpdateConversationRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Conversation", - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "404": map[string]interface{}{ - "description": "对话不存在", - }, - "401": map[string]interface{}{ - "description": "未授权,需要有效的Token", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"对话管理"}, - "summary": "删除对话", - "description": "删除指定的对话及其所有相关数据(消息、漏洞等)。**此操作不可恢复**。", - "operationId": "deleteConversation", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "成功消息", - "example": "删除成功", - }, - }, - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "对话不存在", - }, - "401": map[string]interface{}{ - "description": "未授权,需要有效的Token", - }, - "500": map[string]interface{}{ - "description": "服务器内部错误", - }, - }, - }, - }, - "/api/conversations/{id}/results": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"对话管理"}, - "summary": "获取对话结果", - "description": "获取指定对话的执行结果,包括消息、漏洞信息和执行结果", - "operationId": "getConversationResults", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/ConversationResults", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "对话不存在或结果不存在", - }, - "401": map[string]interface{}{ - "description": "未授权,需要有效的Token", - }, - }, - }, - }, - "/api/agent-loop": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"对话交互"}, - "summary": "发送消息并获取AI回复(非流式)", - "description": "向AI发送消息并获取回复(非流式响应)。**这是与AI交互的核心端点**,与前端聊天功能完全一致。\n**重要说明**:\n- ✅ 通过此API创建/发送的消息会**立即保存到数据库**\n- ✅ 前端页面会**自动刷新**显示新创建的对话和消息\n- ✅ 所有操作都有**完整的交互痕迹**,就像在前端操作一样\n- ✅ 支持角色配置,可以指定使用哪个测试角色\n**推荐使用流程**:\n1. **先创建对话**:调用 `POST /api/conversations` 创建新对话,获取 `conversationId`\n2. **再发送消息**:使用返回的 `conversationId` 调用此端点发送消息\n**使用示例**:\n**步骤1 - 创建对话:**\n```json\nPOST /api/conversations\n{\n \"title\": \"Web应用安全测试\"\n}\n```\n**步骤2 - 发送消息:**\n```json\nPOST /api/agent-loop\n{\n \"conversationId\": \"返回的对话ID\",\n \"message\": \"扫描 http://example.com 的SQL注入漏洞\",\n \"role\": \"渗透测试\"\n}\n```\n**其他方式**:\n如果不提供 `conversationId`,系统会自动创建新对话并发送消息。但**推荐先创建对话**,这样可以更好地管理对话列表。\n**响应**:返回AI的回复、对话ID和MCP执行ID列表。前端会自动刷新显示新消息。", - "operationId": "sendMessage", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "要发送的消息(必需)", - "example": "扫描 http://example.com 的SQL注入漏洞", - }, - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话ID(可选)。\n- **不提供**:自动创建新对话并发送消息(推荐)\n- **提供**:消息会添加到指定对话中(对话必须存在)", - "example": "550e8400-e29b-41d4-a716-446655440000", - }, - "role": map[string]interface{}{ - "type": "string", - "description": "角色名称(可选),如:默认、渗透测试、Web应用扫描等", - "example": "默认", - }, - }, - "required": []string{"message"}, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "消息发送成功,返回AI回复", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "response": map[string]interface{}{ - "type": "string", - "description": "AI的回复内容", - }, - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话ID", - }, - "mcpExecutionIds": map[string]interface{}{ - "type": "array", - "description": "MCP执行ID列表", - "items": map[string]interface{}{ - "type": "string", - }, - }, - "time": map[string]interface{}{ - "type": "string", - "format": "date-time", - "description": "响应时间", - }, - }, - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权,需要有效的Token", - }, - "500": map[string]interface{}{ - "description": "服务器内部错误", - }, - }, - }, - }, - "/api/agent-loop/stream": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"对话交互"}, - "summary": "发送消息并获取AI回复(流式)", - "description": "向AI发送消息并获取流式回复(Server-Sent Events)。**这是与AI交互的核心端点**,与前端聊天功能完全一致。\n**重要说明**:\n- ✅ 通过此API创建/发送的消息会**立即保存到数据库**\n- ✅ 前端页面会**自动刷新**显示新创建的对话和消息\n- ✅ 所有操作都有**完整的交互痕迹**,就像在前端操作一样\n- ✅ 支持角色配置,可以指定使用哪个测试角色\n- ✅ 返回流式响应,适合实时显示AI回复\n**推荐使用流程**:\n1. **先创建对话**:调用 `POST /api/conversations` 创建新对话,获取 `conversationId`\n2. **再发送消息**:使用返回的 `conversationId` 调用此端点发送消息\n**使用示例**:\n**步骤1 - 创建对话:**\n```json\nPOST /api/conversations\n{\n \"title\": \"Web应用安全测试\"\n}\n```\n**步骤2 - 发送消息(流式):**\n```json\nPOST /api/agent-loop/stream\n{\n \"conversationId\": \"返回的对话ID\",\n \"message\": \"扫描 http://example.com 的SQL注入漏洞\",\n \"role\": \"渗透测试\"\n}\n```\n**响应格式**:Server-Sent Events (SSE),事件类型包括:\n- `message`: 用户消息确认\n- `response`: AI回复片段\n- `progress`: 进度更新\n- `done`: 完成\n- `error`: 错误\n- `cancelled`: 已取消", - "operationId": "sendMessageStream", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "要发送的消息(必需)", - "example": "扫描 http://example.com 的SQL注入漏洞", - }, - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话ID(可选)。\n- **不提供**:自动创建新对话并发送消息(推荐)\n- **提供**:消息会添加到指定对话中(对话必须存在)", - "example": "550e8400-e29b-41d4-a716-446655440000", - }, - "role": map[string]interface{}{ - "type": "string", - "description": "角色名称(可选),如:默认、渗透测试、Web应用扫描等", - "example": "默认", - }, - }, - "required": []string{"message"}, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "流式响应(Server-Sent Events)", - "content": map[string]interface{}{ - "text/event-stream": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "string", - "description": "SSE流式数据", - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权,需要有效的Token", - }, - "500": map[string]interface{}{ - "description": "服务器内部错误", - }, - }, - }, - }, - "/api/multi-agent": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"对话交互"}, - "summary": "发送消息并获取 AI 回复(Eino DeepAgent,非流式)", - "description": "与 `POST /api/agent-loop` 请求体相同,但由 **CloudWeGo Eino DeepAgent** 执行多代理编排。**前提**:`multi_agent.enabled: true`(可在设置页或 `config.yaml` 开启);未启用时返回 404 JSON。请求体支持 `webshellConnectionId`(与单代理 WebShell 助手一致)。", - "operationId": "sendMessageMultiAgent", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "要发送的消息(必需)", - }, - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话 ID(可选,不提供则新建)", - }, - "role": map[string]interface{}{ - "type": "string", - "description": "角色名称(可选)", - }, - "webshellConnectionId": map[string]interface{}{ - "type": "string", - "description": "WebShell 连接 ID(可选,与 agent-loop 行为一致)", - }, - }, - "required": []string{"message"}, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "成功,响应格式同 /api/agent-loop", - }, - "400": map[string]interface{}{"description": "参数错误"}, - "401": map[string]interface{}{"description": "未授权"}, - "404": map[string]interface{}{"description": "多代理未启用或对话不存在"}, - "500": map[string]interface{}{"description": "执行失败"}, - }, - }, - }, - "/api/multi-agent/stream": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"对话交互"}, - "summary": "发送消息并获取 AI 回复(Eino DeepAgent,SSE)", - "description": "与 `POST /api/agent-loop/stream` 类似,事件类型兼容;由 Eino DeepAgent 执行。**前提**:`multi_agent.enabled: true`;路由常注册,未启用时仍返回 200 SSE,流内首条为 `type: error` 后接 `done`。支持 `webshellConnectionId`。", - "operationId": "sendMessageMultiAgentStream", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "message": map[string]interface{}{"type": "string"}, - "conversationId": map[string]interface{}{"type": "string"}, - "role": map[string]interface{}{"type": "string"}, - "webshellConnectionId": map[string]interface{}{"type": "string"}, - }, - "required": []string{"message"}, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "text/event-stream(SSE)", - "content": map[string]interface{}{ - "text/event-stream": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "string", - "description": "SSE 流", - }, - }, - }, - }, - "401": map[string]interface{}{"description": "未授权"}, - }, - }, - }, - "/api/agent-loop/cancel": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"对话交互"}, - "summary": "取消任务", - "description": "取消正在执行的Agent Loop任务", - "operationId": "cancelAgentLoop", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/CancelAgentLoopRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "取消请求已提交", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "status": map[string]interface{}{ - "type": "string", - "example": "cancelling", - }, - "conversationId": map[string]interface{}{ - "type": "string", - "description": "对话ID", - }, - "message": map[string]interface{}{ - "type": "string", - "example": "已提交取消请求,任务将在当前步骤完成后停止。", - }, - }, - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "未找到正在执行的任务", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/agent-loop/tasks": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"对话交互"}, - "summary": "列出运行中的任务", - "description": "获取所有正在运行的Agent Loop任务", - "operationId": "listAgentTasks", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "tasks": map[string]interface{}{ - "type": "array", - "description": "任务列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/AgentTask", - }, - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/agent-loop/tasks/completed": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"对话交互"}, - "summary": "列出已完成的任务", - "description": "获取最近完成的Agent Loop任务历史", - "operationId": "listCompletedTasks", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "tasks": map[string]interface{}{ - "type": "array", - "description": "已完成任务列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/AgentTask", - }, - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/batch-tasks": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"批量任务"}, - "summary": "创建批量任务队列", - "description": "创建一个批量任务队列,包含多个任务", - "operationId": "createBatchQueue", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/BatchTaskRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "创建成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queueId": map[string]interface{}{ - "type": "string", - "description": "队列ID", - }, - "queue": map[string]interface{}{ - "$ref": "#/components/schemas/BatchQueue", - }, - "started": map[string]interface{}{ - "type": "boolean", - "description": "是否已立即启动执行", - }, - }, - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "get": map[string]interface{}{ - "tags": []string{"批量任务"}, - "summary": "列出批量任务队列", - "description": "获取所有批量任务队列", - "operationId": "listBatchQueues", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "queues": map[string]interface{}{ - "type": "array", - "description": "队列列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/BatchQueue", - }, - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/batch-tasks/{queueId}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"批量任务"}, - "summary": "获取批量任务队列", - "description": "获取指定批量任务队列的详细信息", - "operationId": "getBatchQueue", - "parameters": []map[string]interface{}{ - { - "name": "queueId", - "in": "path", - "required": true, - "description": "队列ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/BatchQueue", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "队列不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"批量任务"}, - "summary": "删除批量任务队列", - "description": "删除指定的批量任务队列", - "operationId": "deleteBatchQueue", - "parameters": []map[string]interface{}{ - { - "name": "queueId", - "in": "path", - "required": true, - "description": "队列ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "队列不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/batch-tasks/{queueId}/start": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"批量任务"}, - "summary": "启动批量任务队列", - "description": "开始执行批量任务队列中的任务", - "operationId": "startBatchQueue", - "parameters": []map[string]interface{}{ - { - "name": "queueId", - "in": "path", - "required": true, - "description": "队列ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "启动成功", - }, - "404": map[string]interface{}{ - "description": "队列不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/batch-tasks/{queueId}/pause": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"批量任务"}, - "summary": "暂停批量任务队列", - "description": "暂停正在执行的批量任务队列", - "operationId": "pauseBatchQueue", - "parameters": []map[string]interface{}{ - { - "name": "queueId", - "in": "path", - "required": true, - "description": "队列ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "暂停成功", - }, - "404": map[string]interface{}{ - "description": "队列不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/batch-tasks/{queueId}/tasks": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"批量任务"}, - "summary": "添加任务到队列", - "description": "向批量任务队列添加新任务。任务会添加到队列末尾,按照队列顺序依次执行。每个任务会创建一个独立的对话,支持完整的状态跟踪。\n**任务格式**:\n任务内容是一个字符串,描述要执行的安全测试任务。例如:\n- \"扫描 http://example.com 的SQL注入漏洞\"\n- \"对 192.168.1.1 进行端口扫描\"\n- \"检测 https://target.com 的XSS漏洞\"\n**使用示例**:\n```json\n{\n \"task\": \"扫描 http://example.com 的SQL注入漏洞\"\n}\n```", - "operationId": "addBatchTask", - "parameters": []map[string]interface{}{ - { - "name": "queueId", - "in": "path", - "required": true, - "description": "队列ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "required": []string{"task"}, - "properties": map[string]interface{}{ - "task": map[string]interface{}{ - "type": "string", - "description": "任务内容,描述要执行的安全测试任务(必需)", - "example": "扫描 http://example.com 的SQL注入漏洞", - }, - }, - }, - "examples": map[string]interface{}{ - "sqlInjection": map[string]interface{}{ - "summary": "SQL注入扫描", - "description": "扫描目标网站的SQL注入漏洞", - "value": map[string]interface{}{ - "task": "扫描 http://example.com 的SQL注入漏洞", - }, - }, - "portScan": map[string]interface{}{ - "summary": "端口扫描", - "description": "对目标IP进行端口扫描", - "value": map[string]interface{}{ - "task": "对 192.168.1.1 进行端口扫描", - }, - }, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "添加成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "taskId": map[string]interface{}{ - "type": "string", - "description": "新添加的任务ID", - }, - "message": map[string]interface{}{ - "type": "string", - "description": "成功消息", - "example": "任务已添加到队列", - }, - }, - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误(如task为空)", - }, - "404": map[string]interface{}{ - "description": "队列不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/batch-tasks/{queueId}/tasks/{taskId}": map[string]interface{}{ - "put": map[string]interface{}{ - "tags": []string{"批量任务"}, - "summary": "更新批量任务", - "description": "更新批量任务队列中的指定任务", - "operationId": "updateBatchTask", - "parameters": []map[string]interface{}{ - { - "name": "queueId", - "in": "path", - "required": true, - "description": "队列ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - { - "name": "taskId", - "in": "path", - "required": true, - "description": "任务ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "task": map[string]interface{}{ - "type": "string", - "description": "任务内容", - }, - }, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - }, - "404": map[string]interface{}{ - "description": "任务不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"批量任务"}, - "summary": "删除批量任务", - "description": "从批量任务队列中删除指定任务", - "operationId": "deleteBatchTask", - "parameters": []map[string]interface{}{ - { - "name": "queueId", - "in": "path", - "required": true, - "description": "队列ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - { - "name": "taskId", - "in": "path", - "required": true, - "description": "任务ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "任务不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/groups": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "创建分组", - "description": "创建一个新的对话分组", - "operationId": "createGroup", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/CreateGroupRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "创建成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Group", - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误或分组名称已存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "get": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "列出分组", - "description": "获取所有对话分组", - "operationId": "listGroups", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "array", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/Group", - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/groups/{id}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "获取分组", - "description": "获取指定分组的详细信息", - "operationId": "getGroup", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "分组ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Group", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "分组不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "put": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "更新分组", - "description": "更新分组信息", - "operationId": "updateGroup", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "分组ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/UpdateGroupRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Group", - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误或分组名称已存在", - }, - "404": map[string]interface{}{ - "description": "分组不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "删除分组", - "description": "删除指定分组", - "operationId": "deleteGroup", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "分组ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "分组不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/groups/{id}/conversations": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "获取分组中的对话", - "description": "获取指定分组中的所有对话", - "operationId": "getGroupConversations", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "分组ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "array", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/Conversation", - }, - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "分组不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/groups/conversations": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "添加对话到分组", - "description": "将对话添加到指定分组", - "operationId": "addConversationToGroup", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/AddConversationToGroupRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "添加成功", - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "404": map[string]interface{}{ - "description": "对话或分组不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/groups/{id}/conversations/{conversationId}": map[string]interface{}{ - "delete": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "从分组移除对话", - "description": "从指定分组中移除对话", - "operationId": "removeConversationFromGroup", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "分组ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - { - "name": "conversationId", - "in": "path", - "required": true, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "移除成功", - }, - "404": map[string]interface{}{ - "description": "对话或分组不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/vulnerabilities": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"漏洞管理"}, - "summary": "列出漏洞", - "description": "获取漏洞列表,支持分页和筛选", - "operationId": "listVulnerabilities", - "parameters": []map[string]interface{}{ - { - "name": "limit", - "in": "query", - "required": false, - "description": "每页数量", - "schema": map[string]interface{}{ - "type": "integer", - "default": 20, - "minimum": 1, - "maximum": 100, - }, - }, - { - "name": "offset", - "in": "query", - "required": false, - "description": "偏移量", - "schema": map[string]interface{}{ - "type": "integer", - "default": 0, - "minimum": 0, - }, - }, - { - "name": "page", - "in": "query", - "required": false, - "description": "页码(与offset二选一)", - "schema": map[string]interface{}{ - "type": "integer", - "minimum": 1, - }, - }, - { - "name": "id", - "in": "query", - "required": false, - "description": "漏洞ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - { - "name": "conversation_id", - "in": "query", - "required": false, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - { - "name": "severity", - "in": "query", - "required": false, - "description": "严重程度", - "schema": map[string]interface{}{ - "type": "string", - "enum": []string{"critical", "high", "medium", "low", "info"}, - }, - }, - { - "name": "status", - "in": "query", - "required": false, - "description": "状态", - "schema": map[string]interface{}{ - "type": "string", - "enum": []string{"open", "closed", "fixed"}, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/ListVulnerabilitiesResponse", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "post": map[string]interface{}{ - "tags": []string{"漏洞管理"}, - "summary": "创建漏洞", - "description": "创建一个新的漏洞记录", - "operationId": "createVulnerability", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/CreateVulnerabilityRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "创建成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Vulnerability", - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/vulnerabilities/stats": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"漏洞管理"}, - "summary": "获取漏洞统计", - "description": "获取漏洞统计信息", - "operationId": "getVulnerabilityStats", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/VulnerabilityStats", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/vulnerabilities/{id}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"漏洞管理"}, - "summary": "获取漏洞", - "description": "获取指定漏洞的详细信息", - "operationId": "getVulnerability", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "漏洞ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Vulnerability", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "漏洞不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "put": map[string]interface{}{ - "tags": []string{"漏洞管理"}, - "summary": "更新漏洞", - "description": "更新漏洞信息", - "operationId": "updateVulnerability", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "漏洞ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/UpdateVulnerabilityRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Vulnerability", - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "404": map[string]interface{}{ - "description": "漏洞不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"漏洞管理"}, - "summary": "删除漏洞", - "description": "删除指定漏洞", - "operationId": "deleteVulnerability", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "漏洞ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "漏洞不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/roles": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"角色管理"}, - "summary": "列出角色", - "description": "获取所有安全测试角色", - "operationId": "getRoles", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "roles": map[string]interface{}{ - "type": "array", - "description": "角色列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/RoleConfig", - }, - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "post": map[string]interface{}{ - "tags": []string{"角色管理"}, - "summary": "创建角色", - "description": "创建一个新的安全测试角色", - "operationId": "createRole", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/RoleConfig", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "创建成功", - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/roles/{name}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"角色管理"}, - "summary": "获取角色", - "description": "获取指定角色的详细信息", - "operationId": "getRole", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "角色名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "role": map[string]interface{}{ - "$ref": "#/components/schemas/RoleConfig", - }, - }, - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "角色不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "put": map[string]interface{}{ - "tags": []string{"角色管理"}, - "summary": "更新角色", - "description": "更新指定角色的配置", - "operationId": "updateRole", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "角色名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/RoleConfig", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "404": map[string]interface{}{ - "description": "角色不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"角色管理"}, - "summary": "删除角色", - "description": "删除指定角色", - "operationId": "deleteRole", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "角色名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "角色不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/roles/skills/list": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"角色管理"}, - "summary": "获取可用Skills列表", - "description": "获取所有可用的Skills列表,用于角色配置", - "operationId": "getSkills", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "skills": map[string]interface{}{ - "type": "array", - "description": "Skills列表", - "items": map[string]interface{}{ - "type": "string", - }, - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/skills": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"Skills管理"}, - "summary": "列出Skills", - "description": "获取所有Skills列表,支持分页和搜索", - "operationId": "getSkills", - "parameters": []map[string]interface{}{ - { - "name": "limit", - "in": "query", - "required": false, - "description": "每页数量", - "schema": map[string]interface{}{ - "type": "integer", - "default": 20, - }, - }, - { - "name": "offset", - "in": "query", - "required": false, - "description": "偏移量", - "schema": map[string]interface{}{ - "type": "integer", - "default": 0, - }, - }, - { - "name": "search", - "in": "query", - "required": false, - "description": "搜索关键词", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "skills": map[string]interface{}{ - "type": "array", - "description": "Skills列表", - "items": map[string]interface{}{ - "$ref": "#/components/schemas/Skill", - }, - }, - "total": map[string]interface{}{ - "type": "integer", - "description": "总数", - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "post": map[string]interface{}{ - "tags": []string{"Skills管理"}, - "summary": "创建Skill", - "description": "创建一个新的Skill", - "operationId": "createSkill", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/CreateSkillRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "创建成功", - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/skills/stats": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"Skills管理"}, - "summary": "获取Skill统计", - "description": "获取Skill调用统计信息", - "operationId": "getSkillStats", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "description": "统计信息", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"Skills管理"}, - "summary": "清空Skill统计", - "description": "清空所有Skill的调用统计", - "operationId": "clearSkillStats", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "清空成功", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/skills/{name}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"Skills管理"}, - "summary": "获取Skill", - "description": "获取指定Skill的详细信息", - "operationId": "getSkill", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "Skill名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Skill", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "Skill不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "put": map[string]interface{}{ - "tags": []string{"Skills管理"}, - "summary": "更新Skill", - "description": "更新指定Skill的信息", - "operationId": "updateSkill", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "Skill名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/UpdateSkillRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "404": map[string]interface{}{ - "description": "Skill不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"Skills管理"}, - "summary": "删除Skill", - "description": "删除指定Skill", - "operationId": "deleteSkill", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "Skill名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "Skill不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/skills/{name}/bound-roles": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"Skills管理"}, - "summary": "获取绑定角色", - "description": "获取使用指定Skill的所有角色", - "operationId": "getSkillBoundRoles", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "Skill名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "roles": map[string]interface{}{ - "type": "array", - "description": "角色列表", - "items": map[string]interface{}{ - "type": "string", - }, - }, - }, - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "Skill不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/skills/{name}/stats": map[string]interface{}{ - "delete": map[string]interface{}{ - "tags": []string{"Skills管理"}, - "summary": "清空Skill统计", - "description": "清空指定Skill的调用统计", - "operationId": "clearSkillStatsByName", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "Skill名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "清空成功", - }, - "404": map[string]interface{}{ - "description": "Skill不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/monitor": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"监控"}, - "summary": "获取监控信息", - "description": "获取工具执行监控信息,支持分页和筛选", - "operationId": "monitor", - "parameters": []map[string]interface{}{ - { - "name": "page", - "in": "query", - "required": false, - "description": "页码", - "schema": map[string]interface{}{ - "type": "integer", - "default": 1, - "minimum": 1, - }, - }, - { - "name": "page_size", - "in": "query", - "required": false, - "description": "每页数量", - "schema": map[string]interface{}{ - "type": "integer", - "default": 20, - "minimum": 1, - "maximum": 100, - }, - }, - { - "name": "status", - "in": "query", - "required": false, - "description": "状态筛选", - "schema": map[string]interface{}{ - "type": "string", - "enum": []string{"success", "failed", "running"}, - }, - }, - { - "name": "tool", - "in": "query", - "required": false, - "description": "工具名称筛选(支持部分匹配)", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/MonitorResponse", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/monitor/execution/{id}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"监控"}, - "summary": "获取执行记录", - "description": "获取指定执行记录的详细信息", - "operationId": "getExecution", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "执行ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/ToolExecution", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "执行记录不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"监控"}, - "summary": "删除执行记录", - "description": "删除指定的执行记录", - "operationId": "deleteExecution", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "执行ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "执行记录不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/monitor/executions": map[string]interface{}{ - "delete": map[string]interface{}{ - "tags": []string{"监控"}, - "summary": "批量删除执行记录", - "description": "批量删除执行记录", - "operationId": "deleteExecutions", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/monitor/stats": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"监控"}, - "summary": "获取统计信息", - "description": "获取工具执行统计信息", - "operationId": "getStats", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "description": "统计信息", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/config": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"配置管理"}, - "summary": "获取配置", - "description": "获取系统配置信息", - "operationId": "getConfig", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/ConfigResponse", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "put": map[string]interface{}{ - "tags": []string{"配置管理"}, - "summary": "更新配置", - "description": "更新系统配置", - "operationId": "updateConfig", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/UpdateConfigRequest", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/config/tools": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"配置管理"}, - "summary": "获取工具配置", - "description": "获取所有工具的配置信息", - "operationId": "getTools", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "array", - "description": "工具配置列表", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/config/apply": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"配置管理"}, - "summary": "应用配置", - "description": "应用配置更改", - "operationId": "applyConfig", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "应用成功", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/external-mcp": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"外部MCP管理"}, - "summary": "列出外部MCP", - "description": "获取所有外部MCP配置和状态", - "operationId": "getExternalMCPs", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "servers": map[string]interface{}{ - "type": "object", - "description": "MCP服务器配置", - "additionalProperties": map[string]interface{}{ - "$ref": "#/components/schemas/ExternalMCPResponse", - }, - }, - "stats": map[string]interface{}{ - "type": "object", - "description": "统计信息", - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/external-mcp/stats": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"外部MCP管理"}, - "summary": "获取外部MCP统计", - "description": "获取外部MCP统计信息", - "operationId": "getExternalMCPStats", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "description": "统计信息", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/external-mcp/{name}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"外部MCP管理"}, - "summary": "获取外部MCP", - "description": "获取指定外部MCP的配置和状态", - "operationId": "getExternalMCP", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "MCP名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/ExternalMCPResponse", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "MCP不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "put": map[string]interface{}{ - "tags": []string{"外部MCP管理"}, - "summary": "添加或更新外部MCP", - "description": "添加新的外部MCP配置或更新现有配置。\n**传输方式**:\n支持两种传输方式:\n**1. stdio(标准输入输出)**:\n```json\n{\n \"config\": {\n \"enabled\": true,\n \"command\": \"node\",\n \"args\": [\"/path/to/mcp-server.js\"],\n \"env\": {}\n }\n}\n```\n**2. sse(Server-Sent Events)**:\n```json\n{\n \"config\": {\n \"enabled\": true,\n \"transport\": \"sse\",\n \"url\": \"http://127.0.0.1:8082/sse\",\n \"timeout\": 30\n }\n}\n```\n**配置参数说明**:\n- `enabled`: 是否启用(boolean,必需)\n- `command`: 命令(stdio模式必需,如:\"node\", \"python\")\n- `args`: 命令参数数组(stdio模式必需)\n- `env`: 环境变量(object,可选)\n- `transport`: 传输方式(\"stdio\" 或 \"sse\",sse模式必需)\n- `url`: SSE端点URL(sse模式必需)\n- `timeout`: 超时时间(秒,可选,默认30)\n- `description`: 描述(可选)", - "operationId": "addOrUpdateExternalMCP", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "MCP名称(唯一标识符)", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/AddOrUpdateExternalMCPRequest", - }, - "examples": map[string]interface{}{ - "stdio": map[string]interface{}{ - "summary": "stdio模式配置", - "description": "使用标准输入输出方式连接外部MCP服务器", - "value": map[string]interface{}{ - "config": map[string]interface{}{ - "enabled": true, - "command": "node", - "args": []string{"/path/to/mcp-server.js"}, - "env": map[string]interface{}{}, - "timeout": 30, - "description": "Node.js MCP服务器", - }, - }, - }, - "sse": map[string]interface{}{ - "summary": "SSE模式配置", - "description": "使用Server-Sent Events方式连接外部MCP服务器", - "value": map[string]interface{}{ - "config": map[string]interface{}{ - "enabled": true, - "transport": "sse", - "url": "http://127.0.0.1:8082/sse", - "timeout": 30, - "description": "SSE MCP服务器", - }, - }, - }, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "操作成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "example": "外部MCP配置已保存", - }, - }, - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误(如配置格式不正确、缺少必需字段等)", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Error", - }, - "example": map[string]interface{}{ - "error": "stdio模式需要提供command和args参数", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"外部MCP管理"}, - "summary": "删除外部MCP", - "description": "删除指定的外部MCP配置", - "operationId": "deleteExternalMCP", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "MCP名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "MCP不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/external-mcp/{name}/start": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"外部MCP管理"}, - "summary": "启动外部MCP", - "description": "启动指定的外部MCP服务器", - "operationId": "startExternalMCP", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "MCP名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "启动成功", - }, - "404": map[string]interface{}{ - "description": "MCP不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/external-mcp/{name}/stop": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"外部MCP管理"}, - "summary": "停止外部MCP", - "description": "停止指定的外部MCP服务器", - "operationId": "stopExternalMCP", - "parameters": []map[string]interface{}{ - { - "name": "name", - "in": "path", - "required": true, - "description": "MCP名称", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "停止成功", - }, - "404": map[string]interface{}{ - "description": "MCP不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/attack-chain/{conversationId}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"攻击链"}, - "summary": "获取攻击链", - "description": "获取指定对话的攻击链可视化数据", - "operationId": "getAttackChain", - "parameters": []map[string]interface{}{ - { - "name": "conversationId", - "in": "path", - "required": true, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/AttackChain", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "对话不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/attack-chain/{conversationId}/regenerate": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"攻击链"}, - "summary": "重新生成攻击链", - "description": "重新生成指定对话的攻击链可视化数据", - "operationId": "regenerateAttackChain", - "parameters": []map[string]interface{}{ - { - "name": "conversationId", - "in": "path", - "required": true, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "重新生成成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/AttackChain", - }, - }, - }, - }, - "404": map[string]interface{}{ - "description": "对话不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/conversations/{id}/pinned": map[string]interface{}{ - "put": map[string]interface{}{ - "tags": []string{"对话管理"}, - "summary": "设置对话置顶", - "description": "设置或取消对话的置顶状态", - "operationId": "updateConversationPinned", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "required": []string{"pinned"}, - "properties": map[string]interface{}{ - "pinned": map[string]interface{}{ - "type": "boolean", - "description": "是否置顶", - }, - }, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - }, - "404": map[string]interface{}{ - "description": "对话不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/groups/{id}/pinned": map[string]interface{}{ - "put": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "设置分组置顶", - "description": "设置或取消分组的置顶状态", - "operationId": "updateGroupPinned", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "分组ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "required": []string{"pinned"}, - "properties": map[string]interface{}{ - "pinned": map[string]interface{}{ - "type": "boolean", - "description": "是否置顶", - }, - }, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - }, - "404": map[string]interface{}{ - "description": "分组不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/groups/{id}/conversations/{conversationId}/pinned": map[string]interface{}{ - "put": map[string]interface{}{ - "tags": []string{"对话分组"}, - "summary": "设置分组中对话的置顶", - "description": "设置或取消分组中对话的置顶状态", - "operationId": "updateConversationPinnedInGroup", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "分组ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - { - "name": "conversationId", - "in": "path", - "required": true, - "description": "对话ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "required": []string{"pinned"}, - "properties": map[string]interface{}{ - "pinned": map[string]interface{}{ - "type": "boolean", - "description": "是否置顶", - }, - }, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - }, - "404": map[string]interface{}{ - "description": "对话或分组不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/knowledge/categories": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "获取分类", - "description": "获取知识库的所有分类", - "operationId": "getKnowledgeCategories", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "categories": map[string]interface{}{ - "type": "array", - "description": "分类列表", - "items": map[string]interface{}{ - "type": "string", - }, - }, - "enabled": map[string]interface{}{ - "type": "boolean", - "description": "知识库是否启用", - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/knowledge/items": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "列出知识项", - "description": "获取知识库中的所有知识项", - "operationId": "getKnowledgeItems", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "items": map[string]interface{}{ - "type": "array", - "description": "知识项列表", - }, - "enabled": map[string]interface{}{ - "type": "boolean", - "description": "知识库是否启用", - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "post": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "创建知识项", - "description": "创建新的知识项", - "operationId": "createKnowledgeItem", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "description": "知识项数据", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "创建成功", - }, - "400": map[string]interface{}{ - "description": "请求参数错误", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/knowledge/items/{id}": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "获取知识项", - "description": "获取指定知识项的详细信息", - "operationId": "getKnowledgeItem", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "知识项ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - }, - "404": map[string]interface{}{ - "description": "知识项不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "put": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "更新知识项", - "description": "更新指定知识项", - "operationId": "updateKnowledgeItem", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "知识项ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "description": "知识项数据", - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "更新成功", - }, - "404": map[string]interface{}{ - "description": "知识项不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - "delete": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "删除知识项", - "description": "删除指定知识项", - "operationId": "deleteKnowledgeItem", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "知识项ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "知识项不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/knowledge/index-status": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "获取索引状态", - "description": "获取知识库索引的构建状态", - "operationId": "getIndexStatus", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "enabled": map[string]interface{}{ - "type": "boolean", - "description": "知识库是否启用", - }, - "total_items": map[string]interface{}{ - "type": "integer", - "description": "总知识项数", - }, - "indexed_items": map[string]interface{}{ - "type": "integer", - "description": "已索引知识项数", - }, - "progress_percent": map[string]interface{}{ - "type": "number", - "description": "索引进度百分比", - }, - "is_complete": map[string]interface{}{ - "type": "boolean", - "description": "索引是否完成", - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/knowledge/index": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "重建索引", - "description": "重新构建知识库索引", - "operationId": "rebuildIndex", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "重建索引任务已启动", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/knowledge/scan": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "扫描知识库", - "description": "扫描知识库目录,导入新的知识文件", - "operationId": "scanKnowledgeBase", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "扫描任务已启动", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/knowledge/search": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "搜索知识库", - "description": "在知识库中搜索相关内容。基于向量检索,按查询与知识片段的语义相似度(余弦)返回最相关结果。\n**搜索说明**:\n- 语义相似度搜索:嵌入向量 + 余弦相似度,可配置相似度阈值与 TopK\n- 可按风险类型等元数据过滤(如:SQL注入、XSS、文件上传等)\n- 建议先调用 `/api/knowledge/categories` 获取可用的风险类型列表\n**使用示例**:\n```json\n{\n \"query\": \"SQL注入漏洞的检测方法\",\n \"riskType\": \"SQL注入\",\n \"topK\": 5,\n \"threshold\": 0.7\n}\n```", - "operationId": "searchKnowledge", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "required": []string{"query"}, - "properties": map[string]interface{}{ - "query": map[string]interface{}{ - "type": "string", - "description": "搜索查询内容,描述你想要了解的安全知识主题(必需)", - "example": "SQL注入漏洞的检测方法", - }, - "riskType": map[string]interface{}{ - "type": "string", - "description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 `/api/knowledge/categories` 获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。", - "example": "SQL注入", - }, - "topK": map[string]interface{}{ - "type": "integer", - "description": "可选:返回Top-K结果数量,默认5", - "default": 5, - "minimum": 1, - "maximum": 50, - "example": 5, - }, - "threshold": map[string]interface{}{ - "type": "number", - "format": "float", - "description": "可选:相似度阈值(0-1之间),默认0.7。只有相似度大于等于此值的结果才会返回", - "default": 0.7, - "minimum": 0, - "maximum": 1, - "example": 0.7, - }, - }, - }, - "examples": map[string]interface{}{ - "basic": map[string]interface{}{ - "summary": "基础搜索", - "description": "最简单的搜索,只提供查询内容", - "value": map[string]interface{}{ - "query": "SQL注入漏洞的检测方法", - }, - }, - "withRiskType": map[string]interface{}{ - "summary": "按风险类型搜索", - "description": "指定风险类型进行精确搜索", - "value": map[string]interface{}{ - "query": "SQL注入漏洞的检测方法", - "riskType": "SQL注入", - "topK": 5, - "threshold": 0.7, - }, - }, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "搜索成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "results": map[string]interface{}{ - "type": "array", - "description": "搜索结果列表,每个结果包含:item(知识项信息)、chunks(匹配的知识片段)、score(相似度分数)", - "items": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "item": map[string]interface{}{ - "type": "object", - "description": "知识项信息", - }, - "chunks": map[string]interface{}{ - "type": "array", - "description": "匹配的知识片段列表", - }, - "score": map[string]interface{}{ - "type": "number", - "description": "相似度分数(0-1之间)", - }, - }, - }, - }, - "enabled": map[string]interface{}{ - "type": "boolean", - "description": "知识库是否启用", - }, - }, - }, - "example": map[string]interface{}{ - "results": []map[string]interface{}{ - { - "item": map[string]interface{}{ - "id": "item-1", - "title": "SQL注入漏洞检测", - "category": "SQL注入", - }, - "chunks": []map[string]interface{}{ - { - "text": "SQL注入漏洞的检测方法包括...", - }, - }, - "score": 0.85, - }, - }, - "enabled": true, - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求参数错误(如query为空)", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/Error", - }, - "example": map[string]interface{}{ - "error": "查询不能为空", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - "500": map[string]interface{}{ - "description": "服务器内部错误(如知识库未启用或检索失败)", - }, - }, - }, - }, - "/api/knowledge/retrieval-logs": map[string]interface{}{ - "get": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "获取检索日志", - "description": "获取知识库检索日志", - "operationId": "getRetrievalLogs", - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "获取成功", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "logs": map[string]interface{}{ - "type": "array", - "description": "检索日志列表", - }, - "enabled": map[string]interface{}{ - "type": "boolean", - "description": "知识库是否启用", - }, - }, - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/knowledge/retrieval-logs/{id}": map[string]interface{}{ - "delete": map[string]interface{}{ - "tags": []string{"知识库"}, - "summary": "删除检索日志", - "description": "删除指定的检索日志", - "operationId": "deleteRetrievalLog", - "parameters": []map[string]interface{}{ - { - "name": "id", - "in": "path", - "required": true, - "description": "日志ID", - "schema": map[string]interface{}{ - "type": "string", - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "删除成功", - }, - "404": map[string]interface{}{ - "description": "日志不存在", - }, - "401": map[string]interface{}{ - "description": "未授权", - }, - }, - }, - }, - "/api/mcp": map[string]interface{}{ - "post": map[string]interface{}{ - "tags": []string{"MCP"}, - "summary": "MCP端点", - "description": "MCP (Model Context Protocol) 端点,用于处理MCP协议请求。\n**协议说明**:\n本端点遵循 JSON-RPC 2.0 规范,支持以下方法:\n**1. initialize** - 初始化MCP连接\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"init-1\",\n \"method\": \"initialize\",\n \"params\": {\n \"protocolVersion\": \"2024-11-05\",\n \"capabilities\": {},\n \"clientInfo\": {\n \"name\": \"MyClient\",\n \"version\": \"1.0.0\"\n }\n }\n}\n```\n**2. tools/list** - 列出所有可用工具\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"list-1\",\n \"method\": \"tools/list\",\n \"params\": {}\n}\n```\n**3. tools/call** - 调用工具\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"call-1\",\n \"method\": \"tools/call\",\n \"params\": {\n \"name\": \"nmap\",\n \"arguments\": {\n \"target\": \"192.168.1.1\",\n \"ports\": \"80,443\"\n }\n }\n}\n```\n**4. prompts/list** - 列出所有提示词模板\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"prompts-list-1\",\n \"method\": \"prompts/list\",\n \"params\": {}\n}\n```\n**5. prompts/get** - 获取提示词模板\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"prompt-get-1\",\n \"method\": \"prompts/get\",\n \"params\": {\n \"name\": \"prompt-name\",\n \"arguments\": {}\n }\n}\n```\n**6. resources/list** - 列出所有资源\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"resources-list-1\",\n \"method\": \"resources/list\",\n \"params\": {}\n}\n```\n**7. resources/read** - 读取资源内容\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"resource-read-1\",\n \"method\": \"resources/read\",\n \"params\": {\n \"uri\": \"resource://example\"\n }\n}\n```\n**错误代码说明**:\n- `-32700`: Parse error - JSON解析错误\n- `-32600`: Invalid Request - 无效请求\n- `-32601`: Method not found - 方法不存在\n- `-32602`: Invalid params - 参数无效\n- `-32603`: Internal error - 内部错误", - "operationId": "mcpEndpoint", - "requestBody": map[string]interface{}{ - "required": true, - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/MCPMessage", - }, - "examples": map[string]interface{}{ - "listTools": map[string]interface{}{ - "summary": "列出所有工具", - "description": "获取系统中所有可用的MCP工具列表", - "value": map[string]interface{}{ - "jsonrpc": "2.0", - "id": "list-tools-1", - "method": "tools/list", - "params": map[string]interface{}{}, - }, - }, - "callTool": map[string]interface{}{ - "summary": "调用工具", - "description": "调用指定的MCP工具", - "value": map[string]interface{}{ - "jsonrpc": "2.0", - "id": "call-tool-1", - "method": "tools/call", - "params": map[string]interface{}{ - "name": "nmap", - "arguments": map[string]interface{}{ - "target": "192.168.1.1", - "ports": "80,443", - }, - }, - }, - }, - "initialize": map[string]interface{}{ - "summary": "初始化连接", - "description": "初始化MCP连接,获取服务器能力", - "value": map[string]interface{}{ - "jsonrpc": "2.0", - "id": "init-1", - "method": "initialize", - "params": map[string]interface{}{ - "protocolVersion": "2024-11-05", - "capabilities": map[string]interface{}{}, - "clientInfo": map[string]interface{}{ - "name": "MyClient", - "version": "1.0.0", - }, - }, - }, - }, - }, - }, - }, - }, - "responses": map[string]interface{}{ - "200": map[string]interface{}{ - "description": "MCP响应(JSON-RPC 2.0格式)", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/MCPResponse", - }, - "examples": map[string]interface{}{ - "success": map[string]interface{}{ - "summary": "成功响应", - "description": "工具调用成功的响应示例", - "value": map[string]interface{}{ - "jsonrpc": "2.0", - "id": "call-tool-1", - "result": map[string]interface{}{ - "content": []map[string]interface{}{ - { - "type": "text", - "text": "工具执行结果...", - }, - }, - "isError": false, - }, - }, - }, - "error": map[string]interface{}{ - "summary": "错误响应", - "description": "工具调用失败的响应示例", - "value": map[string]interface{}{ - "jsonrpc": "2.0", - "id": "call-tool-1", - "error": map[string]interface{}{ - "code": -32601, - "message": "Tool not found", - "data": "工具 'unknown-tool' 不存在", - }, - }, - }, - }, - }, - }, - }, - "400": map[string]interface{}{ - "description": "请求格式错误(JSON解析失败)", - "content": map[string]interface{}{ - "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/MCPResponse", - }, - "example": map[string]interface{}{ - "id": nil, - "error": map[string]interface{}{ - "code": -32700, - "message": "Parse error", - "data": "unexpected end of JSON input", - }, - "jsonrpc": "2.0", - }, - }, - }, - }, - "401": map[string]interface{}{ - "description": "未授权,需要有效的Token", - }, - "405": map[string]interface{}{ - "description": "方法不允许(仅支持POST请求)", - }, - }, - }, - }, - }, - } - - enrichSpecWithI18nKeys(spec) - c.JSON(http.StatusOK, spec) -} - -// GetConversationResults 获取对话结果(OpenAPI端点) -// 注意:创建对话和获取对话详情直接使用标准的 /api/conversations 端点 -// 这个端点只是为了提供结果聚合功能 -func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) { - conversationID := c.Param("id") - - // 验证对话是否存在 - conv, err := h.db.GetConversation(conversationID) - if err != nil { - h.logger.Error("获取对话失败", zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) - return - } - - // 获取消息列表 - messages, err := h.db.GetMessages(conversationID) - if err != nil { - h.logger.Error("获取消息失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 获取漏洞列表 - vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "") - if err != nil { - h.logger.Warn("获取漏洞列表失败", zap.Error(err)) - vulnList = []*database.Vulnerability{} - } - vulnerabilities := make([]database.Vulnerability, len(vulnList)) - for i, v := range vulnList { - vulnerabilities[i] = *v - } - - // 获取执行结果(从MCP执行记录中获取) - executionResults := []map[string]interface{}{} - for _, msg := range messages { - if len(msg.MCPExecutionIDs) > 0 { - for _, execID := range msg.MCPExecutionIDs { - // 尝试从结果存储中获取执行结果 - if h.resultStorage != nil { - result, err := h.resultStorage.GetResult(execID) - if err == nil && result != "" { - // 获取元数据以获取工具名称和创建时间 - metadata, err := h.resultStorage.GetResultMetadata(execID) - toolName := "unknown" - createdAt := time.Now() - if err == nil && metadata != nil { - toolName = metadata.ToolName - createdAt = metadata.CreatedAt - } - executionResults = append(executionResults, map[string]interface{}{ - "id": execID, - "toolName": toolName, - "status": "success", - "result": result, - "createdAt": createdAt.Format(time.RFC3339), - }) - } - } - } - } - } - - response := map[string]interface{}{ - "conversationId": conv.ID, - "messages": messages, - "vulnerabilities": vulnerabilities, - "executionResults": executionResults, - } - - c.JSON(http.StatusOK, response) -} diff --git a/handler/openapi_i18n.go b/handler/openapi_i18n.go deleted file mode 100644 index 3479766e..00000000 --- a/handler/openapi_i18n.go +++ /dev/null @@ -1,139 +0,0 @@ -package handler - -// apiDocI18n 为 OpenAPI 文档提供 x-i18n-* 扩展键,供前端 apiDocs 国际化使用。 -// 前端通过 apiDocs.tags.* / apiDocs.summary.* / apiDocs.response.* 翻译。 - -var apiDocI18nTagToKey = map[string]string{ - "认证": "auth", "对话管理": "conversationManagement", "对话交互": "conversationInteraction", - "批量任务": "batchTasks", "对话分组": "conversationGroups", "漏洞管理": "vulnerabilityManagement", - "角色管理": "roleManagement", "Skills管理": "skillsManagement", "监控": "monitoring", - "配置管理": "configManagement", "外部MCP管理": "externalMCPManagement", "攻击链": "attackChain", - "知识库": "knowledgeBase", "MCP": "mcp", -} - -var apiDocI18nSummaryToKey = map[string]string{ - "用户登录": "login", "用户登出": "logout", "修改密码": "changePassword", "验证Token": "validateToken", - "创建对话": "createConversation", "列出对话": "listConversations", "查看对话详情": "getConversationDetail", - "更新对话": "updateConversation", "删除对话": "deleteConversation", "获取对话结果": "getConversationResult", - "发送消息并获取AI回复(非流式)": "sendMessageNonStream", "发送消息并获取AI回复(流式)": "sendMessageStream", - "取消任务": "cancelTask", "列出运行中的任务": "listRunningTasks", "列出已完成的任务": "listCompletedTasks", - "创建批量任务队列": "createBatchQueue", "列出批量任务队列": "listBatchQueues", "获取批量任务队列": "getBatchQueue", - "删除批量任务队列": "deleteBatchQueue", "启动批量任务队列": "startBatchQueue", "暂停批量任务队列": "pauseBatchQueue", - "添加任务到队列": "addTaskToQueue", "SQL注入扫描": "sqlInjectionScan", "端口扫描": "portScan", - "更新批量任务": "updateBatchTask", "删除批量任务": "deleteBatchTask", - "创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup", - "删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup", - "从分组移除对话": "removeConversationFromGroup", - "列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats", - "获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability", - "列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole", - "获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill", - "获取Skill统计": "getSkillStats", "清空Skill统计": "clearSkillStats", "获取Skill": "getSkill", - "更新Skill": "updateSkill", "删除Skill": "deleteSkill", "获取绑定角色": "getBoundRoles", - "获取监控信息": "getMonitorInfo", "获取执行记录": "getExecutionRecords", "删除执行记录": "deleteExecutionRecord", - "批量删除执行记录": "batchDeleteExecutionRecords", "获取统计信息": "getStats", - "获取配置": "getConfig", "更新配置": "updateConfig", "获取工具配置": "getToolConfig", "应用配置": "applyConfig", - "列出外部MCP": "listExternalMCP", "获取外部MCP统计": "getExternalMCPStats", "获取外部MCP": "getExternalMCP", - "添加或更新外部MCP": "addOrUpdateExternalMCP", "stdio模式配置": "stdioModeConfig", "SSE模式配置": "sseModeConfig", - "删除外部MCP": "deleteExternalMCP", "启动外部MCP": "startExternalMCP", "停止外部MCP": "stopExternalMCP", - "获取攻击链": "getAttackChain", "重新生成攻击链": "regenerateAttackChain", - "设置对话置顶": "pinConversation", "设置分组置顶": "pinGroup", "设置分组中对话的置顶": "pinGroupConversation", - "获取分类": "getCategories", "列出知识项": "listKnowledgeItems", "创建知识项": "createKnowledgeItem", - "获取知识项": "getKnowledgeItem", "更新知识项": "updateKnowledgeItem", "删除知识项": "deleteKnowledgeItem", - "获取索引状态": "getIndexStatus", "重建索引": "rebuildIndex", "扫描知识库": "scanKnowledgeBase", - "搜索知识库": "searchKnowledgeBase", "基础搜索": "basicSearch", "按风险类型搜索": "searchByRiskType", - "获取检索日志": "getRetrievalLogs", "删除检索日志": "deleteRetrievalLog", - "MCP端点": "mcpEndpoint", "列出所有工具": "listAllTools", "调用工具": "invokeTool", "初始化连接": "initConnection", - "成功响应": "successResponse", "错误响应": "errorResponse", -} - -var apiDocI18nResponseDescToKey = map[string]string{ - "获取成功": "getSuccess", "未授权": "unauthorized", "未授权,需要有效的Token": "unauthorizedToken", - "创建成功": "createSuccess", "请求参数错误": "badRequest", "对话不存在": "conversationNotFound", - "对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty", - "请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound", - "请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig", - "请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed", - "登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess", - "密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid", - "对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess", - "删除成功": "deleteSuccess", "队列不存在": "queueNotFound", "启动成功": "startSuccess", - "暂停成功": "pauseSuccess", "添加成功": "addSuccess", - "任务不存在": "taskNotFound", "对话或分组不存在": "conversationOrGroupNotFound", - "取消请求已提交": "cancelSubmitted", "未找到正在执行的任务": "noRunningTask", - "消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events)": "streamResponse", -} - -// enrichSpecWithI18nKeys 在 spec 的每个 operation 上写入 x-i18n-tags、x-i18n-summary, -// 在每个 response 上写入 x-i18n-description,供前端按 key 做国际化。 -func enrichSpecWithI18nKeys(spec map[string]interface{}) { - paths, _ := spec["paths"].(map[string]interface{}) - if paths == nil { - return - } - for _, pathItem := range paths { - pm, _ := pathItem.(map[string]interface{}) - if pm == nil { - continue - } - for _, method := range []string{"get", "post", "put", "delete", "patch"} { - opVal, ok := pm[method] - if !ok { - continue - } - op, _ := opVal.(map[string]interface{}) - if op == nil { - continue - } - // x-i18n-tags: 与 tags 一一对应的 i18n 键数组(spec 中 tags 为 []string) - switch tags := op["tags"].(type) { - case []string: - if len(tags) > 0 { - keys := make([]string, 0, len(tags)) - for _, s := range tags { - if k := apiDocI18nTagToKey[s]; k != "" { - keys = append(keys, k) - } else { - keys = append(keys, s) - } - } - op["x-i18n-tags"] = keys - } - case []interface{}: - if len(tags) > 0 { - keys := make([]interface{}, 0, len(tags)) - for _, t := range tags { - if s, ok := t.(string); ok { - if k := apiDocI18nTagToKey[s]; k != "" { - keys = append(keys, k) - } else { - keys = append(keys, s) - } - } - } - if len(keys) > 0 { - op["x-i18n-tags"] = keys - } - } - } - // x-i18n-summary - if summary, _ := op["summary"].(string); summary != "" { - if k := apiDocI18nSummaryToKey[summary]; k != "" { - op["x-i18n-summary"] = k - } - } - // responses -> 每个 status -> x-i18n-description - if respMap, _ := op["responses"].(map[string]interface{}); respMap != nil { - for _, rv := range respMap { - if r, _ := rv.(map[string]interface{}); r != nil { - if desc, _ := r["description"].(string); desc != "" { - if k := apiDocI18nResponseDescToKey[desc]; k != "" { - r["x-i18n-description"] = k - } - } - } - } - } - } - } -} diff --git a/handler/robot.go b/handler/robot.go deleted file mode 100644 index a7b8f3a7..00000000 --- a/handler/robot.go +++ /dev/null @@ -1,907 +0,0 @@ -package handler - -import ( - "bytes" - "context" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "encoding/json" - "encoding/xml" - "errors" - "fmt" - "io" - "net/http" - "sort" - "strings" - "sync" - "time" - - "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/database" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -const ( - robotCmdHelp = "帮助" - robotCmdList = "列表" - robotCmdListAlt = "对话列表" - robotCmdSwitch = "切换" - robotCmdContinue = "继续" - robotCmdNew = "新对话" - robotCmdClear = "清空" - robotCmdCurrent = "当前" - robotCmdStop = "停止" - robotCmdRoles = "角色" - robotCmdRolesList = "角色列表" - robotCmdSwitchRole = "切换角色" - robotCmdDelete = "删除" - robotCmdVersion = "版本" -) - -// RobotHandler 企业微信/钉钉/飞书等机器人回调处理 -type RobotHandler struct { - config *config.Config - db *database.DB - agentHandler *AgentHandler - logger *zap.Logger - mu sync.RWMutex - sessions map[string]string // key: "platform_userID", value: conversationID - sessionRoles map[string]string // key: "platform_userID", value: roleName(默认"默认") - cancelMu sync.Mutex // 保护 runningCancels - runningCancels map[string]context.CancelFunc // key: "platform_userID", 用于停止命令中断任务 -} - -// NewRobotHandler 创建机器人处理器 -func NewRobotHandler(cfg *config.Config, db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *RobotHandler { - return &RobotHandler{ - config: cfg, - db: db, - agentHandler: agentHandler, - logger: logger, - sessions: make(map[string]string), - sessionRoles: make(map[string]string), - runningCancels: make(map[string]context.CancelFunc), - } -} - -// sessionKey 生成会话 key -func (h *RobotHandler) sessionKey(platform, userID string) string { - return platform + "_" + userID -} - -// getOrCreateConversation 获取或创建当前会话,title 用于新对话的标题(取用户首条消息前50字) -func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (convID string, isNew bool) { - h.mu.RLock() - convID = h.sessions[h.sessionKey(platform, userID)] - h.mu.RUnlock() - if convID != "" { - return convID, false - } - t := strings.TrimSpace(title) - if t == "" { - t = "新对话 " + time.Now().Format("01-02 15:04") - } else { - t = safeTruncateString(t, 50) - } - conv, err := h.db.CreateConversation(t) - if err != nil { - h.logger.Warn("创建机器人会话失败", zap.Error(err)) - return "", false - } - convID = conv.ID - h.mu.Lock() - h.sessions[h.sessionKey(platform, userID)] = convID - h.mu.Unlock() - return convID, true -} - -// setConversation 切换当前会话 -func (h *RobotHandler) setConversation(platform, userID, convID string) { - h.mu.Lock() - h.sessions[h.sessionKey(platform, userID)] = convID - h.mu.Unlock() -} - -// getRole 获取当前用户使用的角色,未设置时返回"默认" -func (h *RobotHandler) getRole(platform, userID string) string { - h.mu.RLock() - role := h.sessionRoles[h.sessionKey(platform, userID)] - h.mu.RUnlock() - if role == "" { - return "默认" - } - return role -} - -// setRole 设置当前用户使用的角色 -func (h *RobotHandler) setRole(platform, userID, roleName string) { - h.mu.Lock() - h.sessionRoles[h.sessionKey(platform, userID)] = roleName - h.mu.Unlock() -} - -// clearConversation 清空当前会话(切换到新对话) -func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) { - title := "新对话 " + time.Now().Format("01-02 15:04") - conv, err := h.db.CreateConversation(title) - if err != nil { - h.logger.Warn("创建新对话失败", zap.Error(err)) - return "" - } - h.setConversation(platform, userID, conv.ID) - return conv.ID -} - -// HandleMessage 处理用户输入,返回回复文本(供各平台 webhook 调用) -func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply string) { - text = strings.TrimSpace(text) - if text == "" { - return "请输入内容或发送「帮助」/ help 查看命令。" - } - - // 先尝试作为命令处理(支持中英文) - if cmdReply, ok := h.handleRobotCommand(platform, userID, text); ok { - return cmdReply - } - - // 普通消息:走 Agent - convID, _ := h.getOrCreateConversation(platform, userID, text) - if convID == "" { - return "无法创建或获取对话,请稍后再试。" - } - // 若对话标题为「新对话 xx:xx」格式(由「新对话」命令创建),将标题更新为首条消息内容,与 Web 端体验一致 - if conv, err := h.db.GetConversation(convID); err == nil && strings.HasPrefix(conv.Title, "新对话 ") { - newTitle := safeTruncateString(text, 50) - if newTitle != "" { - _ = h.db.UpdateConversationTitle(convID, newTitle) - } - } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - sk := h.sessionKey(platform, userID) - h.cancelMu.Lock() - h.runningCancels[sk] = cancel - h.cancelMu.Unlock() - defer func() { - cancel() - h.cancelMu.Lock() - delete(h.runningCancels, sk) - h.cancelMu.Unlock() - }() - role := h.getRole(platform, userID) - resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role) - if err != nil { - h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err)) - if errors.Is(err, context.Canceled) { - return "任务已取消。" - } - return "处理失败: " + err.Error() - } - if newConvID != convID { - h.setConversation(platform, userID, newConvID) - } - return resp -} - -func (h *RobotHandler) cmdHelp() string { - return "**【CyberStrikeAI 机器人命令】**\n\n" + - "- `帮助` `help` — 显示本帮助 | Show this help\n" + - "- `列表` `list` — 列出所有对话标题与 ID | List conversations\n" + - "- `切换 ` `switch ` — 指定对话继续 | Switch to conversation\n" + - "- `新对话` `new` — 开启新对话 | Start new conversation\n" + - "- `清空` `clear` — 清空当前上下文 | Clear context\n" + - "- `当前` `current` — 显示当前对话 ID 与标题 | Show current conversation\n" + - "- `停止` `stop` — 中断当前任务 | Stop running task\n" + - "- `角色` `roles` — 列出所有可用角色 | List roles\n" + - "- `角色 <名>` `role ` — 切换当前角色 | Switch role\n" + - "- `删除 ` `delete ` — 删除指定对话 | Delete conversation\n" + - "- `版本` `version` — 显示当前版本号 | Show version\n\n" + - "---\n" + - "除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。\n" + - "Otherwise, send any text for AI penetration testing / security analysis." -} - -func (h *RobotHandler) cmdList() string { - convs, err := h.db.ListConversations(50, 0, "") - if err != nil { - return "获取对话列表失败: " + err.Error() - } - if len(convs) == 0 { - return "暂无对话。发送任意内容将自动创建新对话。" - } - var b strings.Builder - b.WriteString("【对话列表】\n") - for i, c := range convs { - if i >= 20 { - b.WriteString("… 仅显示前 20 条\n") - break - } - b.WriteString(fmt.Sprintf("· %s\n ID: %s\n", c.Title, c.ID)) - } - return strings.TrimSuffix(b.String(), "\n") -} - -func (h *RobotHandler) cmdSwitch(platform, userID, convID string) string { - if convID == "" { - return "请指定对话 ID,例如:切换 xxx-xxx-xxx" - } - conv, err := h.db.GetConversation(convID) - if err != nil { - return "对话不存在或 ID 错误。" - } - h.setConversation(platform, userID, conv.ID) - return fmt.Sprintf("已切换到对话:「%s」\nID: %s", conv.Title, conv.ID) -} - -func (h *RobotHandler) cmdNew(platform, userID string) string { - newID := h.clearConversation(platform, userID) - if newID == "" { - return "创建新对话失败,请重试。" - } - return "已开启新对话,可直接发送内容。" -} - -func (h *RobotHandler) cmdClear(platform, userID string) string { - return h.cmdNew(platform, userID) -} - -func (h *RobotHandler) cmdStop(platform, userID string) string { - sk := h.sessionKey(platform, userID) - h.cancelMu.Lock() - cancel, ok := h.runningCancels[sk] - if ok { - delete(h.runningCancels, sk) - cancel() - } - h.cancelMu.Unlock() - if !ok { - return "当前没有正在执行的任务。" - } - return "已停止当前任务。" -} - -func (h *RobotHandler) cmdCurrent(platform, userID string) string { - h.mu.RLock() - convID := h.sessions[h.sessionKey(platform, userID)] - h.mu.RUnlock() - if convID == "" { - return "当前没有进行中的对话。发送任意内容将创建新对话。" - } - conv, err := h.db.GetConversation(convID) - if err != nil { - return "当前对话 ID: " + convID + "(获取标题失败)" - } - role := h.getRole(platform, userID) - return fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role) -} - -func (h *RobotHandler) cmdRoles() string { - if h.config.Roles == nil || len(h.config.Roles) == 0 { - return "暂无可用角色。" - } - names := make([]string, 0, len(h.config.Roles)) - for name, role := range h.config.Roles { - if role.Enabled { - names = append(names, name) - } - } - if len(names) == 0 { - return "暂无可用角色。" - } - sort.Slice(names, func(i, j int) bool { - if names[i] == "默认" { - return true - } - if names[j] == "默认" { - return false - } - return names[i] < names[j] - }) - var b strings.Builder - b.WriteString("【角色列表】\n") - for _, name := range names { - role := h.config.Roles[name] - desc := role.Description - if desc == "" { - desc = "无描述" - } - b.WriteString(fmt.Sprintf("· %s — %s\n", name, desc)) - } - return strings.TrimSuffix(b.String(), "\n") -} - -func (h *RobotHandler) cmdSwitchRole(platform, userID, roleName string) string { - if roleName == "" { - return "请指定角色名称,例如:角色 渗透测试" - } - if h.config.Roles == nil { - return "暂无可用角色。" - } - role, exists := h.config.Roles[roleName] - if !exists { - return fmt.Sprintf("角色「%s」不存在。发送「角色」查看可用角色。", roleName) - } - if !role.Enabled { - return fmt.Sprintf("角色「%s」已禁用。", roleName) - } - h.setRole(platform, userID, roleName) - return fmt.Sprintf("已切换到角色:「%s」\n%s", roleName, role.Description) -} - -func (h *RobotHandler) cmdDelete(platform, userID, convID string) string { - if convID == "" { - return "请指定对话 ID,例如:删除 xxx-xxx-xxx" - } - sk := h.sessionKey(platform, userID) - h.mu.RLock() - currentConvID := h.sessions[sk] - h.mu.RUnlock() - if convID == currentConvID { - // 删除当前对话时,先清空会话绑定 - h.mu.Lock() - delete(h.sessions, sk) - h.mu.Unlock() - } - if err := h.db.DeleteConversation(convID); err != nil { - return "删除失败: " + err.Error() - } - return fmt.Sprintf("已删除对话 ID: %s", convID) -} - -func (h *RobotHandler) cmdVersion() string { - v := h.config.Version - if v == "" { - v = "未知" - } - return "CyberStrikeAI " + v -} - -// handleRobotCommand 处理机器人内置命令;若匹配到命令返回 (回复内容, true),否则返回 ("", false) -func (h *RobotHandler) handleRobotCommand(platform, userID, text string) (string, bool) { - switch { - case text == robotCmdHelp || text == "help" || text == "?" || text == "?": - return h.cmdHelp(), true - case text == robotCmdList || text == robotCmdListAlt || text == "list": - return h.cmdList(), true - case strings.HasPrefix(text, robotCmdSwitch+" ") || strings.HasPrefix(text, robotCmdContinue+" ") || strings.HasPrefix(text, "switch ") || strings.HasPrefix(text, "continue "): - var id string - switch { - case strings.HasPrefix(text, robotCmdSwitch+" "): - id = strings.TrimSpace(text[len(robotCmdSwitch)+1:]) - case strings.HasPrefix(text, robotCmdContinue+" "): - id = strings.TrimSpace(text[len(robotCmdContinue)+1:]) - case strings.HasPrefix(text, "switch "): - id = strings.TrimSpace(text[7:]) - default: - id = strings.TrimSpace(text[9:]) - } - return h.cmdSwitch(platform, userID, id), true - case text == robotCmdNew || text == "new": - return h.cmdNew(platform, userID), true - case text == robotCmdClear || text == "clear": - return h.cmdClear(platform, userID), true - case text == robotCmdCurrent || text == "current": - return h.cmdCurrent(platform, userID), true - case text == robotCmdStop || text == "stop": - return h.cmdStop(platform, userID), true - case text == robotCmdRoles || text == robotCmdRolesList || text == "roles": - return h.cmdRoles(), true - case strings.HasPrefix(text, robotCmdRoles+" ") || strings.HasPrefix(text, robotCmdSwitchRole+" ") || strings.HasPrefix(text, "role "): - var roleName string - switch { - case strings.HasPrefix(text, robotCmdRoles+" "): - roleName = strings.TrimSpace(text[len(robotCmdRoles)+1:]) - case strings.HasPrefix(text, robotCmdSwitchRole+" "): - roleName = strings.TrimSpace(text[len(robotCmdSwitchRole)+1:]) - default: - roleName = strings.TrimSpace(text[5:]) - } - return h.cmdSwitchRole(platform, userID, roleName), true - case strings.HasPrefix(text, robotCmdDelete+" ") || strings.HasPrefix(text, "delete "): - var convID string - if strings.HasPrefix(text, robotCmdDelete+" ") { - convID = strings.TrimSpace(text[len(robotCmdDelete)+1:]) - } else { - convID = strings.TrimSpace(text[7:]) - } - return h.cmdDelete(platform, userID, convID), true - case text == robotCmdVersion || text == "version": - return h.cmdVersion(), true - default: - return "", false - } -} - -// —————— 企业微信 —————— - -// wecomXML 企业微信回调 XML(明文模式下的简化结构;加密模式需先解密再解析) -type wecomXML struct { - ToUserName string `xml:"ToUserName"` - FromUserName string `xml:"FromUserName"` - CreateTime int64 `xml:"CreateTime"` - MsgType string `xml:"MsgType"` - Content string `xml:"Content"` - MsgID string `xml:"MsgId"` - AgentID int64 `xml:"AgentID"` - Encrypt string `xml:"Encrypt"` // 加密模式下消息在此 -} - -// wecomReplyXML 被动回复 XML(仅用于兼容,当前使用手动构造 XML) -type wecomReplyXML struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - FromUserName string `xml:"FromUserName"` - CreateTime int64 `xml:"CreateTime"` - MsgType string `xml:"MsgType"` - Content string `xml:"Content"` -} - -// HandleWecomGET 企业微信 URL 校验(GET) -func (h *RobotHandler) HandleWecomGET(c *gin.Context) { - if !h.config.Robots.Wecom.Enabled { - c.String(http.StatusNotFound, "") - return - } - // Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串 - echostr := c.Query("echostr") - msgSignature := c.Query("msg_signature") - timestamp := c.Query("timestamp") - nonce := c.Query("nonce") - - // 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1 - signature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, echostr) - if signature != msgSignature { - h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature)) - c.String(http.StatusBadRequest, "invalid signature") - return - } - - if echostr == "" { - c.String(http.StatusBadRequest, "missing echostr") - return - } - - // 如果配置了 EncodingAESKey,说明是加密模式,需要解密 echostr - if h.config.Robots.Wecom.EncodingAESKey != "" { - decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, echostr) - if err != nil { - h.logger.Warn("企业微信 echostr 解密失败", zap.Error(err)) - c.String(http.StatusBadRequest, "decrypt failed") - return - } - c.String(http.StatusOK, string(decrypted)) - return - } - - // 明文模式直接返回 echostr - c.String(http.StatusOK, echostr) -} - -// signWecomRequest 生成企业微信请求签名 -// 企业微信签名算法:将 token、timestamp、nonce、echostr 四个值排序后拼接成字符串,再计算 SHA1 -func (h *RobotHandler) signWecomRequest(token, timestamp, nonce, echostr string) string { - strs := []string{token, timestamp, nonce, echostr} - sort.Strings(strs) - s := strings.Join(strs, "") - hash := sha1.Sum([]byte(s)) - return fmt.Sprintf("%x", hash) -} - -// wecomDecrypt 企业微信消息解密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID) -func wecomDecrypt(encodingAESKey, encryptedB64 string) ([]byte, error) { - key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") - if err != nil { - return nil, err - } - if len(key) != 32 { - return nil, fmt.Errorf("encoding_aes_key 解码后应为 32 字节") - } - ciphertext, err := base64.StdEncoding.DecodeString(encryptedB64) - if err != nil { - return nil, err - } - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - iv := key[:16] - mode := cipher.NewCBCDecrypter(block, iv) - if len(ciphertext)%aes.BlockSize != 0 { - return nil, fmt.Errorf("密文长度不是块大小的倍数") - } - plain := make([]byte, len(ciphertext)) - mode.CryptBlocks(plain, ciphertext) - // 去除 PKCS7 填充 - n := int(plain[len(plain)-1]) - if n < 1 || n > 32 { - return nil, fmt.Errorf("无效的 PKCS7 填充") - } - plain = plain[:len(plain)-n] - // 企业微信格式:16 字节随机 + 4 字节长度(大端) + 消息 + corpID - if len(plain) < 20 { - return nil, fmt.Errorf("明文过短") - } - msgLen := binary.BigEndian.Uint32(plain[16:20]) - if int(20+msgLen) > len(plain) { - return nil, fmt.Errorf("消息长度越界") - } - return plain[20 : 20+msgLen], nil -} - -// wecomEncrypt 企业微信消息加密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID) -func wecomEncrypt(encodingAESKey, message, corpID string) (string, error) { - key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") - if err != nil { - return "", err - } - if len(key) != 32 { - return "", fmt.Errorf("encoding_aes_key 解码后应为 32 字节") - } - // 构造明文:16 字节随机 + 4 字节长度 (大端) + 消息 + corpID - random := make([]byte, 16) - if _, err := rand.Read(random); err != nil { - // 降级方案:使用时间戳生成随机数 - for i := range random { - random[i] = byte(time.Now().UnixNano() % 256) - } - } - msgLen := len(message) - msgBytes := []byte(message) - corpBytes := []byte(corpID) - plain := make([]byte, 16+4+msgLen+len(corpBytes)) - copy(plain[:16], random) - binary.BigEndian.PutUint32(plain[16:20], uint32(msgLen)) - copy(plain[20:20+msgLen], msgBytes) - copy(plain[20+msgLen:], corpBytes) - // PKCS7 填充 - padding := aes.BlockSize - len(plain)%aes.BlockSize - pad := bytes.Repeat([]byte{byte(padding)}, padding) - plain = append(plain, pad...) - // AES-256-CBC 加密 - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - iv := key[:16] - ciphertext := make([]byte, len(plain)) - mode := cipher.NewCBCEncrypter(block, iv) - mode.CryptBlocks(ciphertext, plain) - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -// HandleWecomPOST 企业微信消息回调(POST),支持明文与加密模式 -func (h *RobotHandler) HandleWecomPOST(c *gin.Context) { - if !h.config.Robots.Wecom.Enabled { - h.logger.Debug("企业微信机器人未启用,跳过请求") - c.String(http.StatusOK, "") - return - } - // 从 URL 获取签名参数(加密模式回复时需要用到) - timestamp := c.Query("timestamp") - nonce := c.Query("nonce") - msgSignature := c.Query("msg_signature") - - // 先读取请求体,后续解析/签名验证都会用到 - bodyRaw, err := io.ReadAll(c.Request.Body) - if err != nil { - h.logger.Warn("企业微信 POST 读取请求体失败", zap.Error(err)) - c.String(http.StatusOK, "") - return - } - h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw))) - - // 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段 - // 若配置了 Token 则必须校验签名,避免未授权请求触发 Agent(防止平台被接管) - token := h.config.Robots.Wecom.Token - if token != "" { - if msgSignature == "" { - h.logger.Warn("企业微信 POST 缺少签名,已拒绝(需配置 token 并确保回调携带 msg_signature)") - c.String(http.StatusOK, "") - return - } - var tmp wecomXML - if err := xml.Unmarshal(bodyRaw, &tmp); err != nil { - h.logger.Warn("企业微信 POST 签名验证前解析 XML 失败", zap.Error(err)) - c.String(http.StatusOK, "") - return - } - expected := h.signWecomRequest(token, timestamp, nonce, tmp.Encrypt) - if expected != msgSignature { - h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature)) - c.String(http.StatusOK, "") - return - } - } - - var body wecomXML - if err := xml.Unmarshal(bodyRaw, &body); err != nil { - h.logger.Warn("企业微信 POST 解析 XML 失败", zap.Error(err)) - c.String(http.StatusOK, "") - return - } - h.logger.Debug("企业微信 XML 解析成功", zap.String("ToUserName", body.ToUserName), zap.String("FromUserName", body.FromUserName), zap.String("MsgType", body.MsgType), zap.String("Content", body.Content), zap.String("Encrypt", body.Encrypt)) - - // 保存企业 ID(用于明文模式回复) - enterpriseID := body.ToUserName - - // 加密模式:先解密再解析内层 XML - if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" { - h.logger.Debug("企业微信进入加密模式解密流程") - decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, body.Encrypt) - if err != nil { - h.logger.Warn("企业微信消息解密失败", zap.Error(err)) - c.String(http.StatusOK, "") - return - } - h.logger.Debug("企业微信解密成功", zap.String("decrypted", string(decrypted))) - if err := xml.Unmarshal(decrypted, &body); err != nil { - h.logger.Warn("企业微信解密后 XML 解析失败", zap.Error(err)) - c.String(http.StatusOK, "") - return - } - h.logger.Debug("企业微信内层 XML 解析成功", zap.String("FromUserName", body.FromUserName), zap.String("Content", body.Content)) - } - - userID := body.FromUserName - text := strings.TrimSpace(body.Content) - - // 限制回复内容长度(企业微信限制 2048 字节) - maxReplyLen := 2000 - limitReply := func(s string) string { - if len(s) > maxReplyLen { - return s[:maxReplyLen] + "\n\n(内容过长,已截断)" - } - return s - } - - if body.MsgType != "text" { - h.logger.Debug("企业微信收到非文本消息", zap.String("MsgType", body.MsgType)) - h.sendWecomReply(c, userID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce) - return - } - - // 文本消息:先判断是否为内置命令(如 帮助/列表/新对话 等),这类命令处理很快,可以直接走被动回复,避免依赖主动发送 API。 - if cmdReply, ok := h.handleRobotCommand("wecom", userID, text); ok { - h.logger.Debug("企业微信收到命令消息,走被动回复", zap.String("userID", userID), zap.String("text", text)) - h.sendWecomReply(c, userID, enterpriseID, limitReply(cmdReply), timestamp, nonce) - return - } - - h.logger.Debug("企业微信开始处理消息(异步 AI)", zap.String("userID", userID), zap.String("text", text)) - - // 企业微信被动回复有 5 秒超时限制,而 AI 调用通常超过该时长。 - // 这里采用推荐做法:立即返回 success(或空串),然后通过主动发送接口推送完整回复。 - c.String(http.StatusOK, "success") - - // 异步处理消息并通过企业微信主动消息接口发送结果 - go func() { - reply := h.HandleMessage("wecom", userID, text) - reply = limitReply(reply) - h.logger.Debug("企业微信消息处理完成", zap.String("userID", userID), zap.String("reply", reply)) - // 调用企业微信 API 主动发送消息 - h.sendWecomMessageViaAPI(userID, enterpriseID, reply) - }() -} - -// sendWecomReply 发送企业微信回复(加密模式自动加密) -// 参数:toUser=用户 ID, fromUser=企业 ID(明文模式)/CorpID(加密模式), content=回复内容,timestamp/nonce=请求参数 -func (h *RobotHandler) sendWecomReply(c *gin.Context, toUser, fromUser, content, timestamp, nonce string) { - // 加密模式:判断 EncodingAESKey 是否配置 - if h.config.Robots.Wecom.EncodingAESKey != "" { - // 加密模式使用 CorpID 进行加密 - corpID := h.config.Robots.Wecom.CorpID - if corpID == "" { - h.logger.Warn("企业微信加密模式缺少 CorpID 配置") - c.String(http.StatusOK, "") - return - } - - // 构造完整的明文 XML 回复(格式严格按企业微信文档要求) - plainResp := fmt.Sprintf(` - - -%d - - -`, toUser, fromUser, time.Now().Unix(), content) - - encrypted, err := wecomEncrypt(h.config.Robots.Wecom.EncodingAESKey, plainResp, corpID) - if err != nil { - h.logger.Warn("企业微信回复加密失败", zap.Error(err)) - c.String(http.StatusOK, "") - return - } - // 使用请求中的 timestamp/nonce 生成签名(企业微信要求回复时使用与请求相同的 timestamp 和 nonce) - msgSignature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, encrypted) - - h.logger.Debug("企业微信发送加密回复", - zap.String("Encrypt", encrypted[:50]+"..."), - zap.String("MsgSignature", msgSignature), - zap.String("TimeStamp", timestamp), - zap.String("Nonce", nonce)) - - // 加密模式仅返回 4 个核心字段(企业微信官方要求) - xmlResp := fmt.Sprintf(``, encrypted, msgSignature, timestamp, nonce) - // also log the final response body so we can cross-check with the - // network traffic or developer console - h.logger.Debug("企业微信加密回复包", zap.String("xml", xmlResp)) - // for additional confidence, decrypt the payload ourselves and log it - if dec, err2 := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, encrypted); err2 == nil { - h.logger.Debug("企业微信加密回复解密检查", zap.String("plain", string(dec))) - } else { - h.logger.Warn("企业微信加密回复解密检查失败", zap.Error(err2)) - } - - // 使用 c.Writer.Write 直接写入响应,避免 c.String 的转义问题 - c.Writer.WriteHeader(http.StatusOK) - // use text/xml as that's what WeCom examples show - c.Writer.Header().Set("Content-Type", "text/xml; charset=utf-8") - _, _ = c.Writer.Write([]byte(xmlResp)) - h.logger.Debug("企业微信加密回复已发送") - return - } - - // 明文模式 - h.logger.Debug("企业微信发送明文回复", zap.String("ToUserName", toUser), zap.String("FromUserName", fromUser), zap.String("Content", content[:50]+"...")) - - // 手动构造 XML 响应(使用 CDATA 包裹所有字段,并包含 AgentID) - xmlResp := fmt.Sprintf(` - - -%d - - -`, toUser, fromUser, time.Now().Unix(), content) - - // log the exact plaintext response for debugging - h.logger.Debug("企业微信明文回复包", zap.String("xml", xmlResp)) - - // use text/xml as recommended by WeCom docs - c.Header("Content-Type", "text/xml; charset=utf-8") - c.String(http.StatusOK, xmlResp) - h.logger.Debug("企业微信明文回复已发送") -} - -// —————— 测试接口(需登录,用于验证机器人逻辑,无需钉钉/飞书客户端) —————— - -// RobotTestRequest 模拟机器人消息请求 -type RobotTestRequest struct { - Platform string `json:"platform"` // 如 "dingtalk"、"lark"、"wecom" - UserID string `json:"user_id"` - Text string `json:"text"` -} - -// HandleRobotTest 供本地验证:POST JSON { "platform", "user_id", "text" },返回 { "reply": "..." } -func (h *RobotHandler) HandleRobotTest(c *gin.Context) { - var req RobotTestRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "请求体需为 JSON,包含 platform、user_id、text"}) - return - } - platform := strings.TrimSpace(req.Platform) - if platform == "" { - platform = "test" - } - userID := strings.TrimSpace(req.UserID) - if userID == "" { - userID = "test_user" - } - reply := h.HandleMessage(platform, userID, req.Text) - c.JSON(http.StatusOK, gin.H{"reply": reply}) -} - -// sendWecomMessageViaAPI 通过企业微信 API 主动发送消息(用于异步处理后的结果发送) -func (h *RobotHandler) sendWecomMessageViaAPI(toUser, toParty, content string) { - if !h.config.Robots.Wecom.Enabled { - return - } - - secret := h.config.Robots.Wecom.Secret - corpID := h.config.Robots.Wecom.CorpID - agentID := h.config.Robots.Wecom.AgentID - - if secret == "" || corpID == "" { - h.logger.Warn("企业微信主动 API 缺少 secret 或 corpID 配置") - return - } - - // 第 1 步:获取 access_token - tokenURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", corpID, secret) - resp, err := http.Get(tokenURL) - if err != nil { - h.logger.Warn("企业微信获取 token 失败", zap.Error(err)) - return - } - defer resp.Body.Close() - - var tokenResp struct { - AccessToken string `json:"access_token"` - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - } - if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - h.logger.Warn("企业微信 token 响应解析失败", zap.Error(err)) - return - } - if tokenResp.ErrCode != 0 { - h.logger.Warn("企业微信 token 获取错误", zap.String("errmsg", tokenResp.ErrMsg), zap.Int("errcode", tokenResp.ErrCode)) - return - } - - // 第 2 步:构造发送消息请求 - msgReq := map[string]interface{}{ - "touser": toUser, - "msgtype": "text", - "agentid": agentID, - "text": map[string]interface{}{ - "content": content, - }, - } - - msgBody, err := json.Marshal(msgReq) - if err != nil { - h.logger.Warn("企业微信消息序列化失败", zap.Error(err)) - return - } - - // 第 3 步:发送消息 - sendURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", tokenResp.AccessToken) - msgResp, err := http.Post(sendURL, "application/json", bytes.NewReader(msgBody)) - if err != nil { - h.logger.Warn("企业微信主动发送消息失败", zap.Error(err)) - return - } - defer msgResp.Body.Close() - - var sendResp struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - InvalidUser string `json:"invaliduser"` - MsgID string `json:"msgid"` - } - if err := json.NewDecoder(msgResp.Body).Decode(&sendResp); err != nil { - h.logger.Warn("企业微信发送响应解析失败", zap.Error(err)) - return - } - - if sendResp.ErrCode == 0 { - h.logger.Debug("企业微信主动发送消息成功", zap.String("msgid", sendResp.MsgID)) - } else { - h.logger.Warn("企业微信主动发送消息失败", zap.String("errmsg", sendResp.ErrMsg), zap.Int("errcode", sendResp.ErrCode), zap.String("invaliduser", sendResp.InvalidUser)) - } -} - -// —————— 钉钉 —————— - -// HandleDingtalkPOST 钉钉事件回调(流式接入等);当前为占位,返回 200 -func (h *RobotHandler) HandleDingtalkPOST(c *gin.Context) { - if !h.config.Robots.Dingtalk.Enabled { - c.JSON(http.StatusOK, gin.H{}) - return - } - // 钉钉流式/事件回调格式需按官方文档解析并异步回复,此处仅返回 200 - c.JSON(http.StatusOK, gin.H{"message": "ok"}) -} - -// —————— 飞书 —————— - -// HandleLarkPOST 飞书事件回调;当前为占位,返回 200;验证时需返回 challenge -func (h *RobotHandler) HandleLarkPOST(c *gin.Context) { - if !h.config.Robots.Lark.Enabled { - c.JSON(http.StatusOK, gin.H{}) - return - } - var body struct { - Challenge string `json:"challenge"` - } - if err := c.ShouldBindJSON(&body); err == nil && body.Challenge != "" { - c.JSON(http.StatusOK, gin.H{"challenge": body.Challenge}) - return - } - c.JSON(http.StatusOK, gin.H{}) -} diff --git a/handler/role.go b/handler/role.go deleted file mode 100644 index 88c42138..00000000 --- a/handler/role.go +++ /dev/null @@ -1,487 +0,0 @@ -package handler - -import ( - "fmt" - "net/http" - "os" - "path/filepath" - "regexp" - "strings" - - "cyberstrike-ai/internal/config" - - "gopkg.in/yaml.v3" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// RoleHandler 角色处理器 -type RoleHandler struct { - config *config.Config - configPath string - logger *zap.Logger - skillsManager SkillsManager // Skills管理器接口(可选) -} - -// SkillsManager Skills管理器接口 -type SkillsManager interface { - ListSkills() ([]string, error) -} - -// NewRoleHandler 创建新的角色处理器 -func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler { - return &RoleHandler{ - config: cfg, - configPath: configPath, - logger: logger, - } -} - -// SetSkillsManager 设置Skills管理器 -func (h *RoleHandler) SetSkillsManager(manager SkillsManager) { - h.skillsManager = manager -} - -// GetSkills 获取所有可用的skills列表 -func (h *RoleHandler) GetSkills(c *gin.Context) { - if h.skillsManager == nil { - c.JSON(http.StatusOK, gin.H{ - "skills": []string{}, - }) - return - } - - skills, err := h.skillsManager.ListSkills() - if err != nil { - h.logger.Warn("获取skills列表失败", zap.Error(err)) - c.JSON(http.StatusOK, gin.H{ - "skills": []string{}, - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "skills": skills, - }) -} - -// GetRoles 获取所有角色 -func (h *RoleHandler) GetRoles(c *gin.Context) { - if h.config.Roles == nil { - h.config.Roles = make(map[string]config.RoleConfig) - } - - roles := make([]config.RoleConfig, 0, len(h.config.Roles)) - for key, role := range h.config.Roles { - // 确保角色的key与name一致 - if role.Name == "" { - role.Name = key - } - roles = append(roles, role) - } - - c.JSON(http.StatusOK, gin.H{ - "roles": roles, - }) -} - -// GetRole 获取单个角色 -func (h *RoleHandler) GetRole(c *gin.Context) { - roleName := c.Param("name") - if roleName == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"}) - return - } - - if h.config.Roles == nil { - c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"}) - return - } - - role, exists := h.config.Roles[roleName] - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"}) - return - } - - // 确保角色的name与key一致 - if role.Name == "" { - role.Name = roleName - } - - c.JSON(http.StatusOK, gin.H{ - "role": role, - }) -} - -// UpdateRole 更新角色 -func (h *RoleHandler) UpdateRole(c *gin.Context) { - roleName := c.Param("name") - if roleName == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"}) - return - } - - var req config.RoleConfig - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - // 确保角色名称与请求中的name一致 - if req.Name == "" { - req.Name = roleName - } - - // 初始化Roles map - if h.config.Roles == nil { - h.config.Roles = make(map[string]config.RoleConfig) - } - - // 删除所有与角色name相同但key不同的旧角色(避免重复) - // 使用角色name作为key,确保唯一性 - finalKey := req.Name - keysToDelete := make([]string, 0) - for key := range h.config.Roles { - // 如果key与最终的key不同,但name相同,则标记为删除 - if key != finalKey { - role := h.config.Roles[key] - // 确保角色的name字段正确设置 - if role.Name == "" { - role.Name = key - } - if role.Name == req.Name { - keysToDelete = append(keysToDelete, key) - } - } - } - // 删除旧的角色 - for _, key := range keysToDelete { - delete(h.config.Roles, key) - h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name)) - } - - // 如果当前更新的key与最终key不同,也需要删除旧的 - if roleName != finalKey { - delete(h.config.Roles, roleName) - } - - // 如果角色名称改变,需要删除旧文件 - if roleName != finalKey { - configDir := filepath.Dir(h.configPath) - rolesDir := h.config.RolesDir - if rolesDir == "" { - rolesDir = "roles" // 默认目录 - } - - // 如果是相对路径,相对于配置文件所在目录 - if !filepath.IsAbs(rolesDir) { - rolesDir = filepath.Join(configDir, rolesDir) - } - - // 删除旧的角色文件 - oldSafeFileName := sanitizeFileName(roleName) - oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml") - oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml") - - if _, err := os.Stat(oldRoleFileYaml); err == nil { - if err := os.Remove(oldRoleFileYaml); err != nil { - h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err)) - } - } - if _, err := os.Stat(oldRoleFileYml); err == nil { - if err := os.Remove(oldRoleFileYml); err != nil { - h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err)) - } - } - } - - // 使用角色name作为key来保存(确保唯一性) - h.config.Roles[finalKey] = req - - // 保存配置到文件 - if err := h.saveConfig(); err != nil { - h.logger.Error("保存配置失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) - return - } - - h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name)) - c.JSON(http.StatusOK, gin.H{ - "message": "角色已更新", - "role": req, - }) -} - -// CreateRole 创建新角色 -func (h *RoleHandler) CreateRole(c *gin.Context) { - var req config.RoleConfig - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - if req.Name == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"}) - return - } - - // 初始化Roles map - if h.config.Roles == nil { - h.config.Roles = make(map[string]config.RoleConfig) - } - - // 检查角色是否已存在 - if _, exists := h.config.Roles[req.Name]; exists { - c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"}) - return - } - - // 创建角色(默认启用) - if !req.Enabled { - req.Enabled = true - } - - h.config.Roles[req.Name] = req - - // 保存配置到文件 - if err := h.saveConfig(); err != nil { - h.logger.Error("保存配置失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) - return - } - - h.logger.Info("创建角色", zap.String("roleName", req.Name)) - c.JSON(http.StatusOK, gin.H{ - "message": "角色已创建", - "role": req, - }) -} - -// DeleteRole 删除角色 -func (h *RoleHandler) DeleteRole(c *gin.Context) { - roleName := c.Param("name") - if roleName == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"}) - return - } - - if h.config.Roles == nil { - c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"}) - return - } - - if _, exists := h.config.Roles[roleName]; !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"}) - return - } - - // 不允许删除"默认"角色 - if roleName == "默认" { - c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"}) - return - } - - delete(h.config.Roles, roleName) - - // 删除对应的角色文件 - configDir := filepath.Dir(h.configPath) - rolesDir := h.config.RolesDir - if rolesDir == "" { - rolesDir = "roles" // 默认目录 - } - - // 如果是相对路径,相对于配置文件所在目录 - if !filepath.IsAbs(rolesDir) { - rolesDir = filepath.Join(configDir, rolesDir) - } - - // 尝试删除角色文件(.yaml 和 .yml) - safeFileName := sanitizeFileName(roleName) - roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml") - roleFileYml := filepath.Join(rolesDir, safeFileName+".yml") - - // 删除 .yaml 文件(如果存在) - if _, err := os.Stat(roleFileYaml); err == nil { - if err := os.Remove(roleFileYaml); err != nil { - h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err)) - } else { - h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml)) - } - } - - // 删除 .yml 文件(如果存在) - if _, err := os.Stat(roleFileYml); err == nil { - if err := os.Remove(roleFileYml); err != nil { - h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err)) - } else { - h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml)) - } - } - - h.logger.Info("删除角色", zap.String("roleName", roleName)) - c.JSON(http.StatusOK, gin.H{ - "message": "角色已删除", - }) -} - -// saveConfig 保存配置到目录中的文件 -func (h *RoleHandler) saveConfig() error { - configDir := filepath.Dir(h.configPath) - rolesDir := h.config.RolesDir - if rolesDir == "" { - rolesDir = "roles" // 默认目录 - } - - // 如果是相对路径,相对于配置文件所在目录 - if !filepath.IsAbs(rolesDir) { - rolesDir = filepath.Join(configDir, rolesDir) - } - - // 确保目录存在 - if err := os.MkdirAll(rolesDir, 0755); err != nil { - return fmt.Errorf("创建角色目录失败: %w", err) - } - - // 保存每个角色到独立的文件 - if h.config.Roles != nil { - for roleName, role := range h.config.Roles { - // 确保角色名称正确设置 - if role.Name == "" { - role.Name = roleName - } - - // 使用角色名称作为文件名(安全化文件名,避免特殊字符) - safeFileName := sanitizeFileName(role.Name) - roleFile := filepath.Join(rolesDir, safeFileName+".yaml") - - // 将角色配置序列化为YAML - roleData, err := yaml.Marshal(&role) - if err != nil { - h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err)) - continue - } - - // 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义) - roleDataStr := string(roleData) - if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") { - // 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况 - // 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况 - re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`) - roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`) - roleData = []byte(roleDataStr) - } - - // 写入文件 - if err := os.WriteFile(roleFile, roleData, 0644); err != nil { - h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err)) - continue - } - - h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile)) - } - } - - return nil -} - -// sanitizeFileName 将角色名称转换为安全的文件名 -func sanitizeFileName(name string) string { - // 替换可能不安全的字符 - replacer := map[rune]string{ - '/': "_", - '\\': "_", - ':': "_", - '*': "_", - '?': "_", - '"': "_", - '<': "_", - '>': "_", - '|': "_", - ' ': "_", - } - - var result []rune - for _, r := range name { - if replacement, ok := replacer[r]; ok { - result = append(result, []rune(replacement)...) - } else { - result = append(result, r) - } - } - - fileName := string(result) - // 如果文件名为空,使用默认名称 - if fileName == "" { - fileName = "role" - } - - return fileName -} - -// updateRolesConfig 更新角色配置 -func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) { - root := doc.Content[0] - rolesNode := ensureMap(root, "roles") - - // 清空现有角色 - if rolesNode.Kind == yaml.MappingNode { - rolesNode.Content = nil - } - - // 添加新角色(使用name作为key,确保唯一性) - if cfg.Roles != nil { - // 先建立一个以name为key的map,去重(保留最后一个) - rolesByName := make(map[string]config.RoleConfig) - for roleKey, role := range cfg.Roles { - // 确保角色的name字段正确设置 - if role.Name == "" { - role.Name = roleKey - } - // 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个 - rolesByName[role.Name] = role - } - - // 将去重后的角色写入YAML - for roleName, role := range rolesByName { - roleNode := ensureMap(rolesNode, roleName) - setStringInMap(roleNode, "name", role.Name) - setStringInMap(roleNode, "description", role.Description) - setStringInMap(roleNode, "user_prompt", role.UserPrompt) - if role.Icon != "" { - setStringInMap(roleNode, "icon", role.Icon) - } - setBoolInMap(roleNode, "enabled", role.Enabled) - - // 添加工具列表(优先使用tools字段) - if len(role.Tools) > 0 { - toolsNode := ensureArray(roleNode, "tools") - toolsNode.Content = nil - for _, toolKey := range role.Tools { - toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey} - toolsNode.Content = append(toolsNode.Content, toolNode) - } - } else if len(role.MCPs) > 0 { - // 向后兼容:如果没有tools但有mcps,保存mcps - mcpsNode := ensureArray(roleNode, "mcps") - mcpsNode.Content = nil - for _, mcpName := range role.MCPs { - mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName} - mcpsNode.Content = append(mcpsNode.Content, mcpNode) - } - } - } - } -} - -// ensureArray 确保数组中存在指定key的数组节点 -func ensureArray(parent *yaml.Node, key string) *yaml.Node { - _, valueNode := ensureKeyValue(parent, key) - if valueNode.Kind != yaml.SequenceNode { - valueNode.Kind = yaml.SequenceNode - valueNode.Tag = "!!seq" - valueNode.Content = nil - } - return valueNode -} diff --git a/handler/skills.go b/handler/skills.go deleted file mode 100644 index f6577292..00000000 --- a/handler/skills.go +++ /dev/null @@ -1,758 +0,0 @@ -package handler - -import ( - "fmt" - "net/http" - "os" - "path/filepath" - "regexp" - "strings" - - "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/database" - "cyberstrike-ai/internal/skillpackage" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" - "gopkg.in/yaml.v3" -) - -// SkillsHandler Skills处理器(磁盘 + Eino 规范;运行时由 Eino ADK skill 中间件加载) -type SkillsHandler struct { - config *config.Config - configPath string - logger *zap.Logger - db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除) -} - -// NewSkillsHandler 创建新的Skills处理器 -func NewSkillsHandler(cfg *config.Config, configPath string, logger *zap.Logger) *SkillsHandler { - return &SkillsHandler{ - config: cfg, - configPath: configPath, - logger: logger, - } -} - -func (h *SkillsHandler) skillsRootAbs() string { - skillsDir := h.config.SkillsDir - if skillsDir == "" { - skillsDir = "skills" - } - configDir := filepath.Dir(h.configPath) - if !filepath.IsAbs(skillsDir) { - skillsDir = filepath.Join(configDir, skillsDir) - } - return skillsDir -} - -// SetDB 设置数据库连接(用于获取调用统计) -func (h *SkillsHandler) SetDB(db *database.DB) { - h.db = db -} - -// GetSkills 获取所有skills列表(支持分页和搜索) -func (h *SkillsHandler) GetSkills(c *gin.Context) { - allSummaries, err := skillpackage.ListSkillSummaries(h.skillsRootAbs()) - if err != nil { - h.logger.Error("获取skills列表失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - searchKeyword := strings.TrimSpace(c.Query("search")) - - allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries)) - for _, s := range allSummaries { - skillInfo := map[string]interface{}{ - "id": s.ID, - "name": s.Name, - "dir_name": s.DirName, - "description": s.Description, - "version": s.Version, - "path": s.Path, - "tags": s.Tags, - "triggers": s.Triggers, - "script_count": s.ScriptCount, - "file_count": s.FileCount, - "progressive": s.Progressive, - "file_size": s.FileSize, - "mod_time": s.ModTime, - } - allSkillsInfo = append(allSkillsInfo, skillInfo) - } - - filteredSkillsInfo := allSkillsInfo - if searchKeyword != "" { - keywordLower := strings.ToLower(searchKeyword) - filteredSkillsInfo = make([]map[string]interface{}, 0) - for _, skillInfo := range allSkillsInfo { - id := strings.ToLower(fmt.Sprintf("%v", skillInfo["id"])) - name := strings.ToLower(fmt.Sprintf("%v", skillInfo["name"])) - description := strings.ToLower(fmt.Sprintf("%v", skillInfo["description"])) - path := strings.ToLower(fmt.Sprintf("%v", skillInfo["path"])) - version := strings.ToLower(fmt.Sprintf("%v", skillInfo["version"])) - tagsJoined := "" - if tags, ok := skillInfo["tags"].([]string); ok { - tagsJoined = strings.ToLower(strings.Join(tags, " ")) - } - trigJoined := "" - if tr, ok := skillInfo["triggers"].([]string); ok { - trigJoined = strings.ToLower(strings.Join(tr, " ")) - } - if strings.Contains(id, keywordLower) || - strings.Contains(name, keywordLower) || - strings.Contains(description, keywordLower) || - strings.Contains(path, keywordLower) || - strings.Contains(version, keywordLower) || - strings.Contains(tagsJoined, keywordLower) || - strings.Contains(trigJoined, keywordLower) { - filteredSkillsInfo = append(filteredSkillsInfo, skillInfo) - } - } - } - - // 分页参数 - limit := 20 // 默认每页20条 - offset := 0 - if limitStr := c.Query("limit"); limitStr != "" { - if parsed, err := parseInt(limitStr); err == nil && parsed > 0 { - // 允许更大的limit用于搜索场景,但设置一个合理的上限(10000) - if parsed <= 10000 { - limit = parsed - } else { - limit = 10000 - } - } - } - if offsetStr := c.Query("offset"); offsetStr != "" { - if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 { - offset = parsed - } - } - - // 计算分页范围 - total := len(filteredSkillsInfo) - start := offset - end := offset + limit - if start > total { - start = total - } - if end > total { - end = total - } - - // 获取当前页的skill列表 - var paginatedSkillsInfo []map[string]interface{} - if start < end { - paginatedSkillsInfo = filteredSkillsInfo[start:end] - } else { - paginatedSkillsInfo = []map[string]interface{}{} - } - - c.JSON(http.StatusOK, gin.H{ - "skills": paginatedSkillsInfo, - "total": total, - "limit": limit, - "offset": offset, - }) -} - -// GetSkill 获取单个skill的详细信息 -func (h *SkillsHandler) GetSkill(c *gin.Context) { - skillName := c.Param("name") - if skillName == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) - return - } - - resPath := strings.TrimSpace(c.Query("resource_path")) - if resPath == "" { - resPath = strings.TrimSpace(c.Query("skill_script_path")) - } - if resPath != "" { - content, err := skillpackage.ReadScriptText(h.skillsRootAbs(), skillName, resPath, 0) - if err != nil { - h.logger.Warn("读取skill资源失败", zap.String("skill", skillName), zap.String("path", resPath), zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{ - "skill": map[string]interface{}{ - "id": skillName, - }, - "resource": map[string]interface{}{ - "path": resPath, - "content": content, - }, - }) - return - } - - depthStr := strings.ToLower(strings.TrimSpace(c.DefaultQuery("depth", "full"))) - section := strings.TrimSpace(c.Query("section")) - opt := skillpackage.LoadOptions{Section: section} - switch depthStr { - case "summary": - opt.Depth = "summary" - case "full", "": - opt.Depth = "full" - default: - c.JSON(http.StatusBadRequest, gin.H{"error": "depth 仅支持 summary 或 full"}) - return - } - - skill, err := skillpackage.LoadSkill(h.skillsRootAbs(), skillName, opt) - if err != nil { - h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()}) - return - } - - skillPath := skill.Path - skillFile := filepath.Join(skillPath, "SKILL.md") - - fileInfo, _ := os.Stat(skillFile) - var fileSize int64 - var modTime string - if fileInfo != nil { - fileSize = fileInfo.Size() - modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05") - } - - c.JSON(http.StatusOK, gin.H{ - "skill": map[string]interface{}{ - "id": skill.DirName, - "name": skill.Name, - "description": skill.Description, - "content": skill.Content, - "path": skill.Path, - "version": skill.Version, - "tags": skill.Tags, - "scripts": skill.Scripts, - "sections": skill.Sections, - "package_files": skill.PackageFiles, - "file_size": fileSize, - "mod_time": modTime, - "depth": depthStr, - "section": section, - }, - }) -} - -// ListSkillPackageFiles lists all files in a skill directory (Agent Skills layout). -func (h *SkillsHandler) ListSkillPackageFiles(c *gin.Context) { - skillID := c.Param("name") - files, err := skillpackage.ListPackageFiles(h.skillsRootAbs(), skillID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"files": files}) -} - -// GetSkillPackageFile returns one file by relative path (?path=). -func (h *SkillsHandler) GetSkillPackageFile(c *gin.Context) { - skillID := c.Param("name") - rel := strings.TrimSpace(c.Query("path")) - if rel == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "query path is required"}) - return - } - b, err := skillpackage.ReadPackageFile(h.skillsRootAbs(), skillID, rel, 0) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"path": rel, "content": string(b)}) -} - -// PutSkillPackageFile writes a file inside the skill package. -func (h *SkillsHandler) PutSkillPackageFile(c *gin.Context) { - skillID := c.Param("name") - var req struct { - Path string `json:"path" binding:"required"` - Content string `json:"content"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - if req.Path == "SKILL.md" { - if err := skillpackage.ValidateSkillMDPackage([]byte(req.Content), skillID); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - } - if err := skillpackage.WritePackageFile(h.skillsRootAbs(), skillID, req.Path, []byte(req.Content)); err != nil { - h.logger.Error("写入 skill 文件失败", zap.String("skill", skillID), zap.String("path", req.Path), zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "saved", "path": req.Path}) -} - -// GetSkillBoundRoles 获取绑定指定skill的角色列表 -func (h *SkillsHandler) GetSkillBoundRoles(c *gin.Context) { - skillName := c.Param("name") - if skillName == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) - return - } - - boundRoles := h.getRolesBoundToSkill(skillName) - c.JSON(http.StatusOK, gin.H{ - "skill": skillName, - "bound_roles": boundRoles, - "bound_count": len(boundRoles), - }) -} - -// getRolesBoundToSkill 获取绑定指定skill的角色列表(不修改配置) -func (h *SkillsHandler) getRolesBoundToSkill(skillName string) []string { - if h.config.Roles == nil { - return []string{} - } - - boundRoles := make([]string, 0) - for roleName, role := range h.config.Roles { - // 确保角色名称正确设置 - if role.Name == "" { - role.Name = roleName - } - - // 检查角色的Skills列表中是否包含该skill - if len(role.Skills) > 0 { - for _, skill := range role.Skills { - if skill == skillName { - boundRoles = append(boundRoles, roleName) - break - } - } - } - } - - return boundRoles -} - -// CreateSkill 创建新 skill(标准 Agent Skills:生成 SKILL.md + YAML front matter) -func (h *SkillsHandler) CreateSkill(c *gin.Context) { - var req struct { - Name string `json:"name" binding:"required"` - Description string `json:"description" binding:"required"` - Content string `json:"content" binding:"required"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - if !isValidSkillName(req.Name) { - c.JSON(http.StatusBadRequest, gin.H{"error": "skill 目录名须为小写字母、数字、连字符(与 Agent Skills name 一致)"}) - return - } - - manifest := &skillpackage.SkillManifest{ - Name: req.Name, - Description: strings.TrimSpace(req.Description), - } - skillMD, err := skillpackage.BuildSkillMD(manifest, req.Content) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if err := skillpackage.ValidateSkillMDPackage(skillMD, req.Name); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - skillDir := filepath.Join(h.skillsRootAbs(), req.Name) - if err := os.MkdirAll(skillDir, 0755); err != nil { - h.logger.Error("创建skill目录失败", zap.String("skill", req.Name), zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill目录失败: " + err.Error()}) - return - } - - if _, err := os.Stat(filepath.Join(skillDir, "SKILL.md")); err == nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "skill已存在"}) - return - } - - if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillMD, 0644); err != nil { - h.logger.Error("创建 SKILL.md 失败", zap.String("skill", req.Name), zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "创建 SKILL.md 失败: " + err.Error()}) - return - } - - h.logger.Info("创建skill成功", zap.String("skill", req.Name)) - c.JSON(http.StatusOK, gin.H{ - "message": "skill已创建", - "skill": map[string]interface{}{ - "name": req.Name, - "path": skillDir, - }, - }) -} - -// UpdateSkill 更新 SKILL.md(保留 front matter 中除 description 外的字段;可选覆盖 description) -func (h *SkillsHandler) UpdateSkill(c *gin.Context) { - skillName := c.Param("name") - if skillName == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) - return - } - - var req struct { - Description string `json:"description"` - Content string `json:"content" binding:"required"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) - return - } - - mdPath := filepath.Join(h.skillsRootAbs(), skillName, "SKILL.md") - raw, err := os.ReadFile(mdPath) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()}) - return - } - m, _, err := skillpackage.ParseSkillMD(raw) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - if req.Description != "" { - m.Description = strings.TrimSpace(req.Description) - } - skillMD, err := skillpackage.BuildSkillMD(m, req.Content) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if err := skillpackage.ValidateSkillMDPackage(skillMD, skillName); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - skillDir := filepath.Join(h.skillsRootAbs(), skillName) - - if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillMD, 0644); err != nil { - h.logger.Error("更新 SKILL.md 失败", zap.String("skill", skillName), zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "更新 SKILL.md 失败: " + err.Error()}) - return - } - - h.logger.Info("更新skill成功", zap.String("skill", skillName)) - c.JSON(http.StatusOK, gin.H{ - "message": "skill已更新", - }) -} - -// DeleteSkill 删除skill -func (h *SkillsHandler) DeleteSkill(c *gin.Context) { - skillName := c.Param("name") - if skillName == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) - return - } - - // 检查是否有角色绑定了该skill,如果有则自动移除绑定 - affectedRoles := h.removeSkillFromRoles(skillName) - if len(affectedRoles) > 0 { - h.logger.Info("从角色中移除skill绑定", - zap.String("skill", skillName), - zap.Strings("roles", affectedRoles)) - } - - skillDir := filepath.Join(h.skillsRootAbs(), skillName) - if err := os.RemoveAll(skillDir); err != nil { - h.logger.Error("删除skill失败", zap.String("skill", skillName), zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "删除skill失败: " + err.Error()}) - return - } - responseMsg := "skill已删除" - if len(affectedRoles) > 0 { - responseMsg = fmt.Sprintf("skill已删除,已自动从 %d 个角色中移除绑定: %s", - len(affectedRoles), strings.Join(affectedRoles, ", ")) - } - - h.logger.Info("删除skill成功", zap.String("skill", skillName)) - c.JSON(http.StatusOK, gin.H{ - "message": responseMsg, - "affected_roles": affectedRoles, - }) -} - -// GetSkillStats 获取skills调用统计信息 -func (h *SkillsHandler) GetSkillStats(c *gin.Context) { - skillList, err := skillpackage.ListSkillDirNames(h.skillsRootAbs()) - if err != nil { - h.logger.Error("获取skills列表失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - skillsDir := h.skillsRootAbs() - - // 从数据库加载调用统计 - var skillStatsMap map[string]*database.SkillStats - if h.db != nil { - dbStats, err := h.db.LoadSkillStats() - if err != nil { - h.logger.Warn("从数据库加载Skills统计信息失败", zap.Error(err)) - skillStatsMap = make(map[string]*database.SkillStats) - } else { - skillStatsMap = dbStats - } - } else { - skillStatsMap = make(map[string]*database.SkillStats) - } - - // 构建统计信息(包含所有skills,即使没有调用记录) - statsList := make([]map[string]interface{}, 0, len(skillList)) - totalCalls := 0 - totalSuccess := 0 - totalFailed := 0 - - for _, skillName := range skillList { - stat, exists := skillStatsMap[skillName] - if !exists { - stat = &database.SkillStats{ - SkillName: skillName, - TotalCalls: 0, - SuccessCalls: 0, - FailedCalls: 0, - } - } - - totalCalls += stat.TotalCalls - totalSuccess += stat.SuccessCalls - totalFailed += stat.FailedCalls - - lastCallTimeStr := "" - if stat.LastCallTime != nil { - lastCallTimeStr = stat.LastCallTime.Format("2006-01-02 15:04:05") - } - - statsList = append(statsList, map[string]interface{}{ - "skill_name": stat.SkillName, - "total_calls": stat.TotalCalls, - "success_calls": stat.SuccessCalls, - "failed_calls": stat.FailedCalls, - "last_call_time": lastCallTimeStr, - }) - } - - c.JSON(http.StatusOK, gin.H{ - "total_skills": len(skillList), - "total_calls": totalCalls, - "total_success": totalSuccess, - "total_failed": totalFailed, - "skills_dir": skillsDir, - "stats": statsList, - }) -} - -// ClearSkillStats 清空所有Skills统计信息 -func (h *SkillsHandler) ClearSkillStats(c *gin.Context) { - if h.db == nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"}) - return - } - - if err := h.db.ClearSkillStats(); err != nil { - h.logger.Error("清空Skills统计信息失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()}) - return - } - - h.logger.Info("已清空所有Skills统计信息") - c.JSON(http.StatusOK, gin.H{ - "message": "已清空所有Skills统计信息", - }) -} - -// ClearSkillStatsByName 清空指定skill的统计信息 -func (h *SkillsHandler) ClearSkillStatsByName(c *gin.Context) { - skillName := c.Param("name") - if skillName == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) - return - } - - if h.db == nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"}) - return - } - - if err := h.db.ClearSkillStatsByName(skillName); err != nil { - h.logger.Error("清空指定skill统计信息失败", zap.String("skill", skillName), zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()}) - return - } - - h.logger.Info("已清空指定skill统计信息", zap.String("skill", skillName)) - c.JSON(http.StatusOK, gin.H{ - "message": fmt.Sprintf("已清空skill '%s' 的统计信息", skillName), - }) -} - -// removeSkillFromRoles 从所有角色中移除指定的skill绑定 -// 返回受影响角色名称列表 -func (h *SkillsHandler) removeSkillFromRoles(skillName string) []string { - if h.config.Roles == nil { - return []string{} - } - - affectedRoles := make([]string, 0) - rolesToUpdate := make(map[string]config.RoleConfig) - - // 遍历所有角色,查找并移除skill绑定 - for roleName, role := range h.config.Roles { - // 确保角色名称正确设置 - if role.Name == "" { - role.Name = roleName - } - - // 检查角色的Skills列表中是否包含要删除的skill - if len(role.Skills) > 0 { - updated := false - newSkills := make([]string, 0, len(role.Skills)) - for _, skill := range role.Skills { - if skill != skillName { - newSkills = append(newSkills, skill) - } else { - updated = true - } - } - if updated { - role.Skills = newSkills - rolesToUpdate[roleName] = role - affectedRoles = append(affectedRoles, roleName) - } - } - } - - // 如果有角色需要更新,保存到文件 - if len(rolesToUpdate) > 0 { - // 更新内存中的配置 - for roleName, role := range rolesToUpdate { - h.config.Roles[roleName] = role - } - // 保存更新后的角色配置到文件 - if err := h.saveRolesConfig(); err != nil { - h.logger.Error("保存角色配置失败", zap.Error(err)) - } - } - - return affectedRoles -} - -// saveRolesConfig 保存角色配置到文件(从SkillsHandler调用) -func (h *SkillsHandler) saveRolesConfig() error { - configDir := filepath.Dir(h.configPath) - rolesDir := h.config.RolesDir - if rolesDir == "" { - rolesDir = "roles" // 默认目录 - } - - // 如果是相对路径,相对于配置文件所在目录 - if !filepath.IsAbs(rolesDir) { - rolesDir = filepath.Join(configDir, rolesDir) - } - - // 确保目录存在 - if err := os.MkdirAll(rolesDir, 0755); err != nil { - return fmt.Errorf("创建角色目录失败: %w", err) - } - - // 保存每个角色到独立的文件 - if h.config.Roles != nil { - for roleName, role := range h.config.Roles { - // 确保角色名称正确设置 - if role.Name == "" { - role.Name = roleName - } - - // 使用角色名称作为文件名(安全化文件名,避免特殊字符) - safeFileName := sanitizeRoleFileName(role.Name) - roleFile := filepath.Join(rolesDir, safeFileName+".yaml") - - // 将角色配置序列化为YAML - roleData, err := yaml.Marshal(&role) - if err != nil { - h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err)) - continue - } - - // 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义) - roleDataStr := string(roleData) - if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") { - // 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况 - re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`) - roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`) - roleData = []byte(roleDataStr) - } - - // 写入文件 - if err := os.WriteFile(roleFile, roleData, 0644); err != nil { - h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err)) - continue - } - - h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile)) - } - } - - return nil -} - -// sanitizeRoleFileName 将角色名称转换为安全的文件名 -func sanitizeRoleFileName(name string) string { - // 替换可能不安全的字符 - replacer := map[rune]string{ - '/': "_", - '\\': "_", - ':': "_", - '*': "_", - '?': "_", - '"': "_", - '<': "_", - '>': "_", - '|': "_", - ' ': "_", - } - - var result []rune - for _, r := range name { - if replacement, ok := replacer[r]; ok { - result = append(result, []rune(replacement)...) - } else { - result = append(result, r) - } - } - - fileName := string(result) - // 如果文件名为空,使用默认名称 - if fileName == "" { - fileName = "role" - } - - return fileName -} - -// isValidSkillName 验证 skill 目录名(与 Agent Skills 的 name 字段一致:小写、数字、连字符) -func isValidSkillName(name string) bool { - if name == "" || len(name) > 100 { - return false - } - for _, r := range name { - if !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-') { - return false - } - } - return true -} diff --git a/handler/sse_keepalive.go b/handler/sse_keepalive.go deleted file mode 100644 index ae750ecd..00000000 --- a/handler/sse_keepalive.go +++ /dev/null @@ -1,58 +0,0 @@ -package handler - -import ( - "fmt" - "net/http" - "sync" - "time" - - "github.com/gin-gonic/gin" -) - -// sseInterval is how often we write on long SSE streams. Shorter intervals help NATs and -// some proxies that treat connections as idle; 10s is a reasonable balance with traffic. -const sseKeepaliveInterval = 10 * time.Second - -// sseKeepalive sends periodic SSE traffic so proxies (e.g. nginx proxy_read_timeout), NATs, -// and load balancers do not close long-running streams. Some intermediaries ignore comment-only -// lines, so we send both a comment and a minimal data frame (type heartbeat) per tick. -// -// writeMu must be the same mutex used by sendEvent for this request: concurrent writes to -// http.ResponseWriter break chunked transfer encoding (browser: net::ERR_INVALID_CHUNKED_ENCODING). -func sseKeepalive(c *gin.Context, stop <-chan struct{}, writeMu *sync.Mutex) { - if writeMu == nil { - return - } - ticker := time.NewTicker(sseKeepaliveInterval) - defer ticker.Stop() - for { - select { - case <-stop: - return - case <-c.Request.Context().Done(): - return - case <-ticker.C: - select { - case <-stop: - return - case <-c.Request.Context().Done(): - return - default: - } - writeMu.Lock() - if _, err := fmt.Fprintf(c.Writer, ": keepalive\n\n"); err != nil { - writeMu.Unlock() - return - } - // data: frame so strict proxies still see downstream bytes (comments alone may not reset timers) - if _, err := fmt.Fprintf(c.Writer, `data: {"type":"heartbeat"}`+"\n\n"); err != nil { - writeMu.Unlock() - return - } - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } - writeMu.Unlock() - } - } -} diff --git a/handler/task_manager.go b/handler/task_manager.go deleted file mode 100644 index 9964ad5c..00000000 --- a/handler/task_manager.go +++ /dev/null @@ -1,276 +0,0 @@ -package handler - -import ( - "context" - "errors" - "sync" - "time" -) - -// ErrTaskCancelled 用户取消任务的错误 -var ErrTaskCancelled = errors.New("agent task cancelled by user") - -// ErrTaskAlreadyRunning 会话已有任务正在执行 -var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation") - -// AgentTask 描述正在运行的Agent任务 -type AgentTask struct { - ConversationID string `json:"conversationId"` - Message string `json:"message,omitempty"` - StartedAt time.Time `json:"startedAt"` - Status string `json:"status"` - CancellingAt time.Time `json:"-"` // 进入 cancelling 状态的时间,用于清理长时间卡住的任务 - - cancel func(error) -} - -// CompletedTask 已完成的任务(用于历史记录) -type CompletedTask struct { - ConversationID string `json:"conversationId"` - Message string `json:"message,omitempty"` - StartedAt time.Time `json:"startedAt"` - CompletedAt time.Time `json:"completedAt"` - Status string `json:"status"` -} - -// AgentTaskManager 管理正在运行的Agent任务 -type AgentTaskManager struct { - mu sync.RWMutex - tasks map[string]*AgentTask - completedTasks []*CompletedTask // 最近完成的任务历史 - maxHistorySize int // 最大历史记录数 - historyRetention time.Duration // 历史记录保留时间 -} - -const ( - // cancellingStuckThreshold 处于「取消中」超过此时长则强制从运行列表移除。正常取消会在当前步骤内返回, - // 超过则视为卡住,尽快释放会话。常见做法多为 30–60s 内释放。 - cancellingStuckThreshold = 45 * time.Second - // cancellingStuckThresholdLegacy 未记录 CancellingAt 时用 StartedAt 判断的兜底时长 - cancellingStuckThresholdLegacy = 2 * time.Minute - cleanupInterval = 15 * time.Second // 与上面阈值配合,最长约 60s 内移除 -) - -// NewAgentTaskManager 创建任务管理器 -func NewAgentTaskManager() *AgentTaskManager { - m := &AgentTaskManager{ - tasks: make(map[string]*AgentTask), - completedTasks: make([]*CompletedTask, 0), - maxHistorySize: 50, // 最多保留50条历史记录 - historyRetention: 24 * time.Hour, // 保留24小时 - } - go m.runStuckCancellingCleanup() - return m -} - -// runStuckCancellingCleanup 定期将长时间处于「取消中」的任务强制结束,避免卡住无法发新消息 -func (m *AgentTaskManager) runStuckCancellingCleanup() { - ticker := time.NewTicker(cleanupInterval) - defer ticker.Stop() - for range ticker.C { - m.cleanupStuckCancelling() - } -} - -func (m *AgentTaskManager) cleanupStuckCancelling() { - m.mu.Lock() - var toFinish []string - now := time.Now() - for id, task := range m.tasks { - if task.Status != "cancelling" { - continue - } - var elapsed time.Duration - if !task.CancellingAt.IsZero() { - elapsed = now.Sub(task.CancellingAt) - if elapsed < cancellingStuckThreshold { - continue - } - } else { - elapsed = now.Sub(task.StartedAt) - if elapsed < cancellingStuckThresholdLegacy { - continue - } - } - toFinish = append(toFinish, id) - } - m.mu.Unlock() - for _, id := range toFinish { - m.FinishTask(id, "cancelled") - } -} - -// StartTask 注册并开始一个新的任务 -func (m *AgentTaskManager) StartTask(conversationID, message string, cancel context.CancelCauseFunc) (*AgentTask, error) { - m.mu.Lock() - defer m.mu.Unlock() - - if _, exists := m.tasks[conversationID]; exists { - return nil, ErrTaskAlreadyRunning - } - - task := &AgentTask{ - ConversationID: conversationID, - Message: message, - StartedAt: time.Now(), - Status: "running", - cancel: func(err error) { - if cancel != nil { - cancel(err) - } - }, - } - - m.tasks[conversationID] = task - return task, nil -} - -// CancelTask 取消指定会话的任务。若任务已在取消中,仍返回 (true, nil) 以便接口幂等、前端不报错。 -func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool, error) { - m.mu.Lock() - task, exists := m.tasks[conversationID] - if !exists { - m.mu.Unlock() - return false, nil - } - - // 如果已经处于取消流程,视为成功(幂等),避免前端重复点击报「未找到任务」 - if task.Status == "cancelling" { - m.mu.Unlock() - return true, nil - } - - task.Status = "cancelling" - task.CancellingAt = time.Now() - cancel := task.cancel - m.mu.Unlock() - - if cause == nil { - cause = ErrTaskCancelled - } - if cancel != nil { - cancel(cause) - } - return true, nil -} - -// UpdateTaskStatus 更新任务状态但不删除任务(用于在发送事件前更新状态) -func (m *AgentTaskManager) UpdateTaskStatus(conversationID string, status string) { - m.mu.Lock() - defer m.mu.Unlock() - - task, exists := m.tasks[conversationID] - if !exists { - return - } - - if status != "" { - task.Status = status - } -} - -// FinishTask 完成任务并从管理器中移除 -func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string) { - m.mu.Lock() - defer m.mu.Unlock() - - task, exists := m.tasks[conversationID] - if !exists { - return - } - - if finalStatus != "" { - task.Status = finalStatus - } - - // 保存到历史记录 - completedTask := &CompletedTask{ - ConversationID: task.ConversationID, - Message: task.Message, - StartedAt: task.StartedAt, - CompletedAt: time.Now(), - Status: finalStatus, - } - - // 添加到历史记录 - m.completedTasks = append(m.completedTasks, completedTask) - - // 清理过期和过多的历史记录 - m.cleanupHistory() - - // 从运行任务中移除 - delete(m.tasks, conversationID) -} - -// cleanupHistory 清理过期的历史记录 -func (m *AgentTaskManager) cleanupHistory() { - now := time.Now() - cutoffTime := now.Add(-m.historyRetention) - - // 过滤掉过期的记录 - validTasks := make([]*CompletedTask, 0, len(m.completedTasks)) - for _, task := range m.completedTasks { - if task.CompletedAt.After(cutoffTime) { - validTasks = append(validTasks, task) - } - } - - // 如果仍然超过最大数量,只保留最新的 - if len(validTasks) > m.maxHistorySize { - // 按完成时间排序,保留最新的 - // 由于是追加的,最新的在最后,所以直接取最后N个 - start := len(validTasks) - m.maxHistorySize - validTasks = validTasks[start:] - } - - m.completedTasks = validTasks -} - -// GetActiveTasks 返回所有正在运行的任务 -func (m *AgentTaskManager) GetActiveTasks() []*AgentTask { - m.mu.RLock() - defer m.mu.RUnlock() - - result := make([]*AgentTask, 0, len(m.tasks)) - for _, task := range m.tasks { - result = append(result, &AgentTask{ - ConversationID: task.ConversationID, - Message: task.Message, - StartedAt: task.StartedAt, - Status: task.Status, - }) - } - return result -} - -// GetCompletedTasks 返回最近完成的任务历史 -func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask { - m.mu.RLock() - defer m.mu.RUnlock() - - // 清理过期记录(只读锁,不影响其他操作) - // 注意:这里不能直接调用cleanupHistory,因为需要写锁 - // 所以返回时过滤过期记录 - now := time.Now() - cutoffTime := now.Add(-m.historyRetention) - - result := make([]*CompletedTask, 0, len(m.completedTasks)) - for _, task := range m.completedTasks { - if task.CompletedAt.After(cutoffTime) { - result = append(result, task) - } - } - - // 按完成时间倒序排序(最新的在前) - // 由于是追加的,最新的在最后,需要反转 - for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 { - result[i], result[j] = result[j], result[i] - } - - // 限制返回数量 - if len(result) > m.maxHistorySize { - result = result[:m.maxHistorySize] - } - - return result -} diff --git a/handler/terminal.go b/handler/terminal.go deleted file mode 100644 index a17d361d..00000000 --- a/handler/terminal.go +++ /dev/null @@ -1,257 +0,0 @@ -package handler - -import ( - "bytes" - "context" - "encoding/json" - "net/http" - "os" - "os/exec" - "path/filepath" - "runtime" - "strings" - "time" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -const ( - terminalMaxCommandLen = 4096 - terminalMaxOutputLen = 256 * 1024 // 256KB - terminalTimeout = 30 * time.Minute -) - -// TerminalHandler 处理系统设置中的终端命令执行 -type TerminalHandler struct { - logger *zap.Logger -} - -// maskTerminalCommand 对可能包含敏感信息的终端命令做脱敏,避免在日志中直接记录密码等内容 -func maskTerminalCommand(cmd string) string { - trimmed := strings.TrimSpace(cmd) - lower := strings.ToLower(trimmed) - if strings.Contains(lower, "sudo") || strings.Contains(lower, "password") { - return "[masked sensitive terminal command]" - } - if len(trimmed) > 256 { - return trimmed[:256] + "..." - } - return trimmed -} - -// NewTerminalHandler 创建终端处理器 -func NewTerminalHandler(logger *zap.Logger) *TerminalHandler { - return &TerminalHandler{logger: logger} -} - -// RunCommandRequest 执行命令请求 -type RunCommandRequest struct { - Command string `json:"command"` - Shell string `json:"shell,omitempty"` - Cwd string `json:"cwd,omitempty"` -} - -// RunCommandResponse 执行命令响应 -type RunCommandResponse struct { - Stdout string `json:"stdout"` - Stderr string `json:"stderr"` - ExitCode int `json:"exit_code"` - Error string `json:"error,omitempty"` -} - -// RunCommand 执行终端命令(需登录) -func (h *TerminalHandler) RunCommand(c *gin.Context) { - var req RunCommandRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"}) - return - } - - cmdStr := strings.TrimSpace(req.Command) - if cmdStr == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"}) - return - } - if len(cmdStr) > terminalMaxCommandLen { - c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"}) - return - } - - shell := req.Shell - if shell == "" { - if runtime.GOOS == "windows" { - shell = "cmd" - } else { - shell = "sh" - } - } - - ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout) - defer cancel() - - var cmd *exec.Cmd - if runtime.GOOS == "windows" { - cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr) - } else { - cmd = exec.CommandContext(ctx, shell, "-c", cmdStr) - // 无 TTY 时设置 COLUMNS/TERM,使 ping 等工具的 usage 排版与真实终端一致 - cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color") - } - - if req.Cwd != "" { - absCwd, err := filepath.Abs(req.Cwd) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"}) - return - } - cur, _ := os.Getwd() - curAbs, _ := filepath.Abs(cur) - rel, err := filepath.Rel(curAbs, absCwd) - if err != nil || strings.HasPrefix(rel, "..") || rel == ".." { - c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"}) - return - } - cmd.Dir = absCwd - } - - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - err := cmd.Run() - stdoutBytes := stdout.Bytes() - stderrBytes := stderr.Bytes() - - // 限制输出长度,防止内存占用过大(复制后截断,避免修改原 buffer) - truncSuffix := []byte("\n...(输出已截断)\n") - if len(stdoutBytes) > terminalMaxOutputLen { - tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix)) - n := copy(tmp, stdoutBytes[:terminalMaxOutputLen]) - copy(tmp[n:], truncSuffix) - stdoutBytes = tmp - } - if len(stderrBytes) > terminalMaxOutputLen { - tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix)) - n := copy(tmp, stderrBytes[:terminalMaxOutputLen]) - copy(tmp[n:], truncSuffix) - stderrBytes = tmp - } - - exitCode := 0 - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - exitCode = exitErr.ExitCode() - } else { - exitCode = -1 - } - if ctx.Err() == context.DeadlineExceeded { - so := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n") - so = strings.ReplaceAll(so, "\r", "\n") - se := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n") - se = strings.ReplaceAll(se, "\r", "\n") - resp := RunCommandResponse{ - Stdout: so, - Stderr: se, - ExitCode: -1, - Error: "命令执行超时(" + terminalTimeout.String() + ")", - } - c.JSON(http.StatusOK, resp) - return - } - h.logger.Debug("终端命令执行异常", zap.String("command", maskTerminalCommand(cmdStr)), zap.Error(err)) - } - - // 统一为 \n,避免前端因 \r 出现错位/对角线排版 - stdoutStr := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n") - stdoutStr = strings.ReplaceAll(stdoutStr, "\r", "\n") - stderrStr := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n") - stderrStr = strings.ReplaceAll(stderrStr, "\r", "\n") - - resp := RunCommandResponse{ - Stdout: stdoutStr, - Stderr: stderrStr, - ExitCode: exitCode, - } - if err != nil && exitCode != 0 { - resp.Error = err.Error() - } - c.JSON(http.StatusOK, resp) -} - -// streamEvent SSE 事件 -type streamEvent struct { - T string `json:"t"` // "out" | "err" | "exit" - D string `json:"d,omitempty"` - C int `json:"c"` // exit code(不用 omitempty,否则 0 不序列化导致前端显示 [exit undefined]) -} - -// RunCommandStream 流式执行命令,输出实时推送到前端(SSE) -func (h *TerminalHandler) RunCommandStream(c *gin.Context) { - var req RunCommandRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"}) - return - } - cmdStr := strings.TrimSpace(req.Command) - if cmdStr == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"}) - return - } - if len(cmdStr) > terminalMaxCommandLen { - c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"}) - return - } - shell := req.Shell - if shell == "" { - if runtime.GOOS == "windows" { - shell = "cmd" - } else { - shell = "sh" - } - } - ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout) - defer cancel() - - var cmd *exec.Cmd - if runtime.GOOS == "windows" { - cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr) - } else { - cmd = exec.CommandContext(ctx, shell, "-c", cmdStr) - cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color") - } - if req.Cwd != "" { - absCwd, err := filepath.Abs(req.Cwd) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"}) - return - } - cur, _ := os.Getwd() - curAbs, _ := filepath.Abs(cur) - rel, err := filepath.Rel(curAbs, absCwd) - if err != nil || strings.HasPrefix(rel, "..") || rel == ".." { - c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"}) - return - } - cmd.Dir = absCwd - } - - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - c.Writer.WriteHeader(http.StatusOK) - flusher, ok := c.Writer.(http.Flusher) - if !ok { - cancel() - return - } - - sendEvent := func(ev streamEvent) { - body, _ := json.Marshal(ev) - c.SSEvent("", string(body)) - flusher.Flush() - } - - runCommandStreamImpl(cmd, sendEvent, ctx) -} diff --git a/handler/terminal_stream_unix.go b/handler/terminal_stream_unix.go deleted file mode 100644 index 9b543b6c..00000000 --- a/handler/terminal_stream_unix.go +++ /dev/null @@ -1,46 +0,0 @@ -//go:build !windows - -package handler - -import ( - "bufio" - "context" - "os/exec" - "strings" - - "github.com/creack/pty" -) - -const ptyCols = 256 -const ptyRows = 40 - -// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真) -func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) { - ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows}) - if err != nil { - sendEvent(streamEvent{T: "exit", C: -1}) - return - } - defer ptmx.Close() - - normalize := func(s string) string { - s = strings.ReplaceAll(s, "\r\n", "\n") - return strings.ReplaceAll(s, "\r", "\n") - } - sc := bufio.NewScanner(ptmx) - for sc.Scan() { - sendEvent(streamEvent{T: "out", D: normalize(sc.Text())}) - } - exitCode := 0 - if err := cmd.Wait(); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - exitCode = exitErr.ExitCode() - } else { - exitCode = -1 - } - } - if ctx.Err() == context.DeadlineExceeded { - exitCode = -1 - } - sendEvent(streamEvent{T: "exit", C: exitCode}) -} diff --git a/handler/terminal_stream_windows.go b/handler/terminal_stream_windows.go deleted file mode 100644 index 9f69303c..00000000 --- a/handler/terminal_stream_windows.go +++ /dev/null @@ -1,65 +0,0 @@ -//go:build windows - -package handler - -import ( - "bufio" - "context" - "os/exec" - "strings" - "sync" -) - -// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行 -func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) { - stdoutPipe, err := cmd.StdoutPipe() - if err != nil { - sendEvent(streamEvent{T: "exit", C: -1}) - return - } - stderrPipe, err := cmd.StderrPipe() - if err != nil { - sendEvent(streamEvent{T: "exit", C: -1}) - return - } - if err := cmd.Start(); err != nil { - sendEvent(streamEvent{T: "exit", C: -1}) - return - } - - normalize := func(s string) string { - s = strings.ReplaceAll(s, "\r\n", "\n") - return strings.ReplaceAll(s, "\r", "\n") - } - - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - sc := bufio.NewScanner(stdoutPipe) - for sc.Scan() { - sendEvent(streamEvent{T: "out", D: normalize(sc.Text())}) - } - }() - go func() { - defer wg.Done() - sc := bufio.NewScanner(stderrPipe) - for sc.Scan() { - sendEvent(streamEvent{T: "err", D: normalize(sc.Text())}) - } - }() - - wg.Wait() - exitCode := 0 - if err := cmd.Wait(); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - exitCode = exitErr.ExitCode() - } else { - exitCode = -1 - } - } - if ctx.Err() == context.DeadlineExceeded { - exitCode = -1 - } - sendEvent(streamEvent{T: "exit", C: exitCode}) -} diff --git a/handler/terminal_ws_unix.go b/handler/terminal_ws_unix.go deleted file mode 100644 index eaa5df67..00000000 --- a/handler/terminal_ws_unix.go +++ /dev/null @@ -1,112 +0,0 @@ -//go:build !windows - -package handler - -import ( - "encoding/json" - "net/http" - "os" - "os/exec" - "time" - - "github.com/creack/pty" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" -) - -// terminalResize is sent by the frontend when the xterm.js terminal is resized. -type terminalResize struct { - Type string `json:"type"` - Cols uint16 `json:"cols"` - Rows uint16 `json:"rows"` -} - -// wsUpgrader 仅用于系统设置中的终端 WebSocket,会复用已有的登录保护(JWT 中间件在上层路由组) -var wsUpgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - // 由于已在 Gin 路由层做了认证,这里放宽 Origin,方便在同一域名下通过 HTTPS/WSS 访问 - return true - }, -} - -// RunCommandWS 提供真正交互式 Shell:基于 WebSocket + PTY 的长会话 -// 前端建立 WebSocket 连接后,所有键盘输入都会透传到 Shell,Shell 的输出也会实时写回前端。 -func (h *TerminalHandler) RunCommandWS(c *gin.Context) { - conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) - if err != nil { - return - } - defer conn.Close() - - // 启动交互式 Shell,这里优先使用 bash,找不到则退回 sh - shell := "bash" - if _, err := exec.LookPath(shell); err != nil { - shell = "sh" - } - cmd := exec.Command(shell) - cmd.Env = append(os.Environ(), - "COLUMNS=80", - "LINES=24", - "TERM=xterm-256color", - ) - - // Use 80x24 as a safe default; the frontend will send the actual size immediately after connecting. - ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: 80, Rows: 24}) - if err != nil { - return - } - defer ptmx.Close() - - // Shell -> WebSocket:将 PTY 输出实时发给前端 - doneChan := make(chan struct{}) - go func() { - buf := make([]byte, 4096) - for { - n, err := ptmx.Read(buf) - if n > 0 { - _ = conn.WriteMessage(websocket.BinaryMessage, buf[:n]) - } - if err != nil { - break - } - } - close(doneChan) - }() - - // WebSocket -> Shell:将前端输入写入 PTY(包括 sudo 密码、Ctrl+C 等) - conn.SetReadLimit(64 * 1024) - _ = conn.SetReadDeadline(time.Now().Add(terminalTimeout)) - conn.SetPongHandler(func(string) error { - _ = conn.SetReadDeadline(time.Now().Add(terminalTimeout)) - return nil - }) - - for { - msgType, data, err := conn.ReadMessage() - if err != nil { - _ = cmd.Process.Kill() - break - } - if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { - continue - } - if len(data) == 0 { - continue - } - // Check if this is a resize message (JSON with type:"resize") - if msgType == websocket.TextMessage && len(data) > 0 && data[0] == '{' { - var resize terminalResize - if json.Unmarshal(data, &resize) == nil && resize.Type == "resize" && resize.Cols > 0 && resize.Rows > 0 { - _ = pty.Setsize(ptmx, &pty.Winsize{Cols: resize.Cols, Rows: resize.Rows}) - continue - } - } - if _, err := ptmx.Write(data); err != nil { - _ = cmd.Process.Kill() - break - } - } - - <-doneChan -} - diff --git a/handler/vulnerability.go b/handler/vulnerability.go deleted file mode 100644 index 9975efa7..00000000 --- a/handler/vulnerability.go +++ /dev/null @@ -1,263 +0,0 @@ -package handler - -import ( - "net/http" - "strconv" - - "cyberstrike-ai/internal/database" - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// VulnerabilityHandler 漏洞处理器 -type VulnerabilityHandler struct { - db *database.DB - logger *zap.Logger -} - -// NewVulnerabilityHandler 创建新的漏洞处理器 -func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler { - return &VulnerabilityHandler{ - db: db, - logger: logger, - } -} - -// CreateVulnerabilityRequest 创建漏洞请求 -type CreateVulnerabilityRequest struct { - ConversationID string `json:"conversation_id" binding:"required"` - Title string `json:"title" binding:"required"` - Description string `json:"description"` - Severity string `json:"severity" binding:"required"` - Status string `json:"status"` - Type string `json:"type"` - Target string `json:"target"` - Proof string `json:"proof"` - Impact string `json:"impact"` - Recommendation string `json:"recommendation"` -} - -// CreateVulnerability 创建漏洞 -func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) { - var req CreateVulnerabilityRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - vuln := &database.Vulnerability{ - ConversationID: req.ConversationID, - Title: req.Title, - Description: req.Description, - Severity: req.Severity, - Status: req.Status, - Type: req.Type, - Target: req.Target, - Proof: req.Proof, - Impact: req.Impact, - Recommendation: req.Recommendation, - } - - created, err := h.db.CreateVulnerability(vuln) - if err != nil { - h.logger.Error("创建漏洞失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, created) -} - -// GetVulnerability 获取漏洞 -func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) { - id := c.Param("id") - - vuln, err := h.db.GetVulnerability(id) - if err != nil { - h.logger.Error("获取漏洞失败", zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"}) - return - } - - c.JSON(http.StatusOK, vuln) -} - -// ListVulnerabilitiesResponse 漏洞列表响应 -type ListVulnerabilitiesResponse struct { - Vulnerabilities []*database.Vulnerability `json:"vulnerabilities"` - Total int `json:"total"` - Page int `json:"page"` - PageSize int `json:"page_size"` - TotalPages int `json:"total_pages"` -} - -// ListVulnerabilities 列出漏洞 -func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { - limitStr := c.DefaultQuery("limit", "20") - offsetStr := c.DefaultQuery("offset", "0") - pageStr := c.Query("page") - id := c.Query("id") - conversationID := c.Query("conversation_id") - severity := c.Query("severity") - status := c.Query("status") - - limit, _ := strconv.Atoi(limitStr) - offset, _ := strconv.Atoi(offsetStr) - page := 1 - - // 如果提供了page参数,优先使用page计算offset - if pageStr != "" { - if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { - page = p - offset = (page - 1) * limit - } - } - - if limit <= 0 || limit > 100 { - limit = 20 - } - if offset < 0 { - offset = 0 - } - - // 获取总数 - total, err := h.db.CountVulnerabilities(id, conversationID, severity, status) - if err != nil { - h.logger.Error("获取漏洞总数失败", zap.Error(err)) - // 继续执行,使用0作为总数 - total = 0 - } - - // 获取漏洞列表 - vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status) - if err != nil { - h.logger.Error("获取漏洞列表失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 计算总页数 - totalPages := (total + limit - 1) / limit - if totalPages == 0 { - totalPages = 1 - } - - // 如果使用offset计算page,需要重新计算 - if pageStr == "" { - page = (offset / limit) + 1 - } - - response := ListVulnerabilitiesResponse{ - Vulnerabilities: vulnerabilities, - Total: total, - Page: page, - PageSize: limit, - TotalPages: totalPages, - } - - c.JSON(http.StatusOK, response) -} - -// UpdateVulnerabilityRequest 更新漏洞请求 -type UpdateVulnerabilityRequest struct { - Title string `json:"title"` - Description string `json:"description"` - Severity string `json:"severity"` - Status string `json:"status"` - Type string `json:"type"` - Target string `json:"target"` - Proof string `json:"proof"` - Impact string `json:"impact"` - Recommendation string `json:"recommendation"` -} - -// UpdateVulnerability 更新漏洞 -func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) { - id := c.Param("id") - - var req UpdateVulnerabilityRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // 获取现有漏洞 - existing, err := h.db.GetVulnerability(id) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"}) - return - } - - // 更新字段 - if req.Title != "" { - existing.Title = req.Title - } - if req.Description != "" { - existing.Description = req.Description - } - if req.Severity != "" { - existing.Severity = req.Severity - } - if req.Status != "" { - existing.Status = req.Status - } - if req.Type != "" { - existing.Type = req.Type - } - if req.Target != "" { - existing.Target = req.Target - } - if req.Proof != "" { - existing.Proof = req.Proof - } - if req.Impact != "" { - existing.Impact = req.Impact - } - if req.Recommendation != "" { - existing.Recommendation = req.Recommendation - } - - if err := h.db.UpdateVulnerability(id, existing); err != nil { - h.logger.Error("更新漏洞失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 返回更新后的漏洞 - updated, err := h.db.GetVulnerability(id) - if err != nil { - h.logger.Error("获取更新后的漏洞失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, updated) -} - -// DeleteVulnerability 删除漏洞 -func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) { - id := c.Param("id") - - if err := h.db.DeleteVulnerability(id); err != nil { - h.logger.Error("删除漏洞失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) -} - -// GetVulnerabilityStats 获取漏洞统计 -func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) { - conversationID := c.Query("conversation_id") - - stats, err := h.db.GetVulnerabilityStats(conversationID) - if err != nil { - h.logger.Error("获取漏洞统计失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, stats) -} - diff --git a/handler/webshell.go b/handler/webshell.go deleted file mode 100644 index 06da5d61..00000000 --- a/handler/webshell.go +++ /dev/null @@ -1,706 +0,0 @@ -package handler - -import ( - "bytes" - "database/sql" - "encoding/json" - "io" - "net/http" - "net/url" - "strings" - "time" - - "cyberstrike-ai/internal/database" - - "github.com/gin-gonic/gin" - "github.com/google/uuid" - "go.uber.org/zap" -) - -// WebShellHandler 代理执行 WebShell 命令(类似冰蝎/蚁剑),避免前端跨域并统一构建请求 -type WebShellHandler struct { - logger *zap.Logger - client *http.Client - db *database.DB -} - -// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用) -func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler { - return &WebShellHandler{ - logger: logger, - client: &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{DisableKeepAlives: false}, - }, - db: db, - } -} - -// CreateConnectionRequest 创建连接请求 -type CreateConnectionRequest struct { - URL string `json:"url" binding:"required"` - Password string `json:"password"` - Type string `json:"type"` - Method string `json:"method"` - CmdParam string `json:"cmd_param"` - Remark string `json:"remark"` -} - -// UpdateConnectionRequest 更新连接请求 -type UpdateConnectionRequest struct { - URL string `json:"url" binding:"required"` - Password string `json:"password"` - Type string `json:"type"` - Method string `json:"method"` - CmdParam string `json:"cmd_param"` - Remark string `json:"remark"` -} - -// ListConnections 列出所有 WebShell 连接(GET /api/webshell/connections) -func (h *WebShellHandler) ListConnections(c *gin.Context) { - if h.db == nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) - return - } - list, err := h.db.ListWebshellConnections() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if list == nil { - list = []database.WebShellConnection{} - } - c.JSON(http.StatusOK, list) -} - -// CreateConnection 创建 WebShell 连接(POST /api/webshell/connections) -func (h *WebShellHandler) CreateConnection(c *gin.Context) { - if h.db == nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) - return - } - var req CreateConnectionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - req.URL = strings.TrimSpace(req.URL) - if req.URL == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"}) - return - } - if _, err := url.Parse(req.URL); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"}) - return - } - method := strings.ToLower(strings.TrimSpace(req.Method)) - if method != "get" && method != "post" { - method = "post" - } - shellType := strings.ToLower(strings.TrimSpace(req.Type)) - if shellType == "" { - shellType = "php" - } - conn := &database.WebShellConnection{ - ID: "ws_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:12], - URL: req.URL, - Password: strings.TrimSpace(req.Password), - Type: shellType, - Method: method, - CmdParam: strings.TrimSpace(req.CmdParam), - Remark: strings.TrimSpace(req.Remark), - CreatedAt: time.Now(), - } - if err := h.db.CreateWebshellConnection(conn); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, conn) -} - -// UpdateConnection 更新 WebShell 连接(PUT /api/webshell/connections/:id) -func (h *WebShellHandler) UpdateConnection(c *gin.Context) { - if h.db == nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) - return - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) - return - } - var req UpdateConnectionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - req.URL = strings.TrimSpace(req.URL) - if req.URL == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"}) - return - } - if _, err := url.Parse(req.URL); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"}) - return - } - method := strings.ToLower(strings.TrimSpace(req.Method)) - if method != "get" && method != "post" { - method = "post" - } - shellType := strings.ToLower(strings.TrimSpace(req.Type)) - if shellType == "" { - shellType = "php" - } - conn := &database.WebShellConnection{ - ID: id, - URL: req.URL, - Password: strings.TrimSpace(req.Password), - Type: shellType, - Method: method, - CmdParam: strings.TrimSpace(req.CmdParam), - Remark: strings.TrimSpace(req.Remark), - } - if err := h.db.UpdateWebshellConnection(conn); err != nil { - if err == sql.ErrNoRows { - c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - updated, _ := h.db.GetWebshellConnection(id) - if updated != nil { - c.JSON(http.StatusOK, updated) - } else { - c.JSON(http.StatusOK, conn) - } -} - -// DeleteConnection 删除 WebShell 连接(DELETE /api/webshell/connections/:id) -func (h *WebShellHandler) DeleteConnection(c *gin.Context) { - if h.db == nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) - return - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) - return - } - if err := h.db.DeleteWebshellConnection(id); err != nil { - if err == sql.ErrNoRows { - c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"ok": true}) -} - -// GetConnectionState 获取 WebShell 连接关联的前端持久化状态(GET /api/webshell/connections/:id/state) -func (h *WebShellHandler) GetConnectionState(c *gin.Context) { - if h.db == nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) - return - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) - return - } - conn, err := h.db.GetWebshellConnection(id) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if conn == nil { - c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"}) - return - } - stateJSON, err := h.db.GetWebshellConnectionState(id) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - var state interface{} - if err := json.Unmarshal([]byte(stateJSON), &state); err != nil { - state = map[string]interface{}{} - } - c.JSON(http.StatusOK, gin.H{"state": state}) -} - -// SaveConnectionState 保存 WebShell 连接关联的前端持久化状态(PUT /api/webshell/connections/:id/state) -func (h *WebShellHandler) SaveConnectionState(c *gin.Context) { - if h.db == nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) - return - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) - return - } - conn, err := h.db.GetWebshellConnection(id) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if conn == nil { - c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"}) - return - } - var req struct { - State json.RawMessage `json:"state"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - raw := req.State - if len(raw) == 0 { - raw = json.RawMessage(`{}`) - } - if len(raw) > 2*1024*1024 { - c.JSON(http.StatusBadRequest, gin.H{"error": "state payload too large (max 2MB)"}) - return - } - var anyJSON interface{} - if err := json.Unmarshal(raw, &anyJSON); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "state must be valid json"}) - return - } - if err := h.db.UpsertWebshellConnectionState(id, string(raw)); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"ok": true}) -} - -// GetAIHistory 获取指定 WebShell 连接的 AI 助手对话历史(GET /api/webshell/connections/:id/ai-history) -func (h *WebShellHandler) GetAIHistory(c *gin.Context) { - if h.db == nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) - return - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) - return - } - conv, err := h.db.GetConversationByWebshellConnectionID(id) - if err != nil { - h.logger.Warn("获取 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}}) - return - } - if conv == nil { - c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}}) - return - } - c.JSON(http.StatusOK, gin.H{"conversationId": conv.ID, "messages": conv.Messages}) -} - -// ListAIConversations 列出该 WebShell 连接下的所有 AI 对话(供侧边栏) -func (h *WebShellHandler) ListAIConversations(c *gin.Context) { - if h.db == nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) - return - } - id := strings.TrimSpace(c.Param("id")) - if id == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) - return - } - list, err := h.db.ListConversationsByWebshellConnectionID(id) - if err != nil { - h.logger.Warn("列出 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err)) - c.JSON(http.StatusOK, []database.WebShellConversationItem{}) - return - } - if list == nil { - list = []database.WebShellConversationItem{} - } - c.JSON(http.StatusOK, list) -} - -// ExecRequest 执行命令请求(前端传入连接信息 + 命令) -type ExecRequest struct { - URL string `json:"url" binding:"required"` - Password string `json:"password"` - Type string `json:"type"` // php, asp, aspx, jsp, custom - Method string `json:"method"` // GET 或 POST,空则默认 POST - CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd - Command string `json:"command" binding:"required"` -} - -// ExecResponse 执行命令响应 -type ExecResponse struct { - OK bool `json:"ok"` - Output string `json:"output"` - Error string `json:"error,omitempty"` - HTTPCode int `json:"http_code,omitempty"` -} - -// FileOpRequest 文件操作请求 -type FileOpRequest struct { - URL string `json:"url" binding:"required"` - Password string `json:"password"` - Type string `json:"type"` - Method string `json:"method"` // GET 或 POST,空则默认 POST - CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd - Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk - Path string `json:"path"` - TargetPath string `json:"target_path"` // rename 时目标路径 - Content string `json:"content"` // write/upload 时使用 - ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块 -} - -// FileOpResponse 文件操作响应 -type FileOpResponse struct { - OK bool `json:"ok"` - Output string `json:"output"` - Error string `json:"error,omitempty"` -} - -func (h *WebShellHandler) Exec(c *gin.Context) { - var req ExecRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - req.URL = strings.TrimSpace(req.URL) - req.Command = strings.TrimSpace(req.Command) - if req.URL == "" || req.Command == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "url and command are required"}) - return - } - - parsed, err := url.Parse(req.URL) - if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"}) - return - } - - useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET" - cmdParam := strings.TrimSpace(req.CmdParam) - if cmdParam == "" { - cmdParam = "cmd" - } - var httpReq *http.Request - if useGET { - targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, req.Command) - httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil) - } else { - body := h.buildExecBody(req.Type, req.Password, cmdParam, req.Command) - httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body)) - httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - } - if err != nil { - h.logger.Warn("webshell exec NewRequest", zap.Error(err)) - c.JSON(http.StatusInternalServerError, ExecResponse{OK: false, Error: err.Error()}) - return - } - httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)") - - resp, err := h.client.Do(httpReq) - if err != nil { - h.logger.Warn("webshell exec Do", zap.String("url", req.URL), zap.Error(err)) - c.JSON(http.StatusOK, ExecResponse{OK: false, Error: err.Error()}) - return - } - defer resp.Body.Close() - - out, _ := io.ReadAll(resp.Body) - output := string(out) - httpCode := resp.StatusCode - - c.JSON(http.StatusOK, ExecResponse{ - OK: resp.StatusCode == http.StatusOK, - Output: output, - HTTPCode: httpCode, - }) -} - -// buildExecBody 按常见 WebShell 约定构建 POST 体(多数使用 pass + cmd,可配置命令参数名) -func (h *WebShellHandler) buildExecBody(shellType, password, cmdParam, command string) []byte { - form := h.execParams(shellType, password, cmdParam, command) - return []byte(form.Encode()) -} - -// buildExecURL 构建 GET 请求的完整 URL(baseURL + ?pass=xxx&cmd=yyy,cmd 可配置) -func (h *WebShellHandler) buildExecURL(baseURL, shellType, password, cmdParam, command string) string { - form := h.execParams(shellType, password, cmdParam, command) - if parsed, err := url.Parse(baseURL); err == nil { - parsed.RawQuery = form.Encode() - return parsed.String() - } - return baseURL + "?" + form.Encode() -} - -func (h *WebShellHandler) execParams(shellType, password, cmdParam, command string) url.Values { - shellType = strings.ToLower(strings.TrimSpace(shellType)) - if shellType == "" { - shellType = "php" - } - if strings.TrimSpace(cmdParam) == "" { - cmdParam = "cmd" - } - form := url.Values{} - form.Set("pass", password) - form.Set(cmdParam, command) - return form -} - -func (h *WebShellHandler) FileOp(c *gin.Context) { - var req FileOpRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - req.URL = strings.TrimSpace(req.URL) - req.Action = strings.ToLower(strings.TrimSpace(req.Action)) - if req.URL == "" || req.Action == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "url and action are required"}) - return - } - - parsed, err := url.Parse(req.URL) - if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"}) - return - } - - // 通过执行系统命令实现文件操作(与通用一句话兼容) - var command string - shellType := strings.ToLower(strings.TrimSpace(req.Type)) - switch req.Action { - case "list": - path := strings.TrimSpace(req.Path) - if path == "" { - path = "." - } - if shellType == "asp" || shellType == "aspx" { - command = "dir " + h.escapePath(path) - } else { - command = "ls -la " + h.escapePath(path) - } - case "read": - if shellType == "asp" || shellType == "aspx" { - command = "type " + h.escapePath(strings.TrimSpace(req.Path)) - } else { - command = "cat " + h.escapePath(strings.TrimSpace(req.Path)) - } - case "delete": - if shellType == "asp" || shellType == "aspx" { - command = "del " + h.escapePath(strings.TrimSpace(req.Path)) - } else { - command = "rm -f " + h.escapePath(strings.TrimSpace(req.Path)) - } - case "write": - path := h.escapePath(strings.TrimSpace(req.Path)) - command = "echo " + h.escapeForEcho(req.Content) + " > " + path - case "mkdir": - path := strings.TrimSpace(req.Path) - if path == "" { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for mkdir"}) - return - } - if shellType == "asp" || shellType == "aspx" { - command = "md " + h.escapePath(path) - } else { - command = "mkdir -p " + h.escapePath(path) - } - case "rename": - oldPath := strings.TrimSpace(req.Path) - newPath := strings.TrimSpace(req.TargetPath) - if oldPath == "" || newPath == "" { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path and target_path are required for rename"}) - return - } - if shellType == "asp" || shellType == "aspx" { - command = "move /y " + h.escapePath(oldPath) + " " + h.escapePath(newPath) - } else { - command = "mv " + h.escapePath(oldPath) + " " + h.escapePath(newPath) - } - case "upload": - path := strings.TrimSpace(req.Path) - if path == "" { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload"}) - return - } - if len(req.Content) > 512*1024 { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "upload content too large (max 512KB base64)"}) - return - } - // base64 仅含 A-Za-z0-9+/=,用单引号包裹安全 - command = "echo " + "'" + req.Content + "'" + " | base64 -d > " + h.escapePath(path) - case "upload_chunk": - path := strings.TrimSpace(req.Path) - if path == "" { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload_chunk"}) - return - } - redir := ">>" - if req.ChunkIndex == 0 { - redir = ">" - } - command = "echo " + "'" + req.Content + "'" + " | base64 -d " + redir + " " + h.escapePath(path) - default: - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "unsupported action: " + req.Action}) - return - } - - useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET" - cmdParam := strings.TrimSpace(req.CmdParam) - if cmdParam == "" { - cmdParam = "cmd" - } - var httpReq *http.Request - if useGET { - targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, command) - httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil) - } else { - body := h.buildExecBody(req.Type, req.Password, cmdParam, command) - httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body)) - httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - } - if err != nil { - c.JSON(http.StatusInternalServerError, FileOpResponse{OK: false, Error: err.Error()}) - return - } - httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)") - - resp, err := h.client.Do(httpReq) - if err != nil { - c.JSON(http.StatusOK, FileOpResponse{OK: false, Error: err.Error()}) - return - } - defer resp.Body.Close() - - out, _ := io.ReadAll(resp.Body) - output := string(out) - - c.JSON(http.StatusOK, FileOpResponse{ - OK: resp.StatusCode == http.StatusOK, - Output: output, - }) -} - -func (h *WebShellHandler) escapePath(p string) string { - if p == "" { - return "." - } - // 简单转义空格与敏感字符,避免命令注入 - return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'" -} - -func (h *WebShellHandler) escapeForEcho(s string) string { - // 仅用于 write:base64 写入更安全,这里简单用单引号包裹 - return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" -} - -// ExecWithConnection 在指定 WebShell 连接上执行命令(供 MCP/Agent 等非 HTTP 调用) -func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, command string) (output string, ok bool, errMsg string) { - if conn == nil { - return "", false, "connection is nil" - } - command = strings.TrimSpace(command) - if command == "" { - return "", false, "command is required" - } - useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET" - cmdParam := strings.TrimSpace(conn.CmdParam) - if cmdParam == "" { - cmdParam = "cmd" - } - var httpReq *http.Request - var err error - if useGET { - targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command) - httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil) - } else { - body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command) - httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body)) - httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - } - if err != nil { - return "", false, err.Error() - } - httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)") - resp, err := h.client.Do(httpReq) - if err != nil { - return "", false, err.Error() - } - defer resp.Body.Close() - out, _ := io.ReadAll(resp.Body) - return string(out), resp.StatusCode == http.StatusOK, "" -} - -// FileOpWithConnection 在指定 WebShell 连接上执行文件操作(供 MCP/Agent 调用),支持 list / read / write -func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection, action, path, content, targetPath string) (output string, ok bool, errMsg string) { - if conn == nil { - return "", false, "connection is nil" - } - action = strings.ToLower(strings.TrimSpace(action)) - shellType := strings.ToLower(strings.TrimSpace(conn.Type)) - if shellType == "" { - shellType = "php" - } - var command string - switch action { - case "list": - if path == "" { - path = "." - } - if shellType == "asp" || shellType == "aspx" { - command = "dir " + h.escapePath(strings.TrimSpace(path)) - } else { - command = "ls -la " + h.escapePath(strings.TrimSpace(path)) - } - case "read": - path = strings.TrimSpace(path) - if path == "" { - return "", false, "path is required for read" - } - if shellType == "asp" || shellType == "aspx" { - command = "type " + h.escapePath(path) - } else { - command = "cat " + h.escapePath(path) - } - case "write": - path = strings.TrimSpace(path) - if path == "" { - return "", false, "path is required for write" - } - command = "echo " + h.escapeForEcho(content) + " > " + h.escapePath(path) - default: - return "", false, "unsupported action: " + action + " (supported: list, read, write)" - } - useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET" - cmdParam := strings.TrimSpace(conn.CmdParam) - if cmdParam == "" { - cmdParam = "cmd" - } - var httpReq *http.Request - var err error - if useGET { - targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command) - httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil) - } else { - body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command) - httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body)) - httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - } - if err != nil { - return "", false, err.Error() - } - httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)") - resp, err := h.client.Do(httpReq) - if err != nil { - return "", false, err.Error() - } - defer resp.Body.Close() - out, _ := io.ReadAll(resp.Body) - return string(out), resp.StatusCode == http.StatusOK, "" -}