mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
495 lines
14 KiB
Go
495 lines
14 KiB
Go
package handler
|
||
|
||
import (
	"context"
	"fmt"
	"net/http"
	"sort"
	"strconv"
	"time"

	"cyberstrike-ai/internal/database"
	"cyberstrike-ai/internal/knowledge"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)
|
||
|
||
// KnowledgeHandler 知识库处理器
|
||
type KnowledgeHandler struct {
|
||
manager *knowledge.Manager
|
||
retriever *knowledge.Retriever
|
||
indexer *knowledge.Indexer
|
||
db *database.DB
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// NewKnowledgeHandler 创建新的知识库处理器
|
||
func NewKnowledgeHandler(
|
||
manager *knowledge.Manager,
|
||
retriever *knowledge.Retriever,
|
||
indexer *knowledge.Indexer,
|
||
db *database.DB,
|
||
logger *zap.Logger,
|
||
) *KnowledgeHandler {
|
||
return &KnowledgeHandler{
|
||
manager: manager,
|
||
retriever: retriever,
|
||
indexer: indexer,
|
||
db: db,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// GetCategories 获取所有分类
|
||
func (h *KnowledgeHandler) GetCategories(c *gin.Context) {
|
||
categories, err := h.manager.GetCategories()
|
||
if err != nil {
|
||
h.logger.Error("获取分类失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"categories": categories})
|
||
}
|
||
|
||
// GetItems 获取知识项列表(支持按分类分页和关键字搜索,默认不返回完整内容)
|
||
func (h *KnowledgeHandler) GetItems(c *gin.Context) {
|
||
category := c.Query("category")
|
||
searchKeyword := c.Query("search") // 搜索关键字
|
||
|
||
// 如果提供了搜索关键字,执行关键字搜索(在所有数据中搜索)
|
||
if searchKeyword != "" {
|
||
items, err := h.manager.SearchItemsByKeyword(searchKeyword, category)
|
||
if err != nil {
|
||
h.logger.Error("搜索知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 按分类分组结果
|
||
groupedByCategory := make(map[string][]*knowledge.KnowledgeItemSummary)
|
||
for _, item := range items {
|
||
cat := item.Category
|
||
if cat == "" {
|
||
cat = "未分类"
|
||
}
|
||
groupedByCategory[cat] = append(groupedByCategory[cat], item)
|
||
}
|
||
|
||
// 转换为CategoryWithItems格式
|
||
categoriesWithItems := make([]*knowledge.CategoryWithItems, 0, len(groupedByCategory))
|
||
for cat, catItems := range groupedByCategory {
|
||
categoriesWithItems = append(categoriesWithItems, &knowledge.CategoryWithItems{
|
||
Category: cat,
|
||
ItemCount: len(catItems),
|
||
Items: catItems,
|
||
})
|
||
}
|
||
|
||
// 按分类名称排序
|
||
for i := 0; i < len(categoriesWithItems)-1; i++ {
|
||
for j := i + 1; j < len(categoriesWithItems); j++ {
|
||
if categoriesWithItems[i].Category > categoriesWithItems[j].Category {
|
||
categoriesWithItems[i], categoriesWithItems[j] = categoriesWithItems[j], categoriesWithItems[i]
|
||
}
|
||
}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"categories": categoriesWithItems,
|
||
"total": len(categoriesWithItems),
|
||
"search": searchKeyword,
|
||
"is_search": true,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 分页模式:categoryPage=true 表示按分类分页,否则按项分页(向后兼容)
|
||
categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页
|
||
|
||
// 分页参数
|
||
limit := 50 // 默认每页50条(分类分页时为分类数,项分页时为项数)
|
||
offset := 0
|
||
if limitStr := c.Query("limit"); limitStr != "" {
|
||
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 && parsed <= 500 {
|
||
limit = parsed
|
||
}
|
||
}
|
||
if offsetStr := c.Query("offset"); offsetStr != "" {
|
||
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
|
||
offset = parsed
|
||
}
|
||
}
|
||
|
||
// 如果指定了category参数,且使用分类分页模式,则只返回该分类
|
||
if category != "" && categoryPageMode {
|
||
// 单分类模式:返回该分类的所有知识项(不分页)
|
||
items, total, err := h.manager.GetItemsSummary(category, 0, 0)
|
||
if err != nil {
|
||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 包装成分类结构
|
||
categoriesWithItems := []*knowledge.CategoryWithItems{
|
||
{
|
||
Category: category,
|
||
ItemCount: total,
|
||
Items: items,
|
||
},
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"categories": categoriesWithItems,
|
||
"total": 1, // 只有一个分类
|
||
"limit": limit,
|
||
"offset": offset,
|
||
})
|
||
return
|
||
}
|
||
|
||
if categoryPageMode {
|
||
// 按分类分页模式(默认)
|
||
// limit表示每页分类数,推荐5-10个分类
|
||
if limit <= 0 || limit > 100 {
|
||
limit = 10 // 默认每页10个分类
|
||
}
|
||
|
||
categoriesWithItems, totalCategories, err := h.manager.GetCategoriesWithItems(limit, offset)
|
||
if err != nil {
|
||
h.logger.Error("获取分类知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"categories": categoriesWithItems,
|
||
"total": totalCategories,
|
||
"limit": limit,
|
||
"offset": offset,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 按项分页模式(向后兼容)
|
||
// 是否包含完整内容(默认false,只返回摘要)
|
||
includeContent := c.Query("includeContent") == "true"
|
||
|
||
if includeContent {
|
||
// 返回完整内容(向后兼容)
|
||
items, err := h.manager.GetItemsWithOptions(category, limit, offset, true)
|
||
if err != nil {
|
||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 获取总数
|
||
total, err := h.manager.GetItemsCount(category)
|
||
if err != nil {
|
||
h.logger.Warn("获取知识项总数失败", zap.Error(err))
|
||
total = len(items)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"items": items,
|
||
"total": total,
|
||
"limit": limit,
|
||
"offset": offset,
|
||
})
|
||
} else {
|
||
// 返回摘要(不包含完整内容,推荐方式)
|
||
items, total, err := h.manager.GetItemsSummary(category, limit, offset)
|
||
if err != nil {
|
||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"items": items,
|
||
"total": total,
|
||
"limit": limit,
|
||
"offset": offset,
|
||
})
|
||
}
|
||
}
|
||
|
||
// GetItem 获取单个知识项
|
||
func (h *KnowledgeHandler) GetItem(c *gin.Context) {
|
||
id := c.Param("id")
|
||
|
||
item, err := h.manager.GetItem(id)
|
||
if err != nil {
|
||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, item)
|
||
}
|
||
|
||
// CreateItem 创建知识项
|
||
func (h *KnowledgeHandler) CreateItem(c *gin.Context) {
|
||
var req struct {
|
||
Category string `json:"category" binding:"required"`
|
||
Title string `json:"title" binding:"required"`
|
||
Content string `json:"content" binding:"required"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
item, err := h.manager.CreateItem(req.Category, req.Title, req.Content)
|
||
if err != nil {
|
||
h.logger.Error("创建知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 异步索引
|
||
go func() {
|
||
ctx := context.Background()
|
||
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
|
||
h.logger.Warn("索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
|
||
}
|
||
}()
|
||
|
||
c.JSON(http.StatusOK, item)
|
||
}
|
||
|
||
// UpdateItem 更新知识项
|
||
func (h *KnowledgeHandler) UpdateItem(c *gin.Context) {
|
||
id := c.Param("id")
|
||
|
||
var req struct {
|
||
Category string `json:"category" binding:"required"`
|
||
Title string `json:"title" binding:"required"`
|
||
Content string `json:"content" binding:"required"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
item, err := h.manager.UpdateItem(id, req.Category, req.Title, req.Content)
|
||
if err != nil {
|
||
h.logger.Error("更新知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 异步重新索引
|
||
go func() {
|
||
ctx := context.Background()
|
||
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
|
||
h.logger.Warn("重新索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
|
||
}
|
||
}()
|
||
|
||
c.JSON(http.StatusOK, item)
|
||
}
|
||
|
||
// DeleteItem 删除知识项
|
||
func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
|
||
id := c.Param("id")
|
||
|
||
if err := h.manager.DeleteItem(id); err != nil {
|
||
h.logger.Error("删除知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||
}
|
||
|
||
// RebuildIndex 重建索引
|
||
func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
|
||
// 异步重建索引
|
||
go func() {
|
||
ctx := context.Background()
|
||
if err := h.indexer.RebuildIndex(ctx); err != nil {
|
||
h.logger.Error("重建索引失败", zap.Error(err))
|
||
}
|
||
}()
|
||
|
||
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
|
||
}
|
||
|
||
// ScanKnowledgeBase 扫描知识库
|
||
func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
|
||
itemsToIndex, err := h.manager.ScanKnowledgeBase()
|
||
if err != nil {
|
||
h.logger.Error("扫描知识库失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
if len(itemsToIndex) == 0 {
|
||
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"})
|
||
return
|
||
}
|
||
|
||
// 异步索引新添加或更新的项(增量索引)
|
||
go func() {
|
||
ctx := context.Background()
|
||
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
|
||
failedCount := 0
|
||
consecutiveFailures := 0
|
||
var firstFailureItemID string
|
||
var firstFailureError error
|
||
|
||
for i, itemID := range itemsToIndex {
|
||
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
|
||
failedCount++
|
||
consecutiveFailures++
|
||
|
||
// 只在第一个失败时记录详细日志
|
||
if consecutiveFailures == 1 {
|
||
firstFailureItemID = itemID
|
||
firstFailureError = err
|
||
h.logger.Warn("索引知识项失败",
|
||
zap.String("itemId", itemID),
|
||
zap.Int("totalItems", len(itemsToIndex)),
|
||
zap.Error(err),
|
||
)
|
||
}
|
||
|
||
// 如果连续失败2次,立即停止增量索引
|
||
if consecutiveFailures >= 2 {
|
||
h.logger.Error("连续索引失败次数过多,立即停止增量索引",
|
||
zap.Int("consecutiveFailures", consecutiveFailures),
|
||
zap.Int("totalItems", len(itemsToIndex)),
|
||
zap.Int("processedItems", i+1),
|
||
zap.String("firstFailureItemId", firstFailureItemID),
|
||
zap.Error(firstFailureError),
|
||
)
|
||
break
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 成功时重置连续失败计数
|
||
if consecutiveFailures > 0 {
|
||
consecutiveFailures = 0
|
||
firstFailureItemID = ""
|
||
firstFailureError = nil
|
||
}
|
||
|
||
// 减少进度日志频率
|
||
if (i+1)%10 == 0 || i+1 == len(itemsToIndex) {
|
||
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount))
|
||
}
|
||
}
|
||
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
|
||
}()
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
|
||
"items_to_index": len(itemsToIndex),
|
||
})
|
||
}
|
||
|
||
// GetRetrievalLogs 获取检索日志
|
||
func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) {
|
||
conversationID := c.Query("conversationId")
|
||
messageID := c.Query("messageId")
|
||
limit := 50 // 默认50条
|
||
|
||
if limitStr := c.Query("limit"); limitStr != "" {
|
||
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
|
||
limit = parsed
|
||
}
|
||
}
|
||
|
||
logs, err := h.manager.GetRetrievalLogs(conversationID, messageID, limit)
|
||
if err != nil {
|
||
h.logger.Error("获取检索日志失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"logs": logs})
|
||
}
|
||
|
||
// DeleteRetrievalLog 删除检索日志
|
||
func (h *KnowledgeHandler) DeleteRetrievalLog(c *gin.Context) {
|
||
id := c.Param("id")
|
||
|
||
if err := h.manager.DeleteRetrievalLog(id); err != nil {
|
||
h.logger.Error("删除检索日志失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||
}
|
||
|
||
// GetIndexStatus 获取索引状态
|
||
func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
|
||
status, err := h.manager.GetIndexStatus()
|
||
if err != nil {
|
||
h.logger.Error("获取索引状态失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 获取索引器的错误信息
|
||
if h.indexer != nil {
|
||
lastError, lastErrorTime := h.indexer.GetLastError()
|
||
if lastError != "" {
|
||
// 如果错误是最近发生的(5分钟内),则返回错误信息
|
||
if time.Since(lastErrorTime) < 5*time.Minute {
|
||
status["last_error"] = lastError
|
||
status["last_error_time"] = lastErrorTime.Format(time.RFC3339)
|
||
}
|
||
}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, status)
|
||
}
|
||
|
||
// Search 搜索知识库(用于API调用,Agent内部使用Retriever)
|
||
func (h *KnowledgeHandler) Search(c *gin.Context) {
|
||
var req knowledge.SearchRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
results, err := h.retriever.Search(c.Request.Context(), &req)
|
||
if err != nil {
|
||
h.logger.Error("搜索知识库失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"results": results})
|
||
}
|
||
|
||
// GetStats 获取知识库统计信息
|
||
func (h *KnowledgeHandler) GetStats(c *gin.Context) {
|
||
totalCategories, totalItems, err := h.manager.GetStats()
|
||
if err != nil {
|
||
h.logger.Error("获取知识库统计信息失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"enabled": true,
|
||
"total_categories": totalCategories,
|
||
"total_items": totalItems,
|
||
})
|
||
}
|
||
|
||
// parseInt converts a base-10 string to an int.
//
// The whole string must be a valid integer: input with trailing characters
// such as "10abc", an empty string, or an out-of-range value yields an error
// and a zero result. The error is the one returned by strconv.Atoi.
func parseInt(s string) (int, error) {
	n, err := strconv.Atoi(s)
	if err != nil {
		// Do not leak Atoi's partial or clamped value alongside an error.
		return 0, err
	}
	return n, nil
}
|