Files
CyberStrikeAI/internal/app/app.go
2026-01-11 02:03:33 +08:00

1048 lines
35 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"context"
"database/sql"
"fmt"
"net/http"
"os"
"path/filepath"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/handler"
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/storage"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// App 应用
type App struct {
config *config.Config
logger *logger.Logger
router *gin.Engine
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager
agent *agent.Agent
executor *security.Executor
db *database.DB
knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库)
auth *security.AuthManager
knowledgeManager *knowledge.Manager // 知识库管理器(用于动态初始化)
knowledgeRetriever *knowledge.Retriever // 知识库检索器(用于动态初始化)
knowledgeIndexer *knowledge.Indexer // 知识库索引器(用于动态初始化)
knowledgeHandler *handler.KnowledgeHandler // 知识库处理器(用于动态初始化)
agentHandler *handler.AgentHandler // Agent处理器用于更新知识库管理器
}
// New 创建新应用
func New(cfg *config.Config, log *logger.Logger) (*App, error) {
gin.SetMode(gin.ReleaseMode)
router := gin.Default()
// CORS中间件
router.Use(corsMiddleware())
// 认证管理器
authManager, err := security.NewAuthManager(cfg.Auth.Password, cfg.Auth.SessionDurationHours)
if err != nil {
return nil, fmt.Errorf("初始化认证失败: %w", err)
}
// 初始化数据库
dbPath := cfg.Database.Path
if dbPath == "" {
dbPath = "data/conversations.db"
}
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
db, err := database.NewDB(dbPath, log.Logger)
if err != nil {
return nil, fmt.Errorf("初始化数据库失败: %w", err)
}
// 创建MCP服务器带数据库持久化
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
// 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
// 注册工具
executor.RegisterTools(mcpServer)
// 注册漏洞记录工具
registerVulnerabilityTool(mcpServer, db, log.Logger)
if cfg.Auth.GeneratedPassword != "" {
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
cfg.Auth.GeneratedPassword = ""
cfg.Auth.GeneratedPasswordPersisted = false
cfg.Auth.GeneratedPasswordPersistErr = ""
}
// 创建外部MCP管理器使用与内部MCP服务器相同的存储
externalMCPMgr := mcp.NewExternalMCPManagerWithStorage(log.Logger, db)
if cfg.ExternalMCP.Servers != nil {
externalMCPMgr.LoadConfigs(&cfg.ExternalMCP)
// 启动所有启用的外部MCP客户端
externalMCPMgr.StartAllEnabled()
}
// 初始化结果存储
resultStorageDir := "tmp"
if cfg.Agent.ResultStorageDir != "" {
resultStorageDir = cfg.Agent.ResultStorageDir
}
// 确保存储目录存在
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
return nil, fmt.Errorf("创建结果存储目录失败: %w", err)
}
// 创建结果存储实例
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
if err != nil {
return nil, fmt.Errorf("初始化结果存储失败: %w", err)
}
// 创建Agent
maxIterations := cfg.Agent.MaxIterations
if maxIterations <= 0 {
maxIterations = 30 // 默认值
}
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
// 设置结果存储到Agent
agent.SetResultStorage(resultStorage)
// 设置结果存储到Executor用于查询工具
executor.SetResultStorage(resultStorage)
// 初始化知识库模块(如果启用)
var knowledgeManager *knowledge.Manager
var knowledgeRetriever *knowledge.Retriever
var knowledgeIndexer *knowledge.Indexer
var knowledgeHandler *handler.KnowledgeHandler
var knowledgeDBConn *database.DB
log.Logger.Info("检查知识库配置", zap.Bool("enabled", cfg.Knowledge.Enabled))
if cfg.Knowledge.Enabled {
// 确定知识库数据库路径
knowledgeDBPath := cfg.Database.KnowledgeDBPath
var knowledgeDB *sql.DB
if knowledgeDBPath != "" {
// 使用独立的知识库数据库
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil {
return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err)
}
var err error
knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, log.Logger)
if err != nil {
return nil, fmt.Errorf("初始化知识库数据库失败: %w", err)
}
knowledgeDB = knowledgeDBConn.DB
log.Logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath))
} else {
// 向后兼容:使用会话数据库
knowledgeDB = db.DB
log.Logger.Info("使用会话数据库存储知识库数据建议配置knowledge_db_path以分离数据")
}
// 创建知识库管理器
knowledgeManager = knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, log.Logger)
// 创建嵌入器
// 使用OpenAI配置的API Key如果知识库配置中没有指定
if cfg.Knowledge.Embedding.APIKey == "" {
cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey
}
if cfg.Knowledge.Embedding.BaseURL == "" {
cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL
}
httpClient := &http.Client{
Timeout: 30 * time.Minute,
}
openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, log.Logger)
embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, log.Logger)
// 创建检索器
retrievalConfig := &knowledge.RetrievalConfig{
TopK: cfg.Knowledge.Retrieval.TopK,
SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
HybridWeight: cfg.Knowledge.Retrieval.HybridWeight,
}
knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger)
// 创建索引器
knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger)
// 注册知识检索工具到MCP服务器
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger)
// 创建知识库API处理器
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 扫描知识库并建立索引(异步)
go func() {
itemsToIndex, err := knowledgeManager.ScanKnowledgeBase()
if err != nil {
log.Logger.Warn("扫描知识库失败", zap.Error(err))
return
}
// 检查是否已有索引
hasIndex, err := knowledgeIndexer.HasIndex()
if err != nil {
log.Logger.Warn("检查索引状态失败", zap.Error(err))
return
}
if hasIndex {
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次立即停止增量索引
if consecutiveFailures >= 2 {
log.Logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
}
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else {
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
return
}
// 只有在没有索引时才自动重建
log.Logger.Info("未检测到知识库索引,开始自动构建索引")
ctx := context.Background()
if err := knowledgeIndexer.RebuildIndex(ctx); err != nil {
log.Logger.Warn("重建知识库索引失败", zap.Error(err))
}
}()
}
// 获取配置文件路径
configPath := "config.yaml"
if len(os.Args) > 1 {
configPath = os.Args[1]
}
// 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
// 如果知识库已启用设置知识库管理器到AgentHandler以便记录检索日志
if knowledgeManager != nil {
agentHandler.SetKnowledgeManager(knowledgeManager)
}
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器以便获取外部MCP执行记录
conversationHandler := handler.NewConversationHandler(db, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
// 创建 App 实例(部分字段稍后填充)
app := &App{
config: cfg,
logger: log,
router: router,
mcpServer: mcpServer,
externalMCPMgr: externalMCPMgr,
agent: agent,
executor: executor,
db: db,
knowledgeDB: knowledgeDBConn,
auth: authManager,
knowledgeManager: knowledgeManager,
knowledgeRetriever: knowledgeRetriever,
knowledgeIndexer: knowledgeIndexer,
knowledgeHandler: knowledgeHandler,
agentHandler: agentHandler,
}
// 设置漏洞工具注册器(内置工具,必须设置)
vulnerabilityRegistrar := func() error {
registerVulnerabilityTool(mcpServer, db, log.Logger)
return nil
}
configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar)
// 设置知识库初始化器(用于动态初始化,需要在 App 创建后设置)
configHandler.SetKnowledgeInitializer(func() (*handler.KnowledgeHandler, error) {
knowledgeHandler, err := initializeKnowledge(cfg, db, knowledgeDBConn, mcpServer, agentHandler, app, log.Logger)
if err != nil {
return nil, err
}
// 动态初始化后,设置知识库工具注册器和检索器更新器
// 这样后续 ApplyConfig 时就能重新注册工具了
if app.knowledgeRetriever != nil && app.knowledgeManager != nil {
// 创建闭包捕获knowledgeRetriever和knowledgeManager的引用
registrar := func() error {
knowledge.RegisterKnowledgeTool(mcpServer, app.knowledgeRetriever, app.knowledgeManager, log.Logger)
return nil
}
configHandler.SetKnowledgeToolRegistrar(registrar)
// 设置检索器更新器以便在ApplyConfig时更新检索器配置
configHandler.SetRetrieverUpdater(app.knowledgeRetriever)
log.Logger.Info("动态初始化后已设置知识库工具注册器和检索器更新器")
}
return knowledgeHandler, nil
})
// 如果知识库已启用,设置知识库工具注册器和检索器更新器
if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil {
// 创建闭包捕获knowledgeRetriever和knowledgeManager的引用
registrar := func() error {
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger)
return nil
}
configHandler.SetKnowledgeToolRegistrar(registrar)
// 设置检索器更新器以便在ApplyConfig时更新检索器配置
configHandler.SetRetrieverUpdater(knowledgeRetriever)
}
// 设置路由(使用 App 实例以便动态获取 handler
setupRoutes(
router,
authHandler,
agentHandler,
monitorHandler,
conversationHandler,
groupHandler,
configHandler,
externalMCPHandler,
attackChainHandler,
app, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler,
roleHandler,
mcpServer,
authManager,
)
return app, nil
}
// Run 启动应用
func (a *App) Run() error {
// 启动MCP服务器如果启用
if a.config.MCP.Enabled {
go func() {
mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))
mux := http.NewServeMux()
mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP)
if err := http.ListenAndServe(mcpAddr, mux); err != nil {
a.logger.Error("MCP服务器启动失败", zap.Error(err))
}
}()
}
// 启动主服务器
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
return a.router.Run(addr)
}
// Shutdown 关闭应用
func (a *App) Shutdown() {
// 停止所有外部MCP客户端
if a.externalMCPMgr != nil {
a.externalMCPMgr.StopAll()
}
// 关闭知识库数据库连接(如果使用独立数据库)
if a.knowledgeDB != nil {
if err := a.knowledgeDB.Close(); err != nil {
a.logger.Logger.Warn("关闭知识库数据库连接失败", zap.Error(err))
}
}
}
// setupRoutes 设置路由
func setupRoutes(
router *gin.Engine,
authHandler *handler.AuthHandler,
agentHandler *handler.AgentHandler,
monitorHandler *handler.MonitorHandler,
conversationHandler *handler.ConversationHandler,
groupHandler *handler.GroupHandler,
configHandler *handler.ConfigHandler,
externalMCPHandler *handler.ExternalMCPHandler,
attackChainHandler *handler.AttackChainHandler,
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler *handler.VulnerabilityHandler,
roleHandler *handler.RoleHandler,
mcpServer *mcp.Server,
authManager *security.AuthManager,
) {
// API路由
api := router.Group("/api")
// 认证相关路由
authRoutes := api.Group("/auth")
{
authRoutes.POST("/login", authHandler.Login)
authRoutes.POST("/logout", security.AuthMiddleware(authManager), authHandler.Logout)
authRoutes.POST("/change-password", security.AuthMiddleware(authManager), authHandler.ChangePassword)
authRoutes.GET("/validate", security.AuthMiddleware(authManager), authHandler.Validate)
}
protected := api.Group("")
protected.Use(security.AuthMiddleware(authManager))
{
// Agent Loop
protected.POST("/agent-loop", agentHandler.AgentLoop)
// Agent Loop 流式输出
protected.POST("/agent-loop/stream", agentHandler.AgentLoopStream)
// Agent Loop 取消与任务列表
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
protected.GET("/agent-loop/tasks/completed", agentHandler.ListCompletedTasks)
// 批量任务管理
protected.POST("/batch-tasks", agentHandler.CreateBatchQueue)
protected.GET("/batch-tasks", agentHandler.ListBatchQueues)
protected.GET("/batch-tasks/:queueId", agentHandler.GetBatchQueue)
protected.POST("/batch-tasks/:queueId/start", agentHandler.StartBatchQueue)
protected.POST("/batch-tasks/:queueId/pause", agentHandler.PauseBatchQueue)
protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue)
protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask)
protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask)
protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask)
// 对话历史
protected.POST("/conversations", conversationHandler.CreateConversation)
protected.GET("/conversations", conversationHandler.ListConversations)
protected.GET("/conversations/:id", conversationHandler.GetConversation)
protected.PUT("/conversations/:id", conversationHandler.UpdateConversation)
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned)
// 对话分组
protected.POST("/groups", groupHandler.CreateGroup)
protected.GET("/groups", groupHandler.ListGroups)
protected.GET("/groups/:id", groupHandler.GetGroup)
protected.PUT("/groups/:id", groupHandler.UpdateGroup)
protected.DELETE("/groups/:id", groupHandler.DeleteGroup)
protected.PUT("/groups/:id/pinned", groupHandler.UpdateGroupPinned)
protected.GET("/groups/:id/conversations", groupHandler.GetGroupConversations)
protected.POST("/groups/conversations", groupHandler.AddConversationToGroup)
protected.DELETE("/groups/:id/conversations/:conversationId", groupHandler.RemoveConversationFromGroup)
protected.PUT("/groups/:id/conversations/:conversationId/pinned", groupHandler.UpdateConversationPinnedInGroup)
// 监控
protected.GET("/monitor", monitorHandler.Monitor)
protected.GET("/monitor/execution/:id", monitorHandler.GetExecution)
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
protected.GET("/monitor/stats", monitorHandler.GetStats)
// 配置管理
protected.GET("/config", configHandler.GetConfig)
protected.GET("/config/tools", configHandler.GetTools)
protected.PUT("/config", configHandler.UpdateConfig)
protected.POST("/config/apply", configHandler.ApplyConfig)
// 外部MCP管理
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
protected.GET("/external-mcp/:name", externalMCPHandler.GetExternalMCP)
protected.PUT("/external-mcp/:name", externalMCPHandler.AddOrUpdateExternalMCP)
protected.DELETE("/external-mcp/:name", externalMCPHandler.DeleteExternalMCP)
protected.POST("/external-mcp/:name/start", externalMCPHandler.StartExternalMCP)
protected.POST("/external-mcp/:name/stop", externalMCPHandler.StopExternalMCP)
// 攻击链可视化
protected.GET("/attack-chain/:conversationId", attackChainHandler.GetAttackChain)
protected.POST("/attack-chain/:conversationId/regenerate", attackChainHandler.RegenerateAttackChain)
// 知识库管理(始终注册路由,通过 App 实例动态获取 handler
knowledgeRoutes := protected.Group("/knowledge")
{
knowledgeRoutes.GET("/categories", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"categories": []string{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetCategories(c)
})
knowledgeRoutes.GET("/items", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"items": []interface{}{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetItems(c)
})
knowledgeRoutes.GET("/items/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetItem(c)
})
knowledgeRoutes.POST("/items", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.CreateItem(c)
})
knowledgeRoutes.PUT("/items/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.UpdateItem(c)
})
knowledgeRoutes.DELETE("/items/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.DeleteItem(c)
})
knowledgeRoutes.GET("/index-status", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"total_items": 0,
"indexed_items": 0,
"progress_percent": 0,
"is_complete": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetIndexStatus(c)
})
knowledgeRoutes.POST("/index", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.RebuildIndex(c)
})
knowledgeRoutes.POST("/scan", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.ScanKnowledgeBase(c)
})
knowledgeRoutes.GET("/retrieval-logs", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"logs": []interface{}{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetRetrievalLogs(c)
})
knowledgeRoutes.DELETE("/retrieval-logs/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.DeleteRetrievalLog(c)
})
knowledgeRoutes.POST("/search", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"results": []interface{}{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.Search(c)
})
}
// 漏洞管理
protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities)
protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats)
protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability)
protected.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability)
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
// 角色管理
protected.GET("/roles", roleHandler.GetRoles)
protected.GET("/roles/:name", roleHandler.GetRole)
protected.POST("/roles", roleHandler.CreateRole)
protected.PUT("/roles/:name", roleHandler.UpdateRole)
protected.DELETE("/roles/:name", roleHandler.DeleteRole)
// MCP端点
protected.POST("/mcp", func(c *gin.Context) {
mcpServer.HandleHTTP(c.Writer, c.Request)
})
}
// 静态文件
router.Static("/static", "./web/static")
router.LoadHTMLGlob("web/templates/*")
// 前端页面
router.GET("/", func(c *gin.Context) {
c.HTML(http.StatusOK, "index.html", nil)
})
}
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{
Name: builtin.ToolRecordVulnerability,
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"title": map[string]interface{}{
"type": "string",
"description": "漏洞标题(必需)",
},
"description": map[string]interface{}{
"type": "string",
"description": "漏洞详细描述",
},
"severity": map[string]interface{}{
"type": "string",
"description": "漏洞严重程度critical严重、high、medium、low、info信息",
"enum": []string{"critical", "high", "medium", "low", "info"},
},
"vulnerability_type": map[string]interface{}{
"type": "string",
"description": "漏洞类型SQL注入、XSS、CSRF、命令注入等",
},
"target": map[string]interface{}{
"type": "string",
"description": "受影响的目标URL、IP地址、服务等",
},
"proof": map[string]interface{}{
"type": "string",
"description": "漏洞证明POC、截图、请求/响应等)",
},
"impact": map[string]interface{}{
"type": "string",
"description": "漏洞影响说明",
},
"recommendation": map[string]interface{}{
"type": "string",
"description": "修复建议",
},
},
"required": []string{"title", "severity"},
},
}
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
// 从参数中获取conversation_id由Agent自动添加
conversationID, _ := args["conversation_id"].(string)
if conversationID == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: conversation_id 未设置。这是系统错误,请重试。",
},
},
IsError: true,
}, nil
}
title, ok := args["title"].(string)
if !ok || title == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: title 参数必需且不能为空",
},
},
IsError: true,
}, nil
}
severity, ok := args["severity"].(string)
if !ok || severity == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: severity 参数必需且不能为空",
},
},
IsError: true,
}, nil
}
// 验证严重程度
validSeverities := map[string]bool{
"critical": true,
"high": true,
"medium": true,
"low": true,
"info": true,
}
if !validSeverities[severity] {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity),
},
},
IsError: true,
}, nil
}
// 获取可选参数
description := ""
if d, ok := args["description"].(string); ok {
description = d
}
vulnType := ""
if t, ok := args["vulnerability_type"].(string); ok {
vulnType = t
}
target := ""
if t, ok := args["target"].(string); ok {
target = t
}
proof := ""
if p, ok := args["proof"].(string); ok {
proof = p
}
impact := ""
if i, ok := args["impact"].(string); ok {
impact = i
}
recommendation := ""
if r, ok := args["recommendation"].(string); ok {
recommendation = r
}
// 创建漏洞记录
vuln := &database.Vulnerability{
ConversationID: conversationID,
Title: title,
Description: description,
Severity: severity,
Status: "open",
Type: vulnType,
Target: target,
Proof: proof,
Impact: impact,
Recommendation: recommendation,
}
created, err := db.CreateVulnerability(vuln)
if err != nil {
logger.Error("记录漏洞失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("记录漏洞失败: %v", err),
},
},
IsError: true,
}, nil
}
logger.Info("漏洞记录成功",
zap.String("id", created.ID),
zap.String("title", created.Title),
zap.String("severity", created.Severity),
zap.String("conversation_id", conversationID),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n你可以在漏洞管理页面查看和管理此漏洞。", created.ID, created.Title, created.Severity, created.Status),
},
},
IsError: false,
}, nil
}
mcpServer.RegisterTool(tool, handler)
logger.Info("漏洞记录工具注册成功")
}
// initializeKnowledge 初始化知识库组件(用于动态初始化)
func initializeKnowledge(
cfg *config.Config,
db *database.DB,
knowledgeDBConn *database.DB,
mcpServer *mcp.Server,
agentHandler *handler.AgentHandler,
app *App, // 传递 App 引用以便更新知识库组件
logger *zap.Logger,
) (*handler.KnowledgeHandler, error) {
// 确定知识库数据库路径
knowledgeDBPath := cfg.Database.KnowledgeDBPath
var knowledgeDB *sql.DB
if knowledgeDBPath != "" {
// 使用独立的知识库数据库
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil {
return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err)
}
var err error
knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, logger)
if err != nil {
return nil, fmt.Errorf("初始化知识库数据库失败: %w", err)
}
knowledgeDB = knowledgeDBConn.DB
logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath))
} else {
// 向后兼容:使用会话数据库
knowledgeDB = db.DB
logger.Info("使用会话数据库存储知识库数据建议配置knowledge_db_path以分离数据")
}
// 创建知识库管理器
knowledgeManager := knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, logger)
// 创建嵌入器
// 使用OpenAI配置的API Key如果知识库配置中没有指定
if cfg.Knowledge.Embedding.APIKey == "" {
cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey
}
if cfg.Knowledge.Embedding.BaseURL == "" {
cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL
}
httpClient := &http.Client{
Timeout: 30 * time.Minute,
}
openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, logger)
embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, logger)
// 创建检索器
retrievalConfig := &knowledge.RetrievalConfig{
TopK: cfg.Knowledge.Retrieval.TopK,
SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
HybridWeight: cfg.Knowledge.Retrieval.HybridWeight,
}
knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger)
// 创建索引器
knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger)
// 注册知识检索工具到MCP服务器
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger)
// 创建知识库API处理器
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 设置知识库管理器到AgentHandler以便记录检索日志
agentHandler.SetKnowledgeManager(knowledgeManager)
// 更新 App 中的知识库组件(如果 App 不为 nil说明是动态初始化
if app != nil {
app.knowledgeManager = knowledgeManager
app.knowledgeRetriever = knowledgeRetriever
app.knowledgeIndexer = knowledgeIndexer
app.knowledgeHandler = knowledgeHandler
// 如果使用独立数据库,更新 knowledgeDB
if knowledgeDBPath != "" {
app.knowledgeDB = knowledgeDBConn
}
logger.Info("App 中的知识库组件已更新")
}
// 扫描知识库并建立索引(异步)
go func() {
itemsToIndex, err := knowledgeManager.ScanKnowledgeBase()
if err != nil {
logger.Warn("扫描知识库失败", zap.Error(err))
return
}
// 检查是否已有索引
hasIndex, err := knowledgeIndexer.HasIndex()
if err != nil {
logger.Warn("检查索引状态失败", zap.Error(err))
return
}
if hasIndex {
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次立即停止增量索引
if consecutiveFailures >= 2 {
logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
}
logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else {
logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
return
}
// 只有在没有索引时才自动重建
logger.Info("未检测到知识库索引,开始自动构建索引")
ctx := context.Background()
if err := knowledgeIndexer.RebuildIndex(ctx); err != nil {
logger.Warn("重建知识库索引失败", zap.Error(err))
}
}()
return knowledgeHandler, nil
}
// corsMiddleware CORS中间件
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}