Add files via upload

This commit is contained in:
公明
2025-12-27 19:42:21 +08:00
committed by GitHub
parent cb45b9e540
commit 65957b2013
6 changed files with 672 additions and 129 deletions
+86 -40
View File
@@ -16,6 +16,7 @@ import (
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
@@ -29,19 +30,29 @@ type RetrieverUpdater interface {
UpdateConfig(config *knowledge.RetrievalConfig)
}
// KnowledgeInitializer 知识库初始化器接口
type KnowledgeInitializer func() (*KnowledgeHandler, error)
// AppUpdater App更新接口(用于更新App中的知识库组件)
type AppUpdater interface {
UpdateKnowledgeComponents(handler *KnowledgeHandler, manager interface{}, retriever interface{}, indexer interface{})
}
// ConfigHandler 配置处理器
type ConfigHandler struct {
configPath string
config *config.Config
mcpServer *mcp.Server
executor *security.Executor
agent AgentUpdater // Agent接口,用于更新Agent配置
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
logger *zap.Logger
mu sync.RWMutex
configPath string
config *config.Config
mcpServer *mcp.Server
executor *security.Executor
agent AgentUpdater // Agent接口,用于更新Agent配置
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器(可选)
logger *zap.Logger
mu sync.RWMutex
}
// AttackChainUpdater 攻击链处理器更新接口
@@ -83,12 +94,26 @@ func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
h.retrieverUpdater = updater
}
// SetKnowledgeInitializer 设置知识库初始化器
func (h *ConfigHandler) SetKnowledgeInitializer(initializer KnowledgeInitializer) {
h.mu.Lock()
defer h.mu.Unlock()
h.knowledgeInitializer = initializer
}
// SetAppUpdater 设置App更新器
func (h *ConfigHandler) SetAppUpdater(updater AppUpdater) {
h.mu.Lock()
defer h.mu.Unlock()
h.appUpdater = updater
}
// GetConfigResponse 获取配置响应
type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"`
MCP config.MCPConfig `json:"mcp"`
Tools []ToolConfigInfo `json:"tools"`
Agent config.AgentConfig `json:"agent"`
OpenAI config.OpenAIConfig `json:"openai"`
MCP config.MCPConfig `json:"mcp"`
Tools []ToolConfigInfo `json:"tools"`
Agent config.AgentConfig `json:"agent"`
Knowledge config.KnowledgeConfig `json:"knowledge"`
}
@@ -127,7 +152,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
tools[len(tools)-1].Description = desc
}
}
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
if h.mcpServer != nil {
mcpTools := h.mcpServer.GetAllTools()
@@ -287,7 +312,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
allTools = append(allTools, toolInfo)
}
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
if h.mcpServer != nil {
mcpTools := h.mcpServer.GetAllTools()
@@ -296,7 +321,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
if configToolMap[mcpTool.Name] {
continue
}
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
@@ -304,14 +329,14 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
if len(description) > 100 {
description = description[:100] + "..."
}
toolInfo := ToolConfigInfo{
Name: mcpTool.Name,
Description: description,
Enabled: true, // 直接注册的工具默认启用
IsExternal: false,
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name)
@@ -320,7 +345,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
continue // 不匹配,跳过
}
}
allTools = append(allTools, toolInfo)
}
}
@@ -336,7 +361,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} else {
// 获取外部MCP配置,用于判断启用状态
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
for _, externalTool := range externalTools {
// 解析工具名称:mcpName::toolName
var mcpName, actualToolName string
@@ -434,7 +459,7 @@ type UpdateConfigRequest struct {
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
MCP *config.MCPConfig `json:"mcp,omitempty"`
Tools []ToolEnableStatus `json:"tools,omitempty"`
Agent *config.AgentConfig `json:"agent,omitempty"`
Agent *config.AgentConfig `json:"agent,omitempty"`
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
}
@@ -541,12 +566,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName))
continue
}
// 初始化ToolEnabled map
if cfg.ToolEnabled == nil {
cfg.ToolEnabled = make(map[string]bool)
}
// 更新每个工具的启用状态
for toolName, enabled := range toolStates {
cfg.ToolEnabled[toolName] = enabled
@@ -556,7 +581,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
zap.Bool("enabled", enabled),
)
}
// 检查是否有任何工具启用,如果有则启用MCP
hasEnabledTool := false
for _, enabled := range cfg.ToolEnabled {
@@ -565,21 +590,21 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
break
}
}
// 如果MCP之前未启用,但现在有工具启用,则启用MCP
// 如果MCP之前已启用,保持启用状态(允许部分工具禁用)
if !cfg.ExternalMCPEnable && hasEnabledTool {
cfg.ExternalMCPEnable = true
h.logger.Info("自动启用外部MCP(因为有工具启用)", zap.String("mcp", mcpName))
}
h.config.ExternalMCP.Servers[mcpName] = cfg
}
// 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置
// 在循环外部统一更新,避免重复调用
h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP)
// 处理MCP连接状态(异步启动,避免阻塞)
for mcpName := range externalMCPToolMap {
cfg := h.config.ExternalMCP.Servers[mcpName]
@@ -618,18 +643,41 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
// ApplyConfig 应用配置(重新加载并重启相关服务)
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
// 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求)
var needInitKnowledge bool
var knowledgeInitializer KnowledgeInitializer
h.mu.RLock()
needInitKnowledge = h.config.Knowledge.Enabled && h.knowledgeToolRegistrar == nil && h.knowledgeInitializer != nil
if needInitKnowledge {
knowledgeInitializer = h.knowledgeInitializer
}
h.mu.RUnlock()
// 如果需要动态初始化知识库,在锁外执行(这是耗时操作)
if needInitKnowledge {
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
if _, err := knowledgeInitializer(); err != nil {
h.logger.Error("动态初始化知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
return
}
h.logger.Info("知识库动态初始化完成,工具已注册")
}
// 现在获取写锁,执行快速的操作
h.mu.Lock()
defer h.mu.Unlock()
// 重新注册工具(根据新的启用状态)
h.logger.Info("重新注册工具")
// 清空MCP服务器中的工具
h.mcpServer.ClearTools()
// 重新注册安全工具
h.executor.RegisterTools(h.mcpServer)
// 如果知识库启用,重新注册知识库工具
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
h.logger.Info("重新注册知识库工具")
@@ -673,7 +721,7 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
)
c.JSON(http.StatusOK, gin.H{
"message": "配置已应用",
"message": "配置已应用",
"tools_count": len(h.config.Security.Tools),
})
}
@@ -847,7 +895,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
knowledgeNode := ensureMap(root, "knowledge")
setBoolInMap(knowledgeNode, "enabled", cfg.Enabled)
setStringInMap(knowledgeNode, "base_path", cfg.BasePath)
// 更新嵌入配置
embeddingNode := ensureMap(knowledgeNode, "embedding")
setStringInMap(embeddingNode, "provider", cfg.Embedding.Provider)
@@ -858,7 +906,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
if cfg.Embedding.APIKey != "" {
setStringInMap(embeddingNode, "api_key", cfg.Embedding.APIKey)
}
// 更新检索配置
retrievalNode := ensureMap(knowledgeNode, "retrieval")
setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK)
@@ -940,14 +988,14 @@ func findBoolInMap(mapNode *yaml.Node, key string) *bool {
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
return nil
}
for i := 0; i < len(mapNode.Content); i += 2 {
if i+1 >= len(mapNode.Content) {
break
}
keyNode := mapNode.Content[i]
valueNode := mapNode.Content[i+1]
if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key {
if valueNode.Kind == yaml.ScalarNode {
if valueNode.Value == "true" {
@@ -989,5 +1037,3 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
valueNode.Value = fmt.Sprintf("%g", value)
}
}
+20 -5
View File
@@ -170,21 +170,36 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
// ScanKnowledgeBase 扫描知识库
func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
if err := h.manager.ScanKnowledgeBase(); err != nil {
itemsToIndex, err := h.manager.ScanKnowledgeBase()
if err != nil {
h.logger.Error("扫描知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 异步重建索引
if len(itemsToIndex) == 0 {
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"})
return
}
// 异步索引新添加或更新的项(增量索引)
go func() {
ctx := context.Background()
if err := h.indexer.RebuildIndex(ctx); err != nil {
h.logger.Error("重建索引失败", zap.Error(err))
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
for i, itemID := range itemsToIndex {
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
continue
}
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)))
}
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
}()
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,索引重建已开始"})
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
"items_to_index": len(itemsToIndex),
})
}
// GetRetrievalLogs 获取检索日志