Add files via upload

This commit is contained in:
公明
2025-12-27 19:42:21 +08:00
committed by GitHub
parent cb45b9e540
commit 65957b2013
6 changed files with 672 additions and 129 deletions

View File

@@ -26,16 +26,21 @@ import (
// App 应用
type App struct {
config *config.Config
logger *logger.Logger
router *gin.Engine
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager
agent *agent.Agent
executor *security.Executor
db *database.DB
knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库)
auth *security.AuthManager
config *config.Config
logger *logger.Logger
router *gin.Engine
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager
agent *agent.Agent
executor *security.Executor
db *database.DB
knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库)
auth *security.AuthManager
knowledgeManager *knowledge.Manager // 知识库管理器(用于动态初始化)
knowledgeRetriever *knowledge.Retriever // 知识库检索器(用于动态初始化)
knowledgeIndexer *knowledge.Indexer // 知识库索引器(用于动态初始化)
knowledgeHandler *handler.KnowledgeHandler // 知识库处理器(用于动态初始化)
agentHandler *handler.AgentHandler // Agent处理器用于更新知识库管理器
}
// New 创建新应用
@@ -196,12 +201,13 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
// 扫描知识库并建立索引(异步)
go func() {
if err := knowledgeManager.ScanKnowledgeBase(); err != nil {
itemsToIndex, err := knowledgeManager.ScanKnowledgeBase()
if err != nil {
log.Logger.Warn("扫描知识库失败", zap.Error(err))
return
}
// 检查是否已有索引,如果有则跳过自动重建
// 检查是否已有索引
hasIndex, err := knowledgeIndexer.HasIndex()
if err != nil {
log.Logger.Warn("检查索引状态失败", zap.Error(err))
@@ -209,7 +215,20 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
}
if hasIndex {
log.Logger.Info("检测到已有知识库索引,跳过自动重建。如需重建,请手动点击重建索引按钮")
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
continue
}
}
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
} else {
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
return
}
@@ -242,6 +261,51 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
// 创建 App 实例(部分字段稍后填充)
app := &App{
config: cfg,
logger: log,
router: router,
mcpServer: mcpServer,
externalMCPMgr: externalMCPMgr,
agent: agent,
executor: executor,
db: db,
knowledgeDB: knowledgeDBConn,
auth: authManager,
knowledgeManager: knowledgeManager,
knowledgeRetriever: knowledgeRetriever,
knowledgeIndexer: knowledgeIndexer,
knowledgeHandler: knowledgeHandler,
agentHandler: agentHandler,
}
// 设置知识库初始化器(用于动态初始化,需要在 App 创建后设置)
configHandler.SetKnowledgeInitializer(func() (*handler.KnowledgeHandler, error) {
knowledgeHandler, err := initializeKnowledge(cfg, db, knowledgeDBConn, mcpServer, agentHandler, app, log.Logger)
if err != nil {
return nil, err
}
// 动态初始化后,设置知识库工具注册器和检索器更新器
// 这样后续 ApplyConfig 时就能重新注册工具了
if app.knowledgeRetriever != nil && app.knowledgeManager != nil {
// 创建闭包捕获knowledgeRetriever和knowledgeManager的引用
registrar := func() error {
knowledge.RegisterKnowledgeTool(mcpServer, app.knowledgeRetriever, app.knowledgeManager, log.Logger)
return nil
}
configHandler.SetKnowledgeToolRegistrar(registrar)
// 设置检索器更新器以便在ApplyConfig时更新检索器配置
configHandler.SetRetrieverUpdater(app.knowledgeRetriever)
log.Logger.Info("动态初始化后已设置知识库工具注册器和检索器更新器")
}
return knowledgeHandler, nil
})
// 如果知识库已启用,设置知识库工具注册器和检索器更新器
if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil {
// 创建闭包捕获knowledgeRetriever和knowledgeManager的引用
@@ -253,9 +317,8 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
// 设置检索器更新器以便在ApplyConfig时更新检索器配置
configHandler.SetRetrieverUpdater(knowledgeRetriever)
}
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
// 设置路由
// 设置路由(使用 App 实例以便动态获取 handler
setupRoutes(
router,
authHandler,
@@ -266,24 +329,14 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
configHandler,
externalMCPHandler,
attackChainHandler,
knowledgeHandler,
app, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler,
mcpServer,
authManager,
)
return &App{
config: cfg,
logger: log,
router: router,
mcpServer: mcpServer,
externalMCPMgr: externalMCPMgr,
agent: agent,
executor: executor,
db: db,
knowledgeDB: knowledgeDBConn,
auth: authManager,
}, nil
return app, nil
}
// Run 启动应用
@@ -336,7 +389,7 @@ func setupRoutes(
configHandler *handler.ConfigHandler,
externalMCPHandler *handler.ExternalMCPHandler,
attackChainHandler *handler.AttackChainHandler,
knowledgeHandler *handler.KnowledgeHandler,
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler *handler.VulnerabilityHandler,
mcpServer *mcp.Server,
authManager *security.AuthManager,
@@ -409,20 +462,137 @@ func setupRoutes(
protected.GET("/attack-chain/:conversationId", attackChainHandler.GetAttackChain)
protected.POST("/attack-chain/:conversationId/regenerate", attackChainHandler.RegenerateAttackChain)
// 知识库管理(如果启用
if knowledgeHandler != nil {
protected.GET("/knowledge/categories", knowledgeHandler.GetCategories)
protected.GET("/knowledge/items", knowledgeHandler.GetItems)
protected.GET("/knowledge/items/:id", knowledgeHandler.GetItem)
protected.POST("/knowledge/items", knowledgeHandler.CreateItem)
protected.PUT("/knowledge/items/:id", knowledgeHandler.UpdateItem)
protected.DELETE("/knowledge/items/:id", knowledgeHandler.DeleteItem)
protected.GET("/knowledge/index-status", knowledgeHandler.GetIndexStatus)
protected.POST("/knowledge/index", knowledgeHandler.RebuildIndex)
protected.POST("/knowledge/scan", knowledgeHandler.ScanKnowledgeBase)
protected.GET("/knowledge/retrieval-logs", knowledgeHandler.GetRetrievalLogs)
protected.DELETE("/knowledge/retrieval-logs/:id", knowledgeHandler.DeleteRetrievalLog)
protected.POST("/knowledge/search", knowledgeHandler.Search)
// 知识库管理(始终注册路由,通过 App 实例动态获取 handler
knowledgeRoutes := protected.Group("/knowledge")
{
knowledgeRoutes.GET("/categories", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"categories": []string{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetCategories(c)
})
knowledgeRoutes.GET("/items", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"items": []interface{}{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetItems(c)
})
knowledgeRoutes.GET("/items/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetItem(c)
})
knowledgeRoutes.POST("/items", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.CreateItem(c)
})
knowledgeRoutes.PUT("/items/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.UpdateItem(c)
})
knowledgeRoutes.DELETE("/items/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.DeleteItem(c)
})
knowledgeRoutes.GET("/index-status", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"total_items": 0,
"indexed_items": 0,
"progress_percent": 0,
"is_complete": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetIndexStatus(c)
})
knowledgeRoutes.POST("/index", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.RebuildIndex(c)
})
knowledgeRoutes.POST("/scan", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.ScanKnowledgeBase(c)
})
knowledgeRoutes.GET("/retrieval-logs", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"logs": []interface{}{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetRetrievalLogs(c)
})
knowledgeRoutes.DELETE("/retrieval-logs/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.DeleteRetrievalLog(c)
})
knowledgeRoutes.POST("/search", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"results": []interface{}{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.Search(c)
})
}
// 漏洞管理
@@ -594,7 +764,7 @@ func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *z
Title: title,
Description: description,
Severity: severity,
Status: "open",
Status: "open",
Type: vulnType,
Target: target,
Proof: proof,
@@ -638,6 +808,136 @@ func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *z
logger.Info("漏洞记录工具注册成功")
}
// initializeKnowledge 初始化知识库组件(用于动态初始化)
func initializeKnowledge(
cfg *config.Config,
db *database.DB,
knowledgeDBConn *database.DB,
mcpServer *mcp.Server,
agentHandler *handler.AgentHandler,
app *App, // 传递 App 引用以便更新知识库组件
logger *zap.Logger,
) (*handler.KnowledgeHandler, error) {
// 确定知识库数据库路径
knowledgeDBPath := cfg.Database.KnowledgeDBPath
var knowledgeDB *sql.DB
if knowledgeDBPath != "" {
// 使用独立的知识库数据库
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil {
return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err)
}
var err error
knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, logger)
if err != nil {
return nil, fmt.Errorf("初始化知识库数据库失败: %w", err)
}
knowledgeDB = knowledgeDBConn.DB
logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath))
} else {
// 向后兼容:使用会话数据库
knowledgeDB = db.DB
logger.Info("使用会话数据库存储知识库数据建议配置knowledge_db_path以分离数据")
}
// 创建知识库管理器
knowledgeManager := knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, logger)
// 创建嵌入器
// 使用OpenAI配置的API Key如果知识库配置中没有指定
if cfg.Knowledge.Embedding.APIKey == "" {
cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey
}
if cfg.Knowledge.Embedding.BaseURL == "" {
cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL
}
httpClient := &http.Client{
Timeout: 30 * time.Minute,
}
openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, logger)
embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, logger)
// 创建检索器
retrievalConfig := &knowledge.RetrievalConfig{
TopK: cfg.Knowledge.Retrieval.TopK,
SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
HybridWeight: cfg.Knowledge.Retrieval.HybridWeight,
}
knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger)
// 创建索引器
knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger)
// 注册知识检索工具到MCP服务器
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger)
// 创建知识库API处理器
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 设置知识库管理器到AgentHandler以便记录检索日志
agentHandler.SetKnowledgeManager(knowledgeManager)
// 更新 App 中的知识库组件(如果 App 不为 nil说明是动态初始化
if app != nil {
app.knowledgeManager = knowledgeManager
app.knowledgeRetriever = knowledgeRetriever
app.knowledgeIndexer = knowledgeIndexer
app.knowledgeHandler = knowledgeHandler
// 如果使用独立数据库,更新 knowledgeDB
if knowledgeDBPath != "" {
app.knowledgeDB = knowledgeDBConn
}
logger.Info("App 中的知识库组件已更新")
}
// 扫描知识库并建立索引(异步)
go func() {
itemsToIndex, err := knowledgeManager.ScanKnowledgeBase()
if err != nil {
logger.Warn("扫描知识库失败", zap.Error(err))
return
}
// 检查是否已有索引
hasIndex, err := knowledgeIndexer.HasIndex()
if err != nil {
logger.Warn("检查索引状态失败", zap.Error(err))
return
}
if hasIndex {
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
continue
}
}
logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
} else {
logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
return
}
// 只有在没有索引时才自动重建
logger.Info("未检测到知识库索引,开始自动构建索引")
ctx := context.Background()
if err := knowledgeIndexer.RebuildIndex(ctx); err != nil {
logger.Warn("重建知识库索引失败", zap.Error(err))
}
}()
return knowledgeHandler, nil
}
// corsMiddleware CORS中间件
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {

View File

@@ -16,6 +16,7 @@ import (
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
@@ -29,19 +30,29 @@ type RetrieverUpdater interface {
UpdateConfig(config *knowledge.RetrievalConfig)
}
// KnowledgeInitializer 知识库初始化器接口
type KnowledgeInitializer func() (*KnowledgeHandler, error)
// AppUpdater App更新接口用于更新App中的知识库组件
type AppUpdater interface {
UpdateKnowledgeComponents(handler *KnowledgeHandler, manager interface{}, retriever interface{}, indexer interface{})
}
// ConfigHandler 配置处理器
type ConfigHandler struct {
configPath string
config *config.Config
mcpServer *mcp.Server
executor *security.Executor
agent AgentUpdater // Agent接口用于更新Agent配置
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
logger *zap.Logger
mu sync.RWMutex
configPath string
config *config.Config
mcpServer *mcp.Server
executor *security.Executor
agent AgentUpdater // Agent接口用于更新Agent配置
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器可选
logger *zap.Logger
mu sync.RWMutex
}
// AttackChainUpdater 攻击链处理器更新接口
@@ -83,12 +94,26 @@ func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
h.retrieverUpdater = updater
}
// SetKnowledgeInitializer 设置知识库初始化器
func (h *ConfigHandler) SetKnowledgeInitializer(initializer KnowledgeInitializer) {
h.mu.Lock()
defer h.mu.Unlock()
h.knowledgeInitializer = initializer
}
// SetAppUpdater 设置App更新器
func (h *ConfigHandler) SetAppUpdater(updater AppUpdater) {
h.mu.Lock()
defer h.mu.Unlock()
h.appUpdater = updater
}
// GetConfigResponse 获取配置响应
type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"`
MCP config.MCPConfig `json:"mcp"`
Tools []ToolConfigInfo `json:"tools"`
Agent config.AgentConfig `json:"agent"`
OpenAI config.OpenAIConfig `json:"openai"`
MCP config.MCPConfig `json:"mcp"`
Tools []ToolConfigInfo `json:"tools"`
Agent config.AgentConfig `json:"agent"`
Knowledge config.KnowledgeConfig `json:"knowledge"`
}
@@ -127,7 +152,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
tools[len(tools)-1].Description = desc
}
}
// 从MCP服务器获取所有已注册的工具包括直接注册的工具如知识检索工具
if h.mcpServer != nil {
mcpTools := h.mcpServer.GetAllTools()
@@ -287,7 +312,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
allTools = append(allTools, toolInfo)
}
// 从MCP服务器获取所有已注册的工具包括直接注册的工具如知识检索工具
if h.mcpServer != nil {
mcpTools := h.mcpServer.GetAllTools()
@@ -296,7 +321,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
if configToolMap[mcpTool.Name] {
continue
}
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
@@ -304,14 +329,14 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
if len(description) > 100 {
description = description[:100] + "..."
}
toolInfo := ToolConfigInfo{
Name: mcpTool.Name,
Description: description,
Enabled: true, // 直接注册的工具默认启用
IsExternal: false,
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name)
@@ -320,7 +345,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
continue // 不匹配,跳过
}
}
allTools = append(allTools, toolInfo)
}
}
@@ -336,7 +361,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} else {
// 获取外部MCP配置用于判断启用状态
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
for _, externalTool := range externalTools {
// 解析工具名称mcpName::toolName
var mcpName, actualToolName string
@@ -434,7 +459,7 @@ type UpdateConfigRequest struct {
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
MCP *config.MCPConfig `json:"mcp,omitempty"`
Tools []ToolEnableStatus `json:"tools,omitempty"`
Agent *config.AgentConfig `json:"agent,omitempty"`
Agent *config.AgentConfig `json:"agent,omitempty"`
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
}
@@ -541,12 +566,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName))
continue
}
// 初始化ToolEnabled map
if cfg.ToolEnabled == nil {
cfg.ToolEnabled = make(map[string]bool)
}
// 更新每个工具的启用状态
for toolName, enabled := range toolStates {
cfg.ToolEnabled[toolName] = enabled
@@ -556,7 +581,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
zap.Bool("enabled", enabled),
)
}
// 检查是否有任何工具启用如果有则启用MCP
hasEnabledTool := false
for _, enabled := range cfg.ToolEnabled {
@@ -565,21 +590,21 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
break
}
}
// 如果MCP之前未启用但现在有工具启用则启用MCP
// 如果MCP之前已启用保持启用状态允许部分工具禁用
if !cfg.ExternalMCPEnable && hasEnabledTool {
cfg.ExternalMCPEnable = true
h.logger.Info("自动启用外部MCP因为有工具启用", zap.String("mcp", mcpName))
}
h.config.ExternalMCP.Servers[mcpName] = cfg
}
// 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置
// 在循环外部统一更新,避免重复调用
h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP)
// 处理MCP连接状态异步启动避免阻塞
for mcpName := range externalMCPToolMap {
cfg := h.config.ExternalMCP.Servers[mcpName]
@@ -618,18 +643,41 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
// ApplyConfig 应用配置(重新加载并重启相关服务)
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
// 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求)
var needInitKnowledge bool
var knowledgeInitializer KnowledgeInitializer
h.mu.RLock()
needInitKnowledge = h.config.Knowledge.Enabled && h.knowledgeToolRegistrar == nil && h.knowledgeInitializer != nil
if needInitKnowledge {
knowledgeInitializer = h.knowledgeInitializer
}
h.mu.RUnlock()
// 如果需要动态初始化知识库,在锁外执行(这是耗时操作)
if needInitKnowledge {
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
if _, err := knowledgeInitializer(); err != nil {
h.logger.Error("动态初始化知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
return
}
h.logger.Info("知识库动态初始化完成,工具已注册")
}
// 现在获取写锁,执行快速的操作
h.mu.Lock()
defer h.mu.Unlock()
// 重新注册工具(根据新的启用状态)
h.logger.Info("重新注册工具")
// 清空MCP服务器中的工具
h.mcpServer.ClearTools()
// 重新注册安全工具
h.executor.RegisterTools(h.mcpServer)
// 如果知识库启用,重新注册知识库工具
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
h.logger.Info("重新注册知识库工具")
@@ -673,7 +721,7 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
)
c.JSON(http.StatusOK, gin.H{
"message": "配置已应用",
"message": "配置已应用",
"tools_count": len(h.config.Security.Tools),
})
}
@@ -847,7 +895,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
knowledgeNode := ensureMap(root, "knowledge")
setBoolInMap(knowledgeNode, "enabled", cfg.Enabled)
setStringInMap(knowledgeNode, "base_path", cfg.BasePath)
// 更新嵌入配置
embeddingNode := ensureMap(knowledgeNode, "embedding")
setStringInMap(embeddingNode, "provider", cfg.Embedding.Provider)
@@ -858,7 +906,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
if cfg.Embedding.APIKey != "" {
setStringInMap(embeddingNode, "api_key", cfg.Embedding.APIKey)
}
// 更新检索配置
retrievalNode := ensureMap(knowledgeNode, "retrieval")
setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK)
@@ -940,14 +988,14 @@ func findBoolInMap(mapNode *yaml.Node, key string) *bool {
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
return nil
}
for i := 0; i < len(mapNode.Content); i += 2 {
if i+1 >= len(mapNode.Content) {
break
}
keyNode := mapNode.Content[i]
valueNode := mapNode.Content[i+1]
if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key {
if valueNode.Kind == yaml.ScalarNode {
if valueNode.Value == "true" {
@@ -989,5 +1037,3 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
valueNode.Value = fmt.Sprintf("%g", value)
}
}

View File

@@ -170,21 +170,36 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
// ScanKnowledgeBase 扫描知识库
func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
if err := h.manager.ScanKnowledgeBase(); err != nil {
itemsToIndex, err := h.manager.ScanKnowledgeBase()
if err != nil {
h.logger.Error("扫描知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 异步重建索引
if len(itemsToIndex) == 0 {
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"})
return
}
// 异步索引新添加或更新的项(增量索引)
go func() {
ctx := context.Background()
if err := h.indexer.RebuildIndex(ctx); err != nil {
h.logger.Error("重建索引失败", zap.Error(err))
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
for i, itemID := range itemsToIndex {
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
continue
}
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)))
}
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
}()
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,索引重建已开始"})
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
"items_to_index": len(itemsToIndex),
})
}
// GetRetrievalLogs 获取检索日志

View File

@@ -257,7 +257,7 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
return fmt.Errorf("获取知识项失败: %w", err)
}
// 删除旧的向量
// 删除旧的向量(在 RebuildIndex 中已经统一清空,这里保留是为了单独调用 IndexItem 时的兼容性)
_, err = idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID)
if err != nil {
return fmt.Errorf("删除旧向量失败: %w", err)
@@ -338,12 +338,22 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))
// 在开始重建前先清空所有旧的向量确保进度从0开始
// 这样 GetIndexStatus 可以准确反映重建进度
_, err = idx.db.Exec("DELETE FROM knowledge_embeddings")
if err != nil {
idx.logger.Warn("清空旧索引失败", zap.Error(err))
// 继续执行,即使清空失败也尝试重建
} else {
idx.logger.Info("已清空旧索引,开始重建")
}
for i, itemID := range itemIDs {
if err := idx.IndexItem(ctx, itemID); err != nil {
idx.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
continue
}
idx.logger.Debug("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)))
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)))
}
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)))

View File

@@ -31,18 +31,21 @@ func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
}
// ScanKnowledgeBase 扫描知识库目录,更新数据库
func (m *Manager) ScanKnowledgeBase() error {
// 返回需要索引的知识项ID列表新添加的或更新的
func (m *Manager) ScanKnowledgeBase() ([]string, error) {
if m.basePath == "" {
return fmt.Errorf("知识库路径未配置")
return nil, fmt.Errorf("知识库路径未配置")
}
// 确保目录存在
if err := os.MkdirAll(m.basePath, 0755); err != nil {
return fmt.Errorf("创建知识库目录失败: %w", err)
return nil, fmt.Errorf("创建知识库目录失败: %w", err)
}
var itemsToIndex []string
// 遍历知识库目录
return filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
@@ -77,10 +80,12 @@ func (m *Manager) ScanKnowledgeBase() error {
// 检查是否已存在
var existingID string
var existingContent string
var existingUpdatedAt time.Time
err = m.db.QueryRow(
"SELECT id FROM knowledge_base_items WHERE file_path = ?",
"SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?",
path,
).Scan(&existingID)
).Scan(&existingID, &existingContent, &existingUpdatedAt)
if err == sql.ErrNoRows {
// 创建新项
@@ -94,22 +99,38 @@ func (m *Manager) ScanKnowledgeBase() error {
return fmt.Errorf("插入知识项失败: %w", err)
}
m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
// 新添加的项需要索引
itemsToIndex = append(itemsToIndex, id)
} else if err == nil {
// 更新现有项
_, err = m.db.Exec(
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
category, title, string(content), time.Now(), existingID,
)
if err != nil {
return fmt.Errorf("更新知识项失败: %w", err)
// 检查内容是否有变化
contentChanged := existingContent != string(content)
if contentChanged {
// 更新现有项
_, err = m.db.Exec(
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
category, title, string(content), time.Now(), existingID,
)
if err != nil {
return fmt.Errorf("更新知识项失败: %w", err)
}
m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title))
// 内容已更新的项需要重新索引
itemsToIndex = append(itemsToIndex, existingID)
} else {
m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title))
}
m.logger.Debug("更新知识项", zap.String("id", existingID), zap.String("title", title))
} else {
return fmt.Errorf("查询知识项失败: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
return itemsToIndex, nil
}
// GetCategories 获取所有分类(风险类型)
@@ -170,7 +191,7 @@ func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
time.RFC3339,
time.RFC3339Nano,
}
// 解析创建时间
if createdAt != "" {
for _, format := range timeFormats {
@@ -181,7 +202,7 @@ func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
}
}
}
// 解析更新时间
if updatedAt != "" {
for _, format := range timeFormats {
@@ -192,7 +213,7 @@ func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
}
}
}
// 如果更新时间为空,使用创建时间
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
@@ -230,7 +251,7 @@ func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
time.RFC3339,
time.RFC3339Nano,
}
// 解析创建时间
if createdAt != "" {
for _, format := range timeFormats {
@@ -241,7 +262,7 @@ func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
}
}
}
// 解析更新时间
if updatedAt != "" {
for _, format := range timeFormats {
@@ -252,7 +273,7 @@ func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
}
}
}
// 如果更新时间为空,使用创建时间
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
@@ -418,10 +439,10 @@ func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
isComplete := indexedItems >= totalItems && totalItems > 0
return map[string]interface{}{
"total_items": totalItems,
"indexed_items": indexedItems,
"total_items": totalItems,
"indexed_items": indexedItems,
"progress_percent": progressPercent,
"is_complete": isComplete,
"is_complete": isComplete,
}, nil
}
@@ -472,17 +493,17 @@ func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int)
time.RFC3339,
time.RFC3339Nano,
}
for _, format := range timeFormats {
log.CreatedAt, err = time.Parse(format, createdAt)
if err == nil && !log.CreatedAt.IsZero() {
break
}
}
// 如果所有格式都失败,记录警告但继续处理
if log.CreatedAt.IsZero() {
m.logger.Warn("解析检索日志时间失败",
m.logger.Warn("解析检索日志时间失败",
zap.String("timeStr", createdAt),
zap.Error(err),
)
@@ -519,4 +540,3 @@ func (m *Manager) DeleteRetrievalLog(id string) error {
return nil
}

View File

@@ -22,6 +22,32 @@ async function loadKnowledgeCategories() {
throw new Error('获取分类失败');
}
const data = await response.json();
// 检查知识库功能是否启用
if (data.enabled === false) {
// 功能未启用,显示友好提示
const container = document.getElementById('knowledge-items-list');
if (container) {
container.innerHTML = `
<div class="empty-state" style="text-align: center; padding: 40px 20px;">
<div style="font-size: 48px; margin-bottom: 20px;">📚</div>
<h3 style="margin-bottom: 10px; color: #666;">知识库功能未启用</h3>
<p style="color: #999; margin-bottom: 20px;">${data.message || '请前往系统设置启用知识检索功能'}</p>
<button onclick="switchToSettings()" style="
background: #007bff;
color: white;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
font-size: 14px;
">前往设置</button>
</div>
`;
}
return [];
}
knowledgeCategories = data.categories || [];
// 更新分类筛选下拉框
@@ -43,7 +69,10 @@ async function loadKnowledgeCategories() {
return knowledgeCategories;
} catch (error) {
console.error('加载分类失败:', error);
showNotification('加载分类失败: ' + error.message, 'error');
// 只在非功能未启用的情况下显示错误
if (!error.message.includes('知识库功能未启用')) {
showNotification('加载分类失败: ' + error.message, 'error');
}
return [];
}
}
@@ -70,12 +99,42 @@ async function loadKnowledgeItems(category = '') {
throw new Error('获取知识项失败');
}
const data = await response.json();
// 检查知识库功能是否启用
if (data.enabled === false) {
// 功能未启用,显示友好提示(如果还没有显示的话)
const container = document.getElementById('knowledge-items-list');
if (container && !container.querySelector('.empty-state')) {
container.innerHTML = `
<div class="empty-state" style="text-align: center; padding: 40px 20px;">
<div style="font-size: 48px; margin-bottom: 20px;">📚</div>
<h3 style="margin-bottom: 10px; color: #666;">知识库功能未启用</h3>
<p style="color: #999; margin-bottom: 20px;">${data.message || '请前往系统设置启用知识检索功能'}</p>
<button onclick="switchToSettings()" style="
background: #007bff;
color: white;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
font-size: 14px;
">前往设置</button>
</div>
`;
}
knowledgeItems = [];
return [];
}
knowledgeItems = data.items || [];
renderKnowledgeItems(knowledgeItems);
return knowledgeItems;
} catch (error) {
console.error('加载知识项失败:', error);
showNotification('加载知识项失败: ' + error.message, 'error');
// 只在非功能未启用的情况下显示错误
if (!error.message.includes('知识库功能未启用')) {
showNotification('加载知识项失败: ' + error.message, 'error');
}
return [];
}
}
@@ -252,6 +311,17 @@ async function updateIndexProgress() {
const progressContainer = document.getElementById('knowledge-index-progress');
if (!progressContainer) return;
// 检查知识库功能是否启用
if (status.enabled === false) {
// 功能未启用,隐藏进度条
progressContainer.style.display = 'none';
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
indexProgressInterval = null;
}
return;
}
const totalItems = status.total_items || 0;
const indexedItems = status.indexed_items || 0;
const progressPercent = status.progress_percent || 0;
@@ -373,16 +443,35 @@ async function refreshKnowledgeBase() {
if (!response.ok) {
throw new Error('扫描知识库失败');
}
showNotification('扫描完成,索引重建已开始', 'success');
const data = await response.json();
// 根据返回的消息显示不同的提示
if (data.items_to_index && data.items_to_index > 0) {
showNotification(`扫描完成,开始索引 ${data.items_to_index} 个新添加或更新的知识项`, 'success');
} else {
showNotification(data.message || '扫描完成,没有需要索引的新项或更新项', 'success');
}
// 重新加载知识项
await loadKnowledgeCategories();
await loadKnowledgeItems();
// 开始轮询进度
// 停止现有的轮询
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
indexProgressInterval = null;
}
// 如果有需要索引的项,等待一小段时间后立即更新进度
if (data.items_to_index && data.items_to_index > 0) {
await new Promise(resolve => setTimeout(resolve, 500));
updateIndexProgress();
// 开始轮询进度每2秒刷新一次
if (!indexProgressInterval) {
indexProgressInterval = setInterval(updateIndexProgress, 2000);
}
} else {
// 没有需要索引的项,也更新一次以显示当前状态
updateIndexProgress();
}
updateIndexProgress(); // 立即更新一次
} catch (error) {
console.error('刷新知识库失败:', error);
showNotification('刷新知识库失败: ' + error.message, 'error');
@@ -396,6 +485,31 @@ async function rebuildKnowledgeIndex() {
return;
}
showNotification('正在重建索引...', 'info');
// 先停止现有的轮询
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
indexProgressInterval = null;
}
// 立即显示"正在重建"状态,因为重建开始时会清空旧索引
const progressContainer = document.getElementById('knowledge-index-progress');
if (progressContainer) {
progressContainer.style.display = 'block';
progressContainer.innerHTML = `
<div class="knowledge-index-progress">
<div class="progress-header">
<span class="progress-icon">🔨</span>
<span class="progress-text">正在重建索引: 准备中...</span>
</div>
<div class="progress-bar-container">
<div class="progress-bar" style="width: 0%"></div>
</div>
<div class="progress-hint">索引构建完成后,语义搜索功能将可用</div>
</div>
`;
}
const response = await apiFetch('/api/knowledge/index', {
method: 'POST'
});
@@ -404,11 +518,16 @@ async function rebuildKnowledgeIndex() {
}
showNotification('索引重建已开始,将在后台进行', 'success');
// 开始轮询进度
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
// 等待一小段时间,确保后端已经开始处理并清空了旧索引
await new Promise(resolve => setTimeout(resolve, 500));
// 立即更新一次进度
updateIndexProgress();
// 开始轮询进度每2秒刷新一次比默认的3秒更频繁
if (!indexProgressInterval) {
indexProgressInterval = setInterval(updateIndexProgress, 2000);
}
updateIndexProgress(); // 立即更新一次
} catch (error) {
console.error('重建索引失败:', error);
showNotification('重建索引失败: ' + error.message, 'error');
@@ -1528,8 +1647,8 @@ function formatTime(timeStr) {
// 显示通知
function showNotification(message, type = 'info') {
// 如果存在全局通知系统,使用它
if (typeof window.showNotification === 'function') {
// 如果存在全局通知系统(且不是当前函数),使用它
if (typeof window.showNotification === 'function' && window.showNotification !== showNotification) {
window.showNotification(message, type);
return;
}
@@ -1680,6 +1799,39 @@ window.addEventListener('click', function(event) {
}
});
// 切换到设置页面(用于功能未启用时的提示)
function switchToSettings() {
if (typeof switchPage === 'function') {
switchPage('settings');
// 等待设置页面加载后,切换到知识库配置部分
setTimeout(() => {
if (typeof switchSettingsSection === 'function') {
// 查找知识库配置部分(通常在基本设置中)
const knowledgeSection = document.querySelector('[data-section="knowledge"]');
if (knowledgeSection) {
switchSettingsSection('knowledge');
} else {
// 如果没有独立的知识库部分,切换到基本设置
switchSettingsSection('basic');
// 滚动到知识库配置区域
setTimeout(() => {
const knowledgeEnabledCheckbox = document.getElementById('knowledge-enabled');
if (knowledgeEnabledCheckbox) {
knowledgeEnabledCheckbox.scrollIntoView({ behavior: 'smooth', block: 'center' });
// 高亮显示
knowledgeEnabledCheckbox.parentElement.style.transition = 'background-color 0.3s';
knowledgeEnabledCheckbox.parentElement.style.backgroundColor = '#e3f2fd';
setTimeout(() => {
knowledgeEnabledCheckbox.parentElement.style.backgroundColor = '';
}, 2000);
}
}, 300);
}
}
}, 100);
}
}
// 自定义下拉组件交互
document.addEventListener('DOMContentLoaded', function() {
const wrapper = document.getElementById('knowledge-category-filter-wrapper');