From 65957b2013a532cbc33092d06b6d97bf895de6f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sat, 27 Dec 2025 19:42:21 +0800 Subject: [PATCH] Add files via upload --- internal/app/app.go | 388 ++++++++++++++++++++++++++++++---- internal/handler/config.go | 126 +++++++---- internal/handler/knowledge.go | 25 ++- internal/knowledge/indexer.go | 14 +- internal/knowledge/manager.go | 74 ++++--- web/static/js/knowledge.js | 174 ++++++++++++++- 6 files changed, 672 insertions(+), 129 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index c6568db6..d05562ab 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -26,16 +26,21 @@ import ( // App 应用 type App struct { - config *config.Config - logger *logger.Logger - router *gin.Engine - mcpServer *mcp.Server - externalMCPMgr *mcp.ExternalMCPManager - agent *agent.Agent - executor *security.Executor - db *database.DB - knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库) - auth *security.AuthManager + config *config.Config + logger *logger.Logger + router *gin.Engine + mcpServer *mcp.Server + externalMCPMgr *mcp.ExternalMCPManager + agent *agent.Agent + executor *security.Executor + db *database.DB + knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库) + auth *security.AuthManager + knowledgeManager *knowledge.Manager // 知识库管理器(用于动态初始化) + knowledgeRetriever *knowledge.Retriever // 知识库检索器(用于动态初始化) + knowledgeIndexer *knowledge.Indexer // 知识库索引器(用于动态初始化) + knowledgeHandler *handler.KnowledgeHandler // 知识库处理器(用于动态初始化) + agentHandler *handler.AgentHandler // Agent处理器(用于更新知识库管理器) } // New 创建新应用 @@ -196,12 +201,13 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { // 扫描知识库并建立索引(异步) go func() { - if err := knowledgeManager.ScanKnowledgeBase(); err != nil { + itemsToIndex, err := knowledgeManager.ScanKnowledgeBase() + if err != nil { log.Logger.Warn("扫描知识库失败", zap.Error(err)) return } - // 检查是否已有索引,如果有则跳过自动重建 + // 检查是否已有索引 hasIndex, err := knowledgeIndexer.HasIndex() if err != nil { log.Logger.Warn("检查索引状态失败", zap.Error(err)) @@ -209,7 +215,20 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { } if hasIndex { - log.Logger.Info("检测到已有知识库索引,跳过自动重建。如需重建,请手动点击重建索引按钮") + // 如果已有索引,只索引新添加或更新的项 + if len(itemsToIndex) > 0 { + log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex))) + ctx := context.Background() + for _, itemID := range itemsToIndex { + if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil { + log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) + continue + } + } + log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex))) + } else { + log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项") + } return } @@ -242,6 +261,51 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger) vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) + externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) + + // 创建 App 实例(部分字段稍后填充) + app := &App{ + config: cfg, + logger: log, + router: router, + mcpServer: mcpServer, + externalMCPMgr: externalMCPMgr, + agent: agent, + executor: executor, + db: db, + knowledgeDB: knowledgeDBConn, + auth: authManager, + knowledgeManager: knowledgeManager, + knowledgeRetriever: knowledgeRetriever, + knowledgeIndexer: knowledgeIndexer, + knowledgeHandler: knowledgeHandler, + agentHandler: agentHandler, + } + + // 设置知识库初始化器(用于动态初始化,需要在 App 创建后设置) + configHandler.SetKnowledgeInitializer(func() (*handler.KnowledgeHandler, error) { + knowledgeHandler, err := initializeKnowledge(cfg, db, knowledgeDBConn, mcpServer, agentHandler, app, log.Logger) + if err != nil { + return nil, err + } + + // 动态初始化后,设置知识库工具注册器和检索器更新器 + // 这样后续 ApplyConfig 时就能重新注册工具了 + if app.knowledgeRetriever != nil && app.knowledgeManager != nil { + // 创建闭包,捕获knowledgeRetriever和knowledgeManager的引用 + registrar := func() error { + knowledge.RegisterKnowledgeTool(mcpServer, app.knowledgeRetriever, app.knowledgeManager, log.Logger) + return nil + } + configHandler.SetKnowledgeToolRegistrar(registrar) + // 设置检索器更新器,以便在ApplyConfig时更新检索器配置 + configHandler.SetRetrieverUpdater(app.knowledgeRetriever) + log.Logger.Info("动态初始化后已设置知识库工具注册器和检索器更新器") + } + + return knowledgeHandler, nil + }) + // 如果知识库已启用,设置知识库工具注册器和检索器更新器 if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil { // 创建闭包,捕获knowledgeRetriever和knowledgeManager的引用 @@ -253,9 +317,8 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { // 设置检索器更新器,以便在ApplyConfig时更新检索器配置 configHandler.SetRetrieverUpdater(knowledgeRetriever) } - externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) - // 设置路由 + // 设置路由(使用 App 实例以便动态获取 handler) setupRoutes( router, authHandler, @@ -266,24 +329,14 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { configHandler, externalMCPHandler, attackChainHandler, - knowledgeHandler, + app, // 传递 App 实例以便动态获取 knowledgeHandler vulnerabilityHandler, mcpServer, authManager, ) - return &App{ - config: cfg, - logger: log, - router: router, - mcpServer: mcpServer, - externalMCPMgr: externalMCPMgr, - agent: agent, - executor: executor, - db: db, - knowledgeDB: knowledgeDBConn, - auth: authManager, - }, nil + return app, nil + } // Run 启动应用 @@ -336,7 +389,7 @@ func setupRoutes( configHandler *handler.ConfigHandler, externalMCPHandler *handler.ExternalMCPHandler, attackChainHandler *handler.AttackChainHandler, - knowledgeHandler *handler.KnowledgeHandler, + app *App, // 传递 App 实例以便动态获取 knowledgeHandler vulnerabilityHandler *handler.VulnerabilityHandler, mcpServer *mcp.Server, authManager *security.AuthManager, @@ -409,20 +462,137 @@ func setupRoutes( protected.GET("/attack-chain/:conversationId", attackChainHandler.GetAttackChain) protected.POST("/attack-chain/:conversationId/regenerate", attackChainHandler.RegenerateAttackChain) - // 知识库管理(如果启用) - if knowledgeHandler != nil { - protected.GET("/knowledge/categories", knowledgeHandler.GetCategories) - protected.GET("/knowledge/items", knowledgeHandler.GetItems) - protected.GET("/knowledge/items/:id", knowledgeHandler.GetItem) - protected.POST("/knowledge/items", knowledgeHandler.CreateItem) - protected.PUT("/knowledge/items/:id", knowledgeHandler.UpdateItem) - protected.DELETE("/knowledge/items/:id", knowledgeHandler.DeleteItem) - protected.GET("/knowledge/index-status", knowledgeHandler.GetIndexStatus) - protected.POST("/knowledge/index", knowledgeHandler.RebuildIndex) - protected.POST("/knowledge/scan", knowledgeHandler.ScanKnowledgeBase) - protected.GET("/knowledge/retrieval-logs", knowledgeHandler.GetRetrievalLogs) - protected.DELETE("/knowledge/retrieval-logs/:id", knowledgeHandler.DeleteRetrievalLog) - protected.POST("/knowledge/search", knowledgeHandler.Search) + // 知识库管理(始终注册路由,通过 App 实例动态获取 handler) + knowledgeRoutes := protected.Group("/knowledge") + { + knowledgeRoutes.GET("/categories", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "categories": []string{}, + "enabled": false, + "message": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.GetCategories(c) + }) + knowledgeRoutes.GET("/items", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "items": []interface{}{}, + "enabled": false, + "message": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.GetItems(c) + }) + knowledgeRoutes.GET("/items/:id", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "enabled": false, + "message": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.GetItem(c) + }) + knowledgeRoutes.POST("/items", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "enabled": false, + "error": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.CreateItem(c) + }) + knowledgeRoutes.PUT("/items/:id", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "enabled": false, + "error": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.UpdateItem(c) + }) + knowledgeRoutes.DELETE("/items/:id", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "enabled": false, + "error": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.DeleteItem(c) + }) + knowledgeRoutes.GET("/index-status", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "enabled": false, + "total_items": 0, + "indexed_items": 0, + "progress_percent": 0, + "is_complete": false, + "message": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.GetIndexStatus(c) + }) + knowledgeRoutes.POST("/index", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "enabled": false, + "error": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.RebuildIndex(c) + }) + knowledgeRoutes.POST("/scan", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "enabled": false, + "error": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.ScanKnowledgeBase(c) + }) + knowledgeRoutes.GET("/retrieval-logs", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "logs": []interface{}{}, + "enabled": false, + "message": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.GetRetrievalLogs(c) + }) + knowledgeRoutes.DELETE("/retrieval-logs/:id", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "enabled": false, + "error": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.DeleteRetrievalLog(c) + }) + knowledgeRoutes.POST("/search", func(c *gin.Context) { + if app.knowledgeHandler == nil { + c.JSON(http.StatusOK, gin.H{ + "results": []interface{}{}, + "enabled": false, + "message": "知识库功能未启用,请前往系统设置启用知识检索功能", + }) + return + } + app.knowledgeHandler.Search(c) + }) } // 漏洞管理 @@ -594,7 +764,7 @@ func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *z Title: title, Description: description, Severity: severity, - Status: "open", + Status: "open", Type: vulnType, Target: target, Proof: proof, @@ -638,6 +808,136 @@ func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *z logger.Info("漏洞记录工具注册成功") } +// initializeKnowledge 初始化知识库组件(用于动态初始化) +func initializeKnowledge( + cfg *config.Config, + db *database.DB, + knowledgeDBConn *database.DB, + mcpServer *mcp.Server, + agentHandler *handler.AgentHandler, + app *App, // 传递 App 引用以便更新知识库组件 + logger *zap.Logger, +) (*handler.KnowledgeHandler, error) { + // 确定知识库数据库路径 + knowledgeDBPath := cfg.Database.KnowledgeDBPath + var knowledgeDB *sql.DB + + if knowledgeDBPath != "" { + // 使用独立的知识库数据库 + // 确保目录存在 + if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil { + return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err) + } + + var err error + knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, logger) + if err != nil { + return nil, fmt.Errorf("初始化知识库数据库失败: %w", err) + } + knowledgeDB = knowledgeDBConn.DB + logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath)) + } else { + // 向后兼容:使用会话数据库 + knowledgeDB = db.DB + logger.Info("使用会话数据库存储知识库数据(建议配置knowledge_db_path以分离数据)") + } + + // 创建知识库管理器 + knowledgeManager := knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, logger) + + // 创建嵌入器 + // 使用OpenAI配置的API Key(如果知识库配置中没有指定) + if cfg.Knowledge.Embedding.APIKey == "" { + cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey + } + if cfg.Knowledge.Embedding.BaseURL == "" { + cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL + } + + httpClient := &http.Client{ + Timeout: 30 * time.Minute, + } + openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, logger) + embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, logger) + + // 创建检索器 + retrievalConfig := &knowledge.RetrievalConfig{ + TopK: cfg.Knowledge.Retrieval.TopK, + SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold, + HybridWeight: cfg.Knowledge.Retrieval.HybridWeight, + } + knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger) + + // 创建索引器 + knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger) + + // 注册知识检索工具到MCP服务器 + knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger) + + // 创建知识库API处理器 + knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger) + logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil)) + + // 设置知识库管理器到AgentHandler以便记录检索日志 + agentHandler.SetKnowledgeManager(knowledgeManager) + + // 更新 App 中的知识库组件(如果 App 不为 nil,说明是动态初始化) + if app != nil { + app.knowledgeManager = knowledgeManager + app.knowledgeRetriever = knowledgeRetriever + app.knowledgeIndexer = knowledgeIndexer + app.knowledgeHandler = knowledgeHandler + // 如果使用独立数据库,更新 knowledgeDB + if knowledgeDBPath != "" { + app.knowledgeDB = knowledgeDBConn + } + logger.Info("App 中的知识库组件已更新") + } + + // 扫描知识库并建立索引(异步) + go func() { + itemsToIndex, err := knowledgeManager.ScanKnowledgeBase() + if err != nil { + logger.Warn("扫描知识库失败", zap.Error(err)) + return + } + + // 检查是否已有索引 + hasIndex, err := knowledgeIndexer.HasIndex() + if err != nil { + logger.Warn("检查索引状态失败", zap.Error(err)) + return + } + + if hasIndex { + // 如果已有索引,只索引新添加或更新的项 + if len(itemsToIndex) > 0 { + logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex))) + ctx := context.Background() + for _, itemID := range itemsToIndex { + if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil { + logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) + continue + } + } + logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex))) + } else { + logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项") + } + return + } + + // 只有在没有索引时才自动重建 + logger.Info("未检测到知识库索引,开始自动构建索引") + ctx := context.Background() + if err := knowledgeIndexer.RebuildIndex(ctx); err != nil { + logger.Warn("重建知识库索引失败", zap.Error(err)) + } + }() + + return knowledgeHandler, nil +} + // corsMiddleware CORS中间件 func corsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { diff --git a/internal/handler/config.go b/internal/handler/config.go index 74b64f14..cfea5e92 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -16,6 +16,7 @@ import ( "cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/security" + "github.com/gin-gonic/gin" "go.uber.org/zap" "gopkg.in/yaml.v3" @@ -29,19 +30,29 @@ type RetrieverUpdater interface { UpdateConfig(config *knowledge.RetrievalConfig) } +// KnowledgeInitializer 知识库初始化器接口 +type KnowledgeInitializer func() (*KnowledgeHandler, error) + +// AppUpdater App更新接口(用于更新App中的知识库组件) +type AppUpdater interface { + UpdateKnowledgeComponents(handler *KnowledgeHandler, manager interface{}, retriever interface{}, indexer interface{}) +} + // ConfigHandler 配置处理器 type ConfigHandler struct { - configPath string - config *config.Config - mcpServer *mcp.Server - executor *security.Executor - agent AgentUpdater // Agent接口,用于更新Agent配置 - attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 - externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 - knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选) - retrieverUpdater RetrieverUpdater // 检索器更新器(可选) - logger *zap.Logger - mu sync.RWMutex + configPath string + config *config.Config + mcpServer *mcp.Server + executor *security.Executor + agent AgentUpdater // Agent接口,用于更新Agent配置 + attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 + externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 + knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选) + retrieverUpdater RetrieverUpdater // 检索器更新器(可选) + knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) + appUpdater AppUpdater // App更新器(可选) + logger *zap.Logger + mu sync.RWMutex } // AttackChainUpdater 攻击链处理器更新接口 @@ -83,12 +94,26 @@ func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) { h.retrieverUpdater = updater } +// SetKnowledgeInitializer 设置知识库初始化器 +func (h *ConfigHandler) SetKnowledgeInitializer(initializer KnowledgeInitializer) { + h.mu.Lock() + defer h.mu.Unlock() + h.knowledgeInitializer = initializer +} + +// SetAppUpdater 设置App更新器 +func (h *ConfigHandler) SetAppUpdater(updater AppUpdater) { + h.mu.Lock() + defer h.mu.Unlock() + h.appUpdater = updater +} + // GetConfigResponse 获取配置响应 type GetConfigResponse struct { - OpenAI config.OpenAIConfig `json:"openai"` - MCP config.MCPConfig `json:"mcp"` - Tools []ToolConfigInfo `json:"tools"` - Agent config.AgentConfig `json:"agent"` + OpenAI config.OpenAIConfig `json:"openai"` + MCP config.MCPConfig `json:"mcp"` + Tools []ToolConfigInfo `json:"tools"` + Agent config.AgentConfig `json:"agent"` Knowledge config.KnowledgeConfig `json:"knowledge"` } @@ -127,7 +152,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { tools[len(tools)-1].Description = desc } } - + // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) if h.mcpServer != nil { mcpTools := h.mcpServer.GetAllTools() @@ -287,7 +312,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { allTools = append(allTools, toolInfo) } - + // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) if h.mcpServer != nil { mcpTools := h.mcpServer.GetAllTools() @@ -296,7 +321,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { if configToolMap[mcpTool.Name] { continue } - + description := mcpTool.ShortDescription if description == "" { description = mcpTool.Description @@ -304,14 +329,14 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { if len(description) > 100 { description = description[:100] + "..." } - + toolInfo := ToolConfigInfo{ Name: mcpTool.Name, Description: description, Enabled: true, // 直接注册的工具默认启用 IsExternal: false, } - + // 如果有关键词,进行搜索过滤 if searchTermLower != "" { nameLower := strings.ToLower(toolInfo.Name) @@ -320,7 +345,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { continue // 不匹配,跳过 } } - + allTools = append(allTools, toolInfo) } } @@ -336,7 +361,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { } else { // 获取外部MCP配置,用于判断启用状态 externalMCPConfigs := h.externalMCPMgr.GetConfigs() - + for _, externalTool := range externalTools { // 解析工具名称:mcpName::toolName var mcpName, actualToolName string @@ -434,7 +459,7 @@ type UpdateConfigRequest struct { OpenAI *config.OpenAIConfig `json:"openai,omitempty"` MCP *config.MCPConfig `json:"mcp,omitempty"` Tools []ToolEnableStatus `json:"tools,omitempty"` - Agent *config.AgentConfig `json:"agent,omitempty"` + Agent *config.AgentConfig `json:"agent,omitempty"` Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"` } @@ -541,12 +566,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName)) continue } - + // 初始化ToolEnabled map if cfg.ToolEnabled == nil { cfg.ToolEnabled = make(map[string]bool) } - + // 更新每个工具的启用状态 for toolName, enabled := range toolStates { cfg.ToolEnabled[toolName] = enabled @@ -556,7 +581,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { zap.Bool("enabled", enabled), ) } - + // 检查是否有任何工具启用,如果有则启用MCP hasEnabledTool := false for _, enabled := range cfg.ToolEnabled { @@ -565,21 +590,21 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { break } } - + // 如果MCP之前未启用,但现在有工具启用,则启用MCP // 如果MCP之前已启用,保持启用状态(允许部分工具禁用) if !cfg.ExternalMCPEnable && hasEnabledTool { cfg.ExternalMCPEnable = true h.logger.Info("自动启用外部MCP(因为有工具启用)", zap.String("mcp", mcpName)) } - + h.config.ExternalMCP.Servers[mcpName] = cfg } - + // 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置 // 在循环外部统一更新,避免重复调用 h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP) - + // 处理MCP连接状态(异步启动,避免阻塞) for mcpName := range externalMCPToolMap { cfg := h.config.ExternalMCP.Servers[mcpName] @@ -618,18 +643,41 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { // ApplyConfig 应用配置(重新加载并重启相关服务) func (h *ConfigHandler) ApplyConfig(c *gin.Context) { + // 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求) + var needInitKnowledge bool + var knowledgeInitializer KnowledgeInitializer + + h.mu.RLock() + needInitKnowledge = h.config.Knowledge.Enabled && h.knowledgeToolRegistrar == nil && h.knowledgeInitializer != nil + if needInitKnowledge { + knowledgeInitializer = h.knowledgeInitializer + } + h.mu.RUnlock() + + // 如果需要动态初始化知识库,在锁外执行(这是耗时操作) + if needInitKnowledge { + h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件") + if _, err := knowledgeInitializer(); err != nil { + h.logger.Error("动态初始化知识库失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()}) + return + } + h.logger.Info("知识库动态初始化完成,工具已注册") + } + + // 现在获取写锁,执行快速的操作 h.mu.Lock() defer h.mu.Unlock() // 重新注册工具(根据新的启用状态) h.logger.Info("重新注册工具") - + // 清空MCP服务器中的工具 h.mcpServer.ClearTools() - + // 重新注册安全工具 h.executor.RegisterTools(h.mcpServer) - + // 如果知识库启用,重新注册知识库工具 if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil { h.logger.Info("重新注册知识库工具") @@ -673,7 +721,7 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { ) c.JSON(http.StatusOK, gin.H{ - "message": "配置已应用", + "message": "配置已应用", "tools_count": len(h.config.Security.Tools), }) } @@ -847,7 +895,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) { knowledgeNode := ensureMap(root, "knowledge") setBoolInMap(knowledgeNode, "enabled", cfg.Enabled) setStringInMap(knowledgeNode, "base_path", cfg.BasePath) - + // 更新嵌入配置 embeddingNode := ensureMap(knowledgeNode, "embedding") setStringInMap(embeddingNode, "provider", cfg.Embedding.Provider) @@ -858,7 +906,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) { if cfg.Embedding.APIKey != "" { setStringInMap(embeddingNode, "api_key", cfg.Embedding.APIKey) } - + // 更新检索配置 retrievalNode := ensureMap(knowledgeNode, "retrieval") setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK) @@ -940,14 +988,14 @@ func findBoolInMap(mapNode *yaml.Node, key string) *bool { if mapNode == nil || mapNode.Kind != yaml.MappingNode { return nil } - + for i := 0; i < len(mapNode.Content); i += 2 { if i+1 >= len(mapNode.Content) { break } keyNode := mapNode.Content[i] valueNode := mapNode.Content[i+1] - + if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key { if valueNode.Kind == yaml.ScalarNode { if valueNode.Value == "true" { @@ -989,5 +1037,3 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) { valueNode.Value = fmt.Sprintf("%g", value) } } - - diff --git a/internal/handler/knowledge.go b/internal/handler/knowledge.go index 1dd7feb1..dc8b2d9f 100644 --- a/internal/handler/knowledge.go +++ b/internal/handler/knowledge.go @@ -170,21 +170,36 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) { // ScanKnowledgeBase 扫描知识库 func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) { - if err := h.manager.ScanKnowledgeBase(); err != nil { + itemsToIndex, err := h.manager.ScanKnowledgeBase() + if err != nil { h.logger.Error("扫描知识库失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - // 异步重建索引 + if len(itemsToIndex) == 0 { + c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"}) + return + } + + // 异步索引新添加或更新的项(增量索引) go func() { ctx := context.Background() - if err := h.indexer.RebuildIndex(ctx); err != nil { - h.logger.Error("重建索引失败", zap.Error(err)) + h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex))) + for i, itemID := range itemsToIndex { + if err := h.indexer.IndexItem(ctx, itemID); err != nil { + h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) + continue + } + h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex))) } + h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex))) }() - c.JSON(http.StatusOK, gin.H{"message": "扫描完成,索引重建已开始"}) + c.JSON(http.StatusOK, gin.H{ + "message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)), + "items_to_index": len(itemsToIndex), + }) } // GetRetrievalLogs 获取检索日志 diff --git a/internal/knowledge/indexer.go b/internal/knowledge/indexer.go index 275f787f..a02f5bf7 100644 --- a/internal/knowledge/indexer.go +++ b/internal/knowledge/indexer.go @@ -257,7 +257,7 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { return fmt.Errorf("获取知识项失败: %w", err) } - // 删除旧的向量 + // 删除旧的向量(在 RebuildIndex 中已经统一清空,这里保留是为了单独调用 IndexItem 时的兼容性) _, err = idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID) if err != nil { return fmt.Errorf("删除旧向量失败: %w", err) @@ -338,12 +338,22 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs))) + // 在开始重建前,先清空所有旧的向量,确保进度从0开始 + // 这样 GetIndexStatus 可以准确反映重建进度 + _, err = idx.db.Exec("DELETE FROM knowledge_embeddings") + if err != nil { + idx.logger.Warn("清空旧索引失败", zap.Error(err)) + // 继续执行,即使清空失败也尝试重建 + } else { + idx.logger.Info("已清空旧索引,开始重建") + } + for i, itemID := range itemIDs { if err := idx.IndexItem(ctx, itemID); err != nil { idx.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) continue } - idx.logger.Debug("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs))) + idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs))) } idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs))) diff --git a/internal/knowledge/manager.go b/internal/knowledge/manager.go index bebfc4eb..4f9dc95a 100644 --- a/internal/knowledge/manager.go +++ b/internal/knowledge/manager.go @@ -31,18 +31,21 @@ func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager { } // ScanKnowledgeBase 扫描知识库目录,更新数据库 -func (m *Manager) ScanKnowledgeBase() error { +// 返回需要索引的知识项ID列表(新添加的或更新的) +func (m *Manager) ScanKnowledgeBase() ([]string, error) { if m.basePath == "" { - return fmt.Errorf("知识库路径未配置") + return nil, fmt.Errorf("知识库路径未配置") } // 确保目录存在 if err := os.MkdirAll(m.basePath, 0755); err != nil { - return fmt.Errorf("创建知识库目录失败: %w", err) + return nil, fmt.Errorf("创建知识库目录失败: %w", err) } + var itemsToIndex []string + // 遍历知识库目录 - return filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error { + err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } @@ -77,10 +80,12 @@ func (m *Manager) ScanKnowledgeBase() error { // 检查是否已存在 var existingID string + var existingContent string + var existingUpdatedAt time.Time err = m.db.QueryRow( - "SELECT id FROM knowledge_base_items WHERE file_path = ?", + "SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?", path, - ).Scan(&existingID) + ).Scan(&existingID, &existingContent, &existingUpdatedAt) if err == sql.ErrNoRows { // 创建新项 @@ -94,22 +99,38 @@ func (m *Manager) ScanKnowledgeBase() error { return fmt.Errorf("插入知识项失败: %w", err) } m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category)) + // 新添加的项需要索引 + itemsToIndex = append(itemsToIndex, id) } else if err == nil { - // 更新现有项 - _, err = m.db.Exec( - "UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?", - category, title, string(content), time.Now(), existingID, - ) - if err != nil { - return fmt.Errorf("更新知识项失败: %w", err) + // 检查内容是否有变化 + contentChanged := existingContent != string(content) + if contentChanged { + // 更新现有项 + _, err = m.db.Exec( + "UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?", + category, title, string(content), time.Now(), existingID, + ) + if err != nil { + return fmt.Errorf("更新知识项失败: %w", err) + } + m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title)) + // 内容已更新的项需要重新索引 + itemsToIndex = append(itemsToIndex, existingID) + } else { + m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title)) } - m.logger.Debug("更新知识项", zap.String("id", existingID), zap.String("title", title)) } else { return fmt.Errorf("查询知识项失败: %w", err) } return nil }) + + if err != nil { + return nil, err + } + + return itemsToIndex, nil } // GetCategories 获取所有分类(风险类型) @@ -170,7 +191,7 @@ func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) { time.RFC3339, time.RFC3339Nano, } - + // 解析创建时间 if createdAt != "" { for _, format := range timeFormats { @@ -181,7 +202,7 @@ func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) { } } } - + // 解析更新时间 if updatedAt != "" { for _, format := range timeFormats { @@ -192,7 +213,7 @@ func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) { } } } - + // 如果更新时间为空,使用创建时间 if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { item.UpdatedAt = item.CreatedAt @@ -230,7 +251,7 @@ func (m *Manager) GetItem(id string) (*KnowledgeItem, error) { time.RFC3339, time.RFC3339Nano, } - + // 解析创建时间 if createdAt != "" { for _, format := range timeFormats { @@ -241,7 +262,7 @@ func (m *Manager) GetItem(id string) (*KnowledgeItem, error) { } } } - + // 解析更新时间 if updatedAt != "" { for _, format := range timeFormats { @@ -252,7 +273,7 @@ func (m *Manager) GetItem(id string) (*KnowledgeItem, error) { } } } - + // 如果更新时间为空,使用创建时间 if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { item.UpdatedAt = item.CreatedAt @@ -418,10 +439,10 @@ func (m *Manager) GetIndexStatus() (map[string]interface{}, error) { isComplete := indexedItems >= totalItems && totalItems > 0 return map[string]interface{}{ - "total_items": totalItems, - "indexed_items": indexedItems, + "total_items": totalItems, + "indexed_items": indexedItems, "progress_percent": progressPercent, - "is_complete": isComplete, + "is_complete": isComplete, }, nil } @@ -472,17 +493,17 @@ func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) time.RFC3339, time.RFC3339Nano, } - + for _, format := range timeFormats { log.CreatedAt, err = time.Parse(format, createdAt) if err == nil && !log.CreatedAt.IsZero() { break } } - + // 如果所有格式都失败,记录警告但继续处理 if log.CreatedAt.IsZero() { - m.logger.Warn("解析检索日志时间失败", + m.logger.Warn("解析检索日志时间失败", zap.String("timeStr", createdAt), zap.Error(err), ) @@ -519,4 +540,3 @@ func (m *Manager) DeleteRetrievalLog(id string) error { return nil } - diff --git a/web/static/js/knowledge.js b/web/static/js/knowledge.js index f01be7ee..dc0848fe 100644 --- a/web/static/js/knowledge.js +++ b/web/static/js/knowledge.js @@ -22,6 +22,32 @@ async function loadKnowledgeCategories() { throw new Error('获取分类失败'); } const data = await response.json(); + + // 检查知识库功能是否启用 + if (data.enabled === false) { + // 功能未启用,显示友好提示 + const container = document.getElementById('knowledge-items-list'); + if (container) { + container.innerHTML = ` +
${data.message || '请前往系统设置启用知识检索功能'}
+ +${data.message || '请前往系统设置启用知识检索功能'}
+ +