diff --git a/config.yaml b/config.yaml index 59428b3c..3e1b9cc4 100644 --- a/config.yaml +++ b/config.yaml @@ -7,44 +7,38 @@ # 服务器配置 server: - host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 - port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080 - + host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 + port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080 # 日志配置 log: - level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误) - output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径 - + level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误) + output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径 # MCP 协议配置 # MCP (Model Context Protocol) 用于工具注册和调用 mcp: - enabled: true # 是否启用 MCP 服务器 - host: 0.0.0.0 # MCP 服务器监听地址 - port: 8081 # MCP 服务器端口 - + enabled: true # 是否启用 MCP 服务器 + host: 0.0.0.0 # MCP 服务器监听地址 + port: 8081 # MCP 服务器端口 # AI 模型配置(支持 OpenAI 兼容 API) # 必填项:api_key, base_url, model 必须填写才能正常运行 openai: - api_key: sk-xxx # API 密钥(必填) - base_url: https://api.deepseek.com/v1 # API 基础 URL(必填) - # 支持的 API 服务商: - # - OpenAI: https://api.openai.com/v1 - # - DeepSeek: https://api.deepseek.com/v1 - # - 其他兼容 OpenAI 协议的 API - model: deepseek-chat # 模型名称(必填) - # 常用模型: gpt-4, gpt-3.5-turbo, deepseek-chat, claude-3-opus 等 - + base_url: https://api.deepseek.com/v1 # API 基础 URL(必填) + api_key: sk-xxx # API 密钥(必填) + # 支持的 API 服务商: + # - OpenAI: https://api.openai.com/v1 + # - DeepSeek: https://api.deepseek.com/v1 + # - 其他兼容 OpenAI 协议的 API + model: deepseek-chat # 模型名称(必填) + # 常用模型: gpt-4, gpt-3.5-turbo, deepseek-chat, claude-3-opus 等 # Agent 配置 agent: - max_iterations: 30 # 最大迭代次数,AI 代理最多执行多少轮工具调用 - # 达到最大迭代次数时,AI 会自动总结测试结果 - + max_iterations: 30 # 最大迭代次数,AI 代理最多执行多少轮工具调用 + # 达到最大迭代次数时,AI 会自动总结测试结果 # 数据库配置 database: - path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息 - + path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息 # 安全工具配置 security: - tools_dir: tools # 工具配置文件目录(相对于配置文件所在目录) - # 系统会从该目录加载所有 .yaml 格式的工具配置文件 - # 推荐方式:在 tools/ 目录下为每个工具创建独立的配置文件 + tools_dir: tools # 工具配置文件目录(相对于配置文件所在目录) + # 系统会从该目录加载所有 .yaml 格式的工具配置文件 + # 推荐方式:在 tools/ 目录下为每个工具创建独立的配置文件 diff --git a/internal/handler/config.go b/internal/handler/config.go index e24a4e61..0d691006 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "fmt" "net/http" "os" @@ -204,51 +205,27 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { // saveConfig 保存配置到文件 func (h *ConfigHandler) saveConfig() error { - // 读取现有配置文件 + // 读取现有配置文件并创建备份 data, err := os.ReadFile(h.configPath) if err != nil { return fmt.Errorf("读取配置文件失败: %w", err) } - // 解析现有配置 - var existingConfig map[string]interface{} - if err := yaml.Unmarshal(data, &existingConfig); err != nil { + if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil { + h.logger.Warn("创建配置备份失败", zap.Error(err)) + } + + root, err := loadYAMLDocument(h.configPath) + if err != nil { return fmt.Errorf("解析配置文件失败: %w", err) } - // 更新配置值 - if existingConfig["openai"] == nil { - existingConfig["openai"] = make(map[string]interface{}) - } - openaiMap := existingConfig["openai"].(map[string]interface{}) - if h.config.OpenAI.APIKey != "" { - openaiMap["api_key"] = h.config.OpenAI.APIKey - } - if h.config.OpenAI.BaseURL != "" { - openaiMap["base_url"] = h.config.OpenAI.BaseURL - } - if h.config.OpenAI.Model != "" { - openaiMap["model"] = h.config.OpenAI.Model - } + updateAgentConfig(root, h.config.Agent.MaxIterations) + updateMCPConfig(root, h.config.MCP) + updateOpenAIConfig(root, h.config.OpenAI) - if existingConfig["mcp"] == nil { - existingConfig["mcp"] = make(map[string]interface{}) - } - mcpMap := existingConfig["mcp"].(map[string]interface{}) - mcpMap["enabled"] = h.config.MCP.Enabled - if h.config.MCP.Host != "" { - mcpMap["host"] = h.config.MCP.Host - } - if h.config.MCP.Port > 0 { - mcpMap["port"] = h.config.MCP.Port - } - - if h.config.Agent.MaxIterations > 0 { - if existingConfig["agent"] == nil { - existingConfig["agent"] = make(map[string]interface{}) - } - agentMap := existingConfig["agent"].(map[string]interface{}) - agentMap["max_iterations"] = h.config.Agent.MaxIterations + if err := writeYAMLDocument(h.configPath, root); err != nil { + return fmt.Errorf("保存配置文件失败: %w", err) } // 更新工具配置文件中的enabled状态 @@ -271,31 +248,25 @@ func (h *ConfigHandler) saveConfig() error { } } - // 读取工具配置文件 toolData, err := os.ReadFile(toolFile) if err != nil { h.logger.Warn("读取工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err)) continue } - // 解析工具配置 - var toolConfig map[string]interface{} - if err := yaml.Unmarshal(toolData, &toolConfig); err != nil { - h.logger.Warn("解析工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err)) - continue + if err := os.WriteFile(toolFile+".backup", toolData, 0644); err != nil { + h.logger.Warn("创建工具配置备份失败", zap.String("tool", tool.Name), zap.Error(err)) } - // 更新enabled状态 - toolConfig["enabled"] = tool.Enabled - - // 保存工具配置文件 - updatedData, err := yaml.Marshal(toolConfig) + toolDoc, err := loadYAMLDocument(toolFile) if err != nil { - h.logger.Warn("序列化工具配置失败", zap.String("tool", tool.Name), zap.Error(err)) + h.logger.Warn("解析工具配置失败", zap.String("tool", tool.Name), zap.Error(err)) continue } - if err := os.WriteFile(toolFile, updatedData, 0644); err != nil { + setBoolInMap(toolDoc.Content[0], "enabled", tool.Enabled) + + if err := writeYAMLDocument(toolFile, toolDoc); err != nil { h.logger.Warn("保存工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err)) continue } @@ -304,24 +275,160 @@ func (h *ConfigHandler) saveConfig() error { } } - // 保存主配置文件 - updatedData, err := yaml.Marshal(existingConfig) - if err != nil { - return fmt.Errorf("序列化配置失败: %w", err) - } - - // 创建备份 - backupPath := h.configPath + ".backup" - if err := os.WriteFile(backupPath, data, 0644); err != nil { - h.logger.Warn("创建配置备份失败", zap.Error(err)) - } - - // 保存新配置 - if err := os.WriteFile(h.configPath, updatedData, 0644); err != nil { - return fmt.Errorf("保存配置文件失败: %w", err) - } - h.logger.Info("配置已保存", zap.String("path", h.configPath)) return nil } +func loadYAMLDocument(path string) (*yaml.Node, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + if len(bytes.TrimSpace(data)) == 0 { + return newEmptyYAMLDocument(), nil + } + + var doc yaml.Node + if err := yaml.Unmarshal(data, &doc); err != nil { + return nil, err + } + + if doc.Kind != yaml.DocumentNode || len(doc.Content) == 0 { + return newEmptyYAMLDocument(), nil + } + + if doc.Content[0].Kind != yaml.MappingNode { + root := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + doc.Content = []*yaml.Node{root} + } + + return &doc, nil +} + +func newEmptyYAMLDocument() *yaml.Node { + root := &yaml.Node{ + Kind: yaml.DocumentNode, + Content: []*yaml.Node{{Kind: yaml.MappingNode, Tag: "!!map"}}, + } + return root +} + +func writeYAMLDocument(path string, doc *yaml.Node) error { + var buf bytes.Buffer + encoder := yaml.NewEncoder(&buf) + encoder.SetIndent(2) + if err := encoder.Encode(doc); err != nil { + return err + } + if err := encoder.Close(); err != nil { + return err + } + return os.WriteFile(path, buf.Bytes(), 0644) +} + +func updateAgentConfig(doc *yaml.Node, maxIterations int) { + root := doc.Content[0] + agentNode := ensureMap(root, "agent") + setIntInMap(agentNode, "max_iterations", maxIterations) +} + +func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) { + root := doc.Content[0] + mcpNode := ensureMap(root, "mcp") + setBoolInMap(mcpNode, "enabled", cfg.Enabled) + setStringInMap(mcpNode, "host", cfg.Host) + setIntInMap(mcpNode, "port", cfg.Port) +} + +func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) { + root := doc.Content[0] + openaiNode := ensureMap(root, "openai") + setStringInMap(openaiNode, "api_key", cfg.APIKey) + setStringInMap(openaiNode, "base_url", cfg.BaseURL) + setStringInMap(openaiNode, "model", cfg.Model) +} + +func ensureMap(parent *yaml.Node, path ...string) *yaml.Node { + current := parent + for _, key := range path { + value := findMapValue(current, key) + if value == nil { + keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key} + mapNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + current.Content = append(current.Content, keyNode, mapNode) + value = mapNode + } + + if value.Kind != yaml.MappingNode { + value.Kind = yaml.MappingNode + value.Tag = "!!map" + value.Style = 0 + value.Content = nil + } + + current = value + } + + return current +} + +func findMapValue(mapNode *yaml.Node, key string) *yaml.Node { + if mapNode == nil || mapNode.Kind != yaml.MappingNode { + return nil + } + + for i := 0; i < len(mapNode.Content); i += 2 { + if mapNode.Content[i].Value == key { + return mapNode.Content[i+1] + } + } + return nil +} + +func ensureKeyValue(mapNode *yaml.Node, key string) (*yaml.Node, *yaml.Node) { + if mapNode == nil || mapNode.Kind != yaml.MappingNode { + return nil, nil + } + + for i := 0; i < len(mapNode.Content); i += 2 { + if mapNode.Content[i].Value == key { + return mapNode.Content[i], mapNode.Content[i+1] + } + } + + keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key} + valueNode := &yaml.Node{} + mapNode.Content = append(mapNode.Content, keyNode, valueNode) + return keyNode, valueNode +} + +func setStringInMap(mapNode *yaml.Node, key, value string) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.ScalarNode + valueNode.Tag = "!!str" + valueNode.Style = 0 + valueNode.Value = value +} + +func setIntInMap(mapNode *yaml.Node, key string, value int) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.ScalarNode + valueNode.Tag = "!!int" + valueNode.Style = 0 + valueNode.Value = fmt.Sprintf("%d", value) +} + +func setBoolInMap(mapNode *yaml.Node, key string, value bool) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.ScalarNode + valueNode.Tag = "!!bool" + valueNode.Style = 0 + if value { + valueNode.Value = "true" + } else { + valueNode.Value = "false" + } +} + + diff --git a/web/templates/index.html b/web/templates/index.html index 8846cdfa..75ad0bd4 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -68,14 +68,14 @@