Add files via upload

This commit is contained in:
公明
2026-04-21 19:16:09 +08:00
committed by GitHub
parent d037647c21
commit 26116b0822
11 changed files with 436 additions and 292 deletions
+1 -1
View File
@@ -949,7 +949,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
enabled := false
if cfg, exists := externalMCPConfigs[mcpName]; exists {
// 首先检查外部MCP是否启用
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
if !cfg.ExternalMCPEnable {
enabled = false // MCP未启用,所有工具都禁用
} else {
// MCP已启用,检查单个工具的启用状态
+58 -14
View File
@@ -2,6 +2,7 @@ package app
import (
"context"
"crypto/subtle"
"database/sql"
"fmt"
"net/http"
@@ -459,7 +460,9 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
func (a *App) mcpHandlerWithAuth(w http.ResponseWriter, r *http.Request) {
cfg := a.config.MCP
if cfg.AuthHeader != "" {
if r.Header.Get(cfg.AuthHeader) != cfg.AuthHeaderValue {
actual := []byte(r.Header.Get(cfg.AuthHeader))
expected := []byte(cfg.AuthHeaderValue)
if subtle.ConstantTimeCompare(actual, expected) != 1 {
a.logger.Logger.Debug("MCP 鉴权失败:header 缺失或值不匹配", zap.String("header", cfg.AuthHeader))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
@@ -470,18 +473,25 @@ func (a *App) mcpHandlerWithAuth(w http.ResponseWriter, r *http.Request) {
a.mcpServer.HandleHTTP(w, r)
}
// Run 启动应用
// Run 启动应用(向后兼容,不支持优雅关闭)
func (a *App) Run() error {
return a.RunWithContext(context.Background())
}
// RunWithContext 启动应用,支持通过 context 取消来优雅关闭
func (a *App) RunWithContext(ctx context.Context) error {
// 启动MCP服务器(如果启用)
var mcpServer *http.Server
if a.config.MCP.Enabled {
mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))
mux := http.NewServeMux()
mux.HandleFunc("/mcp", a.mcpHandlerWithAuth)
mcpServer = &http.Server{Addr: mcpAddr, Handler: mux}
go func() {
mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))
mux := http.NewServeMux()
mux.HandleFunc("/mcp", a.mcpHandlerWithAuth)
if err := http.ListenAndServe(mcpAddr, mux); err != nil {
if err := mcpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
a.logger.Error("MCP服务器启动失败", zap.Error(err))
}
}()
@@ -491,7 +501,27 @@ func (a *App) Run() error {
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
return a.router.Run(addr)
srv := &http.Server{Addr: addr, Handler: a.router}
// 监听 context 取消,优雅关闭 HTTP 服务器
go func() {
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
a.logger.Error("HTTP服务器关闭失败", zap.Error(err))
}
if mcpServer != nil {
if err := mcpServer.Shutdown(shutdownCtx); err != nil {
a.logger.Error("MCP服务器关闭失败", zap.Error(err))
}
}
}()
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
return err
}
return nil
}
// Shutdown 关闭应用
@@ -519,6 +549,13 @@ func (a *App) Shutdown() {
a.logger.Logger.Warn("关闭知识库数据库连接失败", zap.Error(err))
}
}
// 关闭主数据库连接
if a.db != nil {
if err := a.db.Close(); err != nil {
a.logger.Logger.Warn("关闭主数据库连接失败", zap.Error(err))
}
}
}
// startRobotConnections 根据当前配置启动钉钉/飞书长连接(不先关闭已有连接,仅用于首次启动)
@@ -593,10 +630,16 @@ func setupRoutes(
}
// 机器人回调(无需登录,供企业微信/钉钉/飞书服务器调用)
api.GET("/robot/wecom", robotHandler.HandleWecomGET)
api.POST("/robot/wecom", robotHandler.HandleWecomPOST)
api.POST("/robot/dingtalk", robotHandler.HandleDingtalkPOST)
api.POST("/robot/lark", robotHandler.HandleLarkPOST)
// 添加速率限制:每个 IP 每分钟最多 60 次请求,防止滥用
robotRL := security.NewRateLimiter(60, 1*time.Minute)
robotGroup := api.Group("/robot")
robotGroup.Use(security.RateLimitMiddleware(robotRL))
{
robotGroup.GET("/wecom", robotHandler.HandleWecomGET)
robotGroup.POST("/wecom", robotHandler.HandleWecomPOST)
robotGroup.POST("/dingtalk", robotHandler.HandleDingtalkPOST)
robotGroup.POST("/lark", robotHandler.HandleLarkPOST)
}
protected := api.Group("")
protected.Use(security.AuthMiddleware(authManager))
@@ -680,6 +723,7 @@ func setupRoutes(
// 配置管理
protected.GET("/config", configHandler.GetConfig)
protected.GET("/config/tools", configHandler.GetTools)
protected.GET("/config/tools/:name/schema", configHandler.GetToolSchema)
protected.PUT("/config", configHandler.UpdateConfig)
protected.POST("/config/apply", configHandler.ApplyConfig)
protected.POST("/config/test-openai", configHandler.TestOpenAI)
+7 -4
View File
@@ -2089,14 +2089,17 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti
}
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
return false, nil
}
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
if !h.markBatchQueueRunning(queueID) {
return true, nil
}
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
h.unmarkBatchQueueRunning(queueID)
return false, nil
}
if scheduled {
if queue.ScheduleMode != "cron" {
h.unmarkBatchQueueRunning(queueID)
+63 -68
View File
@@ -543,16 +543,23 @@ func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, resu
// UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId
func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) {
var needDBUpdate bool
// 在锁内只更新内存状态
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
m.mu.Unlock()
return
}
// DB 优先:先持久化,成功后再更新内存,避免重启后状态不一致
if m.db != nil {
if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil {
m.logger.Warn("batch task DB status update failed, skipping memory update",
zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err))
return
}
}
for _, task := range queue.Tasks {
if task.ID == taskID {
task.Status = status
@@ -575,30 +582,27 @@ func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, s
break
}
}
needDBUpdate = m.db != nil
m.mu.Unlock()
// 释放锁后写 DB
if needDBUpdate {
if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil {
m.logger.Warn("batch task DB status update failed", zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err))
}
}
}
// UpdateQueueStatus 更新队列状态
func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
var needDBUpdate bool
// 在锁内只更新内存状态
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
m.mu.Unlock()
return
}
// DB 优先:先持久化,成功后再更新内存
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil {
m.logger.Warn("batch queue DB status update failed, skipping memory update",
zap.String("queueId", queueID), zap.Error(err))
return
}
}
queue.Status = status
now := time.Now()
if status == BatchQueueStatusRunning && queue.StartedAt == nil {
@@ -607,16 +611,6 @@ func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
if status == BatchQueueStatusCompleted || status == BatchQueueStatusCancelled {
queue.CompletedAt = &now
}
needDBUpdate = m.db != nil
m.mu.Unlock()
// 释放锁后写 DB
if needDBUpdate {
if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil {
m.logger.Warn("batch queue DB status update failed", zap.String("queueId", queueID), zap.Error(err))
}
}
}
// UpdateQueueSchedule 更新队列调度配置
@@ -756,6 +750,16 @@ func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool {
if !exists {
return false
}
// DB 优先:先持久化重置,成功后再更新内存,避免 DB 失败导致内存脏状态
if m.db != nil {
if err := m.db.ResetBatchQueueForRerun(queueID); err != nil {
m.logger.Warn("batch queue DB reset for rerun failed, skipping memory update",
zap.String("queueId", queueID), zap.Error(err))
return false
}
}
queue.Status = BatchQueueStatusPending
queue.CurrentIndex = 0
queue.StartedAt = nil
@@ -771,12 +775,6 @@ func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool {
task.Error = ""
task.Result = ""
}
if m.db != nil {
if err := m.db.ResetBatchQueueForRerun(queueID); err != nil {
return false
}
}
return true
}
@@ -870,7 +868,7 @@ func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
return fmt.Errorf("队列正在执行或未就绪,无法删除任务")
}
// 查找并删除任务
// 查找任务
taskIndex := -1
for i, task := range queue.Tasks {
if task.ID == taskID {
@@ -886,18 +884,14 @@ func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
return fmt.Errorf("任务不存在")
}
// 从内存队列中删
queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...)
// 同步到数据库
// DB 优先:先从数据库删除,成功后再从内存移
if m.db != nil {
if err := m.db.DeleteBatchTask(queueID, taskID); err != nil {
// 如果数据库删除失败,恢复内存中的任务
// 这里需要重新插入,但为了简化,我们只记录错误
return fmt.Errorf("删除任务失败: %w", err)
}
}
queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...)
return nil
}
@@ -987,9 +981,7 @@ func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFu
// PauseQueue 暂停队列
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
var cancelFunc context.CancelFunc
var needDBUpdate bool
// 在锁内只更新内存状态
m.mu.Lock()
queue, exists := m.queues[queueID]
if !exists {
@@ -1002,6 +994,16 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
return false
}
// DB 优先:先持久化,成功后再更新内存
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil {
m.logger.Warn("batch queue DB pause update failed, skipping memory update",
zap.String("queueId", queueID), zap.Error(err))
m.mu.Unlock()
return false
}
}
queue.Status = BatchQueueStatusPaused
// 取消当前正在执行的任务(通过取消context)
@@ -1009,22 +1011,13 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
cancelFunc = cancel
delete(m.taskCancels, queueID)
}
needDBUpdate = m.db != nil
m.mu.Unlock()
// 释放锁后执行取消回调
// 释放锁后执行取消回调cancel 可能阻塞,不应持锁)
if cancelFunc != nil {
cancelFunc()
}
// 释放锁后写 DB
if needDBUpdate {
if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil {
m.logger.Warn("batch queue DB pause update failed", zap.String("queueId", queueID), zap.Error(err))
}
}
return true
}
@@ -1032,9 +1025,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
now := time.Now()
var cancelFunc context.CancelFunc
var needDBUpdate bool
// 在锁内只更新内存状态,不做 DB 操作
m.mu.Lock()
queue, exists := m.queues[queueID]
if !exists {
@@ -1047,6 +1038,22 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
return false
}
// DB 优先:先持久化,成功后再更新内存
if m.db != nil {
if err := m.db.CancelPendingBatchTasks(queueID, now); err != nil {
m.logger.Warn("batch task DB batch cancel failed, skipping memory update",
zap.String("queueId", queueID), zap.Error(err))
m.mu.Unlock()
return false
}
if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil {
m.logger.Warn("batch queue DB cancel update failed, skipping memory update",
zap.String("queueId", queueID), zap.Error(err))
m.mu.Unlock()
return false
}
}
queue.Status = BatchQueueStatusCancelled
queue.CompletedAt = &now
@@ -1063,25 +1070,13 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
cancelFunc = cancel
delete(m.taskCancels, queueID)
}
needDBUpdate = m.db != nil
m.mu.Unlock()
// 释放锁后执行取消回调
// 释放锁后执行取消回调cancel 可能阻塞,不应持锁)
if cancelFunc != nil {
cancelFunc()
}
// 释放锁后批量写 DB(单条 SQL 取消所有 pending 任务)
if needDBUpdate {
if err := m.db.CancelPendingBatchTasks(queueID, now); err != nil {
m.logger.Warn("batch task DB batch cancel failed", zap.String("queueId", queueID), zap.Error(err))
}
if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil {
m.logger.Warn("batch queue DB cancel update failed", zap.String("queueId", queueID), zap.Error(err))
}
}
return true
}
+121 -39
View File
@@ -194,12 +194,13 @@ type GetConfigResponse struct {
// ToolConfigInfo 工具配置信息
type ToolConfigInfo struct {
Name string `json:"name"`
Description string `json:"description"`
Enabled bool `json:"enabled"`
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
Name string `json:"name"`
Description string `json:"description"`
Enabled bool `json:"enabled"`
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
InputSchema map[string]interface{} `json:"input_schema,omitempty"` // 工具参数 JSON Schema(用于前端展示详情)
}
// GetConfig 获取当前配置
@@ -211,25 +212,25 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
// 首先从配置文件获取工具
configToolMap := make(map[string]bool)
tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
for _, tool := range h.config.Security.Tools {
configToolMap[tool.Name] = true
tools = append(tools, ToolConfigInfo{
info := ToolConfigInfo{
Name: tool.Name,
Description: h.pickToolDescription(tool.ShortDescription, tool.Description),
Enabled: tool.Enabled,
IsExternal: false,
})
}
tools = append(tools, info)
}
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
if h.mcpServer != nil {
mcpTools := h.mcpServer.GetAllTools()
for _, mcpTool := range mcpTools {
// 跳过已经在配置文件中的工具(避免重复)
if configToolMap[mcpTool.Name] {
continue
}
// 添加直接注册到MCP服务器的工具(如知识检索工具)
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
@@ -240,7 +241,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
tools = append(tools, ToolConfigInfo{
Name: mcpTool.Name,
Description: description,
Enabled: true, // 直接注册的工具默认启用
Enabled: true,
IsExternal: false,
})
}
@@ -442,7 +443,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
toolInfo := ToolConfigInfo{
Name: mcpTool.Name,
Description: description,
Enabled: true, // 直接注册的工具默认启用
Enabled: true,
IsExternal: false,
}
@@ -1142,32 +1143,7 @@ func (h *ConfigHandler) saveConfig() error {
updateRobotsConfig(root, h.config.Robots)
updateMultiAgentConfig(root, h.config.MultiAgent)
// 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用)
// 读取原始配置以保持向后兼容
originalConfigs := make(map[string]map[string]bool)
externalMCPNode := findMapValue(root, "external_mcp")
if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode {
serversNode := findMapValue(externalMCPNode, "servers")
if serversNode != nil && serversNode.Kind == yaml.MappingNode {
for i := 0; i < len(serversNode.Content); i += 2 {
if i+1 >= len(serversNode.Content) {
break
}
nameNode := serversNode.Content[i]
serverNode := serversNode.Content[i+1]
if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode {
serverName := nameNode.Value
originalConfigs[serverName] = make(map[string]bool)
if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil {
originalConfigs[serverName]["enabled"] = *enabledVal
}
if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil {
originalConfigs[serverName]["disabled"] = *disabledVal
}
}
}
}
}
updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs)
updateExternalMCPConfig(root, h.config.ExternalMCP)
if err := writeYAMLDocument(h.configPath, root); err != nil {
return fmt.Errorf("保存配置文件失败: %w", err)
@@ -1585,7 +1561,7 @@ func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, c
}
// 首先检查外部MCP是否启用
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
if !cfg.ExternalMCPEnable {
return false // MCP未启用,所有工具都禁用
}
@@ -1624,3 +1600,109 @@ func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
}
return description
}
// GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据)
func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
toolName := c.Param("name")
if toolName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"})
return
}
// 检查是否为外部工具(格式:mcpName::toolName
externalMCP := c.Query("external_mcp")
if externalMCP != "" {
// 外部 MCP 工具
if h.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
externalTools, _ := h.externalMCPMgr.GetAllTools(ctx)
fullName := externalMCP + "::" + toolName
for _, t := range externalTools {
if t.Name == fullName {
c.JSON(http.StatusOK, gin.H{"input_schema": t.InputSchema})
return
}
}
}
c.JSON(http.StatusNotFound, gin.H{"error": "外部工具未找到"})
return
}
// 内部工具:从 YAML 配置的 Parameters 构建
for _, tool := range h.config.Security.Tools {
if tool.Name == toolName {
c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)})
return
}
}
// MCP 注册工具(如知识检索)
if h.mcpServer != nil {
for _, mt := range h.mcpServer.GetAllTools() {
if mt.Name == toolName {
c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema})
return
}
}
}
c.JSON(http.StatusNotFound, gin.H{"error": "工具未找到"})
}
// buildInputSchemaFromParams 从 YAML 工具的 ParameterConfig 构建 JSON Schema(用于前端展示)。
// 不依赖 MCP 服务器注册状态,所有工具(包括未启用的)都能返回参数定义。
func buildInputSchemaFromParams(params []config.ParameterConfig) map[string]interface{} {
if len(params) == 0 {
return nil
}
properties := make(map[string]interface{})
required := make([]string, 0)
for _, p := range params {
name := strings.TrimSpace(p.Name)
if name == "" {
continue
}
prop := map[string]interface{}{
"type": convertParamType(p.Type),
"description": p.Description,
}
if p.Default != nil {
prop["default"] = p.Default
}
if len(p.Options) > 0 {
prop["enum"] = p.Options
}
properties[name] = prop
if p.Required {
required = append(required, name)
}
}
schema := map[string]interface{}{
"type": "object",
"properties": properties,
}
if len(required) > 0 {
schema["required"] = required
}
return schema
}
func convertParamType(t string) string {
switch strings.TrimSpace(strings.ToLower(t)) {
case "int", "integer", "number":
return "number"
case "bool", "boolean":
return "boolean"
case "array", "list":
return "array"
default:
return "string"
}
}
+41 -124
View File
@@ -157,36 +157,19 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
// 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容
// 同时将值迁移到 external_mcp_enable
cfg := req.Config
if req.Config.Disabled {
// 用户设置了 disabled: true
// 官方 disabled 字段 → ExternalMCPEnable 取反
if cfg.Disabled {
cfg.ExternalMCPEnable = false
cfg.Disabled = true
cfg.Enabled = false
} else if req.Config.Enabled {
// 用户设置了 enabled: true
} else if !cfg.ExternalMCPEnable {
// 用户未显式设置 external_mcp_enable,官方配置默认就是启用的
cfg.ExternalMCPEnable = true
cfg.Enabled = true
cfg.Disabled = false
} else if !req.Config.ExternalMCPEnable {
// 用户没有设置任何字段,且 external_mcp_enable 为 false
// 检查现有配置是否有旧字段
if existingCfg, exists := h.config.ExternalMCP.Servers[name]; exists {
// 保留现有的旧字段
cfg.Enabled = existingCfg.Enabled
cfg.Disabled = existingCfg.Disabled
}
} else {
// 用户通过新字段启用了(external_mcp_enable: true),但没有设置旧字段
// 为了向后兼容,我们设置 enabled: true
// 这样即使原始配置中有 disabled: false,也会被转换为 enabled: true
cfg.Enabled = true
cfg.Disabled = false
}
// 展开 ${VAR} 环境变量
config.ExpandConfigEnv(&cfg)
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
@@ -315,32 +298,25 @@ func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) {
c.JSON(http.StatusOK, stats)
}
// validateConfig 验证配置
// validateConfig 验证配置(同时支持官方 type 字段和旧版 transport 字段)
func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error {
transport := cfg.Transport
transport := cfg.GetTransportType()
if transport == "" {
// 如果没有指定transport,根据是否有command或url判断
if cfg.Command != "" {
transport = "stdio"
} else if cfg.URL != "" {
transport = "http"
} else {
return fmt.Errorf("需要指定commandstdio模式)或urlhttp/sse模式)")
}
return fmt.Errorf("需要指定 commandstdio模式)或 url + typehttp/sse模式)")
}
switch transport {
case "http":
if cfg.URL == "" {
return fmt.Errorf("HTTP模式需要URL")
return fmt.Errorf("HTTP模式需要 url")
}
case "stdio":
if cfg.Command == "" {
return fmt.Errorf("stdio模式需要command")
return fmt.Errorf("stdio模式需要 command")
}
case "sse":
if cfg.URL == "" {
return fmt.Errorf("SSE模式需要URL")
return fmt.Errorf("SSE模式需要 url")
}
default:
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
@@ -351,25 +327,11 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
// isEnabled 检查是否启用
func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool {
// 优先使用 ExternalMCPEnable 字段
// 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容)
if cfg.ExternalMCPEnable {
return true
}
// 向后兼容:检查旧字段
if cfg.Disabled {
return false
}
if cfg.Enabled {
return true
}
// 都没有设置,默认为启用
return true
return cfg.ExternalMCPEnable
}
// saveConfig 保存配置到文件
func (h *ExternalMCPHandler) saveConfig() error {
// 读取现有配置文件并创建备份
data, err := os.ReadFile(h.configPath)
if err != nil {
return fmt.Errorf("读取配置文件失败: %w", err)
@@ -384,37 +346,7 @@ func (h *ExternalMCPHandler) saveConfig() error {
return fmt.Errorf("解析配置文件失败: %w", err)
}
// 在更新前,读取原始配置中的 enabled/disabled 字段,以便保持向后兼容
originalConfigs := make(map[string]map[string]bool)
externalMCPNode := findMapValue(root.Content[0], "external_mcp")
if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode {
serversNode := findMapValue(externalMCPNode, "servers")
if serversNode != nil && serversNode.Kind == yaml.MappingNode {
// 遍历现有的服务器配置,保存 enabled/disabled 字段
for i := 0; i < len(serversNode.Content); i += 2 {
if i+1 >= len(serversNode.Content) {
break
}
nameNode := serversNode.Content[i]
serverNode := serversNode.Content[i+1]
if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode {
serverName := nameNode.Value
originalConfigs[serverName] = make(map[string]bool)
// 检查是否有 enabled 字段
if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil {
originalConfigs[serverName]["enabled"] = *enabledVal
}
// 检查是否有 disabled 字段
if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil {
originalConfigs[serverName]["disabled"] = *disabledVal
}
}
}
}
}
// 更新外部MCP配置
updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs)
updateExternalMCPConfig(root, h.config.ExternalMCP)
if err := writeYAMLDocument(h.configPath, root); err != nil {
return fmt.Errorf("保存配置文件失败: %w", err)
@@ -425,7 +357,7 @@ func (h *ExternalMCPHandler) saveConfig() error {
}
// updateExternalMCPConfig 更新外部MCP配置
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, originalConfigs map[string]map[string]bool) {
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig) {
root := doc.Content[0]
externalMCPNode := ensureMap(root, "external_mcp")
serversNode := ensureMap(externalMCPNode, "servers")
@@ -435,32 +367,31 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
// 添加新的服务器配置
for name, serverCfg := range cfg.Servers {
// 添加服务器名称键
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
// 设置服务器配置字段
// type(官方 MCP 传输类型)
effectiveType := serverCfg.GetTransportType()
if effectiveType != "" && effectiveType != "stdio" {
// stdio 可省略(有 command 时自动推断)
setStringInMap(serverNode, "type", effectiveType)
}
if serverCfg.Command != "" {
setStringInMap(serverNode, "command", serverCfg.Command)
}
if len(serverCfg.Args) > 0 {
setStringArrayInMap(serverNode, "args", serverCfg.Args)
}
// 保存 env 字段(环境变量)
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
envNode := ensureMap(serverNode, "env")
for envKey, envValue := range serverCfg.Env {
setStringInMap(envNode, envKey, envValue)
}
}
if serverCfg.Transport != "" {
setStringInMap(serverNode, "transport", serverCfg.Transport)
}
if serverCfg.URL != "" {
setStringInMap(serverNode, "url", serverCfg.URL)
}
// 保存 headers 字段(HTTP/SSE 请求头)
if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 {
headersNode := ensureMap(serverNode, "headers")
for k, v := range serverCfg.Headers {
@@ -473,46 +404,32 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
if serverCfg.Timeout > 0 {
setIntInMap(serverNode, "timeout", serverCfg.Timeout)
}
// 保存 external_mcp_enable 字段(新字段
// 官方标准字段
if serverCfg.Disabled {
setBoolInMap(serverNode, "disabled", true)
}
if len(serverCfg.AutoApprove) > 0 {
setStringArrayInMap(serverNode, "autoApprove", serverCfg.AutoApprove)
}
// SDK 高级配置
if serverCfg.MaxRetries > 0 {
setIntInMap(serverNode, "max_retries", serverCfg.MaxRetries)
}
if serverCfg.TerminateDuration > 0 {
setIntInMap(serverNode, "terminate_duration", serverCfg.TerminateDuration)
}
if serverCfg.KeepAlive > 0 {
setIntInMap(serverNode, "keep_alive", serverCfg.KeepAlive)
}
setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable)
// 保存 tool_enabled 字段(每个工具的启用状态)
if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 {
toolEnabledNode := ensureMap(serverNode, "tool_enabled")
for toolName, enabled := range serverCfg.ToolEnabled {
setBoolInMap(toolEnabledNode, toolName, enabled)
}
}
// 保留旧的 enabled/disabled 字段以保持向后兼容
originalFields, hasOriginal := originalConfigs[name]
// 如果原始配置中有 enabled 字段,保留它
if hasOriginal {
if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled {
setBoolInMap(serverNode, "enabled", enabledVal)
}
// 如果原始配置中有 disabled 字段,保留它
// 注意:由于 omitemptydisabled: false 不会被保存,但 disabled: true 会被保存
if disabledVal, hasDisabled := originalFields["disabled"]; hasDisabled {
if disabledVal {
setBoolInMap(serverNode, "disabled", disabledVal)
} else {
// 如果原始配置中有 disabled: false,我们保存 enabled: true 来等效表示
// 因为 disabled: false 等价于 enabled: true
setBoolInMap(serverNode, "enabled", true)
}
}
}
// 如果用户在当前请求中明确设置了这些字段,也保存它们
if serverCfg.Enabled {
setBoolInMap(serverNode, "enabled", serverCfg.Enabled)
}
if serverCfg.Disabled {
setBoolInMap(serverNode, "disabled", serverCfg.Disabled)
} else if !hasOriginal && serverCfg.ExternalMCPEnable {
// 如果用户通过新字段启用了,且原始配置中没有旧字段,保存 enabled: true 以保持向后兼容
setBoolInMap(serverNode, "enabled", true)
}
}
}
+22 -32
View File
@@ -60,13 +60,13 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加stdio模式的配置
// 测试添加stdio模式的配置(官方格式:有 command 时 type 可省略)
configJSON := `{
"command": "python3",
"args": ["/path/to/script.py", "--server", "http://example.com"],
"description": "Test stdio MCP",
"timeout": 300,
"enabled": true
"external_mcp_enable": true
}`
var configObj config.ExternalMCPServerConfig
@@ -115,20 +115,17 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
if response.Config.Timeout != 300 {
t.Errorf("期望timeout为300,实际%d", response.Config.Timeout)
}
if !response.Config.Enabled {
t.Error("期望enabled为true")
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加HTTP模式的配置
// 测试添加HTTP模式的配置(使用官方 type 字段)
configJSON := `{
"transport": "http",
"type": "http",
"url": "http://127.0.0.1:8081/mcp",
"enabled": true
"external_mcp_enable": true
}`
var configObj config.ExternalMCPServerConfig
@@ -165,15 +162,12 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Transport != "http" {
t.Errorf("期望transport为http,实际%s", response.Config.Transport)
if response.Config.Type != "http" {
t.Errorf("期望type为http,实际%s", response.Config.Type)
}
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
}
if !response.Config.Enabled {
t.Error("期望enabled为true")
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
@@ -187,22 +181,22 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
}{
{
name: "缺少command和url",
configJSON: `{"enabled": true}`,
expectedErr: "需要指定commandstdio模式)或urlhttp/sse模式)",
configJSON: `{"external_mcp_enable": true}`,
expectedErr: "需要指定 commandstdio模式)或 url + typehttp/sse模式)",
},
{
name: "stdio模式缺少command",
configJSON: `{"args": ["test"], "enabled": true}`,
configJSON: `{"args": ["test"], "external_mcp_enable": true}`,
expectedErr: "stdio模式需要command",
},
{
name: "http模式缺少url",
configJSON: `{"transport": "http", "enabled": true}`,
expectedErr: "HTTP模式需要URL",
configJSON: `{"type": "http", "external_mcp_enable": true}`,
expectedErr: "HTTP模式需要 url",
},
{
name: "无效的transport",
configJSON: `{"transport": "invalid", "enabled": true}`,
name: "无效的type",
configJSON: `{"type": "invalid", "external_mcp_enable": true}`,
expectedErr: "不支持的传输模式",
},
}
@@ -254,7 +248,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
// 先添加一个配置
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
ExternalMCPEnable: true,
}
handler.manager.AddOrUpdateConfig("test-delete", configObj)
@@ -283,11 +277,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
// 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: false,
ExternalMCPEnable: false,
})
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
@@ -326,16 +320,14 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
// 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true,
})
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
@@ -369,8 +361,6 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
// 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true,
})
// 测试启动(可能会失败,因为没有真实的服务器)
@@ -427,7 +417,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
ExternalMCPEnable: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
@@ -470,14 +460,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
// 先添加配置
config1 := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
ExternalMCPEnable: true,
}
handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置
config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
ExternalMCPEnable: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
+16 -4
View File
@@ -411,7 +411,10 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
}
defer resp.Body.Close()
out, _ := io.ReadAll(resp.Body)
out, readErr := io.ReadAll(resp.Body)
if readErr != nil {
h.logger.Warn("webshell exec read body", zap.Error(readErr))
}
output := string(out)
httpCode := resp.StatusCode
@@ -578,7 +581,10 @@ func (h *WebShellHandler) FileOp(c *gin.Context) {
}
defer resp.Body.Close()
out, _ := io.ReadAll(resp.Body)
out, readErr := io.ReadAll(resp.Body)
if readErr != nil {
h.logger.Warn("webshell fileop read body", zap.Error(readErr))
}
output := string(out)
c.JSON(http.StatusOK, FileOpResponse{
@@ -633,7 +639,10 @@ func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection,
return "", false, err.Error()
}
defer resp.Body.Close()
out, _ := io.ReadAll(resp.Body)
out, readErr := io.ReadAll(resp.Body)
if readErr != nil {
h.logger.Warn("webshell ExecWithConnection read body", zap.Error(readErr))
}
return string(out), resp.StatusCode == http.StatusOK, ""
}
@@ -701,6 +710,9 @@ func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection
return "", false, err.Error()
}
defer resp.Body.Close()
out, _ := io.ReadAll(resp.Body)
out, readErr := io.ReadAll(resp.Body)
if readErr != nil {
h.logger.Warn("webshell FileOpWithConnection read body", zap.Error(readErr))
}
return string(out), resp.StatusCode == http.StatusOK, ""
}
+18 -4
View File
@@ -487,7 +487,10 @@ func (c *Client) claudeChatCompletionStream(ctx context.Context, payload interfa
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return "", fmt.Errorf("claude bridge: read error response: %w", readErr)
}
return "", &APIError{
StatusCode: resp.StatusCode,
Body: string(respBody),
@@ -588,7 +591,10 @@ func (c *Client) claudeChatCompletionStreamWithToolCalls(
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return "", nil, "", fmt.Errorf("claude bridge: read error response: %w", readErr)
}
return "", nil, "", &APIError{
StatusCode: resp.StatusCode,
Body: string(respBody),
@@ -824,7 +830,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
// 非 200:尝试把 Claude 错误格式转成 OpenAI 错误格式,便于 Eino 解析
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
bodyBytes, readErr := io.ReadAll(resp.Body)
if readErr != nil {
resp.Body.Close()
return nil, fmt.Errorf("claude bridge: read error response: %w", readErr)
}
resp.Body.Close()
converted := rt.tryConvertClaudeErrorToOpenAI(bodyBytes)
return &http.Response{
@@ -838,7 +848,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
// 非流式:一次性转换响应体
if !claudeReq.Stream {
respBody, _ := io.ReadAll(resp.Body)
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
resp.Body.Close()
return nil, fmt.Errorf("claude bridge: read response: %w", readErr)
}
resp.Body.Close()
oaiJSON, err := claudeToOpenAIResponseJSON(respBody)
if err != nil {
+8 -2
View File
@@ -189,7 +189,10 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{},
// 非200:读完 body 返回
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
c.logger.Warn("failed to read OpenAI error response body", zap.Error(readErr))
}
return "", &APIError{
StatusCode: resp.StatusCode,
Body: string(respBody),
@@ -329,7 +332,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
c.logger.Warn("failed to read OpenAI error response body", zap.Error(readErr))
}
return "", nil, "", &APIError{
StatusCode: resp.StatusCode,
Body: string(respBody),
+81
View File
@@ -0,0 +1,81 @@
package security
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// rateLimitEntry 记录某个 IP 的请求窗口信息
type rateLimitEntry struct {
count int
windowAt time.Time
}
// RateLimiter 基于 IP 的滑动窗口速率限制器
type RateLimiter struct {
mu sync.Mutex
entries map[string]*rateLimitEntry
limit int // 窗口内允许的最大请求数
window time.Duration // 窗口时长
}
// NewRateLimiter 创建速率限制器
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
entries: make(map[string]*rateLimitEntry),
limit: limit,
window: window,
}
// 后台定期清理过期条目,防止内存泄漏
go rl.cleanup()
return rl
}
// cleanup 每分钟清理一次过期条目
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
now := time.Now()
for ip, entry := range rl.entries {
if now.Sub(entry.windowAt) > rl.window {
delete(rl.entries, ip)
}
}
rl.mu.Unlock()
}
}
// allow 检查指定 IP 是否允许通过
func (rl *RateLimiter) allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
entry, ok := rl.entries[ip]
if !ok || now.Sub(entry.windowAt) > rl.window {
rl.entries[ip] = &rateLimitEntry{count: 1, windowAt: now}
return true
}
entry.count++
return entry.count <= rl.limit
}
// RateLimitMiddleware 返回 Gin 中间件,对超限请求返回 429
func RateLimitMiddleware(rl *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
if !rl.allow(ip) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded, please try again later",
})
return
}
c.Next()
}
}