Files
CyberStrikeAI/internal/handler/config.go
2025-11-13 21:56:50 +08:00

435 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"bytes"
"fmt"
"net/http"
"os"
"path/filepath"
"sync"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// ConfigHandler 配置处理器
type ConfigHandler struct {
configPath string
config *config.Config
mcpServer *mcp.Server
executor *security.Executor
agent AgentUpdater // Agent接口用于更新Agent配置
logger *zap.Logger
mu sync.RWMutex
}
// AgentUpdater Agent更新接口
type AgentUpdater interface {
UpdateConfig(cfg *config.OpenAIConfig)
UpdateMaxIterations(maxIterations int)
}
// NewConfigHandler 创建新的配置处理器
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, logger *zap.Logger) *ConfigHandler {
return &ConfigHandler{
configPath: configPath,
config: cfg,
mcpServer: mcpServer,
executor: executor,
agent: agent,
logger: logger,
}
}
// GetConfigResponse 获取配置响应
type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"`
MCP config.MCPConfig `json:"mcp"`
Tools []ToolConfigInfo `json:"tools"`
Agent config.AgentConfig `json:"agent"`
}
// ToolConfigInfo 工具配置信息
type ToolConfigInfo struct {
Name string `json:"name"`
Description string `json:"description"`
Enabled bool `json:"enabled"`
}
// GetConfig 获取当前配置
func (h *ConfigHandler) GetConfig(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
// 获取工具列表
tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
for _, tool := range h.config.Security.Tools {
tools = append(tools, ToolConfigInfo{
Name: tool.Name,
Description: tool.ShortDescription,
Enabled: tool.Enabled,
})
// 如果没有简短描述使用详细描述的前100个字符
if tools[len(tools)-1].Description == "" {
desc := tool.Description
if len(desc) > 100 {
desc = desc[:100] + "..."
}
tools[len(tools)-1].Description = desc
}
}
c.JSON(http.StatusOK, GetConfigResponse{
OpenAI: h.config.OpenAI,
MCP: h.config.MCP,
Tools: tools,
Agent: h.config.Agent,
})
}
// UpdateConfigRequest 更新配置请求
type UpdateConfigRequest struct {
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
MCP *config.MCPConfig `json:"mcp,omitempty"`
Tools []ToolEnableStatus `json:"tools,omitempty"`
Agent *config.AgentConfig `json:"agent,omitempty"`
}
// ToolEnableStatus 工具启用状态
type ToolEnableStatus struct {
Name string `json:"name"`
Enabled bool `json:"enabled"`
}
// UpdateConfig 更新配置
func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
var req UpdateConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
h.mu.Lock()
defer h.mu.Unlock()
// 更新OpenAI配置
if req.OpenAI != nil {
h.config.OpenAI = *req.OpenAI
h.logger.Info("更新OpenAI配置",
zap.String("base_url", h.config.OpenAI.BaseURL),
zap.String("model", h.config.OpenAI.Model),
)
}
// 更新MCP配置
if req.MCP != nil {
h.config.MCP = *req.MCP
h.logger.Info("更新MCP配置",
zap.Bool("enabled", h.config.MCP.Enabled),
zap.String("host", h.config.MCP.Host),
zap.Int("port", h.config.MCP.Port),
)
}
// 更新Agent配置
if req.Agent != nil {
h.config.Agent = *req.Agent
h.logger.Info("更新Agent配置",
zap.Int("max_iterations", h.config.Agent.MaxIterations),
)
}
// 更新工具启用状态
if req.Tools != nil {
toolMap := make(map[string]bool)
for _, toolStatus := range req.Tools {
toolMap[toolStatus.Name] = toolStatus.Enabled
}
// 更新配置中的工具状态
for i := range h.config.Security.Tools {
if enabled, ok := toolMap[h.config.Security.Tools[i].Name]; ok {
h.config.Security.Tools[i].Enabled = enabled
h.logger.Info("更新工具启用状态",
zap.String("tool", h.config.Security.Tools[i].Name),
zap.Bool("enabled", enabled),
)
}
}
}
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
}
// ApplyConfig 应用配置(重新加载并重启相关服务)
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.mu.Lock()
defer h.mu.Unlock()
// 重新注册工具(根据新的启用状态)
h.logger.Info("重新注册工具")
// 清空MCP服务器中的工具
h.mcpServer.ClearTools()
// 重新注册工具
h.executor.RegisterTools(h.mcpServer)
// 更新Agent的OpenAI配置
if h.agent != nil {
h.agent.UpdateConfig(&h.config.OpenAI)
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
h.logger.Info("Agent配置已更新")
}
h.logger.Info("配置已应用",
zap.Int("tools_count", len(h.config.Security.Tools)),
)
c.JSON(http.StatusOK, gin.H{
"message": "配置已应用",
"tools_count": len(h.config.Security.Tools),
})
}
// saveConfig 保存配置到文件
func (h *ConfigHandler) saveConfig() error {
// 读取现有配置文件并创建备份
data, err := os.ReadFile(h.configPath)
if err != nil {
return fmt.Errorf("读取配置文件失败: %w", err)
}
if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil {
h.logger.Warn("创建配置备份失败", zap.Error(err))
}
root, err := loadYAMLDocument(h.configPath)
if err != nil {
return fmt.Errorf("解析配置文件失败: %w", err)
}
updateAgentConfig(root, h.config.Agent.MaxIterations)
updateMCPConfig(root, h.config.MCP)
updateOpenAIConfig(root, h.config.OpenAI)
if err := writeYAMLDocument(h.configPath, root); err != nil {
return fmt.Errorf("保存配置文件失败: %w", err)
}
// 更新工具配置文件中的enabled状态
if h.config.Security.ToolsDir != "" {
configDir := filepath.Dir(h.configPath)
toolsDir := h.config.Security.ToolsDir
if !filepath.IsAbs(toolsDir) {
toolsDir = filepath.Join(configDir, toolsDir)
}
for _, tool := range h.config.Security.Tools {
toolFile := filepath.Join(toolsDir, tool.Name+".yaml")
// 检查文件是否存在
if _, err := os.Stat(toolFile); os.IsNotExist(err) {
// 尝试.yml扩展名
toolFile = filepath.Join(toolsDir, tool.Name+".yml")
if _, err := os.Stat(toolFile); os.IsNotExist(err) {
h.logger.Warn("工具配置文件不存在", zap.String("tool", tool.Name))
continue
}
}
toolData, err := os.ReadFile(toolFile)
if err != nil {
h.logger.Warn("读取工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err))
continue
}
if err := os.WriteFile(toolFile+".backup", toolData, 0644); err != nil {
h.logger.Warn("创建工具配置备份失败", zap.String("tool", tool.Name), zap.Error(err))
}
toolDoc, err := loadYAMLDocument(toolFile)
if err != nil {
h.logger.Warn("解析工具配置失败", zap.String("tool", tool.Name), zap.Error(err))
continue
}
setBoolInMap(toolDoc.Content[0], "enabled", tool.Enabled)
if err := writeYAMLDocument(toolFile, toolDoc); err != nil {
h.logger.Warn("保存工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err))
continue
}
h.logger.Info("更新工具配置", zap.String("tool", tool.Name), zap.Bool("enabled", tool.Enabled))
}
}
h.logger.Info("配置已保存", zap.String("path", h.configPath))
return nil
}
func loadYAMLDocument(path string) (*yaml.Node, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
if len(bytes.TrimSpace(data)) == 0 {
return newEmptyYAMLDocument(), nil
}
var doc yaml.Node
if err := yaml.Unmarshal(data, &doc); err != nil {
return nil, err
}
if doc.Kind != yaml.DocumentNode || len(doc.Content) == 0 {
return newEmptyYAMLDocument(), nil
}
if doc.Content[0].Kind != yaml.MappingNode {
root := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
doc.Content = []*yaml.Node{root}
}
return &doc, nil
}
func newEmptyYAMLDocument() *yaml.Node {
root := &yaml.Node{
Kind: yaml.DocumentNode,
Content: []*yaml.Node{{Kind: yaml.MappingNode, Tag: "!!map"}},
}
return root
}
func writeYAMLDocument(path string, doc *yaml.Node) error {
var buf bytes.Buffer
encoder := yaml.NewEncoder(&buf)
encoder.SetIndent(2)
if err := encoder.Encode(doc); err != nil {
return err
}
if err := encoder.Close(); err != nil {
return err
}
return os.WriteFile(path, buf.Bytes(), 0644)
}
func updateAgentConfig(doc *yaml.Node, maxIterations int) {
root := doc.Content[0]
agentNode := ensureMap(root, "agent")
setIntInMap(agentNode, "max_iterations", maxIterations)
}
func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
root := doc.Content[0]
mcpNode := ensureMap(root, "mcp")
setBoolInMap(mcpNode, "enabled", cfg.Enabled)
setStringInMap(mcpNode, "host", cfg.Host)
setIntInMap(mcpNode, "port", cfg.Port)
}
func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
root := doc.Content[0]
openaiNode := ensureMap(root, "openai")
setStringInMap(openaiNode, "api_key", cfg.APIKey)
setStringInMap(openaiNode, "base_url", cfg.BaseURL)
setStringInMap(openaiNode, "model", cfg.Model)
}
func ensureMap(parent *yaml.Node, path ...string) *yaml.Node {
current := parent
for _, key := range path {
value := findMapValue(current, key)
if value == nil {
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key}
mapNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
current.Content = append(current.Content, keyNode, mapNode)
value = mapNode
}
if value.Kind != yaml.MappingNode {
value.Kind = yaml.MappingNode
value.Tag = "!!map"
value.Style = 0
value.Content = nil
}
current = value
}
return current
}
func findMapValue(mapNode *yaml.Node, key string) *yaml.Node {
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
return nil
}
for i := 0; i < len(mapNode.Content); i += 2 {
if mapNode.Content[i].Value == key {
return mapNode.Content[i+1]
}
}
return nil
}
func ensureKeyValue(mapNode *yaml.Node, key string) (*yaml.Node, *yaml.Node) {
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
return nil, nil
}
for i := 0; i < len(mapNode.Content); i += 2 {
if mapNode.Content[i].Value == key {
return mapNode.Content[i], mapNode.Content[i+1]
}
}
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key}
valueNode := &yaml.Node{}
mapNode.Content = append(mapNode.Content, keyNode, valueNode)
return keyNode, valueNode
}
func setStringInMap(mapNode *yaml.Node, key, value string) {
_, valueNode := ensureKeyValue(mapNode, key)
valueNode.Kind = yaml.ScalarNode
valueNode.Tag = "!!str"
valueNode.Style = 0
valueNode.Value = value
}
func setIntInMap(mapNode *yaml.Node, key string, value int) {
_, valueNode := ensureKeyValue(mapNode, key)
valueNode.Kind = yaml.ScalarNode
valueNode.Tag = "!!int"
valueNode.Style = 0
valueNode.Value = fmt.Sprintf("%d", value)
}
func setBoolInMap(mapNode *yaml.Node, key string, value bool) {
_, valueNode := ensureKeyValue(mapNode, key)
valueNode.Kind = yaml.ScalarNode
valueNode.Tag = "!!bool"
valueNode.Style = 0
if value {
valueNode.Value = "true"
} else {
valueNode.Value = "false"
}
}