mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-01 16:50:47 +02:00
Add files via upload
This commit is contained in:
@@ -26,16 +26,21 @@ import (
|
||||
|
||||
// App 应用
|
||||
type App struct {
|
||||
config *config.Config
|
||||
logger *logger.Logger
|
||||
router *gin.Engine
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
agent *agent.Agent
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库)
|
||||
auth *security.AuthManager
|
||||
config *config.Config
|
||||
logger *logger.Logger
|
||||
router *gin.Engine
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
agent *agent.Agent
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库)
|
||||
auth *security.AuthManager
|
||||
knowledgeManager *knowledge.Manager // 知识库管理器(用于动态初始化)
|
||||
knowledgeRetriever *knowledge.Retriever // 知识库检索器(用于动态初始化)
|
||||
knowledgeIndexer *knowledge.Indexer // 知识库索引器(用于动态初始化)
|
||||
knowledgeHandler *handler.KnowledgeHandler // 知识库处理器(用于动态初始化)
|
||||
agentHandler *handler.AgentHandler // Agent处理器(用于更新知识库管理器)
|
||||
}
|
||||
|
||||
// New 创建新应用
|
||||
@@ -196,12 +201,13 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
|
||||
// 扫描知识库并建立索引(异步)
|
||||
go func() {
|
||||
if err := knowledgeManager.ScanKnowledgeBase(); err != nil {
|
||||
itemsToIndex, err := knowledgeManager.ScanKnowledgeBase()
|
||||
if err != nil {
|
||||
log.Logger.Warn("扫描知识库失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否已有索引,如果有则跳过自动重建
|
||||
// 检查是否已有索引
|
||||
hasIndex, err := knowledgeIndexer.HasIndex()
|
||||
if err != nil {
|
||||
log.Logger.Warn("检查索引状态失败", zap.Error(err))
|
||||
@@ -209,7 +215,20 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
}
|
||||
|
||||
if hasIndex {
|
||||
log.Logger.Info("检测到已有知识库索引,跳过自动重建。如需重建,请手动点击重建索引按钮")
|
||||
// 如果已有索引,只索引新添加或更新的项
|
||||
if len(itemsToIndex) > 0 {
|
||||
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
|
||||
ctx := context.Background()
|
||||
for _, itemID := range itemsToIndex {
|
||||
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
|
||||
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
|
||||
} else {
|
||||
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -242,6 +261,51 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
|
||||
// 创建 App 实例(部分字段稍后填充)
|
||||
app := &App{
|
||||
config: cfg,
|
||||
logger: log,
|
||||
router: router,
|
||||
mcpServer: mcpServer,
|
||||
externalMCPMgr: externalMCPMgr,
|
||||
agent: agent,
|
||||
executor: executor,
|
||||
db: db,
|
||||
knowledgeDB: knowledgeDBConn,
|
||||
auth: authManager,
|
||||
knowledgeManager: knowledgeManager,
|
||||
knowledgeRetriever: knowledgeRetriever,
|
||||
knowledgeIndexer: knowledgeIndexer,
|
||||
knowledgeHandler: knowledgeHandler,
|
||||
agentHandler: agentHandler,
|
||||
}
|
||||
|
||||
// 设置知识库初始化器(用于动态初始化,需要在 App 创建后设置)
|
||||
configHandler.SetKnowledgeInitializer(func() (*handler.KnowledgeHandler, error) {
|
||||
knowledgeHandler, err := initializeKnowledge(cfg, db, knowledgeDBConn, mcpServer, agentHandler, app, log.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 动态初始化后,设置知识库工具注册器和检索器更新器
|
||||
// 这样后续 ApplyConfig 时就能重新注册工具了
|
||||
if app.knowledgeRetriever != nil && app.knowledgeManager != nil {
|
||||
// 创建闭包,捕获knowledgeRetriever和knowledgeManager的引用
|
||||
registrar := func() error {
|
||||
knowledge.RegisterKnowledgeTool(mcpServer, app.knowledgeRetriever, app.knowledgeManager, log.Logger)
|
||||
return nil
|
||||
}
|
||||
configHandler.SetKnowledgeToolRegistrar(registrar)
|
||||
// 设置检索器更新器,以便在ApplyConfig时更新检索器配置
|
||||
configHandler.SetRetrieverUpdater(app.knowledgeRetriever)
|
||||
log.Logger.Info("动态初始化后已设置知识库工具注册器和检索器更新器")
|
||||
}
|
||||
|
||||
return knowledgeHandler, nil
|
||||
})
|
||||
|
||||
// 如果知识库已启用,设置知识库工具注册器和检索器更新器
|
||||
if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil {
|
||||
// 创建闭包,捕获knowledgeRetriever和knowledgeManager的引用
|
||||
@@ -253,9 +317,8 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
// 设置检索器更新器,以便在ApplyConfig时更新检索器配置
|
||||
configHandler.SetRetrieverUpdater(knowledgeRetriever)
|
||||
}
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
|
||||
// 设置路由
|
||||
// 设置路由(使用 App 实例以便动态获取 handler)
|
||||
setupRoutes(
|
||||
router,
|
||||
authHandler,
|
||||
@@ -266,24 +329,14 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
configHandler,
|
||||
externalMCPHandler,
|
||||
attackChainHandler,
|
||||
knowledgeHandler,
|
||||
app, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler,
|
||||
mcpServer,
|
||||
authManager,
|
||||
)
|
||||
|
||||
return &App{
|
||||
config: cfg,
|
||||
logger: log,
|
||||
router: router,
|
||||
mcpServer: mcpServer,
|
||||
externalMCPMgr: externalMCPMgr,
|
||||
agent: agent,
|
||||
executor: executor,
|
||||
db: db,
|
||||
knowledgeDB: knowledgeDBConn,
|
||||
auth: authManager,
|
||||
}, nil
|
||||
return app, nil
|
||||
|
||||
}
|
||||
|
||||
// Run 启动应用
|
||||
@@ -336,7 +389,7 @@ func setupRoutes(
|
||||
configHandler *handler.ConfigHandler,
|
||||
externalMCPHandler *handler.ExternalMCPHandler,
|
||||
attackChainHandler *handler.AttackChainHandler,
|
||||
knowledgeHandler *handler.KnowledgeHandler,
|
||||
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler *handler.VulnerabilityHandler,
|
||||
mcpServer *mcp.Server,
|
||||
authManager *security.AuthManager,
|
||||
@@ -409,20 +462,137 @@ func setupRoutes(
|
||||
protected.GET("/attack-chain/:conversationId", attackChainHandler.GetAttackChain)
|
||||
protected.POST("/attack-chain/:conversationId/regenerate", attackChainHandler.RegenerateAttackChain)
|
||||
|
||||
// 知识库管理(如果启用)
|
||||
if knowledgeHandler != nil {
|
||||
protected.GET("/knowledge/categories", knowledgeHandler.GetCategories)
|
||||
protected.GET("/knowledge/items", knowledgeHandler.GetItems)
|
||||
protected.GET("/knowledge/items/:id", knowledgeHandler.GetItem)
|
||||
protected.POST("/knowledge/items", knowledgeHandler.CreateItem)
|
||||
protected.PUT("/knowledge/items/:id", knowledgeHandler.UpdateItem)
|
||||
protected.DELETE("/knowledge/items/:id", knowledgeHandler.DeleteItem)
|
||||
protected.GET("/knowledge/index-status", knowledgeHandler.GetIndexStatus)
|
||||
protected.POST("/knowledge/index", knowledgeHandler.RebuildIndex)
|
||||
protected.POST("/knowledge/scan", knowledgeHandler.ScanKnowledgeBase)
|
||||
protected.GET("/knowledge/retrieval-logs", knowledgeHandler.GetRetrievalLogs)
|
||||
protected.DELETE("/knowledge/retrieval-logs/:id", knowledgeHandler.DeleteRetrievalLog)
|
||||
protected.POST("/knowledge/search", knowledgeHandler.Search)
|
||||
// 知识库管理(始终注册路由,通过 App 实例动态获取 handler)
|
||||
knowledgeRoutes := protected.Group("/knowledge")
|
||||
{
|
||||
knowledgeRoutes.GET("/categories", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"categories": []string{},
|
||||
"enabled": false,
|
||||
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.GetCategories(c)
|
||||
})
|
||||
knowledgeRoutes.GET("/items", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"items": []interface{}{},
|
||||
"enabled": false,
|
||||
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.GetItems(c)
|
||||
})
|
||||
knowledgeRoutes.GET("/items/:id", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.GetItem(c)
|
||||
})
|
||||
knowledgeRoutes.POST("/items", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.CreateItem(c)
|
||||
})
|
||||
knowledgeRoutes.PUT("/items/:id", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.UpdateItem(c)
|
||||
})
|
||||
knowledgeRoutes.DELETE("/items/:id", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.DeleteItem(c)
|
||||
})
|
||||
knowledgeRoutes.GET("/index-status", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"total_items": 0,
|
||||
"indexed_items": 0,
|
||||
"progress_percent": 0,
|
||||
"is_complete": false,
|
||||
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.GetIndexStatus(c)
|
||||
})
|
||||
knowledgeRoutes.POST("/index", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.RebuildIndex(c)
|
||||
})
|
||||
knowledgeRoutes.POST("/scan", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.ScanKnowledgeBase(c)
|
||||
})
|
||||
knowledgeRoutes.GET("/retrieval-logs", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"logs": []interface{}{},
|
||||
"enabled": false,
|
||||
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.GetRetrievalLogs(c)
|
||||
})
|
||||
knowledgeRoutes.DELETE("/retrieval-logs/:id", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.DeleteRetrievalLog(c)
|
||||
})
|
||||
knowledgeRoutes.POST("/search", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"results": []interface{}{},
|
||||
"enabled": false,
|
||||
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.Search(c)
|
||||
})
|
||||
}
|
||||
|
||||
// 漏洞管理
|
||||
@@ -594,7 +764,7 @@ func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *z
|
||||
Title: title,
|
||||
Description: description,
|
||||
Severity: severity,
|
||||
Status: "open",
|
||||
Status: "open",
|
||||
Type: vulnType,
|
||||
Target: target,
|
||||
Proof: proof,
|
||||
@@ -638,6 +808,136 @@ func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *z
|
||||
logger.Info("漏洞记录工具注册成功")
|
||||
}
|
||||
|
||||
// initializeKnowledge 初始化知识库组件(用于动态初始化)
|
||||
func initializeKnowledge(
|
||||
cfg *config.Config,
|
||||
db *database.DB,
|
||||
knowledgeDBConn *database.DB,
|
||||
mcpServer *mcp.Server,
|
||||
agentHandler *handler.AgentHandler,
|
||||
app *App, // 传递 App 引用以便更新知识库组件
|
||||
logger *zap.Logger,
|
||||
) (*handler.KnowledgeHandler, error) {
|
||||
// 确定知识库数据库路径
|
||||
knowledgeDBPath := cfg.Database.KnowledgeDBPath
|
||||
var knowledgeDB *sql.DB
|
||||
|
||||
if knowledgeDBPath != "" {
|
||||
// 使用独立的知识库数据库
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化知识库数据库失败: %w", err)
|
||||
}
|
||||
knowledgeDB = knowledgeDBConn.DB
|
||||
logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath))
|
||||
} else {
|
||||
// 向后兼容:使用会话数据库
|
||||
knowledgeDB = db.DB
|
||||
logger.Info("使用会话数据库存储知识库数据(建议配置knowledge_db_path以分离数据)")
|
||||
}
|
||||
|
||||
// 创建知识库管理器
|
||||
knowledgeManager := knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, logger)
|
||||
|
||||
// 创建嵌入器
|
||||
// 使用OpenAI配置的API Key(如果知识库配置中没有指定)
|
||||
if cfg.Knowledge.Embedding.APIKey == "" {
|
||||
cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey
|
||||
}
|
||||
if cfg.Knowledge.Embedding.BaseURL == "" {
|
||||
cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 30 * time.Minute,
|
||||
}
|
||||
openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, logger)
|
||||
embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, logger)
|
||||
|
||||
// 创建检索器
|
||||
retrievalConfig := &knowledge.RetrievalConfig{
|
||||
TopK: cfg.Knowledge.Retrieval.TopK,
|
||||
SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
|
||||
HybridWeight: cfg.Knowledge.Retrieval.HybridWeight,
|
||||
}
|
||||
knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger)
|
||||
|
||||
// 创建索引器
|
||||
knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger)
|
||||
|
||||
// 注册知识检索工具到MCP服务器
|
||||
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger)
|
||||
|
||||
// 创建知识库API处理器
|
||||
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
|
||||
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
||||
|
||||
// 设置知识库管理器到AgentHandler以便记录检索日志
|
||||
agentHandler.SetKnowledgeManager(knowledgeManager)
|
||||
|
||||
// 更新 App 中的知识库组件(如果 App 不为 nil,说明是动态初始化)
|
||||
if app != nil {
|
||||
app.knowledgeManager = knowledgeManager
|
||||
app.knowledgeRetriever = knowledgeRetriever
|
||||
app.knowledgeIndexer = knowledgeIndexer
|
||||
app.knowledgeHandler = knowledgeHandler
|
||||
// 如果使用独立数据库,更新 knowledgeDB
|
||||
if knowledgeDBPath != "" {
|
||||
app.knowledgeDB = knowledgeDBConn
|
||||
}
|
||||
logger.Info("App 中的知识库组件已更新")
|
||||
}
|
||||
|
||||
// 扫描知识库并建立索引(异步)
|
||||
go func() {
|
||||
itemsToIndex, err := knowledgeManager.ScanKnowledgeBase()
|
||||
if err != nil {
|
||||
logger.Warn("扫描知识库失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否已有索引
|
||||
hasIndex, err := knowledgeIndexer.HasIndex()
|
||||
if err != nil {
|
||||
logger.Warn("检查索引状态失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if hasIndex {
|
||||
// 如果已有索引,只索引新添加或更新的项
|
||||
if len(itemsToIndex) > 0 {
|
||||
logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
|
||||
ctx := context.Background()
|
||||
for _, itemID := range itemsToIndex {
|
||||
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
|
||||
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
|
||||
} else {
|
||||
logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 只有在没有索引时才自动重建
|
||||
logger.Info("未检测到知识库索引,开始自动构建索引")
|
||||
ctx := context.Background()
|
||||
if err := knowledgeIndexer.RebuildIndex(ctx); err != nil {
|
||||
logger.Warn("重建知识库索引失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
return knowledgeHandler, nil
|
||||
}
|
||||
|
||||
// corsMiddleware CORS中间件
|
||||
func corsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -29,19 +30,29 @@ type RetrieverUpdater interface {
|
||||
UpdateConfig(config *knowledge.RetrievalConfig)
|
||||
}
|
||||
|
||||
// KnowledgeInitializer 知识库初始化器接口
|
||||
type KnowledgeInitializer func() (*KnowledgeHandler, error)
|
||||
|
||||
// AppUpdater App更新接口(用于更新App中的知识库组件)
|
||||
type AppUpdater interface {
|
||||
UpdateKnowledgeComponents(handler *KnowledgeHandler, manager interface{}, retriever interface{}, indexer interface{})
|
||||
}
|
||||
|
||||
// ConfigHandler 配置处理器
|
||||
type ConfigHandler struct {
|
||||
configPath string
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
configPath string
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// AttackChainUpdater 攻击链处理器更新接口
|
||||
@@ -83,12 +94,26 @@ func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
|
||||
h.retrieverUpdater = updater
|
||||
}
|
||||
|
||||
// SetKnowledgeInitializer 设置知识库初始化器
|
||||
func (h *ConfigHandler) SetKnowledgeInitializer(initializer KnowledgeInitializer) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.knowledgeInitializer = initializer
|
||||
}
|
||||
|
||||
// SetAppUpdater 设置App更新器
|
||||
func (h *ConfigHandler) SetAppUpdater(updater AppUpdater) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.appUpdater = updater
|
||||
}
|
||||
|
||||
// GetConfigResponse 获取配置响应
|
||||
type GetConfigResponse struct {
|
||||
OpenAI config.OpenAIConfig `json:"openai"`
|
||||
MCP config.MCPConfig `json:"mcp"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Agent config.AgentConfig `json:"agent"`
|
||||
OpenAI config.OpenAIConfig `json:"openai"`
|
||||
MCP config.MCPConfig `json:"mcp"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Agent config.AgentConfig `json:"agent"`
|
||||
Knowledge config.KnowledgeConfig `json:"knowledge"`
|
||||
}
|
||||
|
||||
@@ -127,7 +152,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
tools[len(tools)-1].Description = desc
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
|
||||
if h.mcpServer != nil {
|
||||
mcpTools := h.mcpServer.GetAllTools()
|
||||
@@ -287,7 +312,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
|
||||
allTools = append(allTools, toolInfo)
|
||||
}
|
||||
|
||||
|
||||
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
|
||||
if h.mcpServer != nil {
|
||||
mcpTools := h.mcpServer.GetAllTools()
|
||||
@@ -296,7 +321,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
if configToolMap[mcpTool.Name] {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
description := mcpTool.ShortDescription
|
||||
if description == "" {
|
||||
description = mcpTool.Description
|
||||
@@ -304,14 +329,14 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
if len(description) > 100 {
|
||||
description = description[:100] + "..."
|
||||
}
|
||||
|
||||
|
||||
toolInfo := ToolConfigInfo{
|
||||
Name: mcpTool.Name,
|
||||
Description: description,
|
||||
Enabled: true, // 直接注册的工具默认启用
|
||||
IsExternal: false,
|
||||
}
|
||||
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(toolInfo.Name)
|
||||
@@ -320,7 +345,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
continue // 不匹配,跳过
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
allTools = append(allTools, toolInfo)
|
||||
}
|
||||
}
|
||||
@@ -336,7 +361,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
} else {
|
||||
// 获取外部MCP配置,用于判断启用状态
|
||||
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
|
||||
|
||||
|
||||
for _, externalTool := range externalTools {
|
||||
// 解析工具名称:mcpName::toolName
|
||||
var mcpName, actualToolName string
|
||||
@@ -434,7 +459,7 @@ type UpdateConfigRequest struct {
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||||
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
||||
}
|
||||
|
||||
@@ -541,12 +566,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName))
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// 初始化ToolEnabled map
|
||||
if cfg.ToolEnabled == nil {
|
||||
cfg.ToolEnabled = make(map[string]bool)
|
||||
}
|
||||
|
||||
|
||||
// 更新每个工具的启用状态
|
||||
for toolName, enabled := range toolStates {
|
||||
cfg.ToolEnabled[toolName] = enabled
|
||||
@@ -556,7 +581,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
zap.Bool("enabled", enabled),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
// 检查是否有任何工具启用,如果有则启用MCP
|
||||
hasEnabledTool := false
|
||||
for _, enabled := range cfg.ToolEnabled {
|
||||
@@ -565,21 +590,21 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 如果MCP之前未启用,但现在有工具启用,则启用MCP
|
||||
// 如果MCP之前已启用,保持启用状态(允许部分工具禁用)
|
||||
if !cfg.ExternalMCPEnable && hasEnabledTool {
|
||||
cfg.ExternalMCPEnable = true
|
||||
h.logger.Info("自动启用外部MCP(因为有工具启用)", zap.String("mcp", mcpName))
|
||||
}
|
||||
|
||||
|
||||
h.config.ExternalMCP.Servers[mcpName] = cfg
|
||||
}
|
||||
|
||||
|
||||
// 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置
|
||||
// 在循环外部统一更新,避免重复调用
|
||||
h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP)
|
||||
|
||||
|
||||
// 处理MCP连接状态(异步启动,避免阻塞)
|
||||
for mcpName := range externalMCPToolMap {
|
||||
cfg := h.config.ExternalMCP.Servers[mcpName]
|
||||
@@ -618,18 +643,41 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
|
||||
// ApplyConfig 应用配置(重新加载并重启相关服务)
|
||||
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
// 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求)
|
||||
var needInitKnowledge bool
|
||||
var knowledgeInitializer KnowledgeInitializer
|
||||
|
||||
h.mu.RLock()
|
||||
needInitKnowledge = h.config.Knowledge.Enabled && h.knowledgeToolRegistrar == nil && h.knowledgeInitializer != nil
|
||||
if needInitKnowledge {
|
||||
knowledgeInitializer = h.knowledgeInitializer
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// 如果需要动态初始化知识库,在锁外执行(这是耗时操作)
|
||||
if needInitKnowledge {
|
||||
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
|
||||
if _, err := knowledgeInitializer(); err != nil {
|
||||
h.logger.Error("动态初始化知识库失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
h.logger.Info("知识库动态初始化完成,工具已注册")
|
||||
}
|
||||
|
||||
// 现在获取写锁,执行快速的操作
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 重新注册工具(根据新的启用状态)
|
||||
h.logger.Info("重新注册工具")
|
||||
|
||||
|
||||
// 清空MCP服务器中的工具
|
||||
h.mcpServer.ClearTools()
|
||||
|
||||
|
||||
// 重新注册安全工具
|
||||
h.executor.RegisterTools(h.mcpServer)
|
||||
|
||||
|
||||
// 如果知识库启用,重新注册知识库工具
|
||||
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
|
||||
h.logger.Info("重新注册知识库工具")
|
||||
@@ -673,7 +721,7 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "配置已应用",
|
||||
"message": "配置已应用",
|
||||
"tools_count": len(h.config.Security.Tools),
|
||||
})
|
||||
}
|
||||
@@ -847,7 +895,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
|
||||
knowledgeNode := ensureMap(root, "knowledge")
|
||||
setBoolInMap(knowledgeNode, "enabled", cfg.Enabled)
|
||||
setStringInMap(knowledgeNode, "base_path", cfg.BasePath)
|
||||
|
||||
|
||||
// 更新嵌入配置
|
||||
embeddingNode := ensureMap(knowledgeNode, "embedding")
|
||||
setStringInMap(embeddingNode, "provider", cfg.Embedding.Provider)
|
||||
@@ -858,7 +906,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
|
||||
if cfg.Embedding.APIKey != "" {
|
||||
setStringInMap(embeddingNode, "api_key", cfg.Embedding.APIKey)
|
||||
}
|
||||
|
||||
|
||||
// 更新检索配置
|
||||
retrievalNode := ensureMap(knowledgeNode, "retrieval")
|
||||
setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK)
|
||||
@@ -940,14 +988,14 @@ func findBoolInMap(mapNode *yaml.Node, key string) *bool {
|
||||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
for i := 0; i < len(mapNode.Content); i += 2 {
|
||||
if i+1 >= len(mapNode.Content) {
|
||||
break
|
||||
}
|
||||
keyNode := mapNode.Content[i]
|
||||
valueNode := mapNode.Content[i+1]
|
||||
|
||||
|
||||
if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key {
|
||||
if valueNode.Kind == yaml.ScalarNode {
|
||||
if valueNode.Value == "true" {
|
||||
@@ -989,5 +1037,3 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
|
||||
valueNode.Value = fmt.Sprintf("%g", value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -170,21 +170,36 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
|
||||
|
||||
// ScanKnowledgeBase 扫描知识库
|
||||
func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
|
||||
if err := h.manager.ScanKnowledgeBase(); err != nil {
|
||||
itemsToIndex, err := h.manager.ScanKnowledgeBase()
|
||||
if err != nil {
|
||||
h.logger.Error("扫描知识库失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 异步重建索引
|
||||
if len(itemsToIndex) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"})
|
||||
return
|
||||
}
|
||||
|
||||
// 异步索引新添加或更新的项(增量索引)
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
if err := h.indexer.RebuildIndex(ctx); err != nil {
|
||||
h.logger.Error("重建索引失败", zap.Error(err))
|
||||
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
|
||||
for i, itemID := range itemsToIndex {
|
||||
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
|
||||
h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)))
|
||||
}
|
||||
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,索引重建已开始"})
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
|
||||
"items_to_index": len(itemsToIndex),
|
||||
})
|
||||
}
|
||||
|
||||
// GetRetrievalLogs 获取检索日志
|
||||
|
||||
@@ -257,7 +257,7 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
|
||||
return fmt.Errorf("获取知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除旧的向量
|
||||
// 删除旧的向量(在 RebuildIndex 中已经统一清空,这里保留是为了单独调用 IndexItem 时的兼容性)
|
||||
_, err = idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除旧向量失败: %w", err)
|
||||
@@ -338,12 +338,22 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
|
||||
|
||||
idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))
|
||||
|
||||
// 在开始重建前,先清空所有旧的向量,确保进度从0开始
|
||||
// 这样 GetIndexStatus 可以准确反映重建进度
|
||||
_, err = idx.db.Exec("DELETE FROM knowledge_embeddings")
|
||||
if err != nil {
|
||||
idx.logger.Warn("清空旧索引失败", zap.Error(err))
|
||||
// 继续执行,即使清空失败也尝试重建
|
||||
} else {
|
||||
idx.logger.Info("已清空旧索引,开始重建")
|
||||
}
|
||||
|
||||
for i, itemID := range itemIDs {
|
||||
if err := idx.IndexItem(ctx, itemID); err != nil {
|
||||
idx.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
idx.logger.Debug("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)))
|
||||
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)))
|
||||
}
|
||||
|
||||
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)))
|
||||
|
||||
@@ -31,18 +31,21 @@ func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
|
||||
}
|
||||
|
||||
// ScanKnowledgeBase 扫描知识库目录,更新数据库
|
||||
func (m *Manager) ScanKnowledgeBase() error {
|
||||
// 返回需要索引的知识项ID列表(新添加的或更新的)
|
||||
func (m *Manager) ScanKnowledgeBase() ([]string, error) {
|
||||
if m.basePath == "" {
|
||||
return fmt.Errorf("知识库路径未配置")
|
||||
return nil, fmt.Errorf("知识库路径未配置")
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(m.basePath, 0755); err != nil {
|
||||
return fmt.Errorf("创建知识库目录失败: %w", err)
|
||||
return nil, fmt.Errorf("创建知识库目录失败: %w", err)
|
||||
}
|
||||
|
||||
var itemsToIndex []string
|
||||
|
||||
// 遍历知识库目录
|
||||
return filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
|
||||
err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -77,10 +80,12 @@ func (m *Manager) ScanKnowledgeBase() error {
|
||||
|
||||
// 检查是否已存在
|
||||
var existingID string
|
||||
var existingContent string
|
||||
var existingUpdatedAt time.Time
|
||||
err = m.db.QueryRow(
|
||||
"SELECT id FROM knowledge_base_items WHERE file_path = ?",
|
||||
"SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?",
|
||||
path,
|
||||
).Scan(&existingID)
|
||||
).Scan(&existingID, &existingContent, &existingUpdatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// 创建新项
|
||||
@@ -94,22 +99,38 @@ func (m *Manager) ScanKnowledgeBase() error {
|
||||
return fmt.Errorf("插入知识项失败: %w", err)
|
||||
}
|
||||
m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
|
||||
// 新添加的项需要索引
|
||||
itemsToIndex = append(itemsToIndex, id)
|
||||
} else if err == nil {
|
||||
// 更新现有项
|
||||
_, err = m.db.Exec(
|
||||
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
|
||||
category, title, string(content), time.Now(), existingID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新知识项失败: %w", err)
|
||||
// 检查内容是否有变化
|
||||
contentChanged := existingContent != string(content)
|
||||
if contentChanged {
|
||||
// 更新现有项
|
||||
_, err = m.db.Exec(
|
||||
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
|
||||
category, title, string(content), time.Now(), existingID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新知识项失败: %w", err)
|
||||
}
|
||||
m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title))
|
||||
// 内容已更新的项需要重新索引
|
||||
itemsToIndex = append(itemsToIndex, existingID)
|
||||
} else {
|
||||
m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title))
|
||||
}
|
||||
m.logger.Debug("更新知识项", zap.String("id", existingID), zap.String("title", title))
|
||||
} else {
|
||||
return fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return itemsToIndex, nil
|
||||
}
|
||||
|
||||
// GetCategories 获取所有分类(风险类型)
|
||||
@@ -170,7 +191,7 @@ func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
|
||||
// 解析创建时间
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
@@ -181,7 +202,7 @@ func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 解析更新时间
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
@@ -192,7 +213,7 @@ func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 如果更新时间为空,使用创建时间
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
@@ -230,7 +251,7 @@ func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
|
||||
// 解析创建时间
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
@@ -241,7 +262,7 @@ func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 解析更新时间
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
@@ -252,7 +273,7 @@ func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 如果更新时间为空,使用创建时间
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
@@ -418,10 +439,10 @@ func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
|
||||
isComplete := indexedItems >= totalItems && totalItems > 0
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_items": totalItems,
|
||||
"indexed_items": indexedItems,
|
||||
"total_items": totalItems,
|
||||
"indexed_items": indexedItems,
|
||||
"progress_percent": progressPercent,
|
||||
"is_complete": isComplete,
|
||||
"is_complete": isComplete,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -472,17 +493,17 @@ func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int)
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
|
||||
for _, format := range timeFormats {
|
||||
log.CreatedAt, err = time.Parse(format, createdAt)
|
||||
if err == nil && !log.CreatedAt.IsZero() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 如果所有格式都失败,记录警告但继续处理
|
||||
if log.CreatedAt.IsZero() {
|
||||
m.logger.Warn("解析检索日志时间失败",
|
||||
m.logger.Warn("解析检索日志时间失败",
|
||||
zap.String("timeStr", createdAt),
|
||||
zap.Error(err),
|
||||
)
|
||||
@@ -519,4 +540,3 @@ func (m *Manager) DeleteRetrievalLog(id string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,32 @@ async function loadKnowledgeCategories() {
|
||||
throw new Error('获取分类失败');
|
||||
}
|
||||
const data = await response.json();
|
||||
|
||||
// 检查知识库功能是否启用
|
||||
if (data.enabled === false) {
|
||||
// 功能未启用,显示友好提示
|
||||
const container = document.getElementById('knowledge-items-list');
|
||||
if (container) {
|
||||
container.innerHTML = `
|
||||
<div class="empty-state" style="text-align: center; padding: 40px 20px;">
|
||||
<div style="font-size: 48px; margin-bottom: 20px;">📚</div>
|
||||
<h3 style="margin-bottom: 10px; color: #666;">知识库功能未启用</h3>
|
||||
<p style="color: #999; margin-bottom: 20px;">${data.message || '请前往系统设置启用知识检索功能'}</p>
|
||||
<button onclick="switchToSettings()" style="
|
||||
background: #007bff;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 10px 20px;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
">前往设置</button>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
knowledgeCategories = data.categories || [];
|
||||
|
||||
// 更新分类筛选下拉框
|
||||
@@ -43,7 +69,10 @@ async function loadKnowledgeCategories() {
|
||||
return knowledgeCategories;
|
||||
} catch (error) {
|
||||
console.error('加载分类失败:', error);
|
||||
showNotification('加载分类失败: ' + error.message, 'error');
|
||||
// 只在非功能未启用的情况下显示错误
|
||||
if (!error.message.includes('知识库功能未启用')) {
|
||||
showNotification('加载分类失败: ' + error.message, 'error');
|
||||
}
|
||||
return [];
|
||||
}
|
||||
}
|
||||
@@ -70,12 +99,42 @@ async function loadKnowledgeItems(category = '') {
|
||||
throw new Error('获取知识项失败');
|
||||
}
|
||||
const data = await response.json();
|
||||
|
||||
// 检查知识库功能是否启用
|
||||
if (data.enabled === false) {
|
||||
// 功能未启用,显示友好提示(如果还没有显示的话)
|
||||
const container = document.getElementById('knowledge-items-list');
|
||||
if (container && !container.querySelector('.empty-state')) {
|
||||
container.innerHTML = `
|
||||
<div class="empty-state" style="text-align: center; padding: 40px 20px;">
|
||||
<div style="font-size: 48px; margin-bottom: 20px;">📚</div>
|
||||
<h3 style="margin-bottom: 10px; color: #666;">知识库功能未启用</h3>
|
||||
<p style="color: #999; margin-bottom: 20px;">${data.message || '请前往系统设置启用知识检索功能'}</p>
|
||||
<button onclick="switchToSettings()" style="
|
||||
background: #007bff;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 10px 20px;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
">前往设置</button>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
knowledgeItems = [];
|
||||
return [];
|
||||
}
|
||||
|
||||
knowledgeItems = data.items || [];
|
||||
renderKnowledgeItems(knowledgeItems);
|
||||
return knowledgeItems;
|
||||
} catch (error) {
|
||||
console.error('加载知识项失败:', error);
|
||||
showNotification('加载知识项失败: ' + error.message, 'error');
|
||||
// 只在非功能未启用的情况下显示错误
|
||||
if (!error.message.includes('知识库功能未启用')) {
|
||||
showNotification('加载知识项失败: ' + error.message, 'error');
|
||||
}
|
||||
return [];
|
||||
}
|
||||
}
|
||||
@@ -252,6 +311,17 @@ async function updateIndexProgress() {
|
||||
const progressContainer = document.getElementById('knowledge-index-progress');
|
||||
if (!progressContainer) return;
|
||||
|
||||
// 检查知识库功能是否启用
|
||||
if (status.enabled === false) {
|
||||
// 功能未启用,隐藏进度条
|
||||
progressContainer.style.display = 'none';
|
||||
if (indexProgressInterval) {
|
||||
clearInterval(indexProgressInterval);
|
||||
indexProgressInterval = null;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const totalItems = status.total_items || 0;
|
||||
const indexedItems = status.indexed_items || 0;
|
||||
const progressPercent = status.progress_percent || 0;
|
||||
@@ -373,16 +443,35 @@ async function refreshKnowledgeBase() {
|
||||
if (!response.ok) {
|
||||
throw new Error('扫描知识库失败');
|
||||
}
|
||||
showNotification('扫描完成,索引重建已开始', 'success');
|
||||
const data = await response.json();
|
||||
// 根据返回的消息显示不同的提示
|
||||
if (data.items_to_index && data.items_to_index > 0) {
|
||||
showNotification(`扫描完成,开始索引 ${data.items_to_index} 个新添加或更新的知识项`, 'success');
|
||||
} else {
|
||||
showNotification(data.message || '扫描完成,没有需要索引的新项或更新项', 'success');
|
||||
}
|
||||
// 重新加载知识项
|
||||
await loadKnowledgeCategories();
|
||||
await loadKnowledgeItems();
|
||||
|
||||
// 开始轮询进度
|
||||
// 停止现有的轮询
|
||||
if (indexProgressInterval) {
|
||||
clearInterval(indexProgressInterval);
|
||||
indexProgressInterval = null;
|
||||
}
|
||||
|
||||
// 如果有需要索引的项,等待一小段时间后立即更新进度
|
||||
if (data.items_to_index && data.items_to_index > 0) {
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
updateIndexProgress();
|
||||
// 开始轮询进度(每2秒刷新一次)
|
||||
if (!indexProgressInterval) {
|
||||
indexProgressInterval = setInterval(updateIndexProgress, 2000);
|
||||
}
|
||||
} else {
|
||||
// 没有需要索引的项,也更新一次以显示当前状态
|
||||
updateIndexProgress();
|
||||
}
|
||||
updateIndexProgress(); // 立即更新一次
|
||||
} catch (error) {
|
||||
console.error('刷新知识库失败:', error);
|
||||
showNotification('刷新知识库失败: ' + error.message, 'error');
|
||||
@@ -396,6 +485,31 @@ async function rebuildKnowledgeIndex() {
|
||||
return;
|
||||
}
|
||||
showNotification('正在重建索引...', 'info');
|
||||
|
||||
// 先停止现有的轮询
|
||||
if (indexProgressInterval) {
|
||||
clearInterval(indexProgressInterval);
|
||||
indexProgressInterval = null;
|
||||
}
|
||||
|
||||
// 立即显示"正在重建"状态,因为重建开始时会清空旧索引
|
||||
const progressContainer = document.getElementById('knowledge-index-progress');
|
||||
if (progressContainer) {
|
||||
progressContainer.style.display = 'block';
|
||||
progressContainer.innerHTML = `
|
||||
<div class="knowledge-index-progress">
|
||||
<div class="progress-header">
|
||||
<span class="progress-icon">🔨</span>
|
||||
<span class="progress-text">正在重建索引: 准备中...</span>
|
||||
</div>
|
||||
<div class="progress-bar-container">
|
||||
<div class="progress-bar" style="width: 0%"></div>
|
||||
</div>
|
||||
<div class="progress-hint">索引构建完成后,语义搜索功能将可用</div>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
const response = await apiFetch('/api/knowledge/index', {
|
||||
method: 'POST'
|
||||
});
|
||||
@@ -404,11 +518,16 @@ async function rebuildKnowledgeIndex() {
|
||||
}
|
||||
showNotification('索引重建已开始,将在后台进行', 'success');
|
||||
|
||||
// 开始轮询进度
|
||||
if (indexProgressInterval) {
|
||||
clearInterval(indexProgressInterval);
|
||||
// 等待一小段时间,确保后端已经开始处理并清空了旧索引
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
|
||||
// 立即更新一次进度
|
||||
updateIndexProgress();
|
||||
|
||||
// 开始轮询进度(每2秒刷新一次,比默认的3秒更频繁)
|
||||
if (!indexProgressInterval) {
|
||||
indexProgressInterval = setInterval(updateIndexProgress, 2000);
|
||||
}
|
||||
updateIndexProgress(); // 立即更新一次
|
||||
} catch (error) {
|
||||
console.error('重建索引失败:', error);
|
||||
showNotification('重建索引失败: ' + error.message, 'error');
|
||||
@@ -1528,8 +1647,8 @@ function formatTime(timeStr) {
|
||||
|
||||
// 显示通知
|
||||
function showNotification(message, type = 'info') {
|
||||
// 如果存在全局通知系统,使用它
|
||||
if (typeof window.showNotification === 'function') {
|
||||
// 如果存在全局通知系统(且不是当前函数),使用它
|
||||
if (typeof window.showNotification === 'function' && window.showNotification !== showNotification) {
|
||||
window.showNotification(message, type);
|
||||
return;
|
||||
}
|
||||
@@ -1680,6 +1799,39 @@ window.addEventListener('click', function(event) {
|
||||
}
|
||||
});
|
||||
|
||||
// 切换到设置页面(用于功能未启用时的提示)
|
||||
function switchToSettings() {
|
||||
if (typeof switchPage === 'function') {
|
||||
switchPage('settings');
|
||||
// 等待设置页面加载后,切换到知识库配置部分
|
||||
setTimeout(() => {
|
||||
if (typeof switchSettingsSection === 'function') {
|
||||
// 查找知识库配置部分(通常在基本设置中)
|
||||
const knowledgeSection = document.querySelector('[data-section="knowledge"]');
|
||||
if (knowledgeSection) {
|
||||
switchSettingsSection('knowledge');
|
||||
} else {
|
||||
// 如果没有独立的知识库部分,切换到基本设置
|
||||
switchSettingsSection('basic');
|
||||
// 滚动到知识库配置区域
|
||||
setTimeout(() => {
|
||||
const knowledgeEnabledCheckbox = document.getElementById('knowledge-enabled');
|
||||
if (knowledgeEnabledCheckbox) {
|
||||
knowledgeEnabledCheckbox.scrollIntoView({ behavior: 'smooth', block: 'center' });
|
||||
// 高亮显示
|
||||
knowledgeEnabledCheckbox.parentElement.style.transition = 'background-color 0.3s';
|
||||
knowledgeEnabledCheckbox.parentElement.style.backgroundColor = '#e3f2fd';
|
||||
setTimeout(() => {
|
||||
knowledgeEnabledCheckbox.parentElement.style.backgroundColor = '';
|
||||
}, 2000);
|
||||
}
|
||||
}, 300);
|
||||
}
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义下拉组件交互
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
const wrapper = document.getElementById('knowledge-category-filter-wrapper');
|
||||
|
||||
Reference in New Issue
Block a user