mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-29 10:41:28 +02:00
Add files via upload
This commit is contained in:
@@ -949,7 +949,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
enabled := false
|
||||
if cfg, exists := externalMCPConfigs[mcpName]; exists {
|
||||
// 首先检查外部MCP是否启用
|
||||
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
|
||||
if !cfg.ExternalMCPEnable {
|
||||
enabled = false // MCP未启用,所有工具都禁用
|
||||
} else {
|
||||
// MCP已启用,检查单个工具的启用状态
|
||||
|
||||
+58
-14
@@ -2,6 +2,7 @@ package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -459,7 +460,9 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
func (a *App) mcpHandlerWithAuth(w http.ResponseWriter, r *http.Request) {
|
||||
cfg := a.config.MCP
|
||||
if cfg.AuthHeader != "" {
|
||||
if r.Header.Get(cfg.AuthHeader) != cfg.AuthHeaderValue {
|
||||
actual := []byte(r.Header.Get(cfg.AuthHeader))
|
||||
expected := []byte(cfg.AuthHeaderValue)
|
||||
if subtle.ConstantTimeCompare(actual, expected) != 1 {
|
||||
a.logger.Logger.Debug("MCP 鉴权失败:header 缺失或值不匹配", zap.String("header", cfg.AuthHeader))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
@@ -470,18 +473,25 @@ func (a *App) mcpHandlerWithAuth(w http.ResponseWriter, r *http.Request) {
|
||||
a.mcpServer.HandleHTTP(w, r)
|
||||
}
|
||||
|
||||
// Run 启动应用
|
||||
// Run 启动应用(向后兼容,不支持优雅关闭)
|
||||
func (a *App) Run() error {
|
||||
return a.RunWithContext(context.Background())
|
||||
}
|
||||
|
||||
// RunWithContext 启动应用,支持通过 context 取消来优雅关闭
|
||||
func (a *App) RunWithContext(ctx context.Context) error {
|
||||
// 启动MCP服务器(如果启用)
|
||||
var mcpServer *http.Server
|
||||
if a.config.MCP.Enabled {
|
||||
mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
|
||||
a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/mcp", a.mcpHandlerWithAuth)
|
||||
|
||||
mcpServer = &http.Server{Addr: mcpAddr, Handler: mux}
|
||||
go func() {
|
||||
mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
|
||||
a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/mcp", a.mcpHandlerWithAuth)
|
||||
|
||||
if err := http.ListenAndServe(mcpAddr, mux); err != nil {
|
||||
if err := mcpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
a.logger.Error("MCP服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
@@ -491,7 +501,27 @@ func (a *App) Run() error {
|
||||
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
|
||||
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
|
||||
|
||||
return a.router.Run(addr)
|
||||
srv := &http.Server{Addr: addr, Handler: a.router}
|
||||
|
||||
// 监听 context 取消,优雅关闭 HTTP 服务器
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
a.logger.Error("HTTP服务器关闭失败", zap.Error(err))
|
||||
}
|
||||
if mcpServer != nil {
|
||||
if err := mcpServer.Shutdown(shutdownCtx); err != nil {
|
||||
a.logger.Error("MCP服务器关闭失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown 关闭应用
|
||||
@@ -519,6 +549,13 @@ func (a *App) Shutdown() {
|
||||
a.logger.Logger.Warn("关闭知识库数据库连接失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭主数据库连接
|
||||
if a.db != nil {
|
||||
if err := a.db.Close(); err != nil {
|
||||
a.logger.Logger.Warn("关闭主数据库连接失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startRobotConnections 根据当前配置启动钉钉/飞书长连接(不先关闭已有连接,仅用于首次启动)
|
||||
@@ -593,10 +630,16 @@ func setupRoutes(
|
||||
}
|
||||
|
||||
// 机器人回调(无需登录,供企业微信/钉钉/飞书服务器调用)
|
||||
api.GET("/robot/wecom", robotHandler.HandleWecomGET)
|
||||
api.POST("/robot/wecom", robotHandler.HandleWecomPOST)
|
||||
api.POST("/robot/dingtalk", robotHandler.HandleDingtalkPOST)
|
||||
api.POST("/robot/lark", robotHandler.HandleLarkPOST)
|
||||
// 添加速率限制:每个 IP 每分钟最多 60 次请求,防止滥用
|
||||
robotRL := security.NewRateLimiter(60, 1*time.Minute)
|
||||
robotGroup := api.Group("/robot")
|
||||
robotGroup.Use(security.RateLimitMiddleware(robotRL))
|
||||
{
|
||||
robotGroup.GET("/wecom", robotHandler.HandleWecomGET)
|
||||
robotGroup.POST("/wecom", robotHandler.HandleWecomPOST)
|
||||
robotGroup.POST("/dingtalk", robotHandler.HandleDingtalkPOST)
|
||||
robotGroup.POST("/lark", robotHandler.HandleLarkPOST)
|
||||
}
|
||||
|
||||
protected := api.Group("")
|
||||
protected.Use(security.AuthMiddleware(authManager))
|
||||
@@ -680,6 +723,7 @@ func setupRoutes(
|
||||
// 配置管理
|
||||
protected.GET("/config", configHandler.GetConfig)
|
||||
protected.GET("/config/tools", configHandler.GetTools)
|
||||
protected.GET("/config/tools/:name/schema", configHandler.GetToolSchema)
|
||||
protected.PUT("/config", configHandler.UpdateConfig)
|
||||
protected.POST("/config/apply", configHandler.ApplyConfig)
|
||||
protected.POST("/config/test-openai", configHandler.TestOpenAI)
|
||||
|
||||
@@ -2089,14 +2089,17 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti
|
||||
}
|
||||
|
||||
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
|
||||
if !h.markBatchQueueRunning(queueID) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if scheduled {
|
||||
if queue.ScheduleMode != "cron" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
|
||||
@@ -543,16 +543,23 @@ func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, resu
|
||||
|
||||
// UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId)
|
||||
func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) {
|
||||
var needDBUpdate bool
|
||||
|
||||
// 在锁内只更新内存状态
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// DB 优先:先持久化,成功后再更新内存,避免重启后状态不一致
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil {
|
||||
m.logger.Warn("batch task DB status update failed, skipping memory update",
|
||||
zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, task := range queue.Tasks {
|
||||
if task.ID == taskID {
|
||||
task.Status = status
|
||||
@@ -575,30 +582,27 @@ func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, s
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
needDBUpdate = m.db != nil
|
||||
m.mu.Unlock()
|
||||
|
||||
// 释放锁后写 DB
|
||||
if needDBUpdate {
|
||||
if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil {
|
||||
m.logger.Warn("batch task DB status update failed", zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateQueueStatus 更新队列状态
|
||||
func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
|
||||
var needDBUpdate bool
|
||||
|
||||
// 在锁内只更新内存状态
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// DB 优先:先持久化,成功后再更新内存
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil {
|
||||
m.logger.Warn("batch queue DB status update failed, skipping memory update",
|
||||
zap.String("queueId", queueID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
queue.Status = status
|
||||
now := time.Now()
|
||||
if status == BatchQueueStatusRunning && queue.StartedAt == nil {
|
||||
@@ -607,16 +611,6 @@ func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
|
||||
if status == BatchQueueStatusCompleted || status == BatchQueueStatusCancelled {
|
||||
queue.CompletedAt = &now
|
||||
}
|
||||
|
||||
needDBUpdate = m.db != nil
|
||||
m.mu.Unlock()
|
||||
|
||||
// 释放锁后写 DB
|
||||
if needDBUpdate {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil {
|
||||
m.logger.Warn("batch queue DB status update failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateQueueSchedule 更新队列调度配置
|
||||
@@ -756,6 +750,16 @@ func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool {
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// DB 优先:先持久化重置,成功后再更新内存,避免 DB 失败导致内存脏状态
|
||||
if m.db != nil {
|
||||
if err := m.db.ResetBatchQueueForRerun(queueID); err != nil {
|
||||
m.logger.Warn("batch queue DB reset for rerun failed, skipping memory update",
|
||||
zap.String("queueId", queueID), zap.Error(err))
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
queue.Status = BatchQueueStatusPending
|
||||
queue.CurrentIndex = 0
|
||||
queue.StartedAt = nil
|
||||
@@ -771,12 +775,6 @@ func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool {
|
||||
task.Error = ""
|
||||
task.Result = ""
|
||||
}
|
||||
|
||||
if m.db != nil {
|
||||
if err := m.db.ResetBatchQueueForRerun(queueID); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -870,7 +868,7 @@ func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
|
||||
return fmt.Errorf("队列正在执行或未就绪,无法删除任务")
|
||||
}
|
||||
|
||||
// 查找并删除任务
|
||||
// 查找任务
|
||||
taskIndex := -1
|
||||
for i, task := range queue.Tasks {
|
||||
if task.ID == taskID {
|
||||
@@ -886,18 +884,14 @@ func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
|
||||
return fmt.Errorf("任务不存在")
|
||||
}
|
||||
|
||||
// 从内存队列中删除
|
||||
queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...)
|
||||
|
||||
// 同步到数据库
|
||||
// DB 优先:先从数据库删除,成功后再从内存移除
|
||||
if m.db != nil {
|
||||
if err := m.db.DeleteBatchTask(queueID, taskID); err != nil {
|
||||
// 如果数据库删除失败,恢复内存中的任务
|
||||
// 这里需要重新插入,但为了简化,我们只记录错误
|
||||
return fmt.Errorf("删除任务失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -987,9 +981,7 @@ func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFu
|
||||
// PauseQueue 暂停队列
|
||||
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
var cancelFunc context.CancelFunc
|
||||
var needDBUpdate bool
|
||||
|
||||
// 在锁内只更新内存状态
|
||||
m.mu.Lock()
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
@@ -1002,6 +994,16 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// DB 优先:先持久化,成功后再更新内存
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil {
|
||||
m.logger.Warn("batch queue DB pause update failed, skipping memory update",
|
||||
zap.String("queueId", queueID), zap.Error(err))
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
queue.Status = BatchQueueStatusPaused
|
||||
|
||||
// 取消当前正在执行的任务(通过取消context)
|
||||
@@ -1009,22 +1011,13 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
cancelFunc = cancel
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
|
||||
needDBUpdate = m.db != nil
|
||||
m.mu.Unlock()
|
||||
|
||||
// 释放锁后执行取消回调
|
||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
}
|
||||
|
||||
// 释放锁后写 DB
|
||||
if needDBUpdate {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil {
|
||||
m.logger.Warn("batch queue DB pause update failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1032,9 +1025,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
now := time.Now()
|
||||
var cancelFunc context.CancelFunc
|
||||
var needDBUpdate bool
|
||||
|
||||
// 在锁内只更新内存状态,不做 DB 操作
|
||||
m.mu.Lock()
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
@@ -1047,6 +1038,22 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// DB 优先:先持久化,成功后再更新内存
|
||||
if m.db != nil {
|
||||
if err := m.db.CancelPendingBatchTasks(queueID, now); err != nil {
|
||||
m.logger.Warn("batch task DB batch cancel failed, skipping memory update",
|
||||
zap.String("queueId", queueID), zap.Error(err))
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil {
|
||||
m.logger.Warn("batch queue DB cancel update failed, skipping memory update",
|
||||
zap.String("queueId", queueID), zap.Error(err))
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
queue.Status = BatchQueueStatusCancelled
|
||||
queue.CompletedAt = &now
|
||||
|
||||
@@ -1063,25 +1070,13 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
cancelFunc = cancel
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
|
||||
needDBUpdate = m.db != nil
|
||||
m.mu.Unlock()
|
||||
|
||||
// 释放锁后执行取消回调
|
||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
}
|
||||
|
||||
// 释放锁后批量写 DB(单条 SQL 取消所有 pending 任务)
|
||||
if needDBUpdate {
|
||||
if err := m.db.CancelPendingBatchTasks(queueID, now); err != nil {
|
||||
m.logger.Warn("batch task DB batch cancel failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
}
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil {
|
||||
m.logger.Warn("batch queue DB cancel update failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
+121
-39
@@ -194,12 +194,13 @@ type GetConfigResponse struct {
|
||||
|
||||
// ToolConfigInfo 工具配置信息
|
||||
type ToolConfigInfo struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Enabled bool `json:"enabled"`
|
||||
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
|
||||
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
|
||||
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Enabled bool `json:"enabled"`
|
||||
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
|
||||
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
|
||||
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
|
||||
InputSchema map[string]interface{} `json:"input_schema,omitempty"` // 工具参数 JSON Schema(用于前端展示详情)
|
||||
}
|
||||
|
||||
// GetConfig 获取当前配置
|
||||
@@ -211,25 +212,25 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
// 首先从配置文件获取工具
|
||||
configToolMap := make(map[string]bool)
|
||||
tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
|
||||
|
||||
for _, tool := range h.config.Security.Tools {
|
||||
configToolMap[tool.Name] = true
|
||||
tools = append(tools, ToolConfigInfo{
|
||||
info := ToolConfigInfo{
|
||||
Name: tool.Name,
|
||||
Description: h.pickToolDescription(tool.ShortDescription, tool.Description),
|
||||
Enabled: tool.Enabled,
|
||||
IsExternal: false,
|
||||
})
|
||||
}
|
||||
tools = append(tools, info)
|
||||
}
|
||||
|
||||
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
|
||||
if h.mcpServer != nil {
|
||||
mcpTools := h.mcpServer.GetAllTools()
|
||||
for _, mcpTool := range mcpTools {
|
||||
// 跳过已经在配置文件中的工具(避免重复)
|
||||
if configToolMap[mcpTool.Name] {
|
||||
continue
|
||||
}
|
||||
// 添加直接注册到MCP服务器的工具(如知识检索工具)
|
||||
description := mcpTool.ShortDescription
|
||||
if description == "" {
|
||||
description = mcpTool.Description
|
||||
@@ -240,7 +241,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
tools = append(tools, ToolConfigInfo{
|
||||
Name: mcpTool.Name,
|
||||
Description: description,
|
||||
Enabled: true, // 直接注册的工具默认启用
|
||||
Enabled: true,
|
||||
IsExternal: false,
|
||||
})
|
||||
}
|
||||
@@ -442,7 +443,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
toolInfo := ToolConfigInfo{
|
||||
Name: mcpTool.Name,
|
||||
Description: description,
|
||||
Enabled: true, // 直接注册的工具默认启用
|
||||
Enabled: true,
|
||||
IsExternal: false,
|
||||
}
|
||||
|
||||
@@ -1142,32 +1143,7 @@ func (h *ConfigHandler) saveConfig() error {
|
||||
updateRobotsConfig(root, h.config.Robots)
|
||||
updateMultiAgentConfig(root, h.config.MultiAgent)
|
||||
// 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用)
|
||||
// 读取原始配置以保持向后兼容
|
||||
originalConfigs := make(map[string]map[string]bool)
|
||||
externalMCPNode := findMapValue(root, "external_mcp")
|
||||
if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode {
|
||||
serversNode := findMapValue(externalMCPNode, "servers")
|
||||
if serversNode != nil && serversNode.Kind == yaml.MappingNode {
|
||||
for i := 0; i < len(serversNode.Content); i += 2 {
|
||||
if i+1 >= len(serversNode.Content) {
|
||||
break
|
||||
}
|
||||
nameNode := serversNode.Content[i]
|
||||
serverNode := serversNode.Content[i+1]
|
||||
if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode {
|
||||
serverName := nameNode.Value
|
||||
originalConfigs[serverName] = make(map[string]bool)
|
||||
if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil {
|
||||
originalConfigs[serverName]["enabled"] = *enabledVal
|
||||
}
|
||||
if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil {
|
||||
originalConfigs[serverName]["disabled"] = *disabledVal
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs)
|
||||
updateExternalMCPConfig(root, h.config.ExternalMCP)
|
||||
|
||||
if err := writeYAMLDocument(h.configPath, root); err != nil {
|
||||
return fmt.Errorf("保存配置文件失败: %w", err)
|
||||
@@ -1585,7 +1561,7 @@ func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, c
|
||||
}
|
||||
|
||||
// 首先检查外部MCP是否启用
|
||||
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
|
||||
if !cfg.ExternalMCPEnable {
|
||||
return false // MCP未启用,所有工具都禁用
|
||||
}
|
||||
|
||||
@@ -1624,3 +1600,109 @@ func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
|
||||
}
|
||||
return description
|
||||
}
|
||||
|
||||
// GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据)
|
||||
func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
toolName := c.Param("name")
|
||||
if toolName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为外部工具(格式:mcpName::toolName)
|
||||
externalMCP := c.Query("external_mcp")
|
||||
if externalMCP != "" {
|
||||
// 外部 MCP 工具
|
||||
if h.externalMCPMgr != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
externalTools, _ := h.externalMCPMgr.GetAllTools(ctx)
|
||||
fullName := externalMCP + "::" + toolName
|
||||
for _, t := range externalTools {
|
||||
if t.Name == fullName {
|
||||
c.JSON(http.StatusOK, gin.H{"input_schema": t.InputSchema})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "外部工具未找到"})
|
||||
return
|
||||
}
|
||||
|
||||
// 内部工具:从 YAML 配置的 Parameters 构建
|
||||
for _, tool := range h.config.Security.Tools {
|
||||
if tool.Name == toolName {
|
||||
c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// MCP 注册工具(如知识检索)
|
||||
if h.mcpServer != nil {
|
||||
for _, mt := range h.mcpServer.GetAllTools() {
|
||||
if mt.Name == toolName {
|
||||
c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "工具未找到"})
|
||||
}
|
||||
|
||||
// buildInputSchemaFromParams 从 YAML 工具的 ParameterConfig 构建 JSON Schema(用于前端展示)。
|
||||
// 不依赖 MCP 服务器注册状态,所有工具(包括未启用的)都能返回参数定义。
|
||||
func buildInputSchemaFromParams(params []config.ParameterConfig) map[string]interface{} {
|
||||
if len(params) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
properties := make(map[string]interface{})
|
||||
required := make([]string, 0)
|
||||
|
||||
for _, p := range params {
|
||||
name := strings.TrimSpace(p.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
prop := map[string]interface{}{
|
||||
"type": convertParamType(p.Type),
|
||||
"description": p.Description,
|
||||
}
|
||||
if p.Default != nil {
|
||||
prop["default"] = p.Default
|
||||
}
|
||||
if len(p.Options) > 0 {
|
||||
prop["enum"] = p.Options
|
||||
}
|
||||
properties[name] = prop
|
||||
if p.Required {
|
||||
required = append(required, name)
|
||||
}
|
||||
}
|
||||
|
||||
schema := map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if len(required) > 0 {
|
||||
schema["required"] = required
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func convertParamType(t string) string {
|
||||
switch strings.TrimSpace(strings.ToLower(t)) {
|
||||
case "int", "integer", "number":
|
||||
return "number"
|
||||
case "bool", "boolean":
|
||||
return "boolean"
|
||||
case "array", "list":
|
||||
return "array"
|
||||
default:
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,36 +157,19 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
}
|
||||
|
||||
// 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容
|
||||
// 同时将值迁移到 external_mcp_enable
|
||||
cfg := req.Config
|
||||
|
||||
if req.Config.Disabled {
|
||||
// 用户设置了 disabled: true
|
||||
// 官方 disabled 字段 → ExternalMCPEnable 取反
|
||||
if cfg.Disabled {
|
||||
cfg.ExternalMCPEnable = false
|
||||
cfg.Disabled = true
|
||||
cfg.Enabled = false
|
||||
} else if req.Config.Enabled {
|
||||
// 用户设置了 enabled: true
|
||||
} else if !cfg.ExternalMCPEnable {
|
||||
// 用户未显式设置 external_mcp_enable,官方配置默认就是启用的
|
||||
cfg.ExternalMCPEnable = true
|
||||
cfg.Enabled = true
|
||||
cfg.Disabled = false
|
||||
} else if !req.Config.ExternalMCPEnable {
|
||||
// 用户没有设置任何字段,且 external_mcp_enable 为 false
|
||||
// 检查现有配置是否有旧字段
|
||||
if existingCfg, exists := h.config.ExternalMCP.Servers[name]; exists {
|
||||
// 保留现有的旧字段
|
||||
cfg.Enabled = existingCfg.Enabled
|
||||
cfg.Disabled = existingCfg.Disabled
|
||||
}
|
||||
} else {
|
||||
// 用户通过新字段启用了(external_mcp_enable: true),但没有设置旧字段
|
||||
// 为了向后兼容,我们设置 enabled: true
|
||||
// 这样即使原始配置中有 disabled: false,也会被转换为 enabled: true
|
||||
cfg.Enabled = true
|
||||
cfg.Disabled = false
|
||||
}
|
||||
|
||||
// 展开 ${VAR} 环境变量
|
||||
config.ExpandConfigEnv(&cfg)
|
||||
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
// 保存到配置文件
|
||||
@@ -315,32 +298,25 @@ func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// validateConfig 验证配置
|
||||
// validateConfig 验证配置(同时支持官方 type 字段和旧版 transport 字段)
|
||||
func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error {
|
||||
transport := cfg.Transport
|
||||
transport := cfg.GetTransportType()
|
||||
if transport == "" {
|
||||
// 如果没有指定transport,根据是否有command或url判断
|
||||
if cfg.Command != "" {
|
||||
transport = "stdio"
|
||||
} else if cfg.URL != "" {
|
||||
transport = "http"
|
||||
} else {
|
||||
return fmt.Errorf("需要指定command(stdio模式)或url(http/sse模式)")
|
||||
}
|
||||
return fmt.Errorf("需要指定 command(stdio模式)或 url + type(http/sse模式)")
|
||||
}
|
||||
|
||||
switch transport {
|
||||
case "http":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("HTTP模式需要URL")
|
||||
return fmt.Errorf("HTTP模式需要 url")
|
||||
}
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("stdio模式需要command")
|
||||
return fmt.Errorf("stdio模式需要 command")
|
||||
}
|
||||
case "sse":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("SSE模式需要URL")
|
||||
return fmt.Errorf("SSE模式需要 url")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
|
||||
@@ -351,25 +327,11 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
|
||||
|
||||
// isEnabled 检查是否启用
|
||||
func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool {
|
||||
// 优先使用 ExternalMCPEnable 字段
|
||||
// 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容)
|
||||
if cfg.ExternalMCPEnable {
|
||||
return true
|
||||
}
|
||||
// 向后兼容:检查旧字段
|
||||
if cfg.Disabled {
|
||||
return false
|
||||
}
|
||||
if cfg.Enabled {
|
||||
return true
|
||||
}
|
||||
// 都没有设置,默认为启用
|
||||
return true
|
||||
return cfg.ExternalMCPEnable
|
||||
}
|
||||
|
||||
// saveConfig 保存配置到文件
|
||||
func (h *ExternalMCPHandler) saveConfig() error {
|
||||
// 读取现有配置文件并创建备份
|
||||
data, err := os.ReadFile(h.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取配置文件失败: %w", err)
|
||||
@@ -384,37 +346,7 @@ func (h *ExternalMCPHandler) saveConfig() error {
|
||||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 在更新前,读取原始配置中的 enabled/disabled 字段,以便保持向后兼容
|
||||
originalConfigs := make(map[string]map[string]bool)
|
||||
externalMCPNode := findMapValue(root.Content[0], "external_mcp")
|
||||
if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode {
|
||||
serversNode := findMapValue(externalMCPNode, "servers")
|
||||
if serversNode != nil && serversNode.Kind == yaml.MappingNode {
|
||||
// 遍历现有的服务器配置,保存 enabled/disabled 字段
|
||||
for i := 0; i < len(serversNode.Content); i += 2 {
|
||||
if i+1 >= len(serversNode.Content) {
|
||||
break
|
||||
}
|
||||
nameNode := serversNode.Content[i]
|
||||
serverNode := serversNode.Content[i+1]
|
||||
if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode {
|
||||
serverName := nameNode.Value
|
||||
originalConfigs[serverName] = make(map[string]bool)
|
||||
// 检查是否有 enabled 字段
|
||||
if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil {
|
||||
originalConfigs[serverName]["enabled"] = *enabledVal
|
||||
}
|
||||
// 检查是否有 disabled 字段
|
||||
if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil {
|
||||
originalConfigs[serverName]["disabled"] = *disabledVal
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新外部MCP配置
|
||||
updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs)
|
||||
updateExternalMCPConfig(root, h.config.ExternalMCP)
|
||||
|
||||
if err := writeYAMLDocument(h.configPath, root); err != nil {
|
||||
return fmt.Errorf("保存配置文件失败: %w", err)
|
||||
@@ -425,7 +357,7 @@ func (h *ExternalMCPHandler) saveConfig() error {
|
||||
}
|
||||
|
||||
// updateExternalMCPConfig 更新外部MCP配置
|
||||
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, originalConfigs map[string]map[string]bool) {
|
||||
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig) {
|
||||
root := doc.Content[0]
|
||||
externalMCPNode := ensureMap(root, "external_mcp")
|
||||
serversNode := ensureMap(externalMCPNode, "servers")
|
||||
@@ -435,32 +367,31 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
|
||||
|
||||
// 添加新的服务器配置
|
||||
for name, serverCfg := range cfg.Servers {
|
||||
// 添加服务器名称键
|
||||
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
|
||||
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
|
||||
|
||||
// 设置服务器配置字段
|
||||
// type(官方 MCP 传输类型)
|
||||
effectiveType := serverCfg.GetTransportType()
|
||||
if effectiveType != "" && effectiveType != "stdio" {
|
||||
// stdio 可省略(有 command 时自动推断)
|
||||
setStringInMap(serverNode, "type", effectiveType)
|
||||
}
|
||||
if serverCfg.Command != "" {
|
||||
setStringInMap(serverNode, "command", serverCfg.Command)
|
||||
}
|
||||
if len(serverCfg.Args) > 0 {
|
||||
setStringArrayInMap(serverNode, "args", serverCfg.Args)
|
||||
}
|
||||
// 保存 env 字段(环境变量)
|
||||
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
|
||||
envNode := ensureMap(serverNode, "env")
|
||||
for envKey, envValue := range serverCfg.Env {
|
||||
setStringInMap(envNode, envKey, envValue)
|
||||
}
|
||||
}
|
||||
if serverCfg.Transport != "" {
|
||||
setStringInMap(serverNode, "transport", serverCfg.Transport)
|
||||
}
|
||||
if serverCfg.URL != "" {
|
||||
setStringInMap(serverNode, "url", serverCfg.URL)
|
||||
}
|
||||
// 保存 headers 字段(HTTP/SSE 请求头)
|
||||
if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 {
|
||||
headersNode := ensureMap(serverNode, "headers")
|
||||
for k, v := range serverCfg.Headers {
|
||||
@@ -473,46 +404,32 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
|
||||
if serverCfg.Timeout > 0 {
|
||||
setIntInMap(serverNode, "timeout", serverCfg.Timeout)
|
||||
}
|
||||
// 保存 external_mcp_enable 字段(新字段)
|
||||
// 官方标准字段
|
||||
if serverCfg.Disabled {
|
||||
setBoolInMap(serverNode, "disabled", true)
|
||||
}
|
||||
if len(serverCfg.AutoApprove) > 0 {
|
||||
setStringArrayInMap(serverNode, "autoApprove", serverCfg.AutoApprove)
|
||||
}
|
||||
|
||||
// SDK 高级配置
|
||||
if serverCfg.MaxRetries > 0 {
|
||||
setIntInMap(serverNode, "max_retries", serverCfg.MaxRetries)
|
||||
}
|
||||
if serverCfg.TerminateDuration > 0 {
|
||||
setIntInMap(serverNode, "terminate_duration", serverCfg.TerminateDuration)
|
||||
}
|
||||
if serverCfg.KeepAlive > 0 {
|
||||
setIntInMap(serverNode, "keep_alive", serverCfg.KeepAlive)
|
||||
}
|
||||
|
||||
setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable)
|
||||
// 保存 tool_enabled 字段(每个工具的启用状态)
|
||||
if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 {
|
||||
toolEnabledNode := ensureMap(serverNode, "tool_enabled")
|
||||
for toolName, enabled := range serverCfg.ToolEnabled {
|
||||
setBoolInMap(toolEnabledNode, toolName, enabled)
|
||||
}
|
||||
}
|
||||
// 保留旧的 enabled/disabled 字段以保持向后兼容
|
||||
originalFields, hasOriginal := originalConfigs[name]
|
||||
|
||||
// 如果原始配置中有 enabled 字段,保留它
|
||||
if hasOriginal {
|
||||
if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled {
|
||||
setBoolInMap(serverNode, "enabled", enabledVal)
|
||||
}
|
||||
// 如果原始配置中有 disabled 字段,保留它
|
||||
// 注意:由于 omitempty,disabled: false 不会被保存,但 disabled: true 会被保存
|
||||
if disabledVal, hasDisabled := originalFields["disabled"]; hasDisabled {
|
||||
if disabledVal {
|
||||
setBoolInMap(serverNode, "disabled", disabledVal)
|
||||
} else {
|
||||
// 如果原始配置中有 disabled: false,我们保存 enabled: true 来等效表示
|
||||
// 因为 disabled: false 等价于 enabled: true
|
||||
setBoolInMap(serverNode, "enabled", true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果用户在当前请求中明确设置了这些字段,也保存它们
|
||||
if serverCfg.Enabled {
|
||||
setBoolInMap(serverNode, "enabled", serverCfg.Enabled)
|
||||
}
|
||||
if serverCfg.Disabled {
|
||||
setBoolInMap(serverNode, "disabled", serverCfg.Disabled)
|
||||
} else if !hasOriginal && serverCfg.ExternalMCPEnable {
|
||||
// 如果用户通过新字段启用了,且原始配置中没有旧字段,保存 enabled: true 以保持向后兼容
|
||||
setBoolInMap(serverNode, "enabled", true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -60,13 +60,13 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 测试添加stdio模式的配置
|
||||
// 测试添加stdio模式的配置(官方格式:有 command 时 type 可省略)
|
||||
configJSON := `{
|
||||
"command": "python3",
|
||||
"args": ["/path/to/script.py", "--server", "http://example.com"],
|
||||
"description": "Test stdio MCP",
|
||||
"timeout": 300,
|
||||
"enabled": true
|
||||
"external_mcp_enable": true
|
||||
}`
|
||||
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
@@ -115,20 +115,17 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
|
||||
if response.Config.Timeout != 300 {
|
||||
t.Errorf("期望timeout为300,实际%d", response.Config.Timeout)
|
||||
}
|
||||
if !response.Config.Enabled {
|
||||
t.Error("期望enabled为true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 测试添加HTTP模式的配置
|
||||
// 测试添加HTTP模式的配置(使用官方 type 字段)
|
||||
configJSON := `{
|
||||
"transport": "http",
|
||||
"type": "http",
|
||||
"url": "http://127.0.0.1:8081/mcp",
|
||||
"enabled": true
|
||||
"external_mcp_enable": true
|
||||
}`
|
||||
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
@@ -165,15 +162,12 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Config.Transport != "http" {
|
||||
t.Errorf("期望transport为http,实际%s", response.Config.Transport)
|
||||
if response.Config.Type != "http" {
|
||||
t.Errorf("期望type为http,实际%s", response.Config.Type)
|
||||
}
|
||||
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
|
||||
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
|
||||
}
|
||||
if !response.Config.Enabled {
|
||||
t.Error("期望enabled为true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
|
||||
@@ -187,22 +181,22 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "缺少command和url",
|
||||
configJSON: `{"enabled": true}`,
|
||||
expectedErr: "需要指定command(stdio模式)或url(http/sse模式)",
|
||||
configJSON: `{"external_mcp_enable": true}`,
|
||||
expectedErr: "需要指定 command(stdio模式)或 url + type(http/sse模式)",
|
||||
},
|
||||
{
|
||||
name: "stdio模式缺少command",
|
||||
configJSON: `{"args": ["test"], "enabled": true}`,
|
||||
configJSON: `{"args": ["test"], "external_mcp_enable": true}`,
|
||||
expectedErr: "stdio模式需要command",
|
||||
},
|
||||
{
|
||||
name: "http模式缺少url",
|
||||
configJSON: `{"transport": "http", "enabled": true}`,
|
||||
expectedErr: "HTTP模式需要URL",
|
||||
configJSON: `{"type": "http", "external_mcp_enable": true}`,
|
||||
expectedErr: "HTTP模式需要 url",
|
||||
},
|
||||
{
|
||||
name: "无效的transport",
|
||||
configJSON: `{"transport": "invalid", "enabled": true}`,
|
||||
name: "无效的type",
|
||||
configJSON: `{"type": "invalid", "external_mcp_enable": true}`,
|
||||
expectedErr: "不支持的传输模式",
|
||||
},
|
||||
}
|
||||
@@ -254,7 +248,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
|
||||
// 先添加一个配置
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-delete", configObj)
|
||||
|
||||
@@ -283,11 +277,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
||||
// 添加多个配置
|
||||
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: false,
|
||||
ExternalMCPEnable: false,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
|
||||
@@ -326,16 +320,14 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
|
||||
// 添加配置
|
||||
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: false,
|
||||
Disabled: true,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
|
||||
@@ -369,8 +361,6 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
|
||||
// 添加一个禁用的配置
|
||||
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: false,
|
||||
Disabled: true,
|
||||
})
|
||||
|
||||
// 测试启动(可能会失败,因为没有真实的服务器)
|
||||
@@ -427,7 +417,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
|
||||
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
@@ -470,14 +460,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
|
||||
// 先添加配置
|
||||
config1 := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-update", config1)
|
||||
|
||||
// 更新配置
|
||||
config2 := config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
|
||||
@@ -411,7 +411,10 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
out, _ := io.ReadAll(resp.Body)
|
||||
out, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell exec read body", zap.Error(readErr))
|
||||
}
|
||||
output := string(out)
|
||||
httpCode := resp.StatusCode
|
||||
|
||||
@@ -578,7 +581,10 @@ func (h *WebShellHandler) FileOp(c *gin.Context) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
out, _ := io.ReadAll(resp.Body)
|
||||
out, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell fileop read body", zap.Error(readErr))
|
||||
}
|
||||
output := string(out)
|
||||
|
||||
c.JSON(http.StatusOK, FileOpResponse{
|
||||
@@ -633,7 +639,10 @@ func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection,
|
||||
return "", false, err.Error()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, _ := io.ReadAll(resp.Body)
|
||||
out, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell ExecWithConnection read body", zap.Error(readErr))
|
||||
}
|
||||
return string(out), resp.StatusCode == http.StatusOK, ""
|
||||
}
|
||||
|
||||
@@ -701,6 +710,9 @@ func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection
|
||||
return "", false, err.Error()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, _ := io.ReadAll(resp.Body)
|
||||
out, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell FileOpWithConnection read body", zap.Error(readErr))
|
||||
}
|
||||
return string(out), resp.StatusCode == http.StatusOK, ""
|
||||
}
|
||||
|
||||
@@ -487,7 +487,10 @@ func (c *Client) claudeChatCompletionStream(ctx context.Context, payload interfa
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return "", fmt.Errorf("claude bridge: read error response: %w", readErr)
|
||||
}
|
||||
return "", &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBody),
|
||||
@@ -588,7 +591,10 @@ func (c *Client) claudeChatCompletionStreamWithToolCalls(
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return "", nil, "", fmt.Errorf("claude bridge: read error response: %w", readErr)
|
||||
}
|
||||
return "", nil, "", &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBody),
|
||||
@@ -824,7 +830,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
|
||||
// 非 200:尝试把 Claude 错误格式转成 OpenAI 错误格式,便于 Eino 解析
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
bodyBytes, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("claude bridge: read error response: %w", readErr)
|
||||
}
|
||||
resp.Body.Close()
|
||||
converted := rt.tryConvertClaudeErrorToOpenAI(bodyBytes)
|
||||
return &http.Response{
|
||||
@@ -838,7 +848,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
|
||||
// 非流式:一次性转换响应体
|
||||
if !claudeReq.Stream {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("claude bridge: read response: %w", readErr)
|
||||
}
|
||||
resp.Body.Close()
|
||||
oaiJSON, err := claudeToOpenAIResponseJSON(respBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -189,7 +189,10 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{},
|
||||
|
||||
// 非200:读完 body 返回
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
c.logger.Warn("failed to read OpenAI error response body", zap.Error(readErr))
|
||||
}
|
||||
return "", &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBody),
|
||||
@@ -329,7 +332,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
c.logger.Warn("failed to read OpenAI error response body", zap.Error(readErr))
|
||||
}
|
||||
return "", nil, "", &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBody),
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// rateLimitEntry 记录某个 IP 的请求窗口信息
|
||||
type rateLimitEntry struct {
|
||||
count int
|
||||
windowAt time.Time
|
||||
}
|
||||
|
||||
// RateLimiter 基于 IP 的滑动窗口速率限制器
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*rateLimitEntry
|
||||
limit int // 窗口内允许的最大请求数
|
||||
window time.Duration // 窗口时长
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建速率限制器
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
entries: make(map[string]*rateLimitEntry),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
// 后台定期清理过期条目,防止内存泄漏
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanup 每分钟清理一次过期条目
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
now := time.Now()
|
||||
for ip, entry := range rl.entries {
|
||||
if now.Sub(entry.windowAt) > rl.window {
|
||||
delete(rl.entries, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// allow 检查指定 IP 是否允许通过
|
||||
func (rl *RateLimiter) allow(ip string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
entry, ok := rl.entries[ip]
|
||||
if !ok || now.Sub(entry.windowAt) > rl.window {
|
||||
rl.entries[ip] = &rateLimitEntry{count: 1, windowAt: now}
|
||||
return true
|
||||
}
|
||||
|
||||
entry.count++
|
||||
return entry.count <= rl.limit
|
||||
}
|
||||
|
||||
// RateLimitMiddleware 返回 Gin 中间件,对超限请求返回 429
|
||||
func RateLimitMiddleware(rl *RateLimiter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
if !rl.allow(ip) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded, please try again later",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user