Add files via upload

This commit is contained in:
公明
2025-11-23 19:46:18 +08:00
committed by GitHub
parent 3304e1996a
commit a8bc32aefb
7 changed files with 138 additions and 45 deletions

View File

@@ -1283,6 +1283,12 @@ func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) {
a.mu.Lock()
defer a.mu.Unlock()
a.config = cfg
// 同时更新MemoryCompressor的配置如果存在
if a.memoryCompressor != nil {
a.memoryCompressor.UpdateConfig(cfg)
}
a.logger.Info("Agent配置已更新",
zap.String("base_url", cfg.BaseURL),
zap.String("model", cfg.Model),

View File

@@ -138,6 +138,26 @@ func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error)
}, nil
}
// UpdateConfig 更新OpenAI配置用于动态更新模型配置
func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
if cfg == nil {
return
}
// 更新summaryModel字段
if cfg.Model != "" {
mc.summaryModel = cfg.Model
}
// 更新completionClient中的配置如果是OpenAICompletionClient
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
openAIClient.config = cfg
mc.logger.Info("MemoryCompressor配置已更新",
zap.String("model", cfg.Model),
)
}
}
// CompressHistory 根据Token限制压缩历史消息。
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage) ([]ChatMessage, bool, error) {
if len(messages) == 0 {
@@ -238,11 +258,24 @@ func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessa
return total
}
// getModelName 获取当前使用的模型名称优先从completionClient获取最新配置
func (mc *MemoryCompressor) getModelName() string {
// 如果completionClient是OpenAICompletionClient从它获取最新的模型名称
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
if openAIClient.config != nil && openAIClient.config.Model != "" {
return openAIClient.config.Model
}
}
// 否则使用保存的summaryModel
return mc.summaryModel
}
func (mc *MemoryCompressor) countTokens(text string) int {
if mc.tokenCounter == nil {
return len(text) / 4
}
count, err := mc.tokenCounter.Count(mc.summaryModel, text)
modelName := mc.getModelName()
count, err := mc.tokenCounter.Count(modelName, text)
if err != nil {
return len(text) / 4
}
@@ -269,7 +302,9 @@ func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMess
conversation := strings.Join(formatted, "\n")
prompt := fmt.Sprintf(summaryPromptTemplate, conversation)
summary, err := mc.completionClient.Complete(ctx, mc.summaryModel, prompt, mc.timeout)
// 使用动态获取的模型名称而不是保存的summaryModel
modelName := mc.getModelName()
summary, err := mc.completionClient.Complete(ctx, modelName, prompt, mc.timeout)
if err != nil {
return ChatMessage{}, err
}