diff --git a/internal/agent/agent.go b/internal/agent/agent.go index d1a59c23..a96c0983 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -949,7 +949,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool { enabled := false if cfg, exists := externalMCPConfigs[mcpName]; exists { // 首先检查外部MCP是否启用 - if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) { + if !cfg.ExternalMCPEnable { enabled = false // MCP未启用,所有工具都禁用 } else { // MCP已启用,检查单个工具的启用状态 diff --git a/internal/app/app.go b/internal/app/app.go index cb5c031c..d4e3dfe7 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -2,6 +2,7 @@ package app import ( "context" + "crypto/subtle" "database/sql" "fmt" "net/http" @@ -459,7 +460,9 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { func (a *App) mcpHandlerWithAuth(w http.ResponseWriter, r *http.Request) { cfg := a.config.MCP if cfg.AuthHeader != "" { - if r.Header.Get(cfg.AuthHeader) != cfg.AuthHeaderValue { + actual := []byte(r.Header.Get(cfg.AuthHeader)) + expected := []byte(cfg.AuthHeaderValue) + if subtle.ConstantTimeCompare(actual, expected) != 1 { a.logger.Logger.Debug("MCP 鉴权失败:header 缺失或值不匹配", zap.String("header", cfg.AuthHeader)) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) @@ -470,18 +473,25 @@ func (a *App) mcpHandlerWithAuth(w http.ResponseWriter, r *http.Request) { a.mcpServer.HandleHTTP(w, r) } -// Run 启动应用 +// Run 启动应用(向后兼容,不支持优雅关闭) func (a *App) Run() error { + return a.RunWithContext(context.Background()) +} + +// RunWithContext 启动应用,支持通过 context 取消来优雅关闭 +func (a *App) RunWithContext(ctx context.Context) error { // 启动MCP服务器(如果启用) + var mcpServer *http.Server if a.config.MCP.Enabled { + mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port) + a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr)) + + mux := http.NewServeMux() + mux.HandleFunc("/mcp", a.mcpHandlerWithAuth) + + mcpServer = &http.Server{Addr: mcpAddr, Handler: mux} go func() { - mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port) - a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr)) - - mux := http.NewServeMux() - mux.HandleFunc("/mcp", a.mcpHandlerWithAuth) - - if err := http.ListenAndServe(mcpAddr, mux); err != nil { + if err := mcpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { a.logger.Error("MCP服务器启动失败", zap.Error(err)) } }() @@ -491,7 +501,27 @@ func (a *App) Run() error { addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port) a.logger.Info("启动HTTP服务器", zap.String("address", addr)) - return a.router.Run(addr) + srv := &http.Server{Addr: addr, Handler: a.router} + + // 监听 context 取消,优雅关闭 HTTP 服务器 + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + a.logger.Error("HTTP服务器关闭失败", zap.Error(err)) + } + if mcpServer != nil { + if err := mcpServer.Shutdown(shutdownCtx); err != nil { + a.logger.Error("MCP服务器关闭失败", zap.Error(err)) + } + } + }() + + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return err + } + return nil } // Shutdown 关闭应用 @@ -519,6 +549,13 @@ func (a *App) Shutdown() { a.logger.Logger.Warn("关闭知识库数据库连接失败", zap.Error(err)) } } + + // 关闭主数据库连接 + if a.db != nil { + if err := a.db.Close(); err != nil { + a.logger.Logger.Warn("关闭主数据库连接失败", zap.Error(err)) + } + } } // startRobotConnections 根据当前配置启动钉钉/飞书长连接(不先关闭已有连接,仅用于首次启动) @@ -593,10 +630,16 @@ func setupRoutes( } // 机器人回调(无需登录,供企业微信/钉钉/飞书服务器调用) - api.GET("/robot/wecom", robotHandler.HandleWecomGET) - api.POST("/robot/wecom", robotHandler.HandleWecomPOST) - api.POST("/robot/dingtalk", robotHandler.HandleDingtalkPOST) - api.POST("/robot/lark", robotHandler.HandleLarkPOST) + // 添加速率限制:每个 IP 每分钟最多 60 次请求,防止滥用 + robotRL := security.NewRateLimiter(60, 1*time.Minute) + robotGroup := api.Group("/robot") + robotGroup.Use(security.RateLimitMiddleware(robotRL)) + { + robotGroup.GET("/wecom", robotHandler.HandleWecomGET) + robotGroup.POST("/wecom", robotHandler.HandleWecomPOST) + robotGroup.POST("/dingtalk", robotHandler.HandleDingtalkPOST) + robotGroup.POST("/lark", robotHandler.HandleLarkPOST) + } protected := api.Group("") protected.Use(security.AuthMiddleware(authManager)) @@ -680,6 +723,7 @@ func setupRoutes( // 配置管理 protected.GET("/config", configHandler.GetConfig) protected.GET("/config/tools", configHandler.GetTools) + protected.GET("/config/tools/:name/schema", configHandler.GetToolSchema) protected.PUT("/config", configHandler.UpdateConfig) protected.POST("/config/apply", configHandler.ApplyConfig) protected.POST("/config/test-openai", configHandler.TestOpenAI) diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 0ee4185b..e52a4f7b 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -2089,14 +2089,17 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti } func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) { - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists { - return false, nil - } + // 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断 if !h.markBatchQueueRunning(queueID) { return true, nil } + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + h.unmarkBatchQueueRunning(queueID) + return false, nil + } + if scheduled { if queue.ScheduleMode != "cron" { h.unmarkBatchQueueRunning(queueID) diff --git a/internal/handler/batch_task_manager.go b/internal/handler/batch_task_manager.go index 5bd03cfb..572588b1 100644 --- a/internal/handler/batch_task_manager.go +++ b/internal/handler/batch_task_manager.go @@ -543,16 +543,23 @@ func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, resu // UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId) func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) { - var needDBUpdate bool - - // 在锁内只更新内存状态 m.mu.Lock() + defer m.mu.Unlock() + queue, exists := m.queues[queueID] if !exists { - m.mu.Unlock() return } + // DB 优先:先持久化,成功后再更新内存,避免重启后状态不一致 + if m.db != nil { + if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil { + m.logger.Warn("batch task DB status update failed, skipping memory update", + zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err)) + return + } + } + for _, task := range queue.Tasks { if task.ID == taskID { task.Status = status @@ -575,30 +582,27 @@ func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, s break } } - - needDBUpdate = m.db != nil - m.mu.Unlock() - - // 释放锁后写 DB - if needDBUpdate { - if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil { - m.logger.Warn("batch task DB status update failed", zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err)) - } - } } // UpdateQueueStatus 更新队列状态 func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) { - var needDBUpdate bool - - // 在锁内只更新内存状态 m.mu.Lock() + defer m.mu.Unlock() + queue, exists := m.queues[queueID] if !exists { - m.mu.Unlock() return } + // DB 优先:先持久化,成功后再更新内存 + if m.db != nil { + if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil { + m.logger.Warn("batch queue DB status update failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + return + } + } + queue.Status = status now := time.Now() if status == BatchQueueStatusRunning && queue.StartedAt == nil { @@ -607,16 +611,6 @@ func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) { if status == BatchQueueStatusCompleted || status == BatchQueueStatusCancelled { queue.CompletedAt = &now } - - needDBUpdate = m.db != nil - m.mu.Unlock() - - // 释放锁后写 DB - if needDBUpdate { - if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil { - m.logger.Warn("batch queue DB status update failed", zap.String("queueId", queueID), zap.Error(err)) - } - } } // UpdateQueueSchedule 更新队列调度配置 @@ -756,6 +750,16 @@ func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool { if !exists { return false } + + // DB 优先:先持久化重置,成功后再更新内存,避免 DB 失败导致内存脏状态 + if m.db != nil { + if err := m.db.ResetBatchQueueForRerun(queueID); err != nil { + m.logger.Warn("batch queue DB reset for rerun failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + return false + } + } + queue.Status = BatchQueueStatusPending queue.CurrentIndex = 0 queue.StartedAt = nil @@ -771,12 +775,6 @@ func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool { task.Error = "" task.Result = "" } - - if m.db != nil { - if err := m.db.ResetBatchQueueForRerun(queueID); err != nil { - return false - } - } return true } @@ -870,7 +868,7 @@ func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error { return fmt.Errorf("队列正在执行或未就绪,无法删除任务") } - // 查找并删除任务 + // 查找任务 taskIndex := -1 for i, task := range queue.Tasks { if task.ID == taskID { @@ -886,18 +884,14 @@ func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error { return fmt.Errorf("任务不存在") } - // 从内存队列中删除 - queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...) - - // 同步到数据库 + // DB 优先:先从数据库删除,成功后再从内存移除 if m.db != nil { if err := m.db.DeleteBatchTask(queueID, taskID); err != nil { - // 如果数据库删除失败,恢复内存中的任务 - // 这里需要重新插入,但为了简化,我们只记录错误 return fmt.Errorf("删除任务失败: %w", err) } } + queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...) return nil } @@ -987,9 +981,7 @@ func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFu // PauseQueue 暂停队列 func (m *BatchTaskManager) PauseQueue(queueID string) bool { var cancelFunc context.CancelFunc - var needDBUpdate bool - // 在锁内只更新内存状态 m.mu.Lock() queue, exists := m.queues[queueID] if !exists { @@ -1002,6 +994,16 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool { return false } + // DB 优先:先持久化,成功后再更新内存 + if m.db != nil { + if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil { + m.logger.Warn("batch queue DB pause update failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + m.mu.Unlock() + return false + } + } + queue.Status = BatchQueueStatusPaused // 取消当前正在执行的任务(通过取消context) @@ -1009,22 +1011,13 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool { cancelFunc = cancel delete(m.taskCancels, queueID) } - - needDBUpdate = m.db != nil m.mu.Unlock() - // 释放锁后执行取消回调 + // 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) if cancelFunc != nil { cancelFunc() } - // 释放锁后写 DB - if needDBUpdate { - if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil { - m.logger.Warn("batch queue DB pause update failed", zap.String("queueId", queueID), zap.Error(err)) - } - } - return true } @@ -1032,9 +1025,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool { func (m *BatchTaskManager) CancelQueue(queueID string) bool { now := time.Now() var cancelFunc context.CancelFunc - var needDBUpdate bool - // 在锁内只更新内存状态,不做 DB 操作 m.mu.Lock() queue, exists := m.queues[queueID] if !exists { @@ -1047,6 +1038,22 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool { return false } + // DB 优先:先持久化,成功后再更新内存 + if m.db != nil { + if err := m.db.CancelPendingBatchTasks(queueID, now); err != nil { + m.logger.Warn("batch task DB batch cancel failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + m.mu.Unlock() + return false + } + if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil { + m.logger.Warn("batch queue DB cancel update failed, skipping memory update", + zap.String("queueId", queueID), zap.Error(err)) + m.mu.Unlock() + return false + } + } + queue.Status = BatchQueueStatusCancelled queue.CompletedAt = &now @@ -1063,25 +1070,13 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool { cancelFunc = cancel delete(m.taskCancels, queueID) } - - needDBUpdate = m.db != nil m.mu.Unlock() - // 释放锁后执行取消回调 + // 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) if cancelFunc != nil { cancelFunc() } - // 释放锁后批量写 DB(单条 SQL 取消所有 pending 任务) - if needDBUpdate { - if err := m.db.CancelPendingBatchTasks(queueID, now); err != nil { - m.logger.Warn("batch task DB batch cancel failed", zap.String("queueId", queueID), zap.Error(err)) - } - if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil { - m.logger.Warn("batch queue DB cancel update failed", zap.String("queueId", queueID), zap.Error(err)) - } - } - return true } diff --git a/internal/handler/config.go b/internal/handler/config.go index 05542567..adf998a6 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -194,12 +194,13 @@ type GetConfigResponse struct { // ToolConfigInfo 工具配置信息 type ToolConfigInfo struct { - Name string `json:"name"` - Description string `json:"description"` - Enabled bool `json:"enabled"` - IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具 - ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具) - RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具) + Name string `json:"name"` + Description string `json:"description"` + Enabled bool `json:"enabled"` + IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具 + ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具) + RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具) + InputSchema map[string]interface{} `json:"input_schema,omitempty"` // 工具参数 JSON Schema(用于前端展示详情) } // GetConfig 获取当前配置 @@ -211,25 +212,25 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { // 首先从配置文件获取工具 configToolMap := make(map[string]bool) tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) + for _, tool := range h.config.Security.Tools { configToolMap[tool.Name] = true - tools = append(tools, ToolConfigInfo{ + info := ToolConfigInfo{ Name: tool.Name, Description: h.pickToolDescription(tool.ShortDescription, tool.Description), Enabled: tool.Enabled, IsExternal: false, - }) + } + tools = append(tools, info) } // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) if h.mcpServer != nil { mcpTools := h.mcpServer.GetAllTools() for _, mcpTool := range mcpTools { - // 跳过已经在配置文件中的工具(避免重复) if configToolMap[mcpTool.Name] { continue } - // 添加直接注册到MCP服务器的工具(如知识检索工具) description := mcpTool.ShortDescription if description == "" { description = mcpTool.Description @@ -240,7 +241,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { tools = append(tools, ToolConfigInfo{ Name: mcpTool.Name, Description: description, - Enabled: true, // 直接注册的工具默认启用 + Enabled: true, IsExternal: false, }) } @@ -442,7 +443,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { toolInfo := ToolConfigInfo{ Name: mcpTool.Name, Description: description, - Enabled: true, // 直接注册的工具默认启用 + Enabled: true, IsExternal: false, } @@ -1142,32 +1143,7 @@ func (h *ConfigHandler) saveConfig() error { updateRobotsConfig(root, h.config.Robots) updateMultiAgentConfig(root, h.config.MultiAgent) // 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用) - // 读取原始配置以保持向后兼容 - originalConfigs := make(map[string]map[string]bool) - externalMCPNode := findMapValue(root, "external_mcp") - if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode { - serversNode := findMapValue(externalMCPNode, "servers") - if serversNode != nil && serversNode.Kind == yaml.MappingNode { - for i := 0; i < len(serversNode.Content); i += 2 { - if i+1 >= len(serversNode.Content) { - break - } - nameNode := serversNode.Content[i] - serverNode := serversNode.Content[i+1] - if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode { - serverName := nameNode.Value - originalConfigs[serverName] = make(map[string]bool) - if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil { - originalConfigs[serverName]["enabled"] = *enabledVal - } - if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil { - originalConfigs[serverName]["disabled"] = *disabledVal - } - } - } - } - } - updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs) + updateExternalMCPConfig(root, h.config.ExternalMCP) if err := writeYAMLDocument(h.configPath, root); err != nil { return fmt.Errorf("保存配置文件失败: %w", err) @@ -1585,7 +1561,7 @@ func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, c } // 首先检查外部MCP是否启用 - if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) { + if !cfg.ExternalMCPEnable { return false // MCP未启用,所有工具都禁用 } @@ -1624,3 +1600,109 @@ func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string { } return description } + +// GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据) +func (h *ConfigHandler) GetToolSchema(c *gin.Context) { + h.mu.RLock() + defer h.mu.RUnlock() + + toolName := c.Param("name") + if toolName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"}) + return + } + + // 检查是否为外部工具(格式:mcpName::toolName) + externalMCP := c.Query("external_mcp") + if externalMCP != "" { + // 外部 MCP 工具 + if h.externalMCPMgr != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + externalTools, _ := h.externalMCPMgr.GetAllTools(ctx) + fullName := externalMCP + "::" + toolName + for _, t := range externalTools { + if t.Name == fullName { + c.JSON(http.StatusOK, gin.H{"input_schema": t.InputSchema}) + return + } + } + } + c.JSON(http.StatusNotFound, gin.H{"error": "外部工具未找到"}) + return + } + + // 内部工具:从 YAML 配置的 Parameters 构建 + for _, tool := range h.config.Security.Tools { + if tool.Name == toolName { + c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)}) + return + } + } + + // MCP 注册工具(如知识检索) + if h.mcpServer != nil { + for _, mt := range h.mcpServer.GetAllTools() { + if mt.Name == toolName { + c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema}) + return + } + } + } + + c.JSON(http.StatusNotFound, gin.H{"error": "工具未找到"}) +} + +// buildInputSchemaFromParams 从 YAML 工具的 ParameterConfig 构建 JSON Schema(用于前端展示)。 +// 不依赖 MCP 服务器注册状态,所有工具(包括未启用的)都能返回参数定义。 +func buildInputSchemaFromParams(params []config.ParameterConfig) map[string]interface{} { + if len(params) == 0 { + return nil + } + + properties := make(map[string]interface{}) + required := make([]string, 0) + + for _, p := range params { + name := strings.TrimSpace(p.Name) + if name == "" { + continue + } + prop := map[string]interface{}{ + "type": convertParamType(p.Type), + "description": p.Description, + } + if p.Default != nil { + prop["default"] = p.Default + } + if len(p.Options) > 0 { + prop["enum"] = p.Options + } + properties[name] = prop + if p.Required { + required = append(required, name) + } + } + + schema := map[string]interface{}{ + "type": "object", + "properties": properties, + } + if len(required) > 0 { + schema["required"] = required + } + return schema +} + +func convertParamType(t string) string { + switch strings.TrimSpace(strings.ToLower(t)) { + case "int", "integer", "number": + return "number" + case "bool", "boolean": + return "boolean" + case "array", "list": + return "array" + default: + return "string" + } +} diff --git a/internal/handler/external_mcp.go b/internal/handler/external_mcp.go index a8b57ae6..e1fcab1e 100644 --- a/internal/handler/external_mcp.go +++ b/internal/handler/external_mcp.go @@ -157,36 +157,19 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) { h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) } - // 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容 - // 同时将值迁移到 external_mcp_enable cfg := req.Config - if req.Config.Disabled { - // 用户设置了 disabled: true + // 官方 disabled 字段 → ExternalMCPEnable 取反 + if cfg.Disabled { cfg.ExternalMCPEnable = false - cfg.Disabled = true - cfg.Enabled = false - } else if req.Config.Enabled { - // 用户设置了 enabled: true + } else if !cfg.ExternalMCPEnable { + // 用户未显式设置 external_mcp_enable,官方配置默认就是启用的 cfg.ExternalMCPEnable = true - cfg.Enabled = true - cfg.Disabled = false - } else if !req.Config.ExternalMCPEnable { - // 用户没有设置任何字段,且 external_mcp_enable 为 false - // 检查现有配置是否有旧字段 - if existingCfg, exists := h.config.ExternalMCP.Servers[name]; exists { - // 保留现有的旧字段 - cfg.Enabled = existingCfg.Enabled - cfg.Disabled = existingCfg.Disabled - } - } else { - // 用户通过新字段启用了(external_mcp_enable: true),但没有设置旧字段 - // 为了向后兼容,我们设置 enabled: true - // 这样即使原始配置中有 disabled: false,也会被转换为 enabled: true - cfg.Enabled = true - cfg.Disabled = false } + // 展开 ${VAR} 环境变量 + config.ExpandConfigEnv(&cfg) + h.config.ExternalMCP.Servers[name] = cfg // 保存到配置文件 @@ -315,32 +298,25 @@ func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) { c.JSON(http.StatusOK, stats) } -// validateConfig 验证配置 +// validateConfig 验证配置(同时支持官方 type 字段和旧版 transport 字段) func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error { - transport := cfg.Transport + transport := cfg.GetTransportType() if transport == "" { - // 如果没有指定transport,根据是否有command或url判断 - if cfg.Command != "" { - transport = "stdio" - } else if cfg.URL != "" { - transport = "http" - } else { - return fmt.Errorf("需要指定command(stdio模式)或url(http/sse模式)") - } + return fmt.Errorf("需要指定 command(stdio模式)或 url + type(http/sse模式)") } switch transport { case "http": if cfg.URL == "" { - return fmt.Errorf("HTTP模式需要URL") + return fmt.Errorf("HTTP模式需要 url") } case "stdio": if cfg.Command == "" { - return fmt.Errorf("stdio模式需要command") + return fmt.Errorf("stdio模式需要 command") } case "sse": if cfg.URL == "" { - return fmt.Errorf("SSE模式需要URL") + return fmt.Errorf("SSE模式需要 url") } default: return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport) @@ -351,25 +327,11 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) // isEnabled 检查是否启用 func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool { - // 优先使用 ExternalMCPEnable 字段 - // 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容) - if cfg.ExternalMCPEnable { - return true - } - // 向后兼容:检查旧字段 - if cfg.Disabled { - return false - } - if cfg.Enabled { - return true - } - // 都没有设置,默认为启用 - return true + return cfg.ExternalMCPEnable } // saveConfig 保存配置到文件 func (h *ExternalMCPHandler) saveConfig() error { - // 读取现有配置文件并创建备份 data, err := os.ReadFile(h.configPath) if err != nil { return fmt.Errorf("读取配置文件失败: %w", err) @@ -384,37 +346,7 @@ func (h *ExternalMCPHandler) saveConfig() error { return fmt.Errorf("解析配置文件失败: %w", err) } - // 在更新前,读取原始配置中的 enabled/disabled 字段,以便保持向后兼容 - originalConfigs := make(map[string]map[string]bool) - externalMCPNode := findMapValue(root.Content[0], "external_mcp") - if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode { - serversNode := findMapValue(externalMCPNode, "servers") - if serversNode != nil && serversNode.Kind == yaml.MappingNode { - // 遍历现有的服务器配置,保存 enabled/disabled 字段 - for i := 0; i < len(serversNode.Content); i += 2 { - if i+1 >= len(serversNode.Content) { - break - } - nameNode := serversNode.Content[i] - serverNode := serversNode.Content[i+1] - if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode { - serverName := nameNode.Value - originalConfigs[serverName] = make(map[string]bool) - // 检查是否有 enabled 字段 - if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil { - originalConfigs[serverName]["enabled"] = *enabledVal - } - // 检查是否有 disabled 字段 - if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil { - originalConfigs[serverName]["disabled"] = *disabledVal - } - } - } - } - } - - // 更新外部MCP配置 - updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs) + updateExternalMCPConfig(root, h.config.ExternalMCP) if err := writeYAMLDocument(h.configPath, root); err != nil { return fmt.Errorf("保存配置文件失败: %w", err) @@ -425,7 +357,7 @@ func (h *ExternalMCPHandler) saveConfig() error { } // updateExternalMCPConfig 更新外部MCP配置 -func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, originalConfigs map[string]map[string]bool) { +func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig) { root := doc.Content[0] externalMCPNode := ensureMap(root, "external_mcp") serversNode := ensureMap(externalMCPNode, "servers") @@ -435,32 +367,31 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi // 添加新的服务器配置 for name, serverCfg := range cfg.Servers { - // 添加服务器名称键 nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name} serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} serversNode.Content = append(serversNode.Content, nameNode, serverNode) - // 设置服务器配置字段 + // type(官方 MCP 传输类型) + effectiveType := serverCfg.GetTransportType() + if effectiveType != "" && effectiveType != "stdio" { + // stdio 可省略(有 command 时自动推断) + setStringInMap(serverNode, "type", effectiveType) + } if serverCfg.Command != "" { setStringInMap(serverNode, "command", serverCfg.Command) } if len(serverCfg.Args) > 0 { setStringArrayInMap(serverNode, "args", serverCfg.Args) } - // 保存 env 字段(环境变量) if serverCfg.Env != nil && len(serverCfg.Env) > 0 { envNode := ensureMap(serverNode, "env") for envKey, envValue := range serverCfg.Env { setStringInMap(envNode, envKey, envValue) } } - if serverCfg.Transport != "" { - setStringInMap(serverNode, "transport", serverCfg.Transport) - } if serverCfg.URL != "" { setStringInMap(serverNode, "url", serverCfg.URL) } - // 保存 headers 字段(HTTP/SSE 请求头) if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 { headersNode := ensureMap(serverNode, "headers") for k, v := range serverCfg.Headers { @@ -473,46 +404,32 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi if serverCfg.Timeout > 0 { setIntInMap(serverNode, "timeout", serverCfg.Timeout) } - // 保存 external_mcp_enable 字段(新字段) + // 官方标准字段 + if serverCfg.Disabled { + setBoolInMap(serverNode, "disabled", true) + } + if len(serverCfg.AutoApprove) > 0 { + setStringArrayInMap(serverNode, "autoApprove", serverCfg.AutoApprove) + } + + // SDK 高级配置 + if serverCfg.MaxRetries > 0 { + setIntInMap(serverNode, "max_retries", serverCfg.MaxRetries) + } + if serverCfg.TerminateDuration > 0 { + setIntInMap(serverNode, "terminate_duration", serverCfg.TerminateDuration) + } + if serverCfg.KeepAlive > 0 { + setIntInMap(serverNode, "keep_alive", serverCfg.KeepAlive) + } + setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable) - // 保存 tool_enabled 字段(每个工具的启用状态) if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 { toolEnabledNode := ensureMap(serverNode, "tool_enabled") for toolName, enabled := range serverCfg.ToolEnabled { setBoolInMap(toolEnabledNode, toolName, enabled) } } - // 保留旧的 enabled/disabled 字段以保持向后兼容 - originalFields, hasOriginal := originalConfigs[name] - - // 如果原始配置中有 enabled 字段,保留它 - if hasOriginal { - if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled { - setBoolInMap(serverNode, "enabled", enabledVal) - } - // 如果原始配置中有 disabled 字段,保留它 - // 注意:由于 omitempty,disabled: false 不会被保存,但 disabled: true 会被保存 - if disabledVal, hasDisabled := originalFields["disabled"]; hasDisabled { - if disabledVal { - setBoolInMap(serverNode, "disabled", disabledVal) - } else { - // 如果原始配置中有 disabled: false,我们保存 enabled: true 来等效表示 - // 因为 disabled: false 等价于 enabled: true - setBoolInMap(serverNode, "enabled", true) - } - } - } - - // 如果用户在当前请求中明确设置了这些字段,也保存它们 - if serverCfg.Enabled { - setBoolInMap(serverNode, "enabled", serverCfg.Enabled) - } - if serverCfg.Disabled { - setBoolInMap(serverNode, "disabled", serverCfg.Disabled) - } else if !hasOriginal && serverCfg.ExternalMCPEnable { - // 如果用户通过新字段启用了,且原始配置中没有旧字段,保存 enabled: true 以保持向后兼容 - setBoolInMap(serverNode, "enabled", true) - } } } diff --git a/internal/handler/external_mcp_test.go b/internal/handler/external_mcp_test.go index a663c489..e52eeced 100644 --- a/internal/handler/external_mcp_test.go +++ b/internal/handler/external_mcp_test.go @@ -60,13 +60,13 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) { router, _, configPath := setupTestRouter() defer cleanupTestConfig(configPath) - // 测试添加stdio模式的配置 + // 测试添加stdio模式的配置(官方格式:有 command 时 type 可省略) configJSON := `{ "command": "python3", "args": ["/path/to/script.py", "--server", "http://example.com"], "description": "Test stdio MCP", "timeout": 300, - "enabled": true + "external_mcp_enable": true }` var configObj config.ExternalMCPServerConfig @@ -115,20 +115,17 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) { if response.Config.Timeout != 300 { t.Errorf("期望timeout为300,实际%d", response.Config.Timeout) } - if !response.Config.Enabled { - t.Error("期望enabled为true") - } } func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) { router, _, configPath := setupTestRouter() defer cleanupTestConfig(configPath) - // 测试添加HTTP模式的配置 + // 测试添加HTTP模式的配置(使用官方 type 字段) configJSON := `{ - "transport": "http", + "type": "http", "url": "http://127.0.0.1:8081/mcp", - "enabled": true + "external_mcp_enable": true }` var configObj config.ExternalMCPServerConfig @@ -165,15 +162,12 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) { t.Fatalf("解析响应失败: %v", err) } - if response.Config.Transport != "http" { - t.Errorf("期望transport为http,实际%s", response.Config.Transport) + if response.Config.Type != "http" { + t.Errorf("期望type为http,实际%s", response.Config.Type) } if response.Config.URL != "http://127.0.0.1:8081/mcp" { t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL) } - if !response.Config.Enabled { - t.Error("期望enabled为true") - } } func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) { @@ -187,22 +181,22 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) { }{ { name: "缺少command和url", - configJSON: `{"enabled": true}`, - expectedErr: "需要指定command(stdio模式)或url(http/sse模式)", + configJSON: `{"external_mcp_enable": true}`, + expectedErr: "需要指定 command(stdio模式)或 url + type(http/sse模式)", }, { name: "stdio模式缺少command", - configJSON: `{"args": ["test"], "enabled": true}`, + configJSON: `{"args": ["test"], "external_mcp_enable": true}`, expectedErr: "stdio模式需要command", }, { name: "http模式缺少url", - configJSON: `{"transport": "http", "enabled": true}`, - expectedErr: "HTTP模式需要URL", + configJSON: `{"type": "http", "external_mcp_enable": true}`, + expectedErr: "HTTP模式需要 url", }, { - name: "无效的transport", - configJSON: `{"transport": "invalid", "enabled": true}`, + name: "无效的type", + configJSON: `{"type": "invalid", "external_mcp_enable": true}`, expectedErr: "不支持的传输模式", }, } @@ -254,7 +248,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) { // 先添加一个配置 configObj := config.ExternalMCPServerConfig{ Command: "python3", - Enabled: true, + ExternalMCPEnable: true, } handler.manager.AddOrUpdateConfig("test-delete", configObj) @@ -283,11 +277,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { // 添加多个配置 handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{ Command: "python3", - Enabled: true, + ExternalMCPEnable: true, }) handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{ URL: "http://127.0.0.1:8081/mcp", - Enabled: false, + ExternalMCPEnable: false, }) req := httptest.NewRequest("GET", "/api/external-mcp", nil) @@ -326,16 +320,14 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) { // 添加配置 handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ Command: "python3", - Enabled: true, + ExternalMCPEnable: true, }) handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ URL: "http://127.0.0.1:8081/mcp", - Enabled: true, + ExternalMCPEnable: true, }) handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ Command: "python3", - Enabled: false, - Disabled: true, }) req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil) @@ -369,8 +361,6 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) { // 添加一个禁用的配置 handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{ Command: "python3", - Enabled: false, - Disabled: true, }) // 测试启动(可能会失败,因为没有真实的服务器) @@ -427,7 +417,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) { configObj := config.ExternalMCPServerConfig{ Command: "python3", - Enabled: true, + ExternalMCPEnable: true, } reqBody := AddOrUpdateExternalMCPRequest{ @@ -470,14 +460,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) { // 先添加配置 config1 := config.ExternalMCPServerConfig{ Command: "python3", - Enabled: true, + ExternalMCPEnable: true, } handler.manager.AddOrUpdateConfig("test-update", config1) // 更新配置 config2 := config.ExternalMCPServerConfig{ URL: "http://127.0.0.1:8081/mcp", - Enabled: true, + ExternalMCPEnable: true, } reqBody := AddOrUpdateExternalMCPRequest{ diff --git a/internal/handler/webshell.go b/internal/handler/webshell.go index 06da5d61..5afa44c6 100644 --- a/internal/handler/webshell.go +++ b/internal/handler/webshell.go @@ -411,7 +411,10 @@ func (h *WebShellHandler) Exec(c *gin.Context) { } defer resp.Body.Close() - out, _ := io.ReadAll(resp.Body) + out, readErr := io.ReadAll(resp.Body) + if readErr != nil { + h.logger.Warn("webshell exec read body", zap.Error(readErr)) + } output := string(out) httpCode := resp.StatusCode @@ -578,7 +581,10 @@ func (h *WebShellHandler) FileOp(c *gin.Context) { } defer resp.Body.Close() - out, _ := io.ReadAll(resp.Body) + out, readErr := io.ReadAll(resp.Body) + if readErr != nil { + h.logger.Warn("webshell fileop read body", zap.Error(readErr)) + } output := string(out) c.JSON(http.StatusOK, FileOpResponse{ @@ -633,7 +639,10 @@ func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, return "", false, err.Error() } defer resp.Body.Close() - out, _ := io.ReadAll(resp.Body) + out, readErr := io.ReadAll(resp.Body) + if readErr != nil { + h.logger.Warn("webshell ExecWithConnection read body", zap.Error(readErr)) + } return string(out), resp.StatusCode == http.StatusOK, "" } @@ -701,6 +710,9 @@ func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection return "", false, err.Error() } defer resp.Body.Close() - out, _ := io.ReadAll(resp.Body) + out, readErr := io.ReadAll(resp.Body) + if readErr != nil { + h.logger.Warn("webshell FileOpWithConnection read body", zap.Error(readErr)) + } return string(out), resp.StatusCode == http.StatusOK, "" } diff --git a/internal/openai/claude_bridge.go b/internal/openai/claude_bridge.go index b6e75d51..ca3a608a 100644 --- a/internal/openai/claude_bridge.go +++ b/internal/openai/claude_bridge.go @@ -487,7 +487,10 @@ func (c *Client) claudeChatCompletionStream(ctx context.Context, payload interfa defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return "", fmt.Errorf("claude bridge: read error response: %w", readErr) + } return "", &APIError{ StatusCode: resp.StatusCode, Body: string(respBody), @@ -588,7 +591,10 @@ func (c *Client) claudeChatCompletionStreamWithToolCalls( defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return "", nil, "", fmt.Errorf("claude bridge: read error response: %w", readErr) + } return "", nil, "", &APIError{ StatusCode: resp.StatusCode, Body: string(respBody), @@ -824,7 +830,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro // 非 200:尝试把 Claude 错误格式转成 OpenAI 错误格式,便于 Eino 解析 if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, readErr := io.ReadAll(resp.Body) + if readErr != nil { + resp.Body.Close() + return nil, fmt.Errorf("claude bridge: read error response: %w", readErr) + } resp.Body.Close() converted := rt.tryConvertClaudeErrorToOpenAI(bodyBytes) return &http.Response{ @@ -838,7 +848,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro // 非流式:一次性转换响应体 if !claudeReq.Stream { - respBody, _ := io.ReadAll(resp.Body) + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + resp.Body.Close() + return nil, fmt.Errorf("claude bridge: read response: %w", readErr) + } resp.Body.Close() oaiJSON, err := claudeToOpenAIResponseJSON(respBody) if err != nil { diff --git a/internal/openai/openai.go b/internal/openai/openai.go index 2c675e5f..7d813d1c 100644 --- a/internal/openai/openai.go +++ b/internal/openai/openai.go @@ -189,7 +189,10 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{}, // 非200:读完 body 返回 if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + c.logger.Warn("failed to read OpenAI error response body", zap.Error(readErr)) + } return "", &APIError{ StatusCode: resp.StatusCode, Body: string(respBody), @@ -329,7 +332,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls( defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + c.logger.Warn("failed to read OpenAI error response body", zap.Error(readErr)) + } return "", nil, "", &APIError{ StatusCode: resp.StatusCode, Body: string(respBody), diff --git a/internal/security/ratelimit.go b/internal/security/ratelimit.go new file mode 100644 index 00000000..1c959237 --- /dev/null +++ b/internal/security/ratelimit.go @@ -0,0 +1,81 @@ +package security + +import ( + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +// rateLimitEntry 记录某个 IP 的请求窗口信息 +type rateLimitEntry struct { + count int + windowAt time.Time +} + +// RateLimiter 基于 IP 的滑动窗口速率限制器 +type RateLimiter struct { + mu sync.Mutex + entries map[string]*rateLimitEntry + limit int // 窗口内允许的最大请求数 + window time.Duration // 窗口时长 +} + +// NewRateLimiter 创建速率限制器 +func NewRateLimiter(limit int, window time.Duration) *RateLimiter { + rl := &RateLimiter{ + entries: make(map[string]*rateLimitEntry), + limit: limit, + window: window, + } + // 后台定期清理过期条目,防止内存泄漏 + go rl.cleanup() + return rl +} + +// cleanup 每分钟清理一次过期条目 +func (rl *RateLimiter) cleanup() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for range ticker.C { + rl.mu.Lock() + now := time.Now() + for ip, entry := range rl.entries { + if now.Sub(entry.windowAt) > rl.window { + delete(rl.entries, ip) + } + } + rl.mu.Unlock() + } +} + +// allow 检查指定 IP 是否允许通过 +func (rl *RateLimiter) allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + entry, ok := rl.entries[ip] + if !ok || now.Sub(entry.windowAt) > rl.window { + rl.entries[ip] = &rateLimitEntry{count: 1, windowAt: now} + return true + } + + entry.count++ + return entry.count <= rl.limit +} + +// RateLimitMiddleware 返回 Gin 中间件,对超限请求返回 429 +func RateLimitMiddleware(rl *RateLimiter) gin.HandlerFunc { + return func(c *gin.Context) { + ip := c.ClientIP() + if !rl.allow(ip) { + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ + "error": "rate limit exceeded, please try again later", + }) + return + } + c.Next() + } +}