mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-27 01:32:26 +02:00
406 lines
14 KiB
Go
406 lines
14 KiB
Go
package app
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"cyberstrike-ai/internal/agent"
|
||
"cyberstrike-ai/internal/database"
|
||
"cyberstrike-ai/internal/mcp"
|
||
"cyberstrike-ai/internal/mcp/builtin"
|
||
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
func conversationIDFromToolCtx(ctx context.Context) string {
|
||
if id := agent.ConversationIDFromContext(ctx); id != "" {
|
||
return id
|
||
}
|
||
return mcp.MCPConversationIDFromContext(ctx)
|
||
}
|
||
|
||
// canAccessVulnerability 校验当前对话是否有权查看该漏洞(默认项目隔离,未绑项目则仅本会话)。
|
||
func canAccessVulnerability(vuln *database.Vulnerability, convID, projectID string) bool {
|
||
if vuln == nil || convID == "" {
|
||
return false
|
||
}
|
||
if projectID != "" {
|
||
if strings.TrimSpace(vuln.ProjectID) == projectID {
|
||
return true
|
||
}
|
||
// 历史记录:写入时尚未绑定 project_id,但属于同一会话
|
||
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID == convID {
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
return vuln.ConversationID == convID
|
||
}
|
||
|
||
func buildVulnerabilityListFilter(db *database.DB, ctx context.Context, args map[string]interface{}) (database.VulnerabilityListFilter, string, error) {
|
||
convID := conversationIDFromToolCtx(ctx)
|
||
if convID == "" {
|
||
return database.VulnerabilityListFilter{}, "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用漏洞查询工具")
|
||
}
|
||
|
||
projectID := ""
|
||
if pid, err := db.GetConversationProjectID(convID); err == nil {
|
||
projectID = strings.TrimSpace(pid)
|
||
}
|
||
|
||
scope := strings.TrimSpace(strArg(args, "scope"))
|
||
if scope == "" {
|
||
if projectID != "" {
|
||
scope = "project"
|
||
} else {
|
||
scope = "conversation"
|
||
}
|
||
}
|
||
|
||
filter := database.VulnerabilityListFilter{
|
||
Severity: strings.TrimSpace(strArg(args, "severity")),
|
||
Status: strings.TrimSpace(strArg(args, "status")),
|
||
}
|
||
if q := strings.TrimSpace(strArg(args, "q")); q != "" {
|
||
filter.Search = q
|
||
} else {
|
||
filter.Search = strings.TrimSpace(strArg(args, "search"))
|
||
}
|
||
|
||
var scopeLabel string
|
||
switch scope {
|
||
case "project":
|
||
if projectID == "" {
|
||
return filter, "", fmt.Errorf("当前对话未绑定项目,无法按项目列出漏洞;请使用 scope=conversation,或先在对话中绑定项目")
|
||
}
|
||
filter.ProjectID = projectID
|
||
scopeLabel = fmt.Sprintf("项目 %s", projectID)
|
||
case "conversation":
|
||
filter.ConversationID = convID
|
||
scopeLabel = fmt.Sprintf("会话 %s", convID)
|
||
default:
|
||
return filter, "", fmt.Errorf("scope 仅支持 project 或 conversation,当前值: %s", scope)
|
||
}
|
||
return filter, scopeLabel, nil
|
||
}
|
||
|
||
func formatVulnerabilityListItem(v *database.Vulnerability) string {
|
||
line := fmt.Sprintf("- id=%s | %s | %s | %s", v.ID, v.Severity, v.Status, v.Title)
|
||
if v.Type != "" {
|
||
line += fmt.Sprintf(" | type=%s", v.Type)
|
||
}
|
||
if v.Target != "" {
|
||
line += fmt.Sprintf(" | target=%s", truncateRunes(v.Target, 80))
|
||
}
|
||
return line
|
||
}
|
||
|
||
func formatVulnerabilityDetail(v *database.Vulnerability) string {
|
||
var b strings.Builder
|
||
b.WriteString(fmt.Sprintf("漏洞ID: %s\n", v.ID))
|
||
b.WriteString(fmt.Sprintf("标题: %s\n", v.Title))
|
||
b.WriteString(fmt.Sprintf("严重程度: %s\n", v.Severity))
|
||
b.WriteString(fmt.Sprintf("状态: %s\n", v.Status))
|
||
if v.Type != "" {
|
||
b.WriteString(fmt.Sprintf("类型: %s\n", v.Type))
|
||
}
|
||
if v.Target != "" {
|
||
b.WriteString(fmt.Sprintf("目标: %s\n", v.Target))
|
||
}
|
||
if v.ProjectID != "" {
|
||
b.WriteString(fmt.Sprintf("项目ID: %s\n", v.ProjectID))
|
||
}
|
||
b.WriteString(fmt.Sprintf("会话ID: %s\n", v.ConversationID))
|
||
if !v.CreatedAt.IsZero() {
|
||
b.WriteString(fmt.Sprintf("创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05")))
|
||
}
|
||
if v.Description != "" {
|
||
b.WriteString("\n--- 描述 ---\n")
|
||
b.WriteString(v.Description)
|
||
b.WriteString("\n")
|
||
}
|
||
if v.Proof != "" {
|
||
b.WriteString("\n--- 证明(POC) ---\n")
|
||
b.WriteString(v.Proof)
|
||
b.WriteString("\n")
|
||
}
|
||
if v.Impact != "" {
|
||
b.WriteString("\n--- 影响 ---\n")
|
||
b.WriteString(v.Impact)
|
||
b.WriteString("\n")
|
||
}
|
||
if v.Recommendation != "" {
|
||
b.WriteString("\n--- 修复建议 ---\n")
|
||
b.WriteString(v.Recommendation)
|
||
b.WriteString("\n")
|
||
}
|
||
return b.String()
|
||
}
|
||
|
||
// truncateRunes shortens s to at most max runes (Unicode code points, not
// bytes) and appends an ellipsis ("…") when anything was cut off. Strings that
// already fit are returned unchanged.
//
// A negative max is treated as 0; without that clamp the slice expression
// below would panic with an out-of-range bound.
func truncateRunes(s string, max int) string {
	if max < 0 {
		max = 0
	}
	r := []rune(s)
	if len(r) <= max {
		return s
	}
	return string(r[:max]) + "…"
}
|
||
|
||
// registerVulnerabilityTools 注册漏洞记录与查询 MCP 工具。
|
||
func registerVulnerabilityTools(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||
registerRecordVulnerabilityTool(mcpServer, db, logger)
|
||
registerListVulnerabilitiesTool(mcpServer, db, logger)
|
||
registerGetVulnerabilityTool(mcpServer, db, logger)
|
||
if logger != nil {
|
||
logger.Info("漏洞 MCP 工具注册成功", zap.Strings("tools", []string{
|
||
builtin.ToolRecordVulnerability,
|
||
builtin.ToolListVulnerabilities,
|
||
builtin.ToolGetVulnerability,
|
||
}))
|
||
}
|
||
}
|
||
|
||
func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||
tool := mcp.Tool{
|
||
Name: builtin.ToolRecordVulnerability,
|
||
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。记录前可先 list_vulnerabilities 避免重复。",
|
||
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"title": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "漏洞标题(必需)",
|
||
},
|
||
"description": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "漏洞详细描述",
|
||
},
|
||
"severity": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
|
||
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||
},
|
||
"vulnerability_type": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
|
||
},
|
||
"target": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "受影响的目标(URL、IP地址、服务等)",
|
||
},
|
||
"proof": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "漏洞证明(POC、截图、请求/响应等)",
|
||
},
|
||
"impact": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "漏洞影响说明",
|
||
},
|
||
"recommendation": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "修复建议",
|
||
},
|
||
},
|
||
"required": []string{"title", "severity"},
|
||
},
|
||
}
|
||
|
||
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
conversationID := strings.TrimSpace(strArg(args, "conversation_id"))
|
||
if conversationID == "" {
|
||
conversationID = conversationIDFromToolCtx(ctx)
|
||
}
|
||
if conversationID == "" {
|
||
return textResult("错误: conversation_id 未设置。这是系统错误,请重试。", true), nil
|
||
}
|
||
|
||
title := strings.TrimSpace(strArg(args, "title"))
|
||
if title == "" {
|
||
return textResult("错误: title 参数必需且不能为空", true), nil
|
||
}
|
||
|
||
severity := strings.TrimSpace(strArg(args, "severity"))
|
||
if severity == "" {
|
||
return textResult("错误: severity 参数必需且不能为空", true), nil
|
||
}
|
||
|
||
validSeverities := map[string]bool{
|
||
"critical": true, "high": true, "medium": true, "low": true, "info": true,
|
||
}
|
||
if !validSeverities[severity] {
|
||
return textResult(fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), true), nil
|
||
}
|
||
|
||
projectID := ""
|
||
if pid, perr := db.GetConversationProjectID(conversationID); perr == nil {
|
||
projectID = strings.TrimSpace(pid)
|
||
}
|
||
|
||
vuln := &database.Vulnerability{
|
||
ConversationID: conversationID,
|
||
ProjectID: projectID,
|
||
Title: title,
|
||
Description: strArg(args, "description"),
|
||
Severity: severity,
|
||
Status: "open",
|
||
Type: strArg(args, "vulnerability_type"),
|
||
Target: strArg(args, "target"),
|
||
Proof: strArg(args, "proof"),
|
||
Impact: strArg(args, "impact"),
|
||
Recommendation: strArg(args, "recommendation"),
|
||
}
|
||
|
||
created, err := db.CreateVulnerability(vuln)
|
||
if err != nil {
|
||
if logger != nil {
|
||
logger.Error("记录漏洞失败", zap.Error(err))
|
||
}
|
||
return textResult(fmt.Sprintf("记录漏洞失败: %v", err), true), nil
|
||
}
|
||
|
||
if logger != nil {
|
||
logger.Info("漏洞记录成功",
|
||
zap.String("id", created.ID),
|
||
zap.String("title", created.Title),
|
||
zap.String("severity", created.Severity),
|
||
zap.String("conversation_id", conversationID),
|
||
)
|
||
}
|
||
|
||
return textResult(fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n可使用 get_vulnerability(id) 查看详情,或 list_vulnerabilities 查看列表。",
|
||
created.ID, created.Title, created.Severity, created.Status), false), nil
|
||
})
|
||
}
|
||
|
||
func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||
tool := mcp.Tool{
|
||
Name: builtin.ToolListVulnerabilities,
|
||
Description: "列出当前授权范围内的漏洞(摘要)。默认:对话已绑定项目时列出该项目下全部漏洞;未绑项目时仅列出当前会话漏洞。可用 scope=conversation 仅看本会话。记录新漏洞前建议先调用以避免重复。",
|
||
ShortDescription: "列出漏洞(默认当前项目)",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"scope": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "范围:project(默认,需绑定项目)| conversation(仅当前会话)",
|
||
"enum": []string{"project", "conversation"},
|
||
},
|
||
"severity": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "按严重程度筛选:critical、high、medium、low、info",
|
||
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||
},
|
||
"status": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "按状态筛选:open、confirmed、fixed、false_positive",
|
||
"enum": []string{"open", "confirmed", "fixed", "false_positive"},
|
||
},
|
||
"q": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "关键词搜索(标题、描述、类型、目标等)",
|
||
},
|
||
"limit": map[string]interface{}{
|
||
"type": "integer",
|
||
"description": "返回条数上限,默认 30,最大 100",
|
||
},
|
||
"offset": map[string]interface{}{
|
||
"type": "integer",
|
||
"description": "分页偏移,默认 0",
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
filter, scopeLabel, err := buildVulnerabilityListFilter(db, ctx, args)
|
||
if err != nil {
|
||
return textResult("错误: "+err.Error(), true), nil
|
||
}
|
||
|
||
limit := intArg(args, "limit", 30)
|
||
if limit <= 0 || limit > 100 {
|
||
limit = 30
|
||
}
|
||
offset := intArg(args, "offset", 0)
|
||
if offset < 0 {
|
||
offset = 0
|
||
}
|
||
|
||
total, err := db.CountVulnerabilities(filter)
|
||
if err != nil {
|
||
if logger != nil {
|
||
logger.Warn("统计漏洞失败", zap.Error(err))
|
||
}
|
||
total = 0
|
||
}
|
||
|
||
list, err := db.ListVulnerabilities(limit, offset, filter)
|
||
if err != nil {
|
||
return textResult("错误: "+err.Error(), true), nil
|
||
}
|
||
|
||
var b strings.Builder
|
||
b.WriteString(fmt.Sprintf("范围: %s\n总计: %d | 本页: %d 条 (limit=%d offset=%d)\n\n", scopeLabel, total, len(list), limit, offset))
|
||
if len(list) == 0 {
|
||
b.WriteString("(暂无漏洞记录)\n")
|
||
} else {
|
||
for _, v := range list {
|
||
b.WriteString(formatVulnerabilityListItem(v))
|
||
b.WriteString("\n")
|
||
}
|
||
if total > offset+len(list) {
|
||
b.WriteString(fmt.Sprintf("\n(还有更多,可增大 offset 或使用 q/severity/status 筛选)\n"))
|
||
}
|
||
}
|
||
b.WriteString("\n需要 POC 与完整字段请对具体 id 调用 get_vulnerability。")
|
||
return textResult(b.String(), false), nil
|
||
})
|
||
}
|
||
|
||
func registerGetVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||
tool := mcp.Tool{
|
||
Name: builtin.ToolGetVulnerability,
|
||
Description: "按漏洞 ID 获取完整详情(含 POC、影响、修复建议)。仅能访问当前项目或当前会话下的漏洞(与 list_vulnerabilities 授权范围一致)。",
|
||
ShortDescription: "按 ID 获取漏洞详情",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "漏洞 ID(list_vulnerabilities 返回的 id)",
|
||
},
|
||
},
|
||
"required": []string{"id"},
|
||
},
|
||
}
|
||
|
||
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
convID := conversationIDFromToolCtx(ctx)
|
||
if convID == "" {
|
||
return textResult("错误: 无法确定当前对话,请在对话上下文中使用本工具", true), nil
|
||
}
|
||
|
||
id := strings.TrimSpace(strArg(args, "id"))
|
||
if id == "" {
|
||
return textResult("错误: id 必填", true), nil
|
||
}
|
||
|
||
vuln, err := db.GetVulnerability(id)
|
||
if err != nil {
|
||
return textResult("错误: 漏洞不存在或查询失败", true), nil
|
||
}
|
||
|
||
projectID := ""
|
||
if pid, perr := db.GetConversationProjectID(convID); perr == nil {
|
||
projectID = strings.TrimSpace(pid)
|
||
}
|
||
|
||
if !canAccessVulnerability(vuln, convID, projectID) {
|
||
return textResult("错误: 无权访问该漏洞(仅可查看当前项目或当前会话下的记录)", true), nil
|
||
}
|
||
|
||
return textResult(formatVulnerabilityDetail(vuln), false), nil
|
||
})
|
||
}
|