Files
CyberStrikeAI/internal/handler/attackchain.go
2025-11-23 19:46:18 +08:00

174 lines
5.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"context"
"net/http"
"sync"
"time"
"cyberstrike-ai/internal/attackchain"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AttackChainHandler 攻击链处理器
type AttackChainHandler struct {
db *database.DB
logger *zap.Logger
openAIConfig *config.OpenAIConfig
mu sync.RWMutex // 保护 openAIConfig 的并发访问
// 用于防止同一对话的并发生成
generatingLocks sync.Map // map[string]*sync.Mutex
}
// NewAttackChainHandler 创建新的攻击链处理器
func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *AttackChainHandler {
return &AttackChainHandler{
db: db,
logger: logger,
openAIConfig: openAIConfig,
}
}
// UpdateConfig 更新OpenAI配置
func (h *AttackChainHandler) UpdateConfig(cfg *config.OpenAIConfig) {
h.mu.Lock()
defer h.mu.Unlock()
h.openAIConfig = cfg
h.logger.Info("AttackChainHandler配置已更新",
zap.String("base_url", cfg.BaseURL),
zap.String("model", cfg.Model),
)
}
// getOpenAIConfig 获取OpenAI配置线程安全
func (h *AttackChainHandler) getOpenAIConfig() *config.OpenAIConfig {
h.mu.RLock()
defer h.mu.RUnlock()
return h.openAIConfig
}
// GetAttackChain 获取攻击链(按需生成)
// GET /api/attack-chain/:conversationId
func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
conversationID := c.Param("conversationId")
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
return
}
// 检查对话是否存在
_, err := h.db.GetConversation(conversationID)
if err != nil {
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
// 先尝试从数据库加载(如果已生成过)
openAIConfig := h.getOpenAIConfig()
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
chain, err := builder.LoadChainFromDatabase(conversationID)
if err == nil && len(chain.Nodes) > 0 {
// 如果已存在,直接返回
h.logger.Info("返回已存在的攻击链", zap.String("conversationId", conversationID))
c.JSON(http.StatusOK, chain)
return
}
// 如果不存在,则生成新的攻击链(按需生成)
// 使用锁机制防止同一对话的并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
// 尝试获取锁,如果正在生成则返回错误
acquired := lock.TryLock()
if !acquired {
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
return
}
defer lock.Unlock()
// 再次检查是否已生成(可能在等待锁的过程中已经生成完成)
chain, err = builder.LoadChainFromDatabase(conversationID)
if err == nil && len(chain.Nodes) > 0 {
h.logger.Info("返回已存在的攻击链(在锁等待期间已生成)", zap.String("conversationId", conversationID))
c.JSON(http.StatusOK, chain)
return
}
h.logger.Info("开始生成攻击链", zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
chain, err = builder.BuildChainFromConversation(ctx, conversationID)
if err != nil {
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
return
}
// 生成完成后,从锁映射中删除(可选,保留也可以用于防止短时间内重复生成)
// h.generatingLocks.Delete(conversationID)
c.JSON(http.StatusOK, chain)
}
// RegenerateAttackChain 重新生成攻击链
// POST /api/attack-chain/:conversationId/regenerate
func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
conversationID := c.Param("conversationId")
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
return
}
// 检查对话是否存在
_, err := h.db.GetConversation(conversationID)
if err != nil {
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
// 删除旧的攻击链
if err := h.db.DeleteAttackChain(conversationID); err != nil {
h.logger.Warn("删除旧攻击链失败", zap.Error(err))
}
// 使用锁机制防止并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
acquired := lock.TryLock()
if !acquired {
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
return
}
defer lock.Unlock()
// 生成新的攻击链
h.logger.Info("重新生成攻击链", zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
openAIConfig := h.getOpenAIConfig()
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
chain, err := builder.BuildChainFromConversation(ctx, conversationID)
if err != nil {
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
return
}
c.JSON(http.StatusOK, chain)
}