package app

import (
	"context"
	"fmt"
	"strings"

	"cyberstrike-ai/internal/agent"
	"cyberstrike-ai/internal/database"
	"cyberstrike-ai/internal/mcp"
	"cyberstrike-ai/internal/mcp/builtin"
	"go.uber.org/zap"
)

// conversationIDFromToolCtx returns the conversation ID carried by ctx.
// The agent context value is consulted first; when it is empty the MCP
// context value is returned instead (which may itself be empty).
func conversationIDFromToolCtx(ctx context.Context) string {
	if id := agent.ConversationIDFromContext(ctx); id != "" {
		return id
	}
	return mcp.MCPConversationIDFromContext(ctx)
}

// canAccessVulnerability reports whether the conversation convID may view vuln.
//
// Access is project-scoped by default: when projectID is non-empty the
// vulnerability must belong to that project, or be a legacy record that has no
// project ID and was written by the same conversation. When projectID is empty
// only records of the same conversation are visible.
//
// A nil vuln or an empty convID always yields false.
func canAccessVulnerability(vuln *database.Vulnerability, convID, projectID string) bool {
	if vuln == nil || convID == "" {
		return false
	}
	if projectID == "" {
		return vuln.ConversationID == convID
	}
	switch strings.TrimSpace(vuln.ProjectID) {
	case projectID:
		return true
	case "":
		// Legacy record: stored before a project was bound, but it belongs
		// to this very conversation.
		return vuln.ConversationID == convID
	default:
		return false
	}
}

// buildVulnerabilityListFilter derives the list filter and a human-readable
// scope label from the tool arguments and the conversation found in ctx.
//
// Recognised arguments: "scope" ("project" or "conversation"), "severity",
// "status", and "q" (falling back to "search" when "q" is blank). When
// "scope" is omitted it defaults to "project" if the conversation is bound to
// a project and to "conversation" otherwise.
//
// An error is returned when no conversation can be determined, when
// scope=project is requested without a bound project, or when scope has any
// other value.
func buildVulnerabilityListFilter(db *database.DB, ctx context.Context, args map[string]interface{}) (database.VulnerabilityListFilter, string, error) {
	convID := conversationIDFromToolCtx(ctx)
	if convID == "" {
		return database.VulnerabilityListFilter{}, "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用漏洞查询工具")
	}

	// Best effort: a lookup failure is treated the same as "no project bound".
	projectID := ""
	if pid, err := db.GetConversationProjectID(convID); err == nil {
		projectID = strings.TrimSpace(pid)
	}

	scope := strings.TrimSpace(strArg(args, "scope"))
	if scope == "" {
		if projectID != "" {
			scope = "project"
		} else {
			scope = "conversation"
		}
	}

	filter := database.VulnerabilityListFilter{
		Severity: strings.TrimSpace(strArg(args, "severity")),
		Status:   strings.TrimSpace(strArg(args, "status")),
	}
	// "q" wins over "search" when both are supplied.
	if q := strings.TrimSpace(strArg(args, "q")); q != "" {
		filter.Search = q
	} else {
		filter.Search = strings.TrimSpace(strArg(args, "search"))
	}

	var scopeLabel string
	switch scope {
	case "project":
		if projectID == "" {
			return filter, "", fmt.Errorf("当前对话未绑定项目,无法按项目列出漏洞;请使用 scope=conversation,或先在对话中绑定项目")
		}
		filter.ProjectID = projectID
		scopeLabel = fmt.Sprintf("项目 %s", projectID)
	case "conversation":
		filter.ConversationID = convID
		scopeLabel = fmt.Sprintf("会话 %s", convID)
	default:
		return filter, "", fmt.Errorf("scope 仅支持 project 或 conversation,当前值: %s", scope)
	}
	return filter, scopeLabel, nil
}

// formatVulnerabilityListItem renders v as a single summary line:
// id, severity, status and title, followed by the type and the target
// (truncated to 80 runes) when those are set.
func formatVulnerabilityListItem(v *database.Vulnerability) string {
	var b strings.Builder
	fmt.Fprintf(&b, "- id=%s | %s | %s | %s", v.ID, v.Severity, v.Status, v.Title)
	if v.Type != "" {
		fmt.Fprintf(&b, " | type=%s", v.Type)
	}
	if v.Target != "" {
		fmt.Fprintf(&b, " | target=%s", truncateRunes(v.Target, 80))
	}
	return b.String()
}

// formatVulnerabilityDetail renders every populated field of v as a
// multi-line text report. Header fields come first, followed by one
// "--- title ---" section for each non-empty long-form field.
func formatVulnerabilityDetail(v *database.Vulnerability) string {
	var b strings.Builder

	fmt.Fprintf(&b, "漏洞ID: %s\n", v.ID)
	fmt.Fprintf(&b, "标题: %s\n", v.Title)
	fmt.Fprintf(&b, "严重程度: %s\n", v.Severity)
	fmt.Fprintf(&b, "状态: %s\n", v.Status)
	if v.Type != "" {
		fmt.Fprintf(&b, "类型: %s\n", v.Type)
	}
	if v.Target != "" {
		fmt.Fprintf(&b, "目标: %s\n", v.Target)
	}
	if v.ProjectID != "" {
		fmt.Fprintf(&b, "项目ID: %s\n", v.ProjectID)
	}
	fmt.Fprintf(&b, "会话ID: %s\n", v.ConversationID)
	if !v.CreatedAt.IsZero() {
		fmt.Fprintf(&b, "创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05"))
	}

	// section appends one titled free-text section; empty bodies are skipped.
	section := func(title, body string) {
		if body == "" {
			return
		}
		b.WriteString("\n--- " + title + " ---\n")
		b.WriteString(body)
		b.WriteString("\n")
	}
	section("描述", v.Description)
	section("证明(POC)", v.Proof)
	section("影响", v.Impact)
	section("修复建议", v.Recommendation)

	return b.String()
}

// truncateRunes returns s unchanged when it has at most max runes; otherwise
// it returns the first max runes followed by "…". Truncation is rune-based so
// multi-byte characters are never split. A negative max is treated as 0
// instead of panicking on the slice expression.
func truncateRunes(s string, max int) string {
	if max < 0 {
		max = 0
	}
	r := []rune(s)
	if len(r) <= max {
		return s
	}
	return string(r[:max]) + "…"
}

// registerVulnerabilityTools registers the MCP tools for recording and
// querying vulnerabilities (record, list, get) on mcpServer. When logger is
// non-nil the registered tool names are logged.
func registerVulnerabilityTools(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
	registerRecordVulnerabilityTool(mcpServer, db, logger)
	registerListVulnerabilitiesTool(mcpServer, db, logger)
	registerGetVulnerabilityTool(mcpServer, db, logger)

	if logger != nil {
		logger.Info("漏洞 MCP 工具注册成功", zap.Strings("tools", []string{
			builtin.ToolRecordVulnerability,
			builtin.ToolListVulnerabilities,
			builtin.ToolGetVulnerability,
		}))
	}
}

// registerRecordVulnerabilityTool registers the tool that stores a newly
// found vulnerability. The handler validates title and severity, resolves the
// conversation (and its project, if any), and creates the record with status
// "open". Validation and storage failures are reported through the tool
// result rather than through the handler's error return.
func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
	tool := mcp.Tool{
		Name:             builtin.ToolRecordVulnerability,
		Description:      "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。记录前可先 list_vulnerabilities 避免重复。",
		ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"title": map[string]interface{}{
					"type":        "string",
					"description": "漏洞标题(必需)",
				},
				"description": map[string]interface{}{
					"type":        "string",
					"description": "漏洞详细描述",
				},
				"severity": map[string]interface{}{
					"type":        "string",
					"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
					"enum":        []string{"critical", "high", "medium", "low", "info"},
				},
				"vulnerability_type": map[string]interface{}{
					"type":        "string",
					"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
				},
				"target": map[string]interface{}{
					"type":        "string",
					"description": "受影响的目标(URL、IP地址、服务等)",
				},
				"proof": map[string]interface{}{
					"type":        "string",
					"description": "漏洞证明(POC、截图、请求/响应等)",
				},
				"impact": map[string]interface{}{
					"type":        "string",
					"description": "漏洞影响说明",
				},
				"recommendation": map[string]interface{}{
					"type":        "string",
					"description": "修复建议",
				},
			},
			"required": []string{"title", "severity"},
		},
	}

	mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		// An explicit conversation_id argument takes precedence over the
		// context value.
		// NOTE(review): conversation_id is not declared in InputSchema —
		// presumably injected by the caller; confirm it cannot be spoofed.
		conversationID := strings.TrimSpace(strArg(args, "conversation_id"))
		if conversationID == "" {
			conversationID = conversationIDFromToolCtx(ctx)
		}
		if conversationID == "" {
			return textResult("错误: conversation_id 未设置。这是系统错误,请重试。", true), nil
		}

		title := strings.TrimSpace(strArg(args, "title"))
		if title == "" {
			return textResult("错误: title 参数必需且不能为空", true), nil
		}

		severity := strings.TrimSpace(strArg(args, "severity"))
		if severity == "" {
			return textResult("错误: severity 参数必需且不能为空", true), nil
		}
		switch severity {
		case "critical", "high", "medium", "low", "info":
			// accepted
		default:
			return textResult(fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), true), nil
		}

		// Best effort: if the project lookup fails the record is stored
		// without a project ID.
		projectID := ""
		if pid, perr := db.GetConversationProjectID(conversationID); perr == nil {
			projectID = strings.TrimSpace(pid)
		}

		vuln := &database.Vulnerability{
			ConversationID: conversationID,
			ProjectID:      projectID,
			Title:          title,
			Description:    strArg(args, "description"),
			Severity:       severity,
			Status:         "open",
			Type:           strArg(args, "vulnerability_type"),
			Target:         strArg(args, "target"),
			Proof:          strArg(args, "proof"),
			Impact:         strArg(args, "impact"),
			Recommendation: strArg(args, "recommendation"),
		}

		created, err := db.CreateVulnerability(vuln)
		if err != nil {
			if logger != nil {
				logger.Error("记录漏洞失败", zap.Error(err))
			}
			return textResult(fmt.Sprintf("记录漏洞失败: %v", err), true), nil
		}

		if logger != nil {
			logger.Info("漏洞记录成功",
				zap.String("id", created.ID),
				zap.String("title", created.Title),
				zap.String("severity", created.Severity),
				zap.String("conversation_id", conversationID),
			)
		}

		return textResult(fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n可使用 get_vulnerability(id) 查看详情,或 list_vulnerabilities 查看列表。",
			created.ID, created.Title, created.Severity, created.Status), false), nil
	})
}

// registerListVulnerabilitiesTool registers the tool that lists vulnerability
// summaries within the caller's scope (see buildVulnerabilityListFilter).
//
// Paging: "limit" defaults to 30; non-positive values fall back to the
// default and values above 100 are clamped to 100, matching the maximum
// documented in the schema. A negative "offset" is treated as 0.
func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
	const (
		defaultLimit = 30
		maxLimit     = 100
	)

	tool := mcp.Tool{
		Name:             builtin.ToolListVulnerabilities,
		Description:      "列出当前授权范围内的漏洞(摘要)。默认:对话已绑定项目时列出该项目下全部漏洞;未绑项目时仅列出当前会话漏洞。可用 scope=conversation 仅看本会话。记录新漏洞前建议先调用以避免重复。",
		ShortDescription: "列出漏洞(默认当前项目)",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"scope": map[string]interface{}{
					"type":        "string",
					"description": "范围:project(默认,需绑定项目)| conversation(仅当前会话)",
					"enum":        []string{"project", "conversation"},
				},
				"severity": map[string]interface{}{
					"type":        "string",
					"description": "按严重程度筛选:critical、high、medium、low、info",
					"enum":        []string{"critical", "high", "medium", "low", "info"},
				},
				"status": map[string]interface{}{
					"type":        "string",
					"description": "按状态筛选:open、confirmed、fixed、false_positive",
					"enum":        []string{"open", "confirmed", "fixed", "false_positive"},
				},
				"q": map[string]interface{}{
					"type":        "string",
					"description": "关键词搜索(标题、描述、类型、目标等)",
				},
				"limit": map[string]interface{}{
					"type":        "integer",
					"description": "返回条数上限,默认 30,最大 100",
				},
				"offset": map[string]interface{}{
					"type":        "integer",
					"description": "分页偏移,默认 0",
				},
			},
		},
	}

	mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		filter, scopeLabel, err := buildVulnerabilityListFilter(db, ctx, args)
		if err != nil {
			return textResult("错误: "+err.Error(), true), nil
		}

		limit := intArg(args, "limit", defaultLimit)
		switch {
		case limit <= 0:
			limit = defaultLimit
		case limit > maxLimit:
			// Clamp instead of resetting to the default, so asking for more
			// than the maximum never yields fewer rows than asking for it.
			limit = maxLimit
		}
		offset := intArg(args, "offset", 0)
		if offset < 0 {
			offset = 0
		}

		// The total is informational only: a count failure is logged and the
		// total is reported as 0 while the listing itself still proceeds.
		total, err := db.CountVulnerabilities(filter)
		if err != nil {
			if logger != nil {
				logger.Warn("统计漏洞失败", zap.Error(err))
			}
			total = 0
		}

		list, err := db.ListVulnerabilities(limit, offset, filter)
		if err != nil {
			return textResult("错误: "+err.Error(), true), nil
		}

		var b strings.Builder
		fmt.Fprintf(&b, "范围: %s\n总计: %d | 本页: %d 条 (limit=%d offset=%d)\n\n",
			scopeLabel, total, len(list), limit, offset)
		if len(list) == 0 {
			b.WriteString("(暂无漏洞记录)\n")
		} else {
			for _, v := range list {
				b.WriteString(formatVulnerabilityListItem(v))
				b.WriteString("\n")
			}
			if total > offset+len(list) {
				b.WriteString("\n(还有更多,可增大 offset 或使用 q/severity/status 筛选)\n")
			}
		}
		b.WriteString("\n需要 POC 与完整字段请对具体 id 调用 get_vulnerability。")
		return textResult(b.String(), false), nil
	})
}

// registerGetVulnerabilityTool registers the tool that returns the full
// detail of one vulnerability by ID. The record is only returned when
// canAccessVulnerability grants access for the calling conversation and its
// bound project (if any).
func registerGetVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
	tool := mcp.Tool{
		Name:             builtin.ToolGetVulnerability,
		Description:      "按漏洞 ID 获取完整详情(含 POC、影响、修复建议)。仅能访问当前项目或当前会话下的漏洞(与 list_vulnerabilities 授权范围一致)。",
		ShortDescription: "按 ID 获取漏洞详情",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"id": map[string]interface{}{
					"type":        "string",
					"description": "漏洞 ID(list_vulnerabilities 返回的 id)",
				},
			},
			"required": []string{"id"},
		},
	}

	mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		convID := conversationIDFromToolCtx(ctx)
		if convID == "" {
			return textResult("错误: 无法确定当前对话,请在对话上下文中使用本工具", true), nil
		}

		id := strings.TrimSpace(strArg(args, "id"))
		if id == "" {
			return textResult("错误: id 必填", true), nil
		}

		vuln, err := db.GetVulnerability(id)
		if err != nil {
			return textResult("错误: 漏洞不存在或查询失败", true), nil
		}

		// Best effort: a failed project lookup narrows access to the
		// conversation-only rule.
		projectID := ""
		if pid, perr := db.GetConversationProjectID(convID); perr == nil {
			projectID = strings.TrimSpace(pid)
		}

		if !canAccessVulnerability(vuln, convID, projectID) {
			return textResult("错误: 无权访问该漏洞(仅可查看当前项目或当前会话下的记录)", true), nil
		}
		return textResult(formatVulnerabilityDetail(vuln), false), nil
	})
}