From 5b3709b9adde30991cf9a00162c50b22bc2322a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Mon, 9 Mar 2026 22:37:37 +0800 Subject: [PATCH] Add files via upload --- cmd/server/main.go | 9 +++ internal/app/app.go | 17 ++++- internal/config/config.go | 128 +++++++++++++++++++++++++++++++++++++- 3 files changed, 150 insertions(+), 4 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index cb4292bd..9a962a3e 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -19,6 +19,15 @@ func main() { return } + // MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置 + if err := config.EnsureMCPAuth(*configPath, cfg); err != nil { + fmt.Printf("MCP 鉴权配置失败: %v\n", err) + return + } + if cfg.MCP.Enabled { + config.PrintMCPConfigJSON(cfg.MCP) + } + // 初始化日志 log := logger.New(cfg.Log.Level, cfg.Log.Output) diff --git a/internal/app/app.go b/internal/app/app.go index 11edb3c1..a09b683b 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -442,6 +442,21 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { } +// mcpHandlerWithAuth 在鉴权通过后转发到 MCP 处理;若配置了 auth_header 则校验请求头,否则直接放行 +func (a *App) mcpHandlerWithAuth(w http.ResponseWriter, r *http.Request) { + cfg := a.config.MCP + if cfg.AuthHeader != "" { + if r.Header.Get(cfg.AuthHeader) != cfg.AuthHeaderValue { + a.logger.Logger.Debug("MCP 鉴权失败:header 缺失或值不匹配", zap.String("header", cfg.AuthHeader)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"unauthorized"}`)) + return + } + } + a.mcpServer.HandleHTTP(w, r) +} + // Run 启动应用 func (a *App) Run() error { // 启动MCP服务器(如果启用) @@ -451,7 +466,7 @@ func (a *App) Run() error { a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr)) mux := http.NewServeMux() - mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP) + mux.HandleFunc("/mcp", a.mcpHandlerWithAuth) if err := http.ListenAndServe(mcpAddr, mux); err != nil { a.logger.Error("MCP服务器启动失败", zap.Error(err)) diff --git a/internal/config/config.go b/internal/config/config.go index 83b0997f..3db8cb00 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,8 @@ package config import ( "crypto/rand" "encoding/base64" + "encoding/hex" + "encoding/json" "fmt" "os" "path/filepath" @@ -74,9 +76,11 @@ type LogConfig struct { } type MCPConfig struct { - Enabled bool `yaml:"enabled"` - Host string `yaml:"host"` - Port int `yaml:"port"` + Enabled bool `yaml:"enabled"` + Host string `yaml:"host"` + Port int `yaml:"port"` + AuthHeader string `yaml:"auth_header,omitempty"` // 鉴权 header 名,留空表示不鉴权 + AuthHeaderValue string `yaml:"auth_header_value,omitempty"` // 鉴权 header 值,需与请求中该 header 一致 } type OpenAIConfig struct { @@ -384,6 +388,124 @@ func PrintGeneratedPasswordWarning(password string, persisted bool, persistErr s fmt.Println("----------------------------------------------------------------") } +// generateRandomToken 生成用于 MCP 鉴权的随机字符串(64 位十六进制) +func generateRandomToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +// persistMCPAuth 将 MCP 的 auth_header / auth_header_value 写回配置文件 +func persistMCPAuth(path string, mcp *MCPConfig) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + lines := strings.Split(string(data), "\n") + inMcpBlock := false + mcpIndent := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if !inMcpBlock { + if strings.HasPrefix(trimmed, "mcp:") { + inMcpBlock = true + mcpIndent = len(line) - len(strings.TrimLeft(line, " ")) + } + continue + } + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + leadingSpaces := len(line) - len(strings.TrimLeft(line, " ")) + if leadingSpaces <= mcpIndent { + inMcpBlock = false + mcpIndent = -1 + if strings.HasPrefix(trimmed, "mcp:") { + inMcpBlock = true + mcpIndent = leadingSpaces + } + continue + } + + prefix := line[:leadingSpaces] + rest := strings.TrimSpace(line[leadingSpaces:]) + comment := "" + if idx := strings.Index(line, "#"); idx >= 0 { + comment = strings.TrimRight(line[idx:], " ") + } + withComment := "" + if comment != "" { + if !strings.HasPrefix(comment, " ") { + withComment = " " + } + withComment += comment + } + + if strings.HasPrefix(rest, "auth_header_value:") { + lines[i] = fmt.Sprintf("%sauth_header_value: %q%s", prefix, mcp.AuthHeaderValue, withComment) + } else if strings.HasPrefix(rest, "auth_header:") { + lines[i] = fmt.Sprintf("%sauth_header: %q%s", prefix, mcp.AuthHeader, withComment) + } + } + + return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644) +} + +// EnsureMCPAuth 在 MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置 +func EnsureMCPAuth(path string, cfg *Config) error { + if !cfg.MCP.Enabled || strings.TrimSpace(cfg.MCP.AuthHeaderValue) != "" { + return nil + } + token, err := generateRandomToken() + if err != nil { + return fmt.Errorf("生成 MCP 鉴权密钥失败: %w", err) + } + cfg.MCP.AuthHeaderValue = token + if strings.TrimSpace(cfg.MCP.AuthHeader) == "" { + cfg.MCP.AuthHeader = "X-MCP-Token" + } + return persistMCPAuth(path, &cfg.MCP) +} + +// PrintMCPConfigJSON 向终端输出 MCP 配置的 JSON,可直接复制到 Cursor / Claude Code 的 mcp 配置中使用 +func PrintMCPConfigJSON(mcp MCPConfig) { + if !mcp.Enabled { + return + } + hostForURL := strings.TrimSpace(mcp.Host) + if hostForURL == "" || hostForURL == "0.0.0.0" { + hostForURL = "localhost" + } + url := fmt.Sprintf("http://%s:%d/mcp", hostForURL, mcp.Port) + headers := map[string]string{} + if mcp.AuthHeader != "" { + headers[mcp.AuthHeader] = mcp.AuthHeaderValue + } + serverEntry := map[string]interface{}{ + "url": url, + } + if len(headers) > 0 { + serverEntry["headers"] = headers + } + // Claude Code 需要 type: "http" + serverEntry["type"] = "http" + out := map[string]interface{}{ + "mcpServers": map[string]interface{}{ + "cyberstrike-ai": serverEntry, + }, + } + b, _ := json.MarshalIndent(out, "", " ") + fmt.Println("[CyberStrikeAI] MCP 配置(可复制到 Cursor / Claude Code 使用):") + fmt.Println(" Cursor: 放入 ~/.cursor/mcp.json 的 mcpServers,或项目 .cursor/mcp.json") + fmt.Println(" Claude Code: 放入 .mcp.json 或 ~/.claude.json 的 mcpServers") + fmt.Println("----------------------------------------------------------------") + fmt.Println(string(b)) + fmt.Println("----------------------------------------------------------------") +} + // LoadToolsFromDir 从目录加载所有工具配置文件 func LoadToolsFromDir(dir string) ([]ToolConfig, error) { var tools []ToolConfig