diff --git a/internal/app/app.go b/internal/app/app.go index 5a2cef85..93f3dd8e 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -111,7 +111,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error executor.RegisterTools(mcpServer) // 注册漏洞记录工具 - registerVulnerabilityTool(mcpServer, db, log.Logger) + registerVulnerabilityTools(mcpServer, db, log.Logger) + registerProjectFactTools(mcpServer, db, cfg, log.Logger) if cfg.Auth.GeneratedPassword != "" { config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr) @@ -346,6 +347,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error authHandler.SetAudit(auditSvc) attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger) vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger) + projectHandler := handler.NewProjectHandler(db, log.Logger) vulnerabilityHandler.SetAudit(auditSvc) webshellHandler := handler.NewWebShellHandler(log.Logger, db) webshellHandler.SetAudit(auditSvc) @@ -414,7 +416,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error // 设置漏洞工具注册器(内置工具,必须设置) vulnerabilityRegistrar := func() error { - registerVulnerabilityTool(mcpServer, db, log.Logger) + registerVulnerabilityTools(mcpServer, db, log.Logger) + registerProjectFactTools(mcpServer, db, cfg, log.Logger) return nil } configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar) @@ -502,6 +505,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error attackChainHandler, app, // 传递 App 实例以便动态获取 knowledgeHandler vulnerabilityHandler, + projectHandler, webshellHandler, chatUploadsHandler, roleHandler, @@ -747,6 +751,7 @@ func setupRoutes( attackChainHandler *handler.AttackChainHandler, app *App, // 传递 App 实例以便动态获取 knowledgeHandler vulnerabilityHandler *handler.VulnerabilityHandler, + projectHandler *handler.ProjectHandler, webshellHandler *handler.WebShellHandler, chatUploadsHandler *handler.ChatUploadsHandler, roleHandler *handler.RoleHandler, @@ -851,6 +856,7 @@ func setupRoutes( protected.GET("/conversations/:id", conversationHandler.GetConversation) protected.GET("/messages/:id/process-details", conversationHandler.GetMessageProcessDetails) protected.PUT("/conversations/:id", conversationHandler.UpdateConversation) + protected.PUT("/conversations/:id/project", conversationHandler.SetConversationProject) protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation) protected.POST("/conversations/:id/delete-turn", conversationHandler.DeleteConversationTurn) protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned) @@ -1067,6 +1073,18 @@ func setupRoutes( protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability) protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability) + // 项目管理与事实黑板 + protected.GET("/projects", projectHandler.ListProjects) + protected.POST("/projects", projectHandler.CreateProject) + protected.GET("/projects/:id", projectHandler.GetProject) + protected.PUT("/projects/:id", projectHandler.UpdateProject) + protected.DELETE("/projects/:id", projectHandler.DeleteProject) + protected.GET("/projects/:id/facts", projectHandler.ListFacts) + protected.POST("/projects/:id/facts", projectHandler.CreateFact) + protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact) + protected.DELETE("/projects/:id/facts/:factId", projectHandler.DeleteFact) + protected.POST("/projects/:id/facts/deprecate", projectHandler.DeprecateFact) + // WebShell 管理(代理执行 + 连接配置存 SQLite) protected.GET("/webshell/connections", webshellHandler.ListConnections) protected.POST("/webshell/connections", webshellHandler.CreateConnection) @@ -1187,195 +1205,6 @@ func setupRoutes( }) } -// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器 -func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) { - tool := mcp.Tool{ - Name: builtin.ToolRecordVulnerability, - Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。", - ShortDescription: "记录发现的漏洞详情到漏洞管理系统", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "title": map[string]interface{}{ - "type": "string", - "description": "漏洞标题(必需)", - }, - "description": map[string]interface{}{ - "type": "string", - "description": "漏洞详细描述", - }, - "severity": map[string]interface{}{ - "type": "string", - "description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)", - "enum": []string{"critical", "high", "medium", "low", "info"}, - }, - "vulnerability_type": map[string]interface{}{ - "type": "string", - "description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等", - }, - "target": map[string]interface{}{ - "type": "string", - "description": "受影响的目标(URL、IP地址、服务等)", - }, - "proof": map[string]interface{}{ - "type": "string", - "description": "漏洞证明(POC、截图、请求/响应等)", - }, - "impact": map[string]interface{}{ - "type": "string", - "description": "漏洞影响说明", - }, - "recommendation": map[string]interface{}{ - "type": "string", - "description": "修复建议", - }, - }, - "required": []string{"title", "severity"}, - }, - } - - handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - // 从参数中获取conversation_id(由Agent自动添加) - conversationID, _ := args["conversation_id"].(string) - if conversationID == "" { - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: "错误: conversation_id 未设置。这是系统错误,请重试。", - }, - }, - IsError: true, - }, nil - } - - title, ok := args["title"].(string) - if !ok || title == "" { - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: "错误: title 参数必需且不能为空", - }, - }, - IsError: true, - }, nil - } - - severity, ok := args["severity"].(string) - if !ok || severity == "" { - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: "错误: severity 参数必需且不能为空", - }, - }, - IsError: true, - }, nil - } - - // 验证严重程度 - validSeverities := map[string]bool{ - "critical": true, - "high": true, - "medium": true, - "low": true, - "info": true, - } - if !validSeverities[severity] { - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), - }, - }, - IsError: true, - }, nil - } - - // 获取可选参数 - description := "" - if d, ok := args["description"].(string); ok { - description = d - } - - vulnType := "" - if t, ok := args["vulnerability_type"].(string); ok { - vulnType = t - } - - target := "" - if t, ok := args["target"].(string); ok { - target = t - } - - proof := "" - if p, ok := args["proof"].(string); ok { - proof = p - } - - impact := "" - if i, ok := args["impact"].(string); ok { - impact = i - } - - recommendation := "" - if r, ok := args["recommendation"].(string); ok { - recommendation = r - } - - // 创建漏洞记录 - vuln := &database.Vulnerability{ - ConversationID: conversationID, - Title: title, - Description: description, - Severity: severity, - Status: "open", - Type: vulnType, - Target: target, - Proof: proof, - Impact: impact, - Recommendation: recommendation, - } - - created, err := db.CreateVulnerability(vuln) - if err != nil { - logger.Error("记录漏洞失败", zap.Error(err)) - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: fmt.Sprintf("记录漏洞失败: %v", err), - }, - }, - IsError: true, - }, nil - } - - logger.Info("漏洞记录成功", - zap.String("id", created.ID), - zap.String("title", created.Title), - zap.String("severity", created.Severity), - zap.String("conversation_id", conversationID), - ) - - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n你可以在漏洞管理页面查看和管理此漏洞。", created.ID, created.Title, created.Severity, created.Status), - }, - }, - IsError: false, - }, nil - } - - mcpServer.RegisterTool(tool, handler) - logger.Info("漏洞记录工具注册成功") -} - // registerWebshellTools 注册 WebShell 相关 MCP 工具,供 AI 助手在指定连接上执行命令与文件操作 func registerWebshellTools(mcpServer *mcp.Server, db *database.DB, webshellHandler *handler.WebShellHandler, logger *zap.Logger) { if db == nil || webshellHandler == nil { diff --git a/internal/app/project_fact_tools.go b/internal/app/project_fact_tools.go new file mode 100644 index 00000000..efef739d --- /dev/null +++ b/internal/app/project_fact_tools.go @@ -0,0 +1,278 @@ +package app + +import ( + "context" + "fmt" + "strings" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + + "go.uber.org/zap" +) + +func projectIDFromConversation(db *database.DB, ctx context.Context) (string, error) { + convID := agent.ConversationIDFromContext(ctx) + if convID == "" { + return "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用项目事实工具") + } + pid, err := db.GetConversationProjectID(convID) + if err != nil { + return "", err + } + if strings.TrimSpace(pid) == "" { + return "", fmt.Errorf("当前对话未绑定项目,请先在对话中选择项目或创建带项目的对话") + } + return pid, nil +} + +func textResult(msg string, isErr bool) *mcp.ToolResult { + return &mcp.ToolResult{ + Content: []mcp.Content{{Type: "text", Text: msg}}, + IsError: isErr, + } +} + +// registerProjectFactTools 注册项目黑板 MCP 工具。 +func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *config.Config, logger *zap.Logger) { + if db == nil || cfg == nil || !cfg.Project.Enabled { + if logger != nil { + logger.Info("项目黑板工具未注册(未启用)") + } + return + } + + upsertTool := mcp.Tool{ + Name: builtin.ToolUpsertProjectFact, + Description: "写入或更新项目黑板事实。用于记录环境认知、目标信息、认证特征等(非正式漏洞条目)。同 fact_key 会覆盖更新。需要当前对话已绑定项目。", + ShortDescription: "写入/更新项目事实", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "fact_key": map[string]interface{}{ + "type": "string", + "description": "项目内唯一 key,建议格式 category/slug,如 target/primary_domain", + }, + "category": map[string]interface{}{ + "type": "string", + "description": "分类:target、auth、infra、business、note 等", + }, + "summary": map[string]interface{}{ + "type": "string", + "description": "单行摘要(会注入到后续对话索引)", + }, + "body": map[string]interface{}{ + "type": "string", + "description": "完整详情(POC、长文本等,仅 get_project_fact 返回)", + }, + "confidence": map[string]interface{}{ + "type": "string", + "description": "confirmed | tentative | deprecated", + "enum": []string{"confirmed", "tentative", "deprecated"}, + }, + "pinned": map[string]interface{}{ + "type": "boolean", + "description": "是否优先出现在黑板索引", + }, + "related_vulnerability_id": map[string]interface{}{ + "type": "string", + "description": "可选:关联的漏洞记录 ID", + }, + }, + "required": []string{"fact_key", "summary"}, + }, + } + + mcpServer.RegisterTool(upsertTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + projectID, err := projectIDFromConversation(db, ctx) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + factKey, _ := args["fact_key"].(string) + summary, _ := args["summary"].(string) + if strings.TrimSpace(factKey) == "" || strings.TrimSpace(summary) == "" { + return textResult("错误: fact_key 与 summary 必填", true), nil + } + if len([]rune(summary)) > cfg.Project.FactSummaryMaxRunesEffective() { + return textResult(fmt.Sprintf("错误: summary 过长(最多 %d 字)", cfg.Project.FactSummaryMaxRunesEffective()), true), nil + } + f := &database.ProjectFact{ + ProjectID: projectID, + FactKey: factKey, + Category: strArg(args, "category"), + Summary: summary, + Body: strArg(args, "body"), + Confidence: strArg(args, "confidence"), + Pinned: boolArg(args, "pinned"), + RelatedVulnerabilityID: strArg(args, "related_vulnerability_id"), + } + if convID := agent.ConversationIDFromContext(ctx); convID != "" { + f.SourceConversationID = convID + } + created, err := db.UpsertProjectFact(f) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + return textResult(fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence), false), nil + }) + + getTool := mcp.Tool{ + Name: builtin.ToolGetProjectFact, + Description: "按 fact_key 获取项目事实完整 body 与元数据。摘要不足时必须调用本工具,禁止臆造细节。", + ShortDescription: "按 key 获取事实详情", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "fact_key": map[string]interface{}{"type": "string", "description": "事实 key"}, + }, + "required": []string{"fact_key"}, + }, + } + mcpServer.RegisterTool(getTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + projectID, err := projectIDFromConversation(db, ctx) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + key := strings.TrimSpace(strArg(args, "fact_key")) + if key == "" { + return textResult("错误: fact_key 必填", true), nil + } + f, err := db.GetProjectFactByKey(projectID, key) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + msg := fmt.Sprintf("fact_key: %s\ncategory: %s\nconfidence: %s\nsummary: %s\nupdated_at: %s\n\n--- body ---\n%s", + f.FactKey, f.Category, f.Confidence, f.Summary, f.UpdatedAt.Format("2006-01-02 15:04:05"), f.Body) + return textResult(msg, false), nil + }) + + listTool := mcp.Tool{ + Name: builtin.ToolListProjectFacts, + Description: "列出当前项目的事实(分页)。", + ShortDescription: "列出项目事实", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "category": map[string]interface{}{"type": "string"}, + "confidence": map[string]interface{}{"type": "string"}, + "limit": map[string]interface{}{"type": "integer"}, + "offset": map[string]interface{}{"type": "integer"}, + }, + }, + } + mcpServer.RegisterTool(listTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + projectID, err := projectIDFromConversation(db, ctx) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + limit := intArg(args, "limit", 50) + offset := intArg(args, "offset", 0) + filter := database.ProjectFactListFilter{ + Category: strArg(args, "category"), + Confidence: strArg(args, "confidence"), + } + list, err := db.ListProjectFacts(projectID, filter, limit, offset) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + var b strings.Builder + b.WriteString(fmt.Sprintf("共 %d 条(limit=%d offset=%d):\n", len(list), limit, offset)) + for _, f := range list { + b.WriteString(fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, f.Summary, f.Confidence)) + } + return textResult(b.String(), false), nil + }) + + searchTool := mcp.Tool{ + Name: builtin.ToolSearchProjectFacts, + Description: "按关键词搜索项目事实(summary/body/fact_key)。", + ShortDescription: "搜索项目事实", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{"type": "string"}, + "limit": map[string]interface{}{"type": "integer"}, + "offset": map[string]interface{}{"type": "integer"}, + }, + "required": []string{"query"}, + }, + } + mcpServer.RegisterTool(searchTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + projectID, err := projectIDFromConversation(db, ctx) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + q := strings.TrimSpace(strArg(args, "query")) + if q == "" { + return textResult("错误: query 必填", true), nil + } + list, err := db.ListProjectFacts(projectID, database.ProjectFactListFilter{Search: q}, intArg(args, "limit", 30), intArg(args, "offset", 0)) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + var b strings.Builder + b.WriteString(fmt.Sprintf("搜索 \"%s\" 命中 %d 条:\n", q, len(list))) + for _, f := range list { + b.WriteString(fmt.Sprintf("- [%s] %s — %s\n", f.FactKey, f.Category, f.Summary)) + } + return textResult(b.String(), false), nil + }) + + deprecateTool := mcp.Tool{ + Name: builtin.ToolDeprecateProjectFact, + Description: "将事实标记为 deprecated,从黑板索引中排除。", + ShortDescription: "废弃项目事实", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "fact_key": map[string]interface{}{"type": "string"}, + }, + "required": []string{"fact_key"}, + }, + } + mcpServer.RegisterTool(deprecateTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + projectID, err := projectIDFromConversation(db, ctx) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + key := strings.TrimSpace(strArg(args, "fact_key")) + if err := db.DeprecateProjectFact(projectID, key); err != nil { + return textResult("错误: "+err.Error(), true), nil + } + return textResult("事实已标记为 deprecated: "+key, false), nil + }) + + if logger != nil { + logger.Info("项目黑板 MCP 工具注册成功") + } +} + +func strArg(args map[string]interface{}, key string) string { + if v, ok := args[key].(string); ok { + return v + } + return "" +} + +func boolArg(args map[string]interface{}, key string) bool { + if v, ok := args[key].(bool); ok { + return v + } + return false +} + +func intArg(args map[string]interface{}, key string, def int) int { + switch v := args[key].(type) { + case float64: + return int(v) + case int: + return v + case int64: + return int(v) + default: + return def + } +} diff --git a/internal/app/vulnerability_tools.go b/internal/app/vulnerability_tools.go new file mode 100644 index 00000000..8359208c --- /dev/null +++ b/internal/app/vulnerability_tools.go @@ -0,0 +1,405 @@ +package app + +import ( + "context" + "fmt" + "strings" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + + "go.uber.org/zap" +) + +func conversationIDFromToolCtx(ctx context.Context) string { + if id := agent.ConversationIDFromContext(ctx); id != "" { + return id + } + return mcp.MCPConversationIDFromContext(ctx) +} + +// canAccessVulnerability 校验当前对话是否有权查看该漏洞(默认项目隔离,未绑项目则仅本会话)。 +func canAccessVulnerability(vuln *database.Vulnerability, convID, projectID string) bool { + if vuln == nil || convID == "" { + return false + } + if projectID != "" { + if strings.TrimSpace(vuln.ProjectID) == projectID { + return true + } + // 历史记录:写入时尚未绑定 project_id,但属于同一会话 + if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID == convID { + return true + } + return false + } + return vuln.ConversationID == convID +} + +func buildVulnerabilityListFilter(db *database.DB, ctx context.Context, args map[string]interface{}) (database.VulnerabilityListFilter, string, error) { + convID := conversationIDFromToolCtx(ctx) + if convID == "" { + return database.VulnerabilityListFilter{}, "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用漏洞查询工具") + } + + projectID := "" + if pid, err := db.GetConversationProjectID(convID); err == nil { + projectID = strings.TrimSpace(pid) + } + + scope := strings.TrimSpace(strArg(args, "scope")) + if scope == "" { + if projectID != "" { + scope = "project" + } else { + scope = "conversation" + } + } + + filter := database.VulnerabilityListFilter{ + Severity: strings.TrimSpace(strArg(args, "severity")), + Status: strings.TrimSpace(strArg(args, "status")), + } + if q := strings.TrimSpace(strArg(args, "q")); q != "" { + filter.Search = q + } else { + filter.Search = strings.TrimSpace(strArg(args, "search")) + } + + var scopeLabel string + switch scope { + case "project": + if projectID == "" { + return filter, "", fmt.Errorf("当前对话未绑定项目,无法按项目列出漏洞;请使用 scope=conversation,或先在对话中绑定项目") + } + filter.ProjectID = projectID + scopeLabel = fmt.Sprintf("项目 %s", projectID) + case "conversation": + filter.ConversationID = convID + scopeLabel = fmt.Sprintf("会话 %s", convID) + default: + return filter, "", fmt.Errorf("scope 仅支持 project 或 conversation,当前值: %s", scope) + } + return filter, scopeLabel, nil +} + +func formatVulnerabilityListItem(v *database.Vulnerability) string { + line := fmt.Sprintf("- id=%s | %s | %s | %s", v.ID, v.Severity, v.Status, v.Title) + if v.Type != "" { + line += fmt.Sprintf(" | type=%s", v.Type) + } + if v.Target != "" { + line += fmt.Sprintf(" | target=%s", truncateRunes(v.Target, 80)) + } + return line +} + +func formatVulnerabilityDetail(v *database.Vulnerability) string { + var b strings.Builder + b.WriteString(fmt.Sprintf("漏洞ID: %s\n", v.ID)) + b.WriteString(fmt.Sprintf("标题: %s\n", v.Title)) + b.WriteString(fmt.Sprintf("严重程度: %s\n", v.Severity)) + b.WriteString(fmt.Sprintf("状态: %s\n", v.Status)) + if v.Type != "" { + b.WriteString(fmt.Sprintf("类型: %s\n", v.Type)) + } + if v.Target != "" { + b.WriteString(fmt.Sprintf("目标: %s\n", v.Target)) + } + if v.ProjectID != "" { + b.WriteString(fmt.Sprintf("项目ID: %s\n", v.ProjectID)) + } + b.WriteString(fmt.Sprintf("会话ID: %s\n", v.ConversationID)) + if !v.CreatedAt.IsZero() { + b.WriteString(fmt.Sprintf("创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05"))) + } + if v.Description != "" { + b.WriteString("\n--- 描述 ---\n") + b.WriteString(v.Description) + b.WriteString("\n") + } + if v.Proof != "" { + b.WriteString("\n--- 证明(POC) ---\n") + b.WriteString(v.Proof) + b.WriteString("\n") + } + if v.Impact != "" { + b.WriteString("\n--- 影响 ---\n") + b.WriteString(v.Impact) + b.WriteString("\n") + } + if v.Recommendation != "" { + b.WriteString("\n--- 修复建议 ---\n") + b.WriteString(v.Recommendation) + b.WriteString("\n") + } + return b.String() +} + +func truncateRunes(s string, max int) string { + r := []rune(s) + if len(r) <= max { + return s + } + return string(r[:max]) + "…" +} + +// registerVulnerabilityTools 注册漏洞记录与查询 MCP 工具。 +func registerVulnerabilityTools(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) { + registerRecordVulnerabilityTool(mcpServer, db, logger) + registerListVulnerabilitiesTool(mcpServer, db, logger) + registerGetVulnerabilityTool(mcpServer, db, logger) + if logger != nil { + logger.Info("漏洞 MCP 工具注册成功", zap.Strings("tools", []string{ + builtin.ToolRecordVulnerability, + builtin.ToolListVulnerabilities, + builtin.ToolGetVulnerability, + })) + } +} + +func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) { + tool := mcp.Tool{ + Name: builtin.ToolRecordVulnerability, + Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。记录前可先 list_vulnerabilities 避免重复。", + ShortDescription: "记录发现的漏洞详情到漏洞管理系统", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "title": map[string]interface{}{ + "type": "string", + "description": "漏洞标题(必需)", + }, + "description": map[string]interface{}{ + "type": "string", + "description": "漏洞详细描述", + }, + "severity": map[string]interface{}{ + "type": "string", + "description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)", + "enum": []string{"critical", "high", "medium", "low", "info"}, + }, + "vulnerability_type": map[string]interface{}{ + "type": "string", + "description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等", + }, + "target": map[string]interface{}{ + "type": "string", + "description": "受影响的目标(URL、IP地址、服务等)", + }, + "proof": map[string]interface{}{ + "type": "string", + "description": "漏洞证明(POC、截图、请求/响应等)", + }, + "impact": map[string]interface{}{ + "type": "string", + "description": "漏洞影响说明", + }, + "recommendation": map[string]interface{}{ + "type": "string", + "description": "修复建议", + }, + }, + "required": []string{"title", "severity"}, + }, + } + + mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + conversationID := strings.TrimSpace(strArg(args, "conversation_id")) + if conversationID == "" { + conversationID = conversationIDFromToolCtx(ctx) + } + if conversationID == "" { + return textResult("错误: conversation_id 未设置。这是系统错误,请重试。", true), nil + } + + title := strings.TrimSpace(strArg(args, "title")) + if title == "" { + return textResult("错误: title 参数必需且不能为空", true), nil + } + + severity := strings.TrimSpace(strArg(args, "severity")) + if severity == "" { + return textResult("错误: severity 参数必需且不能为空", true), nil + } + + validSeverities := map[string]bool{ + "critical": true, "high": true, "medium": true, "low": true, "info": true, + } + if !validSeverities[severity] { + return textResult(fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), true), nil + } + + projectID := "" + if pid, perr := db.GetConversationProjectID(conversationID); perr == nil { + projectID = strings.TrimSpace(pid) + } + + vuln := &database.Vulnerability{ + ConversationID: conversationID, + ProjectID: projectID, + Title: title, + Description: strArg(args, "description"), + Severity: severity, + Status: "open", + Type: strArg(args, "vulnerability_type"), + Target: strArg(args, "target"), + Proof: strArg(args, "proof"), + Impact: strArg(args, "impact"), + Recommendation: strArg(args, "recommendation"), + } + + created, err := db.CreateVulnerability(vuln) + if err != nil { + if logger != nil { + logger.Error("记录漏洞失败", zap.Error(err)) + } + return textResult(fmt.Sprintf("记录漏洞失败: %v", err), true), nil + } + + if logger != nil { + logger.Info("漏洞记录成功", + zap.String("id", created.ID), + zap.String("title", created.Title), + zap.String("severity", created.Severity), + zap.String("conversation_id", conversationID), + ) + } + + return textResult(fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n可使用 get_vulnerability(id) 查看详情,或 list_vulnerabilities 查看列表。", + created.ID, created.Title, created.Severity, created.Status), false), nil + }) +} + +func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) { + tool := mcp.Tool{ + Name: builtin.ToolListVulnerabilities, + Description: "列出当前授权范围内的漏洞(摘要)。默认:对话已绑定项目时列出该项目下全部漏洞;未绑项目时仅列出当前会话漏洞。可用 scope=conversation 仅看本会话。记录新漏洞前建议先调用以避免重复。", + ShortDescription: "列出漏洞(默认当前项目)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "scope": map[string]interface{}{ + "type": "string", + "description": "范围:project(默认,需绑定项目)| conversation(仅当前会话)", + "enum": []string{"project", "conversation"}, + }, + "severity": map[string]interface{}{ + "type": "string", + "description": "按严重程度筛选:critical、high、medium、low、info", + "enum": []string{"critical", "high", "medium", "low", "info"}, + }, + "status": map[string]interface{}{ + "type": "string", + "description": "按状态筛选:open、confirmed、fixed、false_positive", + "enum": []string{"open", "confirmed", "fixed", "false_positive"}, + }, + "q": map[string]interface{}{ + "type": "string", + "description": "关键词搜索(标题、描述、类型、目标等)", + }, + "limit": map[string]interface{}{ + "type": "integer", + "description": "返回条数上限,默认 30,最大 100", + }, + "offset": map[string]interface{}{ + "type": "integer", + "description": "分页偏移,默认 0", + }, + }, + }, + } + + mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + filter, scopeLabel, err := buildVulnerabilityListFilter(db, ctx, args) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + + limit := intArg(args, "limit", 30) + if limit <= 0 || limit > 100 { + limit = 30 + } + offset := intArg(args, "offset", 0) + if offset < 0 { + offset = 0 + } + + total, err := db.CountVulnerabilities(filter) + if err != nil { + if logger != nil { + logger.Warn("统计漏洞失败", zap.Error(err)) + } + total = 0 + } + + list, err := db.ListVulnerabilities(limit, offset, filter) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + + var b strings.Builder + b.WriteString(fmt.Sprintf("范围: %s\n总计: %d | 本页: %d 条 (limit=%d offset=%d)\n\n", scopeLabel, total, len(list), limit, offset)) + if len(list) == 0 { + b.WriteString("(暂无漏洞记录)\n") + } else { + for _, v := range list { + b.WriteString(formatVulnerabilityListItem(v)) + b.WriteString("\n") + } + if total > offset+len(list) { + b.WriteString(fmt.Sprintf("\n(还有更多,可增大 offset 或使用 q/severity/status 筛选)\n")) + } + } + b.WriteString("\n需要 POC 与完整字段请对具体 id 调用 get_vulnerability。") + return textResult(b.String(), false), nil + }) +} + +func registerGetVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) { + tool := mcp.Tool{ + Name: builtin.ToolGetVulnerability, + Description: "按漏洞 ID 获取完整详情(含 POC、影响、修复建议)。仅能访问当前项目或当前会话下的漏洞(与 list_vulnerabilities 授权范围一致)。", + ShortDescription: "按 ID 获取漏洞详情", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + "description": "漏洞 ID(list_vulnerabilities 返回的 id)", + }, + }, + "required": []string{"id"}, + }, + } + + mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + convID := conversationIDFromToolCtx(ctx) + if convID == "" { + return textResult("错误: 无法确定当前对话,请在对话上下文中使用本工具", true), nil + } + + id := strings.TrimSpace(strArg(args, "id")) + if id == "" { + return textResult("错误: id 必填", true), nil + } + + vuln, err := db.GetVulnerability(id) + if err != nil { + return textResult("错误: 漏洞不存在或查询失败", true), nil + } + + projectID := "" + if pid, perr := db.GetConversationProjectID(convID); perr == nil { + projectID = strings.TrimSpace(pid) + } + + if !canAccessVulnerability(vuln, convID, projectID) { + return textResult("错误: 无权访问该漏洞(仅可查看当前项目或当前会话下的记录)", true), nil + } + + return textResult(formatVulnerabilityDetail(vuln), false), nil + }) +} diff --git a/internal/multiagent/eino_single_runner.go b/internal/multiagent/eino_single_runner.go index 5cdadcc2..980e118b 100644 --- a/internal/multiagent/eino_single_runner.go +++ b/internal/multiagent/eino_single_runner.go @@ -13,6 +13,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/openai" + "cyberstrike-ai/internal/project" "cyberstrike-ai/internal/reasoning" einoopenai "github.com/cloudwego/eino-ext/components/model/openai" @@ -38,6 +39,7 @@ func RunEinoSingleChatModelAgent( roleTools []string, progress func(eventType, message string, data interface{}), reasoningClient *reasoning.ClientIntent, + systemPromptExtra string, ) (*RunResult, error) { if appCfg == nil || ag == nil { return nil, fmt.Errorf("eino single: 配置或 Agent 为空") @@ -177,7 +179,8 @@ func RunEinoSingleChatModelAgent( }, EmitInternalEvents: true, } - ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools, singleToolSearchActive) + ins := project.AppendSystemPromptBlock(ag.EinoSingleAgentSystemInstruction(), systemPromptExtra) + ins = injectToolNamesOnlyInstruction(ctx, ins, mainTools, singleToolSearchActive) if logger != nil { names := collectToolNames(ctx, mainTools) mountedNames := collectToolNames(ctx, mainToolsForCfg) diff --git a/internal/multiagent/orchestrator_instruction.go b/internal/multiagent/orchestrator_instruction.go index a1fd01d3..c5bf840e 100644 --- a/internal/multiagent/orchestrator_instruction.go +++ b/internal/multiagent/orchestrator_instruction.go @@ -106,16 +106,16 @@ func DefaultPlanExecuteOrchestratorInstruction() string { 当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。 -## 漏洞记录 +## 项目黑板(事实)与漏洞记录(分离) -发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。 +绑定项目时会自动注入黑板索引(fact_key + 摘要)。**摘要不足必须 ` + builtin.ToolGetProjectFact + `(fact_key) 取 body,禁止臆造。** 环境认知用 ` + builtin.ToolUpsertProjectFact + `(key 如 target/primary_domain);正式漏洞用 ` + builtin.ToolRecordVulnerability + `(记前可先 ` + builtin.ToolListVulnerabilities + ` 防重复,详情用 ` + builtin.ToolGetVulnerability + `);二者可各记一次。误报用 ` + builtin.ToolDeprecateProjectFact + `。漏洞查询默认仅当前项目(未绑项目则仅当前会话)。 -严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。记录后可在授权范围内继续测试。 +严重程度:critical / high / medium / low / info。证明须含足够证据。 ## 技能库(Skills)与知识库 - 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。 -- plan_execute 执行器通过 MCP 使用知识库与漏洞记录等;Skills 的渐进式加载在「多代理 / Eino DeepAgent」等模式中由内置 skill 工具完成(需 multi_agent.eino_skills)。 +- plan_execute 执行器通过 MCP 使用知识库、项目事实与漏洞记录等;Skills 的渐进式加载在「多代理 / Eino DeepAgent」等模式中由内置 skill 工具完成(需 multi_agent.eino_skills)。 - 若需要完整 Skill 工作流而当前会话无 skill 工具,请在计划或对用户说明中建议切换多代理或 Eino 编排会话。 ## 执行器对用户输出(重要) @@ -206,7 +206,7 @@ func DefaultSupervisorOrchestratorInstruction() string { - **委派优先**:可独立封装、需要专项上下文的子目标(枚举、验证、归纳、报告素材)优先 transfer 给匹配子代理,并在委派说明中写清:子目标、约束、期望交付物结构、证据要求。 - **亲自执行**:仅当无合适专家、需全局衔接或子代理结果不足时,由你直接调用工具。 - **汇总**:子代理输出是证据来源;你要对齐矛盾、补全上下文,给出统一结论与可复现验证步骤,避免机械拼接。 -- **漏洞**:有效漏洞应通过 ` + builtin.ToolRecordVulnerability + ` 记录(含 POC 与严重性:critical / high / medium / low / info)。 +- **事实与漏洞**:环境认知用 ` + builtin.ToolUpsertProjectFact + `;正式漏洞用 ` + builtin.ToolRecordVulnerability + `,查询用 ` + builtin.ToolListVulnerabilities + ` / ` + builtin.ToolGetVulnerability + `;索引摘要不足时必须 ` + builtin.ToolGetProjectFact + ` 取详情。 ## transfer 交接与防重复劳动 diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go index 9c98f966..6e33fcd3 100644 --- a/internal/multiagent/runner.go +++ b/internal/multiagent/runner.go @@ -17,6 +17,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/openai" + "cyberstrike-ai/internal/project" "cyberstrike-ai/internal/reasoning" einoopenai "github.com/cloudwego/eino-ext/components/model/openai" @@ -64,6 +65,7 @@ func RunDeepAgent( agentsMarkdownDir string, orchestrationOverride string, reasoningClient *reasoning.ClientIntent, + systemPromptExtra string, ) (*RunResult, error) { if appCfg == nil || ma == nil || ag == nil { return nil, fmt.Errorf("multiagent: 配置或 Agent 为空") @@ -339,6 +341,7 @@ func RunDeepAgent( return nil, err } + orchInstruction = project.AppendSystemPromptBlock(orchInstruction, systemPromptExtra) orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive) if logger != nil { mainNames := collectToolNames(ctx, mainTools) @@ -387,7 +390,8 @@ func RunDeepAgent( // noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。 deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()} - if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes); mw != nil { + taskEnrichExtra := systemPromptExtra + if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes, taskEnrichExtra); mw != nil { deepHandlers = append(deepHandlers, mw) } if len(mainOrchestratorPre) > 0 { diff --git a/internal/multiagent/sub_agent_context.go b/internal/multiagent/sub_agent_context.go index d2ec73cb..b31269c3 100644 --- a/internal/multiagent/sub_agent_context.go +++ b/internal/multiagent/sub_agent_context.go @@ -30,8 +30,15 @@ type taskContextEnrichMiddleware struct { // newTaskContextEnrichMiddleware returns a middleware that enriches task // descriptions with user conversation context. Returns nil if disabled // (maxRunes < 0) or no user messages exist. -func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int) adk.ChatModelAgentMiddleware { +func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware { supplement := buildUserContextSupplement(userMessage, history, maxRunes) + if bb := strings.TrimSpace(projectBlackboard); bb != "" { + if supplement != "" { + supplement += "\n\n## 项目黑板索引\n" + bb + } else { + supplement = "\n\n## 项目黑板索引\n" + bb + } + } if supplement == "" { return nil } diff --git a/internal/multiagent/sub_agent_context_test.go b/internal/multiagent/sub_agent_context_test.go index 72e10762..0ce3c5a5 100644 --- a/internal/multiagent/sub_agent_context_test.go +++ b/internal/multiagent/sub_agent_context_test.go @@ -105,6 +105,7 @@ func TestTaskContextEnrichMiddleware_EnrichesTaskDescription(t *testing.T) { "继续测试", []agent.ChatMessage{{Role: "user", Content: "http://8.163.32.73:8081 pikachu靶场"}}, 0, + "", ) if mw == nil { t.Fatal("expected non-nil middleware") @@ -149,7 +150,7 @@ func TestTaskContextEnrichMiddleware_EnrichesTaskDescription(t *testing.T) { } func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) { - mw := newTaskContextEnrichMiddleware("test", nil, 0) + mw := newTaskContextEnrichMiddleware("test", nil, 0, "") if mw == nil { t.Fatal("expected non-nil middleware") } @@ -175,7 +176,7 @@ func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) { } func TestTaskContextEnrichMiddleware_NilWhenDisabled(t *testing.T) { - mw := newTaskContextEnrichMiddleware("test", nil, -1) + mw := newTaskContextEnrichMiddleware("test", nil, -1, "") if mw != nil { t.Error("middleware should be nil when disabled") }