mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
Add files via upload
This commit is contained in:
@@ -19,6 +19,7 @@ type AttackChainHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
openAIConfig *config.OpenAIConfig
|
||||
mu sync.RWMutex // 保护 openAIConfig 的并发访问
|
||||
// 用于防止同一对话的并发生成
|
||||
generatingLocks sync.Map // map[string]*sync.Mutex
|
||||
}
|
||||
@@ -32,6 +33,24 @@ func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, l
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig 更新OpenAI配置
|
||||
func (h *AttackChainHandler) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.openAIConfig = cfg
|
||||
h.logger.Info("AttackChainHandler配置已更新",
|
||||
zap.String("base_url", cfg.BaseURL),
|
||||
zap.String("model", cfg.Model),
|
||||
)
|
||||
}
|
||||
|
||||
// getOpenAIConfig 获取OpenAI配置(线程安全)
|
||||
func (h *AttackChainHandler) getOpenAIConfig() *config.OpenAIConfig {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.openAIConfig
|
||||
}
|
||||
|
||||
// GetAttackChain 获取攻击链(按需生成)
|
||||
// GET /api/attack-chain/:conversationId
|
||||
func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
|
||||
@@ -50,7 +69,8 @@ func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 先尝试从数据库加载(如果已生成过)
|
||||
builder := attackchain.NewBuilder(h.db, h.openAIConfig, h.logger)
|
||||
openAIConfig := h.getOpenAIConfig()
|
||||
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
|
||||
chain, err := builder.LoadChainFromDatabase(conversationID)
|
||||
if err == nil && len(chain.Nodes) > 0 {
|
||||
// 如果已存在,直接返回
|
||||
@@ -139,7 +159,8 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
builder := attackchain.NewBuilder(h.db, h.openAIConfig, h.logger)
|
||||
openAIConfig := h.getOpenAIConfig()
|
||||
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
|
||||
chain, err := builder.BuildChainFromConversation(ctx, conversationID)
|
||||
if err != nil {
|
||||
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
|
||||
@@ -22,14 +22,20 @@ import (
|
||||
|
||||
// ConfigHandler 配置处理器
|
||||
type ConfigHandler struct {
|
||||
configPath string
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
configPath string
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// AttackChainUpdater 攻击链处理器更新接口
|
||||
type AttackChainUpdater interface {
|
||||
UpdateConfig(cfg *config.OpenAIConfig)
|
||||
}
|
||||
|
||||
// AgentUpdater Agent更新接口
|
||||
@@ -39,15 +45,16 @@ type AgentUpdater interface {
|
||||
}
|
||||
|
||||
// NewConfigHandler 创建新的配置处理器
|
||||
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler {
|
||||
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
configPath: configPath,
|
||||
config: cfg,
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
agent: agent,
|
||||
externalMCPMgr: externalMCPMgr,
|
||||
logger: logger,
|
||||
configPath: configPath,
|
||||
config: cfg,
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
agent: agent,
|
||||
attackChainHandler: attackChainHandler,
|
||||
externalMCPMgr: externalMCPMgr,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -522,6 +529,12 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("Agent配置已更新")
|
||||
}
|
||||
|
||||
// 更新AttackChainHandler的OpenAI配置
|
||||
if h.attackChainHandler != nil {
|
||||
h.attackChainHandler.UpdateConfig(&h.config.OpenAI)
|
||||
h.logger.Info("AttackChainHandler配置已更新")
|
||||
}
|
||||
|
||||
h.logger.Info("配置已应用",
|
||||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user