mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-27 17:52:28 +02:00
Add files via upload
This commit is contained in:
+86
-40
@@ -16,6 +16,7 @@ import (
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -29,19 +30,29 @@ type RetrieverUpdater interface {
|
||||
UpdateConfig(config *knowledge.RetrievalConfig)
|
||||
}
|
||||
|
||||
// KnowledgeInitializer 知识库初始化器接口
|
||||
type KnowledgeInitializer func() (*KnowledgeHandler, error)
|
||||
|
||||
// AppUpdater App更新接口(用于更新App中的知识库组件)
|
||||
type AppUpdater interface {
|
||||
UpdateKnowledgeComponents(handler *KnowledgeHandler, manager interface{}, retriever interface{}, indexer interface{})
|
||||
}
|
||||
|
||||
// ConfigHandler 配置处理器
|
||||
type ConfigHandler struct {
|
||||
configPath string
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
configPath string
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// AttackChainUpdater 攻击链处理器更新接口
|
||||
@@ -83,12 +94,26 @@ func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
|
||||
h.retrieverUpdater = updater
|
||||
}
|
||||
|
||||
// SetKnowledgeInitializer 设置知识库初始化器
|
||||
func (h *ConfigHandler) SetKnowledgeInitializer(initializer KnowledgeInitializer) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.knowledgeInitializer = initializer
|
||||
}
|
||||
|
||||
// SetAppUpdater 设置App更新器
|
||||
func (h *ConfigHandler) SetAppUpdater(updater AppUpdater) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.appUpdater = updater
|
||||
}
|
||||
|
||||
// GetConfigResponse 获取配置响应
|
||||
type GetConfigResponse struct {
|
||||
OpenAI config.OpenAIConfig `json:"openai"`
|
||||
MCP config.MCPConfig `json:"mcp"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Agent config.AgentConfig `json:"agent"`
|
||||
OpenAI config.OpenAIConfig `json:"openai"`
|
||||
MCP config.MCPConfig `json:"mcp"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Agent config.AgentConfig `json:"agent"`
|
||||
Knowledge config.KnowledgeConfig `json:"knowledge"`
|
||||
}
|
||||
|
||||
@@ -127,7 +152,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
tools[len(tools)-1].Description = desc
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
|
||||
if h.mcpServer != nil {
|
||||
mcpTools := h.mcpServer.GetAllTools()
|
||||
@@ -287,7 +312,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
|
||||
allTools = append(allTools, toolInfo)
|
||||
}
|
||||
|
||||
|
||||
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
|
||||
if h.mcpServer != nil {
|
||||
mcpTools := h.mcpServer.GetAllTools()
|
||||
@@ -296,7 +321,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
if configToolMap[mcpTool.Name] {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
description := mcpTool.ShortDescription
|
||||
if description == "" {
|
||||
description = mcpTool.Description
|
||||
@@ -304,14 +329,14 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
if len(description) > 100 {
|
||||
description = description[:100] + "..."
|
||||
}
|
||||
|
||||
|
||||
toolInfo := ToolConfigInfo{
|
||||
Name: mcpTool.Name,
|
||||
Description: description,
|
||||
Enabled: true, // 直接注册的工具默认启用
|
||||
IsExternal: false,
|
||||
}
|
||||
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(toolInfo.Name)
|
||||
@@ -320,7 +345,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
continue // 不匹配,跳过
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
allTools = append(allTools, toolInfo)
|
||||
}
|
||||
}
|
||||
@@ -336,7 +361,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
} else {
|
||||
// 获取外部MCP配置,用于判断启用状态
|
||||
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
|
||||
|
||||
|
||||
for _, externalTool := range externalTools {
|
||||
// 解析工具名称:mcpName::toolName
|
||||
var mcpName, actualToolName string
|
||||
@@ -434,7 +459,7 @@ type UpdateConfigRequest struct {
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||||
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
||||
}
|
||||
|
||||
@@ -541,12 +566,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName))
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// 初始化ToolEnabled map
|
||||
if cfg.ToolEnabled == nil {
|
||||
cfg.ToolEnabled = make(map[string]bool)
|
||||
}
|
||||
|
||||
|
||||
// 更新每个工具的启用状态
|
||||
for toolName, enabled := range toolStates {
|
||||
cfg.ToolEnabled[toolName] = enabled
|
||||
@@ -556,7 +581,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
zap.Bool("enabled", enabled),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
// 检查是否有任何工具启用,如果有则启用MCP
|
||||
hasEnabledTool := false
|
||||
for _, enabled := range cfg.ToolEnabled {
|
||||
@@ -565,21 +590,21 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 如果MCP之前未启用,但现在有工具启用,则启用MCP
|
||||
// 如果MCP之前已启用,保持启用状态(允许部分工具禁用)
|
||||
if !cfg.ExternalMCPEnable && hasEnabledTool {
|
||||
cfg.ExternalMCPEnable = true
|
||||
h.logger.Info("自动启用外部MCP(因为有工具启用)", zap.String("mcp", mcpName))
|
||||
}
|
||||
|
||||
|
||||
h.config.ExternalMCP.Servers[mcpName] = cfg
|
||||
}
|
||||
|
||||
|
||||
// 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置
|
||||
// 在循环外部统一更新,避免重复调用
|
||||
h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP)
|
||||
|
||||
|
||||
// 处理MCP连接状态(异步启动,避免阻塞)
|
||||
for mcpName := range externalMCPToolMap {
|
||||
cfg := h.config.ExternalMCP.Servers[mcpName]
|
||||
@@ -618,18 +643,41 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
|
||||
// ApplyConfig 应用配置(重新加载并重启相关服务)
|
||||
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
// 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求)
|
||||
var needInitKnowledge bool
|
||||
var knowledgeInitializer KnowledgeInitializer
|
||||
|
||||
h.mu.RLock()
|
||||
needInitKnowledge = h.config.Knowledge.Enabled && h.knowledgeToolRegistrar == nil && h.knowledgeInitializer != nil
|
||||
if needInitKnowledge {
|
||||
knowledgeInitializer = h.knowledgeInitializer
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// 如果需要动态初始化知识库,在锁外执行(这是耗时操作)
|
||||
if needInitKnowledge {
|
||||
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
|
||||
if _, err := knowledgeInitializer(); err != nil {
|
||||
h.logger.Error("动态初始化知识库失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
h.logger.Info("知识库动态初始化完成,工具已注册")
|
||||
}
|
||||
|
||||
// 现在获取写锁,执行快速的操作
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 重新注册工具(根据新的启用状态)
|
||||
h.logger.Info("重新注册工具")
|
||||
|
||||
|
||||
// 清空MCP服务器中的工具
|
||||
h.mcpServer.ClearTools()
|
||||
|
||||
|
||||
// 重新注册安全工具
|
||||
h.executor.RegisterTools(h.mcpServer)
|
||||
|
||||
|
||||
// 如果知识库启用,重新注册知识库工具
|
||||
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
|
||||
h.logger.Info("重新注册知识库工具")
|
||||
@@ -673,7 +721,7 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "配置已应用",
|
||||
"message": "配置已应用",
|
||||
"tools_count": len(h.config.Security.Tools),
|
||||
})
|
||||
}
|
||||
@@ -847,7 +895,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
|
||||
knowledgeNode := ensureMap(root, "knowledge")
|
||||
setBoolInMap(knowledgeNode, "enabled", cfg.Enabled)
|
||||
setStringInMap(knowledgeNode, "base_path", cfg.BasePath)
|
||||
|
||||
|
||||
// 更新嵌入配置
|
||||
embeddingNode := ensureMap(knowledgeNode, "embedding")
|
||||
setStringInMap(embeddingNode, "provider", cfg.Embedding.Provider)
|
||||
@@ -858,7 +906,7 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
|
||||
if cfg.Embedding.APIKey != "" {
|
||||
setStringInMap(embeddingNode, "api_key", cfg.Embedding.APIKey)
|
||||
}
|
||||
|
||||
|
||||
// 更新检索配置
|
||||
retrievalNode := ensureMap(knowledgeNode, "retrieval")
|
||||
setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK)
|
||||
@@ -940,14 +988,14 @@ func findBoolInMap(mapNode *yaml.Node, key string) *bool {
|
||||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
for i := 0; i < len(mapNode.Content); i += 2 {
|
||||
if i+1 >= len(mapNode.Content) {
|
||||
break
|
||||
}
|
||||
keyNode := mapNode.Content[i]
|
||||
valueNode := mapNode.Content[i+1]
|
||||
|
||||
|
||||
if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key {
|
||||
if valueNode.Kind == yaml.ScalarNode {
|
||||
if valueNode.Value == "true" {
|
||||
@@ -989,5 +1037,3 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
|
||||
valueNode.Value = fmt.Sprintf("%g", value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -170,21 +170,36 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
|
||||
|
||||
// ScanKnowledgeBase 扫描知识库
|
||||
func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
|
||||
if err := h.manager.ScanKnowledgeBase(); err != nil {
|
||||
itemsToIndex, err := h.manager.ScanKnowledgeBase()
|
||||
if err != nil {
|
||||
h.logger.Error("扫描知识库失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 异步重建索引
|
||||
if len(itemsToIndex) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"})
|
||||
return
|
||||
}
|
||||
|
||||
// 异步索引新添加或更新的项(增量索引)
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
if err := h.indexer.RebuildIndex(ctx); err != nil {
|
||||
h.logger.Error("重建索引失败", zap.Error(err))
|
||||
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
|
||||
for i, itemID := range itemsToIndex {
|
||||
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
|
||||
h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)))
|
||||
}
|
||||
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,索引重建已开始"})
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
|
||||
"items_to_index": len(itemsToIndex),
|
||||
})
|
||||
}
|
||||
|
||||
// GetRetrievalLogs 获取检索日志
|
||||
|
||||
Reference in New Issue
Block a user