package handler import ( "context" "fmt" "net/http" "time" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/knowledge" "github.com/gin-gonic/gin" "go.uber.org/zap" ) // KnowledgeHandler 知识库处理器 type KnowledgeHandler struct { manager *knowledge.Manager retriever *knowledge.Retriever indexer *knowledge.Indexer db *database.DB logger *zap.Logger } // NewKnowledgeHandler 创建新的知识库处理器 func NewKnowledgeHandler( manager *knowledge.Manager, retriever *knowledge.Retriever, indexer *knowledge.Indexer, db *database.DB, logger *zap.Logger, ) *KnowledgeHandler { return &KnowledgeHandler{ manager: manager, retriever: retriever, indexer: indexer, db: db, logger: logger, } } // GetCategories 获取所有分类 func (h *KnowledgeHandler) GetCategories(c *gin.Context) { categories, err := h.manager.GetCategories() if err != nil { h.logger.Error("获取分类失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"categories": categories}) } // GetItems 获取知识项列表(支持按分类分页和关键字搜索,默认不返回完整内容) func (h *KnowledgeHandler) GetItems(c *gin.Context) { category := c.Query("category") searchKeyword := c.Query("search") // 搜索关键字 // 如果提供了搜索关键字,执行关键字搜索(在所有数据中搜索) if searchKeyword != "" { items, err := h.manager.SearchItemsByKeyword(searchKeyword, category) if err != nil { h.logger.Error("搜索知识项失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // 按分类分组结果 groupedByCategory := make(map[string][]*knowledge.KnowledgeItemSummary) for _, item := range items { cat := item.Category if cat == "" { cat = "未分类" } groupedByCategory[cat] = append(groupedByCategory[cat], item) } // 转换为CategoryWithItems格式 categoriesWithItems := make([]*knowledge.CategoryWithItems, 0, len(groupedByCategory)) for cat, catItems := range groupedByCategory { categoriesWithItems = append(categoriesWithItems, &knowledge.CategoryWithItems{ Category: cat, ItemCount: len(catItems), Items: catItems, }) } // 按分类名称排序 for i := 0; i < len(categoriesWithItems)-1; i++ { for j := i + 1; j < len(categoriesWithItems); j++ { if categoriesWithItems[i].Category > categoriesWithItems[j].Category { categoriesWithItems[i], categoriesWithItems[j] = categoriesWithItems[j], categoriesWithItems[i] } } } c.JSON(http.StatusOK, gin.H{ "categories": categoriesWithItems, "total": len(categoriesWithItems), "search": searchKeyword, "is_search": true, }) return } // 分页模式:categoryPage=true 表示按分类分页,否则按项分页(向后兼容) categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页 // 分页参数 limit := 50 // 默认每页50条(分类分页时为分类数,项分页时为项数) offset := 0 if limitStr := c.Query("limit"); limitStr != "" { if parsed, err := parseInt(limitStr); err == nil && parsed > 0 && parsed <= 500 { limit = parsed } } if offsetStr := c.Query("offset"); offsetStr != "" { if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 { offset = parsed } } // 如果指定了category参数,且使用分类分页模式,则只返回该分类 if category != "" && categoryPageMode { // 单分类模式:返回该分类的所有知识项(不分页) items, total, err := h.manager.GetItemsSummary(category, 0, 0) if err != nil { h.logger.Error("获取知识项失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // 包装成分类结构 categoriesWithItems := []*knowledge.CategoryWithItems{ { Category: category, ItemCount: total, Items: items, }, } c.JSON(http.StatusOK, gin.H{ "categories": categoriesWithItems, "total": 1, // 只有一个分类 "limit": limit, "offset": offset, }) return } if categoryPageMode { // 按分类分页模式(默认) // limit表示每页分类数,推荐5-10个分类 if limit <= 0 || limit > 100 { limit = 10 // 默认每页10个分类 } categoriesWithItems, totalCategories, err := h.manager.GetCategoriesWithItems(limit, offset) if err != nil { h.logger.Error("获取分类知识项失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{ "categories": categoriesWithItems, "total": totalCategories, "limit": limit, "offset": offset, }) return } // 按项分页模式(向后兼容) // 是否包含完整内容(默认false,只返回摘要) includeContent := c.Query("includeContent") == "true" if includeContent { // 返回完整内容(向后兼容) items, err := h.manager.GetItemsWithOptions(category, limit, offset, true) if err != nil { h.logger.Error("获取知识项失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // 获取总数 total, err := h.manager.GetItemsCount(category) if err != nil { h.logger.Warn("获取知识项总数失败", zap.Error(err)) total = len(items) } c.JSON(http.StatusOK, gin.H{ "items": items, "total": total, "limit": limit, "offset": offset, }) } else { // 返回摘要(不包含完整内容,推荐方式) items, total, err := h.manager.GetItemsSummary(category, limit, offset) if err != nil { h.logger.Error("获取知识项失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{ "items": items, "total": total, "limit": limit, "offset": offset, }) } } // GetItem 获取单个知识项 func (h *KnowledgeHandler) GetItem(c *gin.Context) { id := c.Param("id") item, err := h.manager.GetItem(id) if err != nil { h.logger.Error("获取知识项失败", zap.Error(err)) c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, item) } // CreateItem 创建知识项 func (h *KnowledgeHandler) CreateItem(c *gin.Context) { var req struct { Category string `json:"category" binding:"required"` Title string `json:"title" binding:"required"` Content string `json:"content" binding:"required"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } item, err := h.manager.CreateItem(req.Category, req.Title, req.Content) if err != nil { h.logger.Error("创建知识项失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // 异步索引 go func() { ctx := context.Background() if err := h.indexer.IndexItem(ctx, item.ID); err != nil { h.logger.Warn("索引知识项失败", zap.String("itemId", item.ID), zap.Error(err)) } }() c.JSON(http.StatusOK, item) } // UpdateItem 更新知识项 func (h *KnowledgeHandler) UpdateItem(c *gin.Context) { id := c.Param("id") var req struct { Category string `json:"category" binding:"required"` Title string `json:"title" binding:"required"` Content string `json:"content" binding:"required"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } item, err := h.manager.UpdateItem(id, req.Category, req.Title, req.Content) if err != nil { h.logger.Error("更新知识项失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // 异步重新索引 go func() { ctx := context.Background() if err := h.indexer.IndexItem(ctx, item.ID); err != nil { h.logger.Warn("重新索引知识项失败", zap.String("itemId", item.ID), zap.Error(err)) } }() c.JSON(http.StatusOK, item) } // DeleteItem 删除知识项 func (h *KnowledgeHandler) DeleteItem(c *gin.Context) { id := c.Param("id") if err := h.manager.DeleteItem(id); err != nil { h.logger.Error("删除知识项失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) } // RebuildIndex 重建索引 func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) { // 异步重建索引 go func() { ctx := context.Background() if err := h.indexer.RebuildIndex(ctx); err != nil { h.logger.Error("重建索引失败", zap.Error(err)) } }() c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"}) } // ScanKnowledgeBase 扫描知识库 func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) { itemsToIndex, err := h.manager.ScanKnowledgeBase() if err != nil { h.logger.Error("扫描知识库失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } if len(itemsToIndex) == 0 { c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"}) return } // 异步索引新添加或更新的项(增量索引) go func() { ctx := context.Background() h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex))) failedCount := 0 consecutiveFailures := 0 var firstFailureItemID string var firstFailureError error for i, itemID := range itemsToIndex { if err := h.indexer.IndexItem(ctx, itemID); err != nil { failedCount++ consecutiveFailures++ // 只在第一个失败时记录详细日志 if consecutiveFailures == 1 { firstFailureItemID = itemID firstFailureError = err h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Int("totalItems", len(itemsToIndex)), zap.Error(err), ) } // 如果连续失败2次,立即停止增量索引 if consecutiveFailures >= 2 { h.logger.Error("连续索引失败次数过多,立即停止增量索引", zap.Int("consecutiveFailures", consecutiveFailures), zap.Int("totalItems", len(itemsToIndex)), zap.Int("processedItems", i+1), zap.String("firstFailureItemId", firstFailureItemID), zap.Error(firstFailureError), ) break } continue } // 成功时重置连续失败计数 if consecutiveFailures > 0 { consecutiveFailures = 0 firstFailureItemID = "" firstFailureError = nil } // 减少进度日志频率 if (i+1)%10 == 0 || i+1 == len(itemsToIndex) { h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount)) } } h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount)) }() c.JSON(http.StatusOK, gin.H{ "message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)), "items_to_index": len(itemsToIndex), }) } // GetRetrievalLogs 获取检索日志 func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) { conversationID := c.Query("conversationId") messageID := c.Query("messageId") limit := 50 // 默认50条 if limitStr := c.Query("limit"); limitStr != "" { if parsed, err := parseInt(limitStr); err == nil && parsed > 0 { limit = parsed } } logs, err := h.manager.GetRetrievalLogs(conversationID, messageID, limit) if err != nil { h.logger.Error("获取检索日志失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"logs": logs}) } // DeleteRetrievalLog 删除检索日志 func (h *KnowledgeHandler) DeleteRetrievalLog(c *gin.Context) { id := c.Param("id") if err := h.manager.DeleteRetrievalLog(id); err != nil { h.logger.Error("删除检索日志失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) } // GetIndexStatus 获取索引状态 func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) { status, err := h.manager.GetIndexStatus() if err != nil { h.logger.Error("获取索引状态失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // 获取索引器的错误信息 if h.indexer != nil { lastError, lastErrorTime := h.indexer.GetLastError() if lastError != "" { // 如果错误是最近发生的(5分钟内),则返回错误信息 if time.Since(lastErrorTime) < 5*time.Minute { status["last_error"] = lastError status["last_error_time"] = lastErrorTime.Format(time.RFC3339) } } } c.JSON(http.StatusOK, status) } // Search 搜索知识库(用于API调用,Agent内部使用Retriever) func (h *KnowledgeHandler) Search(c *gin.Context) { var req knowledge.SearchRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } results, err := h.retriever.Search(c.Request.Context(), &req) if err != nil { h.logger.Error("搜索知识库失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"results": results}) } // GetStats 获取知识库统计信息 func (h *KnowledgeHandler) GetStats(c *gin.Context) { totalCategories, totalItems, err := h.manager.GetStats() if err != nil { h.logger.Error("获取知识库统计信息失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{ "enabled": true, "total_categories": totalCategories, "total_items": totalItems, }) } // 辅助函数:解析整数 func parseInt(s string) (int, error) { var result int _, err := fmt.Sscanf(s, "%d", &result) return result, err }