mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
435 lines
11 KiB
Go
435 lines
11 KiB
Go
package handler
|
||
|
||
import (
|
||
"bytes"
|
||
"fmt"
|
||
"net/http"
|
||
"os"
|
||
"path/filepath"
|
||
"sync"
|
||
|
||
"cyberstrike-ai/internal/config"
|
||
"cyberstrike-ai/internal/mcp"
|
||
"cyberstrike-ai/internal/security"
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/zap"
|
||
"gopkg.in/yaml.v3"
|
||
)
|
||
|
||
// ConfigHandler 配置处理器
|
||
type ConfigHandler struct {
|
||
configPath string
|
||
config *config.Config
|
||
mcpServer *mcp.Server
|
||
executor *security.Executor
|
||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||
logger *zap.Logger
|
||
mu sync.RWMutex
|
||
}
|
||
|
||
// AgentUpdater Agent更新接口
|
||
type AgentUpdater interface {
|
||
UpdateConfig(cfg *config.OpenAIConfig)
|
||
UpdateMaxIterations(maxIterations int)
|
||
}
|
||
|
||
// NewConfigHandler 创建新的配置处理器
|
||
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, logger *zap.Logger) *ConfigHandler {
|
||
return &ConfigHandler{
|
||
configPath: configPath,
|
||
config: cfg,
|
||
mcpServer: mcpServer,
|
||
executor: executor,
|
||
agent: agent,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// GetConfigResponse 获取配置响应
|
||
type GetConfigResponse struct {
|
||
OpenAI config.OpenAIConfig `json:"openai"`
|
||
MCP config.MCPConfig `json:"mcp"`
|
||
Tools []ToolConfigInfo `json:"tools"`
|
||
Agent config.AgentConfig `json:"agent"`
|
||
}
|
||
|
||
// ToolConfigInfo 工具配置信息
|
||
type ToolConfigInfo struct {
|
||
Name string `json:"name"`
|
||
Description string `json:"description"`
|
||
Enabled bool `json:"enabled"`
|
||
}
|
||
|
||
// GetConfig 获取当前配置
|
||
func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||
h.mu.RLock()
|
||
defer h.mu.RUnlock()
|
||
|
||
// 获取工具列表
|
||
tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
|
||
for _, tool := range h.config.Security.Tools {
|
||
tools = append(tools, ToolConfigInfo{
|
||
Name: tool.Name,
|
||
Description: tool.ShortDescription,
|
||
Enabled: tool.Enabled,
|
||
})
|
||
// 如果没有简短描述,使用详细描述的前100个字符
|
||
if tools[len(tools)-1].Description == "" {
|
||
desc := tool.Description
|
||
if len(desc) > 100 {
|
||
desc = desc[:100] + "..."
|
||
}
|
||
tools[len(tools)-1].Description = desc
|
||
}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, GetConfigResponse{
|
||
OpenAI: h.config.OpenAI,
|
||
MCP: h.config.MCP,
|
||
Tools: tools,
|
||
Agent: h.config.Agent,
|
||
})
|
||
}
|
||
|
||
// UpdateConfigRequest 更新配置请求
|
||
type UpdateConfigRequest struct {
|
||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||
}
|
||
|
||
// ToolEnableStatus 工具启用状态
|
||
type ToolEnableStatus struct {
|
||
Name string `json:"name"`
|
||
Enabled bool `json:"enabled"`
|
||
}
|
||
|
||
// UpdateConfig 更新配置
|
||
func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||
var req UpdateConfigRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
// 更新OpenAI配置
|
||
if req.OpenAI != nil {
|
||
h.config.OpenAI = *req.OpenAI
|
||
h.logger.Info("更新OpenAI配置",
|
||
zap.String("base_url", h.config.OpenAI.BaseURL),
|
||
zap.String("model", h.config.OpenAI.Model),
|
||
)
|
||
}
|
||
|
||
// 更新MCP配置
|
||
if req.MCP != nil {
|
||
h.config.MCP = *req.MCP
|
||
h.logger.Info("更新MCP配置",
|
||
zap.Bool("enabled", h.config.MCP.Enabled),
|
||
zap.String("host", h.config.MCP.Host),
|
||
zap.Int("port", h.config.MCP.Port),
|
||
)
|
||
}
|
||
|
||
// 更新Agent配置
|
||
if req.Agent != nil {
|
||
h.config.Agent = *req.Agent
|
||
h.logger.Info("更新Agent配置",
|
||
zap.Int("max_iterations", h.config.Agent.MaxIterations),
|
||
)
|
||
}
|
||
|
||
// 更新工具启用状态
|
||
if req.Tools != nil {
|
||
toolMap := make(map[string]bool)
|
||
for _, toolStatus := range req.Tools {
|
||
toolMap[toolStatus.Name] = toolStatus.Enabled
|
||
}
|
||
|
||
// 更新配置中的工具状态
|
||
for i := range h.config.Security.Tools {
|
||
if enabled, ok := toolMap[h.config.Security.Tools[i].Name]; ok {
|
||
h.config.Security.Tools[i].Enabled = enabled
|
||
h.logger.Info("更新工具启用状态",
|
||
zap.String("tool", h.config.Security.Tools[i].Name),
|
||
zap.Bool("enabled", enabled),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 保存配置到文件
|
||
if err := h.saveConfig(); err != nil {
|
||
h.logger.Error("保存配置失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||
}
|
||
|
||
// ApplyConfig 应用配置(重新加载并重启相关服务)
|
||
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
// 重新注册工具(根据新的启用状态)
|
||
h.logger.Info("重新注册工具")
|
||
|
||
// 清空MCP服务器中的工具
|
||
h.mcpServer.ClearTools()
|
||
|
||
// 重新注册工具
|
||
h.executor.RegisterTools(h.mcpServer)
|
||
|
||
// 更新Agent的OpenAI配置
|
||
if h.agent != nil {
|
||
h.agent.UpdateConfig(&h.config.OpenAI)
|
||
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
|
||
h.logger.Info("Agent配置已更新")
|
||
}
|
||
|
||
h.logger.Info("配置已应用",
|
||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||
)
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": "配置已应用",
|
||
"tools_count": len(h.config.Security.Tools),
|
||
})
|
||
}
|
||
|
||
// saveConfig 保存配置到文件
|
||
func (h *ConfigHandler) saveConfig() error {
|
||
// 读取现有配置文件并创建备份
|
||
data, err := os.ReadFile(h.configPath)
|
||
if err != nil {
|
||
return fmt.Errorf("读取配置文件失败: %w", err)
|
||
}
|
||
|
||
if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil {
|
||
h.logger.Warn("创建配置备份失败", zap.Error(err))
|
||
}
|
||
|
||
root, err := loadYAMLDocument(h.configPath)
|
||
if err != nil {
|
||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||
}
|
||
|
||
updateAgentConfig(root, h.config.Agent.MaxIterations)
|
||
updateMCPConfig(root, h.config.MCP)
|
||
updateOpenAIConfig(root, h.config.OpenAI)
|
||
|
||
if err := writeYAMLDocument(h.configPath, root); err != nil {
|
||
return fmt.Errorf("保存配置文件失败: %w", err)
|
||
}
|
||
|
||
// 更新工具配置文件中的enabled状态
|
||
if h.config.Security.ToolsDir != "" {
|
||
configDir := filepath.Dir(h.configPath)
|
||
toolsDir := h.config.Security.ToolsDir
|
||
if !filepath.IsAbs(toolsDir) {
|
||
toolsDir = filepath.Join(configDir, toolsDir)
|
||
}
|
||
|
||
for _, tool := range h.config.Security.Tools {
|
||
toolFile := filepath.Join(toolsDir, tool.Name+".yaml")
|
||
// 检查文件是否存在
|
||
if _, err := os.Stat(toolFile); os.IsNotExist(err) {
|
||
// 尝试.yml扩展名
|
||
toolFile = filepath.Join(toolsDir, tool.Name+".yml")
|
||
if _, err := os.Stat(toolFile); os.IsNotExist(err) {
|
||
h.logger.Warn("工具配置文件不存在", zap.String("tool", tool.Name))
|
||
continue
|
||
}
|
||
}
|
||
|
||
toolData, err := os.ReadFile(toolFile)
|
||
if err != nil {
|
||
h.logger.Warn("读取工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err))
|
||
continue
|
||
}
|
||
|
||
if err := os.WriteFile(toolFile+".backup", toolData, 0644); err != nil {
|
||
h.logger.Warn("创建工具配置备份失败", zap.String("tool", tool.Name), zap.Error(err))
|
||
}
|
||
|
||
toolDoc, err := loadYAMLDocument(toolFile)
|
||
if err != nil {
|
||
h.logger.Warn("解析工具配置失败", zap.String("tool", tool.Name), zap.Error(err))
|
||
continue
|
||
}
|
||
|
||
setBoolInMap(toolDoc.Content[0], "enabled", tool.Enabled)
|
||
|
||
if err := writeYAMLDocument(toolFile, toolDoc); err != nil {
|
||
h.logger.Warn("保存工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err))
|
||
continue
|
||
}
|
||
|
||
h.logger.Info("更新工具配置", zap.String("tool", tool.Name), zap.Bool("enabled", tool.Enabled))
|
||
}
|
||
}
|
||
|
||
h.logger.Info("配置已保存", zap.String("path", h.configPath))
|
||
return nil
|
||
}
|
||
|
||
func loadYAMLDocument(path string) (*yaml.Node, error) {
|
||
data, err := os.ReadFile(path)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if len(bytes.TrimSpace(data)) == 0 {
|
||
return newEmptyYAMLDocument(), nil
|
||
}
|
||
|
||
var doc yaml.Node
|
||
if err := yaml.Unmarshal(data, &doc); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if doc.Kind != yaml.DocumentNode || len(doc.Content) == 0 {
|
||
return newEmptyYAMLDocument(), nil
|
||
}
|
||
|
||
if doc.Content[0].Kind != yaml.MappingNode {
|
||
root := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||
doc.Content = []*yaml.Node{root}
|
||
}
|
||
|
||
return &doc, nil
|
||
}
|
||
|
||
func newEmptyYAMLDocument() *yaml.Node {
|
||
root := &yaml.Node{
|
||
Kind: yaml.DocumentNode,
|
||
Content: []*yaml.Node{{Kind: yaml.MappingNode, Tag: "!!map"}},
|
||
}
|
||
return root
|
||
}
|
||
|
||
func writeYAMLDocument(path string, doc *yaml.Node) error {
|
||
var buf bytes.Buffer
|
||
encoder := yaml.NewEncoder(&buf)
|
||
encoder.SetIndent(2)
|
||
if err := encoder.Encode(doc); err != nil {
|
||
return err
|
||
}
|
||
if err := encoder.Close(); err != nil {
|
||
return err
|
||
}
|
||
return os.WriteFile(path, buf.Bytes(), 0644)
|
||
}
|
||
|
||
func updateAgentConfig(doc *yaml.Node, maxIterations int) {
|
||
root := doc.Content[0]
|
||
agentNode := ensureMap(root, "agent")
|
||
setIntInMap(agentNode, "max_iterations", maxIterations)
|
||
}
|
||
|
||
func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
|
||
root := doc.Content[0]
|
||
mcpNode := ensureMap(root, "mcp")
|
||
setBoolInMap(mcpNode, "enabled", cfg.Enabled)
|
||
setStringInMap(mcpNode, "host", cfg.Host)
|
||
setIntInMap(mcpNode, "port", cfg.Port)
|
||
}
|
||
|
||
func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
|
||
root := doc.Content[0]
|
||
openaiNode := ensureMap(root, "openai")
|
||
setStringInMap(openaiNode, "api_key", cfg.APIKey)
|
||
setStringInMap(openaiNode, "base_url", cfg.BaseURL)
|
||
setStringInMap(openaiNode, "model", cfg.Model)
|
||
}
|
||
|
||
func ensureMap(parent *yaml.Node, path ...string) *yaml.Node {
|
||
current := parent
|
||
for _, key := range path {
|
||
value := findMapValue(current, key)
|
||
if value == nil {
|
||
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key}
|
||
mapNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||
current.Content = append(current.Content, keyNode, mapNode)
|
||
value = mapNode
|
||
}
|
||
|
||
if value.Kind != yaml.MappingNode {
|
||
value.Kind = yaml.MappingNode
|
||
value.Tag = "!!map"
|
||
value.Style = 0
|
||
value.Content = nil
|
||
}
|
||
|
||
current = value
|
||
}
|
||
|
||
return current
|
||
}
|
||
|
||
func findMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
||
return nil
|
||
}
|
||
|
||
for i := 0; i < len(mapNode.Content); i += 2 {
|
||
if mapNode.Content[i].Value == key {
|
||
return mapNode.Content[i+1]
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func ensureKeyValue(mapNode *yaml.Node, key string) (*yaml.Node, *yaml.Node) {
|
||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
||
return nil, nil
|
||
}
|
||
|
||
for i := 0; i < len(mapNode.Content); i += 2 {
|
||
if mapNode.Content[i].Value == key {
|
||
return mapNode.Content[i], mapNode.Content[i+1]
|
||
}
|
||
}
|
||
|
||
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key}
|
||
valueNode := &yaml.Node{}
|
||
mapNode.Content = append(mapNode.Content, keyNode, valueNode)
|
||
return keyNode, valueNode
|
||
}
|
||
|
||
func setStringInMap(mapNode *yaml.Node, key, value string) {
|
||
_, valueNode := ensureKeyValue(mapNode, key)
|
||
valueNode.Kind = yaml.ScalarNode
|
||
valueNode.Tag = "!!str"
|
||
valueNode.Style = 0
|
||
valueNode.Value = value
|
||
}
|
||
|
||
func setIntInMap(mapNode *yaml.Node, key string, value int) {
|
||
_, valueNode := ensureKeyValue(mapNode, key)
|
||
valueNode.Kind = yaml.ScalarNode
|
||
valueNode.Tag = "!!int"
|
||
valueNode.Style = 0
|
||
valueNode.Value = fmt.Sprintf("%d", value)
|
||
}
|
||
|
||
func setBoolInMap(mapNode *yaml.Node, key string, value bool) {
|
||
_, valueNode := ensureKeyValue(mapNode, key)
|
||
valueNode.Kind = yaml.ScalarNode
|
||
valueNode.Tag = "!!bool"
|
||
valueNode.Style = 0
|
||
if value {
|
||
valueNode.Value = "true"
|
||
} else {
|
||
valueNode.Value = "false"
|
||
}
|
||
}
|
||
|
||
|