mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 00:09:29 +02:00
174 lines
5.6 KiB
Go
174 lines
5.6 KiB
Go
package handler
|
||
|
||
import (
|
||
"context"
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
|
||
"cyberstrike-ai/internal/attackchain"
|
||
"cyberstrike-ai/internal/config"
|
||
"cyberstrike-ai/internal/database"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// AttackChainHandler 攻击链处理器
|
||
type AttackChainHandler struct {
|
||
db *database.DB
|
||
logger *zap.Logger
|
||
openAIConfig *config.OpenAIConfig
|
||
mu sync.RWMutex // 保护 openAIConfig 的并发访问
|
||
// 用于防止同一对话的并发生成
|
||
generatingLocks sync.Map // map[string]*sync.Mutex
|
||
}
|
||
|
||
// NewAttackChainHandler 创建新的攻击链处理器
|
||
func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *AttackChainHandler {
|
||
return &AttackChainHandler{
|
||
db: db,
|
||
logger: logger,
|
||
openAIConfig: openAIConfig,
|
||
}
|
||
}
|
||
|
||
// UpdateConfig 更新OpenAI配置
|
||
func (h *AttackChainHandler) UpdateConfig(cfg *config.OpenAIConfig) {
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
h.openAIConfig = cfg
|
||
h.logger.Info("AttackChainHandler配置已更新",
|
||
zap.String("base_url", cfg.BaseURL),
|
||
zap.String("model", cfg.Model),
|
||
)
|
||
}
|
||
|
||
// getOpenAIConfig 获取OpenAI配置(线程安全)
|
||
func (h *AttackChainHandler) getOpenAIConfig() *config.OpenAIConfig {
|
||
h.mu.RLock()
|
||
defer h.mu.RUnlock()
|
||
return h.openAIConfig
|
||
}
|
||
|
||
// GetAttackChain 获取攻击链(按需生成)
|
||
// GET /api/attack-chain/:conversationId
|
||
func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
|
||
conversationID := c.Param("conversationId")
|
||
if conversationID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
|
||
return
|
||
}
|
||
|
||
// 检查对话是否存在
|
||
_, err := h.db.GetConversation(conversationID)
|
||
if err != nil {
|
||
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||
return
|
||
}
|
||
|
||
// 先尝试从数据库加载(如果已生成过)
|
||
openAIConfig := h.getOpenAIConfig()
|
||
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
|
||
chain, err := builder.LoadChainFromDatabase(conversationID)
|
||
if err == nil && len(chain.Nodes) > 0 {
|
||
// 如果已存在,直接返回
|
||
h.logger.Info("返回已存在的攻击链", zap.String("conversationId", conversationID))
|
||
c.JSON(http.StatusOK, chain)
|
||
return
|
||
}
|
||
|
||
// 如果不存在,则生成新的攻击链(按需生成)
|
||
// 使用锁机制防止同一对话的并发生成
|
||
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
|
||
lock := lockInterface.(*sync.Mutex)
|
||
|
||
// 尝试获取锁,如果正在生成则返回错误
|
||
acquired := lock.TryLock()
|
||
if !acquired {
|
||
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
|
||
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
|
||
return
|
||
}
|
||
defer lock.Unlock()
|
||
|
||
// 再次检查是否已生成(可能在等待锁的过程中已经生成完成)
|
||
chain, err = builder.LoadChainFromDatabase(conversationID)
|
||
if err == nil && len(chain.Nodes) > 0 {
|
||
h.logger.Info("返回已存在的攻击链(在锁等待期间已生成)", zap.String("conversationId", conversationID))
|
||
c.JSON(http.StatusOK, chain)
|
||
return
|
||
}
|
||
|
||
h.logger.Info("开始生成攻击链", zap.String("conversationId", conversationID))
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||
defer cancel()
|
||
|
||
chain, err = builder.BuildChainFromConversation(ctx, conversationID)
|
||
if err != nil {
|
||
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
// 生成完成后,从锁映射中删除(可选,保留也可以用于防止短时间内重复生成)
|
||
// h.generatingLocks.Delete(conversationID)
|
||
|
||
c.JSON(http.StatusOK, chain)
|
||
}
|
||
|
||
// RegenerateAttackChain 重新生成攻击链
|
||
// POST /api/attack-chain/:conversationId/regenerate
|
||
func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
|
||
conversationID := c.Param("conversationId")
|
||
if conversationID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
|
||
return
|
||
}
|
||
|
||
// 检查对话是否存在
|
||
_, err := h.db.GetConversation(conversationID)
|
||
if err != nil {
|
||
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||
return
|
||
}
|
||
|
||
// 删除旧的攻击链
|
||
if err := h.db.DeleteAttackChain(conversationID); err != nil {
|
||
h.logger.Warn("删除旧攻击链失败", zap.Error(err))
|
||
}
|
||
|
||
// 使用锁机制防止并发生成
|
||
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
|
||
lock := lockInterface.(*sync.Mutex)
|
||
|
||
acquired := lock.TryLock()
|
||
if !acquired {
|
||
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
|
||
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
|
||
return
|
||
}
|
||
defer lock.Unlock()
|
||
|
||
// 生成新的攻击链
|
||
h.logger.Info("重新生成攻击链", zap.String("conversationId", conversationID))
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||
defer cancel()
|
||
|
||
openAIConfig := h.getOpenAIConfig()
|
||
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
|
||
chain, err := builder.BuildChainFromConversation(ctx, conversationID)
|
||
if err != nil {
|
||
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, chain)
|
||
}
|
||
|