diff --git a/internal/app/app.go b/internal/app/app.go index 119e2017..a566e2f7 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -214,23 +214,53 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { return } - if hasIndex { - // 如果已有索引,只索引新添加或更新的项 - if len(itemsToIndex) > 0 { - log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex))) - ctx := context.Background() - for _, itemID := range itemsToIndex { - if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil { - log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) - continue + if hasIndex { + // 如果已有索引,只索引新添加或更新的项 + if len(itemsToIndex) > 0 { + log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex))) + ctx := context.Background() + consecutiveFailures := 0 + var firstFailureItemID string + var firstFailureError error + failedCount := 0 + + for _, itemID := range itemsToIndex { + if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil { + failedCount++ + consecutiveFailures++ + + if consecutiveFailures == 1 { + firstFailureItemID = itemID + firstFailureError = err + log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) + } + + // 如果连续失败2次,立即停止增量索引 + if consecutiveFailures >= 2 { + log.Logger.Error("连续索引失败次数过多,立即停止增量索引", + zap.Int("consecutiveFailures", consecutiveFailures), + zap.Int("totalItems", len(itemsToIndex)), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + break + } + continue + } + + // 成功时重置连续失败计数 + if consecutiveFailures > 0 { + consecutiveFailures = 0 + firstFailureItemID = "" + firstFailureError = nil + } } + log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount)) + } else { + log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项") } - log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex))) - } else { - log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项") + return } - return - } // 只有在没有索引时才自动重建 log.Logger.Info("未检测到知识库索引,开始自动构建索引") @@ -934,13 +964,43 @@ func initializeKnowledge( if len(itemsToIndex) > 0 { logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex))) ctx := context.Background() + consecutiveFailures := 0 + var firstFailureItemID string + var firstFailureError error + failedCount := 0 + for _, itemID := range itemsToIndex { if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil { - logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) + failedCount++ + consecutiveFailures++ + + if consecutiveFailures == 1 { + firstFailureItemID = itemID + firstFailureError = err + logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) + } + + // 如果连续失败2次,立即停止增量索引 + if consecutiveFailures >= 2 { + logger.Error("连续索引失败次数过多,立即停止增量索引", + zap.Int("consecutiveFailures", consecutiveFailures), + zap.Int("totalItems", len(itemsToIndex)), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + break + } continue } + + // 成功时重置连续失败计数 + if consecutiveFailures > 0 { + consecutiveFailures = 0 + firstFailureItemID = "" + firstFailureError = nil + } } - logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex))) + logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount)) } else { logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项") } diff --git a/internal/handler/config.go b/internal/handler/config.go index cd0e2a5b..93233de7 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -57,6 +57,7 @@ type ConfigHandler struct { appUpdater AppUpdater // App更新器(可选) logger *zap.Logger mu sync.RWMutex + lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更) } // AttackChainUpdater 攻击链处理器更新接口 @@ -72,15 +73,26 @@ type AgentUpdater interface { // NewConfigHandler 创建新的配置处理器 func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler { + // 保存初始的嵌入模型配置(如果知识库已启用) + var lastEmbeddingConfig *config.EmbeddingConfig + if cfg.Knowledge.Enabled { + lastEmbeddingConfig = &config.EmbeddingConfig{ + Provider: cfg.Knowledge.Embedding.Provider, + Model: cfg.Knowledge.Embedding.Model, + BaseURL: cfg.Knowledge.Embedding.BaseURL, + APIKey: cfg.Knowledge.Embedding.APIKey, + } + } return &ConfigHandler{ - configPath: configPath, - config: cfg, - mcpServer: mcpServer, - executor: executor, - agent: agent, - attackChainHandler: attackChainHandler, - externalMCPMgr: externalMCPMgr, - logger: logger, + configPath: configPath, + config: cfg, + mcpServer: mcpServer, + executor: executor, + agent: agent, + attackChainHandler: attackChainHandler, + externalMCPMgr: externalMCPMgr, + logger: logger, + lastEmbeddingConfig: lastEmbeddingConfig, } } @@ -522,6 +534,15 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { // 更新Knowledge配置 if req.Knowledge != nil { + // 保存旧的嵌入模型配置(用于检测变更) + if h.config.Knowledge.Enabled { + h.lastEmbeddingConfig = &config.EmbeddingConfig{ + Provider: h.config.Knowledge.Embedding.Provider, + Model: h.config.Knowledge.Embedding.Model, + BaseURL: h.config.Knowledge.Embedding.BaseURL, + APIKey: h.config.Knowledge.Embedding.APIKey, + } + } h.config.Knowledge = *req.Knowledge h.logger.Info("更新Knowledge配置", zap.Bool("enabled", h.config.Knowledge.Enabled), @@ -676,10 +697,55 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { h.logger.Info("知识库动态初始化完成,工具已注册") } + // 检查嵌入模型配置是否变更(需要在锁外执行,避免阻塞) + var needReinitKnowledge bool + var reinitKnowledgeInitializer KnowledgeInitializer + h.mu.RLock() + if h.config.Knowledge.Enabled && h.knowledgeInitializer != nil && h.lastEmbeddingConfig != nil { + // 检查嵌入模型配置是否变更 + currentEmbedding := h.config.Knowledge.Embedding + if currentEmbedding.Provider != h.lastEmbeddingConfig.Provider || + currentEmbedding.Model != h.lastEmbeddingConfig.Model || + currentEmbedding.BaseURL != h.lastEmbeddingConfig.BaseURL || + currentEmbedding.APIKey != h.lastEmbeddingConfig.APIKey { + needReinitKnowledge = true + reinitKnowledgeInitializer = h.knowledgeInitializer + h.logger.Info("检测到嵌入模型配置变更,需要重新初始化知识库组件", + zap.String("old_model", h.lastEmbeddingConfig.Model), + zap.String("new_model", currentEmbedding.Model), + zap.String("old_base_url", h.lastEmbeddingConfig.BaseURL), + zap.String("new_base_url", currentEmbedding.BaseURL), + ) + } + } + h.mu.RUnlock() + + // 如果需要重新初始化知识库(嵌入模型配置变更),在锁外执行 + if needReinitKnowledge { + h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)") + if _, err := reinitKnowledgeInitializer(); err != nil { + h.logger.Error("重新初始化知识库失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()}) + return + } + h.logger.Info("知识库组件重新初始化完成") + } + // 现在获取写锁,执行快速的操作 h.mu.Lock() defer h.mu.Unlock() + // 如果重新初始化了知识库,更新嵌入模型配置记录 + if needReinitKnowledge && h.config.Knowledge.Enabled { + h.lastEmbeddingConfig = &config.EmbeddingConfig{ + Provider: h.config.Knowledge.Embedding.Provider, + Model: h.config.Knowledge.Embedding.Model, + BaseURL: h.config.Knowledge.Embedding.BaseURL, + APIKey: h.config.Knowledge.Embedding.APIKey, + } + h.logger.Info("已更新嵌入模型配置记录") + } + // 重新注册工具(根据新的启用状态) h.logger.Info("重新注册工具") @@ -722,20 +788,30 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { h.logger.Info("AttackChainHandler配置已更新") } - // 更新检索器配置(如果知识库启用) - if h.config.Knowledge.Enabled && h.retrieverUpdater != nil { - retrievalConfig := &knowledge.RetrievalConfig{ - TopK: h.config.Knowledge.Retrieval.TopK, - SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold, - HybridWeight: h.config.Knowledge.Retrieval.HybridWeight, + // 更新检索器配置(如果知识库启用) + if h.config.Knowledge.Enabled && h.retrieverUpdater != nil { + retrievalConfig := &knowledge.RetrievalConfig{ + TopK: h.config.Knowledge.Retrieval.TopK, + SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold, + HybridWeight: h.config.Knowledge.Retrieval.HybridWeight, + } + h.retrieverUpdater.UpdateConfig(retrievalConfig) + h.logger.Info("检索器配置已更新", + zap.Int("top_k", retrievalConfig.TopK), + zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold), + zap.Float64("hybrid_weight", retrievalConfig.HybridWeight), + ) + } + + // 更新嵌入模型配置记录(如果知识库启用) + if h.config.Knowledge.Enabled { + h.lastEmbeddingConfig = &config.EmbeddingConfig{ + Provider: h.config.Knowledge.Embedding.Provider, + Model: h.config.Knowledge.Embedding.Model, + BaseURL: h.config.Knowledge.Embedding.BaseURL, + APIKey: h.config.Knowledge.Embedding.APIKey, + } } - h.retrieverUpdater.UpdateConfig(retrievalConfig) - h.logger.Info("检索器配置已更新", - zap.Int("top_k", retrievalConfig.TopK), - zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold), - zap.Float64("hybrid_weight", retrievalConfig.HybridWeight), - ) - } h.logger.Info("配置已应用", zap.Int("tools_count", len(h.config.Security.Tools)), diff --git a/internal/handler/knowledge.go b/internal/handler/knowledge.go index fa6c009b..79addac8 100644 --- a/internal/handler/knowledge.go +++ b/internal/handler/knowledge.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "time" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/knowledge" @@ -336,14 +337,54 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) { go func() { ctx := context.Background() h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex))) + failedCount := 0 + consecutiveFailures := 0 + var firstFailureItemID string + var firstFailureError error + for i, itemID := range itemsToIndex { if err := h.indexer.IndexItem(ctx, itemID); err != nil { - h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) + failedCount++ + consecutiveFailures++ + + // 只在第一个失败时记录详细日志 + if consecutiveFailures == 1 { + firstFailureItemID = itemID + firstFailureError = err + h.logger.Warn("索引知识项失败", + zap.String("itemId", itemID), + zap.Int("totalItems", len(itemsToIndex)), + zap.Error(err), + ) + } + + // 如果连续失败2次,立即停止增量索引 + if consecutiveFailures >= 2 { + h.logger.Error("连续索引失败次数过多,立即停止增量索引", + zap.Int("consecutiveFailures", consecutiveFailures), + zap.Int("totalItems", len(itemsToIndex)), + zap.Int("processedItems", i+1), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + break + } continue } - h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex))) + + // 成功时重置连续失败计数 + if consecutiveFailures > 0 { + consecutiveFailures = 0 + firstFailureItemID = "" + firstFailureError = nil + } + + // 减少进度日志频率 + if (i+1)%10 == 0 || i+1 == len(itemsToIndex) { + h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount)) + } } - h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex))) + h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount)) }() c.JSON(http.StatusOK, gin.H{ @@ -396,6 +437,18 @@ func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) { return } + // 获取索引器的错误信息 + if h.indexer != nil { + lastError, lastErrorTime := h.indexer.GetLastError() + if lastError != "" { + // 如果错误是最近发生的(5分钟内),则返回错误信息 + if time.Since(lastErrorTime) < 5*time.Minute { + status["last_error"] = lastError + status["last_error_time"] = lastErrorTime.Format(time.RFC3339) + } + } + } + c.JSON(http.StatusOK, status) } diff --git a/internal/knowledge/indexer.go b/internal/knowledge/indexer.go index a02f5bf7..d5a49afc 100644 --- a/internal/knowledge/indexer.go +++ b/internal/knowledge/indexer.go @@ -7,6 +7,8 @@ import ( "fmt" "regexp" "strings" + "sync" + "time" "github.com/google/uuid" "go.uber.org/zap" @@ -19,6 +21,12 @@ type Indexer struct { logger *zap.Logger chunkSize int // 每个块的最大token数(估算) overlap int // 块之间的重叠token数 + + // 错误跟踪 + mu sync.RWMutex + lastError string // 最近一次错误信息 + lastErrorTime time.Time // 最近一次错误时间 + errorCount int // 连续错误计数 } // NewIndexer 创建新的索引器 @@ -267,13 +275,13 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { chunks := idx.ChunkText(content) idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks))) + // 跟踪该知识项的错误 + itemErrorCount := 0 + var firstError error + firstErrorChunkIndex := -1 + // 向量化每个块(包含category和title信息,以便向量检索时能匹配到风险类型) for i, chunk := range chunks { - chunkPreview := chunk - if len(chunkPreview) > 200 { - chunkPreview = chunkPreview[:200] + "..." - } - // 将category和title信息包含到向量化的文本中 // 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}" // 这样向量嵌入就会包含风险类型信息,即使SQL过滤失败,向量相似度也能帮助匹配 @@ -281,13 +289,43 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding) if err != nil { - idx.logger.Warn("向量化失败", - zap.String("itemId", itemID), - zap.Int("chunkIndex", i), - zap.Int("chunkLength", len(chunk)), - zap.String("chunkPreview", chunkPreview), - zap.Error(err), - ) + itemErrorCount++ + if firstError == nil { + firstError = err + firstErrorChunkIndex = i + // 只在第一个块失败时记录详细日志 + chunkPreview := chunk + if len(chunkPreview) > 200 { + chunkPreview = chunkPreview[:200] + "..." + } + idx.logger.Warn("向量化失败", + zap.String("itemId", itemID), + zap.Int("chunkIndex", i), + zap.Int("totalChunks", len(chunks)), + zap.String("chunkPreview", chunkPreview), + zap.Error(err), + ) + + // 更新全局错误跟踪 + errorMsg := fmt.Sprintf("向量化失败 (知识项: %s): %v", itemID, err) + idx.mu.Lock() + idx.lastError = errorMsg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + } + + // 如果连续失败2个块,立即停止处理该知识项(降低阈值,更快停止) + // 这样可以避免继续浪费API调用,同时也能更快地检测到配置问题 + if itemErrorCount >= 2 { + idx.logger.Error("知识项连续向量化失败,停止处理", + zap.String("itemId", itemID), + zap.Int("totalChunks", len(chunks)), + zap.Int("failedChunks", itemErrorCount), + zap.Int("firstErrorChunkIndex", firstErrorChunkIndex), + zap.Error(firstError), + ) + return fmt.Errorf("知识项连续向量化失败 (%d个块失败): %v", itemErrorCount, firstError) + } continue } @@ -321,6 +359,13 @@ func (idx *Indexer) HasIndex() (bool, error) { // RebuildIndex 重建所有索引 func (idx *Indexer) RebuildIndex(ctx context.Context) error { + // 重置错误跟踪 + idx.mu.Lock() + idx.lastError = "" + idx.lastErrorTime = time.Time{} + idx.errorCount = 0 + idx.mu.Unlock() + rows, err := idx.db.Query("SELECT id FROM knowledge_base_items") if err != nil { return fmt.Errorf("查询知识项失败: %w", err) @@ -348,14 +393,84 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { idx.logger.Info("已清空旧索引,开始重建") } + failedCount := 0 + consecutiveFailures := 0 + maxConsecutiveFailures := 2 // 连续失败2次后立即停止(降低阈值,更快停止) + firstFailureItemID := "" + var firstFailureError error + for i, itemID := range itemIDs { if err := idx.IndexItem(ctx, itemID); err != nil { - idx.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) + failedCount++ + consecutiveFailures++ + + // 只在第一个失败时记录详细日志 + if consecutiveFailures == 1 { + firstFailureItemID = itemID + firstFailureError = err + idx.logger.Warn("索引知识项失败", + zap.String("itemId", itemID), + zap.Int("totalItems", len(itemIDs)), + zap.Error(err), + ) + } + + // 如果连续失败过多,可能是配置问题,立即停止索引 + if consecutiveFailures >= maxConsecutiveFailures { + errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API密钥无效、余额不足等)。第一个失败项: %s, 错误: %v", consecutiveFailures, firstFailureItemID, firstFailureError) + idx.mu.Lock() + idx.lastError = errorMsg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + + idx.logger.Error("连续索引失败次数过多,立即停止索引", + zap.Int("consecutiveFailures", consecutiveFailures), + zap.Int("totalItems", len(itemIDs)), + zap.Int("processedItems", i+1), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + return fmt.Errorf("连续索引失败次数过多: %v", firstFailureError) + } + + // 如果失败的知识项过多,记录警告但继续处理(降低阈值到30%) + if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 { + errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项: %s, 错误: %v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError) + idx.mu.Lock() + idx.lastError = errorMsg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + + idx.logger.Error("索引失败的知识项过多,可能存在配置问题", + zap.Int("failedCount", failedCount), + zap.Int("totalItems", len(itemIDs)), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + } continue } - idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs))) + + // 成功时重置连续失败计数和第一个失败信息 + if consecutiveFailures > 0 { + consecutiveFailures = 0 + firstFailureItemID = "" + firstFailureError = nil + } + + // 减少进度日志频率(每10个或每10%记录一次) + if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) { + idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount)) + } } - idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs))) + idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount)) return nil } + +// GetLastError 获取最近一次错误信息 +func (idx *Indexer) GetLastError() (string, time.Time) { + idx.mu.RLock() + defer idx.mu.RUnlock() + return idx.lastError, idx.lastErrorTime +} diff --git a/web/static/js/knowledge.js b/web/static/js/knowledge.js index 26126157..783b7eff 100644 --- a/web/static/js/knowledge.js +++ b/web/static/js/knowledge.js @@ -457,6 +457,7 @@ async function updateIndexProgress() { const indexedItems = status.indexed_items || 0; const progressPercent = status.progress_percent || 0; const isComplete = status.is_complete || false; + const lastError = status.last_error || ''; if (totalItems === 0) { // 没有知识项,隐藏进度条 @@ -471,6 +472,58 @@ async function updateIndexProgress() { // 显示进度条 progressContainer.style.display = 'block'; + // 如果有错误信息,显示错误 + if (lastError) { + progressContainer.innerHTML = ` +