mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-01 00:30:33 +02:00
262 lines
6.8 KiB
Go
262 lines
6.8 KiB
Go
package handler
|
||
|
||
import (
	"context"
	"fmt"
	"net/http"
	"strconv"

	"cyberstrike-ai/internal/database"
	"cyberstrike-ai/internal/knowledge"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)
|
||
|
||
// KnowledgeHandler 知识库处理器
|
||
type KnowledgeHandler struct {
|
||
manager *knowledge.Manager
|
||
retriever *knowledge.Retriever
|
||
indexer *knowledge.Indexer
|
||
db *database.DB
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// NewKnowledgeHandler 创建新的知识库处理器
|
||
func NewKnowledgeHandler(
|
||
manager *knowledge.Manager,
|
||
retriever *knowledge.Retriever,
|
||
indexer *knowledge.Indexer,
|
||
db *database.DB,
|
||
logger *zap.Logger,
|
||
) *KnowledgeHandler {
|
||
return &KnowledgeHandler{
|
||
manager: manager,
|
||
retriever: retriever,
|
||
indexer: indexer,
|
||
db: db,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// GetCategories 获取所有分类
|
||
func (h *KnowledgeHandler) GetCategories(c *gin.Context) {
|
||
categories, err := h.manager.GetCategories()
|
||
if err != nil {
|
||
h.logger.Error("获取分类失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"categories": categories})
|
||
}
|
||
|
||
// GetItems 获取知识项列表
|
||
func (h *KnowledgeHandler) GetItems(c *gin.Context) {
|
||
category := c.Query("category")
|
||
|
||
items, err := h.manager.GetItems(category)
|
||
if err != nil {
|
||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"items": items})
|
||
}
|
||
|
||
// GetItem 获取单个知识项
|
||
func (h *KnowledgeHandler) GetItem(c *gin.Context) {
|
||
id := c.Param("id")
|
||
|
||
item, err := h.manager.GetItem(id)
|
||
if err != nil {
|
||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, item)
|
||
}
|
||
|
||
// CreateItem 创建知识项
|
||
func (h *KnowledgeHandler) CreateItem(c *gin.Context) {
|
||
var req struct {
|
||
Category string `json:"category" binding:"required"`
|
||
Title string `json:"title" binding:"required"`
|
||
Content string `json:"content" binding:"required"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
item, err := h.manager.CreateItem(req.Category, req.Title, req.Content)
|
||
if err != nil {
|
||
h.logger.Error("创建知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 异步索引
|
||
go func() {
|
||
ctx := context.Background()
|
||
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
|
||
h.logger.Warn("索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
|
||
}
|
||
}()
|
||
|
||
c.JSON(http.StatusOK, item)
|
||
}
|
||
|
||
// UpdateItem 更新知识项
|
||
func (h *KnowledgeHandler) UpdateItem(c *gin.Context) {
|
||
id := c.Param("id")
|
||
|
||
var req struct {
|
||
Category string `json:"category" binding:"required"`
|
||
Title string `json:"title" binding:"required"`
|
||
Content string `json:"content" binding:"required"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
item, err := h.manager.UpdateItem(id, req.Category, req.Title, req.Content)
|
||
if err != nil {
|
||
h.logger.Error("更新知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 异步重新索引
|
||
go func() {
|
||
ctx := context.Background()
|
||
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
|
||
h.logger.Warn("重新索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
|
||
}
|
||
}()
|
||
|
||
c.JSON(http.StatusOK, item)
|
||
}
|
||
|
||
// DeleteItem 删除知识项
|
||
func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
|
||
id := c.Param("id")
|
||
|
||
if err := h.manager.DeleteItem(id); err != nil {
|
||
h.logger.Error("删除知识项失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||
}
|
||
|
||
// RebuildIndex 重建索引
|
||
func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
|
||
// 异步重建索引
|
||
go func() {
|
||
ctx := context.Background()
|
||
if err := h.indexer.RebuildIndex(ctx); err != nil {
|
||
h.logger.Error("重建索引失败", zap.Error(err))
|
||
}
|
||
}()
|
||
|
||
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
|
||
}
|
||
|
||
// ScanKnowledgeBase 扫描知识库
|
||
func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
|
||
if err := h.manager.ScanKnowledgeBase(); err != nil {
|
||
h.logger.Error("扫描知识库失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 异步重建索引
|
||
go func() {
|
||
ctx := context.Background()
|
||
if err := h.indexer.RebuildIndex(ctx); err != nil {
|
||
h.logger.Error("重建索引失败", zap.Error(err))
|
||
}
|
||
}()
|
||
|
||
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,索引重建已开始"})
|
||
}
|
||
|
||
// GetRetrievalLogs 获取检索日志
|
||
func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) {
|
||
conversationID := c.Query("conversationId")
|
||
messageID := c.Query("messageId")
|
||
limit := 50 // 默认50条
|
||
|
||
if limitStr := c.Query("limit"); limitStr != "" {
|
||
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
|
||
limit = parsed
|
||
}
|
||
}
|
||
|
||
logs, err := h.manager.GetRetrievalLogs(conversationID, messageID, limit)
|
||
if err != nil {
|
||
h.logger.Error("获取检索日志失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"logs": logs})
|
||
}
|
||
|
||
// DeleteRetrievalLog 删除检索日志
|
||
func (h *KnowledgeHandler) DeleteRetrievalLog(c *gin.Context) {
|
||
id := c.Param("id")
|
||
|
||
if err := h.manager.DeleteRetrievalLog(id); err != nil {
|
||
h.logger.Error("删除检索日志失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||
}
|
||
|
||
// GetIndexStatus 获取索引状态
|
||
func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
|
||
status, err := h.manager.GetIndexStatus()
|
||
if err != nil {
|
||
h.logger.Error("获取索引状态失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, status)
|
||
}
|
||
|
||
// Search 搜索知识库(用于API调用,Agent内部使用Retriever)
|
||
func (h *KnowledgeHandler) Search(c *gin.Context) {
|
||
var req knowledge.SearchRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
results, err := h.retriever.Search(c.Request.Context(), &req)
|
||
if err != nil {
|
||
h.logger.Error("搜索知识库失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"results": results})
|
||
}
|
||
|
||
// parseInt converts s to an int. The whole string must be a base-10 integer
// with an optional leading sign.
//
// The previous fmt.Sscanf("%d") implementation stopped at the first
// character that could not belong to the number, so input such as "12abc"
// was silently accepted as 12. strconv.Atoi validates the entire string and
// therefore also rejects surrounding whitespace.
//
// On failure it returns 0 and an error that wraps the strconv error.
func parseInt(s string) (int, error) {
	n, err := strconv.Atoi(s)
	if err != nil {
		return 0, fmt.Errorf("parsing integer %q: %w", s, err)
	}
	return n, nil
}
|
||
|