diff --git a/internal/agent/agent.go b/internal/agent/agent.go index bc858106..428d0caf 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -1283,6 +1283,12 @@ func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) { a.mu.Lock() defer a.mu.Unlock() a.config = cfg + + // 同时更新MemoryCompressor的配置(如果存在) + if a.memoryCompressor != nil { + a.memoryCompressor.UpdateConfig(cfg) + } + a.logger.Info("Agent配置已更新", zap.String("base_url", cfg.BaseURL), zap.String("model", cfg.Model), diff --git a/internal/agent/memory_compressor.go b/internal/agent/memory_compressor.go index ece7497d..14aec88b 100644 --- a/internal/agent/memory_compressor.go +++ b/internal/agent/memory_compressor.go @@ -138,6 +138,26 @@ func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error) }, nil } +// UpdateConfig 更新OpenAI配置(用于动态更新模型配置) +func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) { + if cfg == nil { + return + } + + // 更新summaryModel字段 + if cfg.Model != "" { + mc.summaryModel = cfg.Model + } + + // 更新completionClient中的配置(如果是OpenAICompletionClient) + if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok { + openAIClient.config = cfg + mc.logger.Info("MemoryCompressor配置已更新", + zap.String("model", cfg.Model), + ) + } +} + // CompressHistory 根据Token限制压缩历史消息。 func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage) ([]ChatMessage, bool, error) { if len(messages) == 0 { @@ -238,11 +258,24 @@ func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessa return total } +// getModelName 获取当前使用的模型名称(优先从completionClient获取最新配置) +func (mc *MemoryCompressor) getModelName() string { + // 如果completionClient是OpenAICompletionClient,从它获取最新的模型名称 + if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok { + if openAIClient.config != nil && openAIClient.config.Model != "" { + return openAIClient.config.Model + } + } + // 否则使用保存的summaryModel + return mc.summaryModel +} + func (mc *MemoryCompressor) countTokens(text string) int { if mc.tokenCounter == nil { return len(text) / 4 } - count, err := mc.tokenCounter.Count(mc.summaryModel, text) + modelName := mc.getModelName() + count, err := mc.tokenCounter.Count(modelName, text) if err != nil { return len(text) / 4 } @@ -269,7 +302,9 @@ func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMess conversation := strings.Join(formatted, "\n") prompt := fmt.Sprintf(summaryPromptTemplate, conversation) - summary, err := mc.completionClient.Complete(ctx, mc.summaryModel, prompt, mc.timeout) + // 使用动态获取的模型名称,而不是保存的summaryModel + modelName := mc.getModelName() + summary, err := mc.completionClient.Complete(ctx, modelName, prompt, mc.timeout) if err != nil { return ChatMessage{}, err } diff --git a/internal/app/app.go b/internal/app/app.go index b7963094..845f18bf 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -128,9 +128,9 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 conversationHandler := handler.NewConversationHandler(db, log.Logger) authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) - configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, externalMCPMgr, log.Logger) - externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger) + configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) + externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) // 设置路由 setupRoutes( diff --git a/internal/database/conversation.go b/internal/database/conversation.go index e21c6c27..e15785a0 100644 --- a/internal/database/conversation.go +++ b/internal/database/conversation.go @@ -74,15 +74,15 @@ func (db *DB) GetConversation(id string) (*Conversation, error) { conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) } if err1 != nil { - conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt) + conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } - + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) if err2 != nil { conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) } if err2 != nil { - conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt) + conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) } // 加载消息 @@ -155,17 +155,17 @@ func (db *DB) ListConversations(limit, offset int) ([]*Conversation, error) { conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) } if err1 != nil { - conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt) + conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } - + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) if err2 != nil { conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) } if err2 != nil { - conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt) + conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) } - + conversations = append(conversations, &conv) } @@ -208,7 +208,7 @@ func (db *DB) DeleteConversation(id string) error { // AddMessage 添加消息 func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) { id := uuid.New().String() - + var mcpIDsJSON string if len(mcpExecutionIDs) > 0 { jsonData, err := json.Marshal(mcpExecutionIDs) @@ -272,7 +272,7 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) { msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) } if err != nil { - msg.CreatedAt, err = time.Parse(time.RFC3339, createdAt) + msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } // 解析MCP执行ID @@ -290,19 +290,19 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) { // ProcessDetail 过程详情事件 type ProcessDetail struct { - ID string `json:"id"` - MessageID string `json:"messageId"` - ConversationID string `json:"conversationId"` - EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error - Message string `json:"message"` - Data string `json:"data"` // JSON格式的数据 - CreatedAt time.Time `json:"createdAt"` + ID string `json:"id"` + MessageID string `json:"messageId"` + ConversationID string `json:"conversationId"` + EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error + Message string `json:"message"` + Data string `json:"data"` // JSON格式的数据 + CreatedAt time.Time `json:"createdAt"` } // AddProcessDetail 添加过程详情事件 func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error { id := uuid.New().String() - + var dataJSON string if data != nil { jsonData, err := json.Marshal(data) @@ -351,7 +351,7 @@ func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) { detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) } if err != nil { - detail.CreatedAt, err = time.Parse(time.RFC3339, createdAt) + detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } details = append(details, detail) @@ -387,7 +387,7 @@ func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) } if err != nil { - detail.CreatedAt, err = time.Parse(time.RFC3339, createdAt) + detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } detailsMap[detail.MessageID] = append(detailsMap[detail.MessageID], detail) @@ -395,4 +395,3 @@ func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string return detailsMap, nil } - diff --git a/internal/handler/attackchain.go b/internal/handler/attackchain.go index e018c004..2b78b9bf 100644 --- a/internal/handler/attackchain.go +++ b/internal/handler/attackchain.go @@ -19,6 +19,7 @@ type AttackChainHandler struct { db *database.DB logger *zap.Logger openAIConfig *config.OpenAIConfig + mu sync.RWMutex // 保护 openAIConfig 的并发访问 // 用于防止同一对话的并发生成 generatingLocks sync.Map // map[string]*sync.Mutex } @@ -32,6 +33,24 @@ func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, l } } +// UpdateConfig 更新OpenAI配置 +func (h *AttackChainHandler) UpdateConfig(cfg *config.OpenAIConfig) { + h.mu.Lock() + defer h.mu.Unlock() + h.openAIConfig = cfg + h.logger.Info("AttackChainHandler配置已更新", + zap.String("base_url", cfg.BaseURL), + zap.String("model", cfg.Model), + ) +} + +// getOpenAIConfig 获取OpenAI配置(线程安全) +func (h *AttackChainHandler) getOpenAIConfig() *config.OpenAIConfig { + h.mu.RLock() + defer h.mu.RUnlock() + return h.openAIConfig +} + // GetAttackChain 获取攻击链(按需生成) // GET /api/attack-chain/:conversationId func (h *AttackChainHandler) GetAttackChain(c *gin.Context) { @@ -50,7 +69,8 @@ func (h *AttackChainHandler) GetAttackChain(c *gin.Context) { } // 先尝试从数据库加载(如果已生成过) - builder := attackchain.NewBuilder(h.db, h.openAIConfig, h.logger) + openAIConfig := h.getOpenAIConfig() + builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger) chain, err := builder.LoadChainFromDatabase(conversationID) if err == nil && len(chain.Nodes) > 0 { // 如果已存在,直接返回 @@ -139,7 +159,8 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - builder := attackchain.NewBuilder(h.db, h.openAIConfig, h.logger) + openAIConfig := h.getOpenAIConfig() + builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger) chain, err := builder.BuildChainFromConversation(ctx, conversationID) if err != nil { h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err)) diff --git a/internal/handler/config.go b/internal/handler/config.go index 94e01d1d..6cd08e0d 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -22,14 +22,20 @@ import ( // ConfigHandler 配置处理器 type ConfigHandler struct { - configPath string - config *config.Config - mcpServer *mcp.Server - executor *security.Executor - agent AgentUpdater // Agent接口,用于更新Agent配置 - externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 - logger *zap.Logger - mu sync.RWMutex + configPath string + config *config.Config + mcpServer *mcp.Server + executor *security.Executor + agent AgentUpdater // Agent接口,用于更新Agent配置 + attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 + externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 + logger *zap.Logger + mu sync.RWMutex +} + +// AttackChainUpdater 攻击链处理器更新接口 +type AttackChainUpdater interface { + UpdateConfig(cfg *config.OpenAIConfig) } // AgentUpdater Agent更新接口 @@ -39,15 +45,16 @@ type AgentUpdater interface { } // NewConfigHandler 创建新的配置处理器 -func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler { +func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler { return &ConfigHandler{ - configPath: configPath, - config: cfg, - mcpServer: mcpServer, - executor: executor, - agent: agent, - externalMCPMgr: externalMCPMgr, - logger: logger, + configPath: configPath, + config: cfg, + mcpServer: mcpServer, + executor: executor, + agent: agent, + attackChainHandler: attackChainHandler, + externalMCPMgr: externalMCPMgr, + logger: logger, } } @@ -522,6 +529,12 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { h.logger.Info("Agent配置已更新") } + // 更新AttackChainHandler的OpenAI配置 + if h.attackChainHandler != nil { + h.attackChainHandler.UpdateConfig(&h.config.OpenAI) + h.logger.Info("AttackChainHandler配置已更新") + } + h.logger.Info("配置已应用", zap.Int("tools_count", len(h.config.Security.Tools)), ) diff --git a/web/static/js/chat.js b/web/static/js/chat.js index ba722f04..8b5db254 100644 --- a/web/static/js/chat.js +++ b/web/static/js/chat.js @@ -506,7 +506,7 @@ function initializeChatUI() { let messageCounter = 0; // 添加消息 -function addMessage(role, content, mcpExecutionIds = null, progressId = null) { +function addMessage(role, content, mcpExecutionIds = null, progressId = null, createdAt = null) { const messagesDiv = document.getElementById('chat-messages'); const messageDiv = document.createElement('div'); messageCounter++; @@ -582,7 +582,25 @@ function addMessage(role, content, mcpExecutionIds = null, progressId = null) { // 添加时间戳 const timeDiv = document.createElement('div'); timeDiv.className = 'message-time'; - timeDiv.textContent = new Date().toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' }); + // 如果有传入的创建时间,使用它;否则使用当前时间 + let messageTime; + if (createdAt) { + // 处理字符串或Date对象 + if (typeof createdAt === 'string') { + messageTime = new Date(createdAt); + } else if (createdAt instanceof Date) { + messageTime = createdAt; + } else { + messageTime = new Date(createdAt); + } + // 如果解析失败,使用当前时间 + if (isNaN(messageTime.getTime())) { + messageTime = new Date(); + } + } else { + messageTime = new Date(); + } + timeDiv.textContent = messageTime.toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' }); contentWrapper.appendChild(timeDiv); // 如果有MCP执行ID或进度ID,添加查看详情区域(统一使用"渗透测试详情"样式) @@ -1088,7 +1106,8 @@ async function loadConversation(conversationId) { } } - const messageId = addMessage(msg.role, displayContent, msg.mcpExecutionIds || []); + // 传递消息的创建时间 + const messageId = addMessage(msg.role, displayContent, msg.mcpExecutionIds || [], null, msg.createdAt); // 如果有过程详情,显示它们 if (msg.processDetails && msg.processDetails.length > 0 && msg.role === 'assistant') { // 延迟一下,确保消息已经渲染