mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-08 07:14:00 +02:00
Add files via upload
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
@@ -23,6 +24,7 @@ type Agent struct {
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
maxIterations int
|
||||
mu sync.RWMutex // 添加互斥锁以支持并发更新
|
||||
}
|
||||
|
||||
// NewAgent 创建新的Agent
|
||||
@@ -938,6 +940,27 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateConfig 更新OpenAI配置
|
||||
func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.config = cfg
|
||||
a.logger.Info("Agent配置已更新",
|
||||
zap.String("base_url", cfg.BaseURL),
|
||||
zap.String("model", cfg.Model),
|
||||
)
|
||||
}
|
||||
|
||||
// UpdateMaxIterations 更新最大迭代次数
|
||||
func (a *Agent) UpdateMaxIterations(maxIterations int) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if maxIterations > 0 {
|
||||
a.maxIterations = maxIterations
|
||||
a.logger.Info("Agent最大迭代次数已更新", zap.Int("max_iterations", maxIterations))
|
||||
}
|
||||
}
|
||||
|
||||
// formatToolError 格式化工具错误信息,提供更友好的错误描述
|
||||
func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string {
|
||||
errorMsg := fmt.Sprintf(`工具执行失败
|
||||
|
||||
+14
-2
@@ -73,9 +73,16 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger)
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
|
||||
// 获取配置文件路径
|
||||
configPath := "config.yaml"
|
||||
if len(os.Args) > 1 {
|
||||
configPath = os.Args[1]
|
||||
}
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, log.Logger)
|
||||
|
||||
// 设置路由
|
||||
setupRoutes(router, agentHandler, monitorHandler, conversationHandler, mcpServer)
|
||||
setupRoutes(router, agentHandler, monitorHandler, conversationHandler, configHandler, mcpServer)
|
||||
|
||||
return &App{
|
||||
config: cfg,
|
||||
@@ -113,7 +120,7 @@ func (a *App) Run() error {
|
||||
}
|
||||
|
||||
// setupRoutes 设置路由
|
||||
func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, mcpServer *mcp.Server) {
|
||||
func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, configHandler *handler.ConfigHandler, mcpServer *mcp.Server) {
|
||||
// API路由
|
||||
api := router.Group("/api")
|
||||
{
|
||||
@@ -137,6 +144,11 @@ func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitor
|
||||
api.GET("/monitor/stats", monitorHandler.GetStats)
|
||||
api.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities)
|
||||
|
||||
// 配置管理
|
||||
api.GET("/config", configHandler.GetConfig)
|
||||
api.PUT("/config", configHandler.UpdateConfig)
|
||||
api.POST("/config/apply", configHandler.ApplyConfig)
|
||||
|
||||
// MCP端点
|
||||
api.POST("/mcp", func(c *gin.Context) {
|
||||
mcpServer.HandleHTTP(c.Writer, c.Request)
|
||||
|
||||
@@ -36,9 +36,9 @@ type MCPConfig struct {
|
||||
}
|
||||
|
||||
type OpenAIConfig struct {
|
||||
APIKey string `yaml:"api_key"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
Model string `yaml:"model"`
|
||||
APIKey string `yaml:"api_key" json:"api_key"`
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
Model string `yaml:"model" json:"model"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
@@ -51,7 +51,7 @@ type DatabaseConfig struct {
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
MaxIterations int `yaml:"max_iterations"`
|
||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||
}
|
||||
|
||||
type ToolConfig struct {
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigHandler 配置处理器
|
||||
type ConfigHandler struct {
|
||||
configPath string
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// AgentUpdater Agent更新接口
|
||||
type AgentUpdater interface {
|
||||
UpdateConfig(cfg *config.OpenAIConfig)
|
||||
UpdateMaxIterations(maxIterations int)
|
||||
}
|
||||
|
||||
// NewConfigHandler 创建新的配置处理器
|
||||
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, logger *zap.Logger) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
configPath: configPath,
|
||||
config: cfg,
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
agent: agent,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfigResponse 获取配置响应
|
||||
type GetConfigResponse struct {
|
||||
OpenAI config.OpenAIConfig `json:"openai"`
|
||||
MCP config.MCPConfig `json:"mcp"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Agent config.AgentConfig `json:"agent"`
|
||||
}
|
||||
|
||||
// ToolConfigInfo 工具配置信息
|
||||
type ToolConfigInfo struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// GetConfig 获取当前配置
|
||||
func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// 获取工具列表
|
||||
tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
|
||||
for _, tool := range h.config.Security.Tools {
|
||||
tools = append(tools, ToolConfigInfo{
|
||||
Name: tool.Name,
|
||||
Description: tool.ShortDescription,
|
||||
Enabled: tool.Enabled,
|
||||
})
|
||||
// 如果没有简短描述,使用详细描述的前100个字符
|
||||
if tools[len(tools)-1].Description == "" {
|
||||
desc := tool.Description
|
||||
if len(desc) > 100 {
|
||||
desc = desc[:100] + "..."
|
||||
}
|
||||
tools[len(tools)-1].Description = desc
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, GetConfigResponse{
|
||||
OpenAI: h.config.OpenAI,
|
||||
MCP: h.config.MCP,
|
||||
Tools: tools,
|
||||
Agent: h.config.Agent,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateConfigRequest 更新配置请求
|
||||
type UpdateConfigRequest struct {
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||||
}
|
||||
|
||||
// ToolEnableStatus 工具启用状态
|
||||
type ToolEnableStatus struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// UpdateConfig 更新配置
|
||||
func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
var req UpdateConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 更新OpenAI配置
|
||||
if req.OpenAI != nil {
|
||||
h.config.OpenAI = *req.OpenAI
|
||||
h.logger.Info("更新OpenAI配置",
|
||||
zap.String("base_url", h.config.OpenAI.BaseURL),
|
||||
zap.String("model", h.config.OpenAI.Model),
|
||||
)
|
||||
}
|
||||
|
||||
// 更新MCP配置
|
||||
if req.MCP != nil {
|
||||
h.config.MCP = *req.MCP
|
||||
h.logger.Info("更新MCP配置",
|
||||
zap.Bool("enabled", h.config.MCP.Enabled),
|
||||
zap.String("host", h.config.MCP.Host),
|
||||
zap.Int("port", h.config.MCP.Port),
|
||||
)
|
||||
}
|
||||
|
||||
// 更新Agent配置
|
||||
if req.Agent != nil {
|
||||
h.config.Agent = *req.Agent
|
||||
h.logger.Info("更新Agent配置",
|
||||
zap.Int("max_iterations", h.config.Agent.MaxIterations),
|
||||
)
|
||||
}
|
||||
|
||||
// 更新工具启用状态
|
||||
if req.Tools != nil {
|
||||
toolMap := make(map[string]bool)
|
||||
for _, toolStatus := range req.Tools {
|
||||
toolMap[toolStatus.Name] = toolStatus.Enabled
|
||||
}
|
||||
|
||||
// 更新配置中的工具状态
|
||||
for i := range h.config.Security.Tools {
|
||||
if enabled, ok := toolMap[h.config.Security.Tools[i].Name]; ok {
|
||||
h.config.Security.Tools[i].Enabled = enabled
|
||||
h.logger.Info("更新工具启用状态",
|
||||
zap.String("tool", h.config.Security.Tools[i].Name),
|
||||
zap.Bool("enabled", enabled),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存配置到文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
// ApplyConfig 应用配置(重新加载并重启相关服务)
|
||||
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 重新注册工具(根据新的启用状态)
|
||||
h.logger.Info("重新注册工具")
|
||||
|
||||
// 清空MCP服务器中的工具
|
||||
h.mcpServer.ClearTools()
|
||||
|
||||
// 重新注册工具
|
||||
h.executor.RegisterTools(h.mcpServer)
|
||||
|
||||
// 更新Agent的OpenAI配置
|
||||
if h.agent != nil {
|
||||
h.agent.UpdateConfig(&h.config.OpenAI)
|
||||
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
|
||||
h.logger.Info("Agent配置已更新")
|
||||
}
|
||||
|
||||
h.logger.Info("配置已应用",
|
||||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||||
)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "配置已应用",
|
||||
"tools_count": len(h.config.Security.Tools),
|
||||
})
|
||||
}
|
||||
|
||||
// saveConfig 保存配置到文件
|
||||
func (h *ConfigHandler) saveConfig() error {
|
||||
// 读取现有配置文件
|
||||
data, err := os.ReadFile(h.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析现有配置
|
||||
var existingConfig map[string]interface{}
|
||||
if err := yaml.Unmarshal(data, &existingConfig); err != nil {
|
||||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新配置值
|
||||
if existingConfig["openai"] == nil {
|
||||
existingConfig["openai"] = make(map[string]interface{})
|
||||
}
|
||||
openaiMap := existingConfig["openai"].(map[string]interface{})
|
||||
if h.config.OpenAI.APIKey != "" {
|
||||
openaiMap["api_key"] = h.config.OpenAI.APIKey
|
||||
}
|
||||
if h.config.OpenAI.BaseURL != "" {
|
||||
openaiMap["base_url"] = h.config.OpenAI.BaseURL
|
||||
}
|
||||
if h.config.OpenAI.Model != "" {
|
||||
openaiMap["model"] = h.config.OpenAI.Model
|
||||
}
|
||||
|
||||
if existingConfig["mcp"] == nil {
|
||||
existingConfig["mcp"] = make(map[string]interface{})
|
||||
}
|
||||
mcpMap := existingConfig["mcp"].(map[string]interface{})
|
||||
mcpMap["enabled"] = h.config.MCP.Enabled
|
||||
if h.config.MCP.Host != "" {
|
||||
mcpMap["host"] = h.config.MCP.Host
|
||||
}
|
||||
if h.config.MCP.Port > 0 {
|
||||
mcpMap["port"] = h.config.MCP.Port
|
||||
}
|
||||
|
||||
if h.config.Agent.MaxIterations > 0 {
|
||||
if existingConfig["agent"] == nil {
|
||||
existingConfig["agent"] = make(map[string]interface{})
|
||||
}
|
||||
agentMap := existingConfig["agent"].(map[string]interface{})
|
||||
agentMap["max_iterations"] = h.config.Agent.MaxIterations
|
||||
}
|
||||
|
||||
// 更新工具配置文件中的enabled状态
|
||||
if h.config.Security.ToolsDir != "" {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
toolsDir := h.config.Security.ToolsDir
|
||||
if !filepath.IsAbs(toolsDir) {
|
||||
toolsDir = filepath.Join(configDir, toolsDir)
|
||||
}
|
||||
|
||||
for _, tool := range h.config.Security.Tools {
|
||||
toolFile := filepath.Join(toolsDir, tool.Name+".yaml")
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(toolFile); os.IsNotExist(err) {
|
||||
// 尝试.yml扩展名
|
||||
toolFile = filepath.Join(toolsDir, tool.Name+".yml")
|
||||
if _, err := os.Stat(toolFile); os.IsNotExist(err) {
|
||||
h.logger.Warn("工具配置文件不存在", zap.String("tool", tool.Name))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 读取工具配置文件
|
||||
toolData, err := os.ReadFile(toolFile)
|
||||
if err != nil {
|
||||
h.logger.Warn("读取工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析工具配置
|
||||
var toolConfig map[string]interface{}
|
||||
if err := yaml.Unmarshal(toolData, &toolConfig); err != nil {
|
||||
h.logger.Warn("解析工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新enabled状态
|
||||
toolConfig["enabled"] = tool.Enabled
|
||||
|
||||
// 保存工具配置文件
|
||||
updatedData, err := yaml.Marshal(toolConfig)
|
||||
if err != nil {
|
||||
h.logger.Warn("序列化工具配置失败", zap.String("tool", tool.Name), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.WriteFile(toolFile, updatedData, 0644); err != nil {
|
||||
h.logger.Warn("保存工具配置文件失败", zap.String("tool", tool.Name), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
h.logger.Info("更新工具配置", zap.String("tool", tool.Name), zap.Bool("enabled", tool.Enabled))
|
||||
}
|
||||
}
|
||||
|
||||
// 保存主配置文件
|
||||
updatedData, err := yaml.Marshal(existingConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建备份
|
||||
backupPath := h.configPath + ".backup"
|
||||
if err := os.WriteFile(backupPath, data, 0644); err != nil {
|
||||
h.logger.Warn("创建配置备份失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 保存新配置
|
||||
if err := os.WriteFile(h.configPath, updatedData, 0644); err != nil {
|
||||
return fmt.Errorf("保存配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Info("配置已保存", zap.String("path", h.configPath))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -66,6 +66,26 @@ func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
|
||||
}
|
||||
}
|
||||
|
||||
// ClearTools 清空所有工具(用于重新加载配置)
|
||||
func (s *Server) ClearTools() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 清空工具和工具定义
|
||||
s.tools = make(map[string]ToolHandler)
|
||||
s.toolDefs = make(map[string]Tool)
|
||||
|
||||
// 清空工具相关的资源(保留其他资源)
|
||||
newResources := make(map[string]*Resource)
|
||||
for uri, resource := range s.resources {
|
||||
// 保留非工具资源
|
||||
if !strings.HasPrefix(uri, "tool://") {
|
||||
newResources[uri] = resource
|
||||
}
|
||||
}
|
||||
s.resources = newResources
|
||||
}
|
||||
|
||||
// HandleHTTP 处理HTTP请求
|
||||
func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
|
||||
Reference in New Issue
Block a user