mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-01 08:40:42 +02:00
Add files via upload
This commit is contained in:
@@ -1283,6 +1283,12 @@ func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.config = cfg
|
||||
|
||||
// 同时更新MemoryCompressor的配置(如果存在)
|
||||
if a.memoryCompressor != nil {
|
||||
a.memoryCompressor.UpdateConfig(cfg)
|
||||
}
|
||||
|
||||
a.logger.Info("Agent配置已更新",
|
||||
zap.String("base_url", cfg.BaseURL),
|
||||
zap.String("model", cfg.Model),
|
||||
|
||||
@@ -138,6 +138,26 @@ func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateConfig 更新OpenAI配置(用于动态更新模型配置)
|
||||
func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 更新summaryModel字段
|
||||
if cfg.Model != "" {
|
||||
mc.summaryModel = cfg.Model
|
||||
}
|
||||
|
||||
// 更新completionClient中的配置(如果是OpenAICompletionClient)
|
||||
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
|
||||
openAIClient.config = cfg
|
||||
mc.logger.Info("MemoryCompressor配置已更新",
|
||||
zap.String("model", cfg.Model),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// CompressHistory 根据Token限制压缩历史消息。
|
||||
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage) ([]ChatMessage, bool, error) {
|
||||
if len(messages) == 0 {
|
||||
@@ -238,11 +258,24 @@ func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessa
|
||||
return total
|
||||
}
|
||||
|
||||
// getModelName 获取当前使用的模型名称(优先从completionClient获取最新配置)
|
||||
func (mc *MemoryCompressor) getModelName() string {
|
||||
// 如果completionClient是OpenAICompletionClient,从它获取最新的模型名称
|
||||
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
|
||||
if openAIClient.config != nil && openAIClient.config.Model != "" {
|
||||
return openAIClient.config.Model
|
||||
}
|
||||
}
|
||||
// 否则使用保存的summaryModel
|
||||
return mc.summaryModel
|
||||
}
|
||||
|
||||
func (mc *MemoryCompressor) countTokens(text string) int {
|
||||
if mc.tokenCounter == nil {
|
||||
return len(text) / 4
|
||||
}
|
||||
count, err := mc.tokenCounter.Count(mc.summaryModel, text)
|
||||
modelName := mc.getModelName()
|
||||
count, err := mc.tokenCounter.Count(modelName, text)
|
||||
if err != nil {
|
||||
return len(text) / 4
|
||||
}
|
||||
@@ -269,7 +302,9 @@ func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMess
|
||||
conversation := strings.Join(formatted, "\n")
|
||||
prompt := fmt.Sprintf(summaryPromptTemplate, conversation)
|
||||
|
||||
summary, err := mc.completionClient.Complete(ctx, mc.summaryModel, prompt, mc.timeout)
|
||||
// 使用动态获取的模型名称,而不是保存的summaryModel
|
||||
modelName := mc.getModelName()
|
||||
summary, err := mc.completionClient.Complete(ctx, modelName, prompt, mc.timeout)
|
||||
if err != nil {
|
||||
return ChatMessage{}, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user