Add files via upload

This commit is contained in:
公明
2025-12-27 03:57:01 +08:00
committed by GitHub
parent 3e0867d459
commit 604e31d247
5 changed files with 196 additions and 47 deletions

View File

@@ -13,6 +13,7 @@ import (
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
@@ -23,6 +24,11 @@ import (
// KnowledgeToolRegistrar 知识库工具注册器接口
type KnowledgeToolRegistrar func() error
// RetrieverUpdater 检索器更新接口
type RetrieverUpdater interface {
UpdateConfig(config *knowledge.RetrievalConfig)
}
// ConfigHandler 配置处理器
type ConfigHandler struct {
configPath string
@@ -33,6 +39,7 @@ type ConfigHandler struct {
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
logger *zap.Logger
mu sync.RWMutex
}
@@ -69,6 +76,13 @@ func (h *ConfigHandler) SetKnowledgeToolRegistrar(registrar KnowledgeToolRegistr
h.knowledgeToolRegistrar = registrar
}
// SetRetrieverUpdater 设置检索器更新器
func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
h.mu.Lock()
defer h.mu.Unlock()
h.retrieverUpdater = updater
}
// GetConfigResponse 获取配置响应
type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"`
@@ -639,6 +653,21 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("AttackChainHandler配置已更新")
}
// 更新检索器配置(如果知识库启用)
if h.config.Knowledge.Enabled && h.retrieverUpdater != nil {
retrievalConfig := &knowledge.RetrievalConfig{
TopK: h.config.Knowledge.Retrieval.TopK,
SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold,
HybridWeight: h.config.Knowledge.Retrieval.HybridWeight,
}
h.retrieverUpdater.UpdateConfig(retrievalConfig)
h.logger.Info("检索器配置已更新",
zap.Int("top_k", retrievalConfig.TopK),
zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold),
zap.Float64("hybrid_weight", retrievalConfig.HybridWeight),
)
}
h.logger.Info("配置已应用",
zap.Int("tools_count", len(h.config.Security.Tools)),
)
@@ -952,7 +981,13 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
valueNode.Kind = yaml.ScalarNode
valueNode.Tag = "!!float"
valueNode.Style = 0
valueNode.Value = fmt.Sprintf("%g", value)
// 对于0.0到1.0之间的值如hybrid_weight使用%.1f确保0.0被明确序列化为"0.0"
// 对于其他值,使用%g自动选择最合适的格式
if value >= 0.0 && value <= 1.0 {
valueNode.Value = fmt.Sprintf("%.1f", value)
} else {
valueNode.Value = fmt.Sprintf("%g", value)
}
}