diff --git a/internal/config/config.go b/internal/config/config.go index 1db1a626..85aaee3b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -81,6 +81,7 @@ type ExternalMCPServerConfig struct { // stdio模式配置 Command string `yaml:"command,omitempty" json:"command,omitempty"` Args []string `yaml:"args,omitempty" json:"args,omitempty"` + Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式) // HTTP模式配置 Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio" diff --git a/internal/handler/config.go b/internal/handler/config.go index 0176c360..a0965721 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -47,17 +47,17 @@ type ConfigHandler struct { config *config.Config mcpServer *mcp.Server executor *security.Executor - agent AgentUpdater // Agent接口,用于更新Agent配置 - attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 + agent AgentUpdater // Agent接口,用于更新Agent配置 + attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 - knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选) + knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选) vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选) - retrieverUpdater RetrieverUpdater // 检索器更新器(可选) - knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) - appUpdater AppUpdater // App更新器(可选) + retrieverUpdater RetrieverUpdater // 检索器更新器(可选) + knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) + appUpdater AppUpdater // App更新器(可选) logger *zap.Logger mu sync.RWMutex - lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更) + lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更) } // AttackChainUpdater 攻击链处理器更新接口 @@ -790,30 +790,30 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { h.logger.Info("AttackChainHandler配置已更新") } - // 更新检索器配置(如果知识库启用) - if h.config.Knowledge.Enabled && h.retrieverUpdater != nil { - retrievalConfig := &knowledge.RetrievalConfig{ - TopK: h.config.Knowledge.Retrieval.TopK, - SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold, - HybridWeight: h.config.Knowledge.Retrieval.HybridWeight, - } - h.retrieverUpdater.UpdateConfig(retrievalConfig) - h.logger.Info("检索器配置已更新", - zap.Int("top_k", retrievalConfig.TopK), - zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold), - zap.Float64("hybrid_weight", retrievalConfig.HybridWeight), - ) + // 更新检索器配置(如果知识库启用) + if h.config.Knowledge.Enabled && h.retrieverUpdater != nil { + retrievalConfig := &knowledge.RetrievalConfig{ + TopK: h.config.Knowledge.Retrieval.TopK, + SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold, + HybridWeight: h.config.Knowledge.Retrieval.HybridWeight, } + h.retrieverUpdater.UpdateConfig(retrievalConfig) + h.logger.Info("检索器配置已更新", + zap.Int("top_k", retrievalConfig.TopK), + zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold), + zap.Float64("hybrid_weight", retrievalConfig.HybridWeight), + ) + } - // 更新嵌入模型配置记录(如果知识库启用) - if h.config.Knowledge.Enabled { - h.lastEmbeddingConfig = &config.EmbeddingConfig{ - Provider: h.config.Knowledge.Embedding.Provider, - Model: h.config.Knowledge.Embedding.Model, - BaseURL: h.config.Knowledge.Embedding.BaseURL, - APIKey: h.config.Knowledge.Embedding.APIKey, - } + // 更新嵌入模型配置记录(如果知识库启用) + if h.config.Knowledge.Enabled { + h.lastEmbeddingConfig = &config.EmbeddingConfig{ + Provider: h.config.Knowledge.Embedding.Provider, + Model: h.config.Knowledge.Embedding.Model, + BaseURL: h.config.Knowledge.Embedding.BaseURL, + APIKey: h.config.Knowledge.Embedding.APIKey, } + } h.logger.Info("配置已应用", zap.Int("tools_count", len(h.config.Security.Tools)), diff --git a/internal/handler/external_mcp.go b/internal/handler/external_mcp.go index 207566c7..1fc9e8b3 100644 --- a/internal/handler/external_mcp.go +++ b/internal/handler/external_mcp.go @@ -446,6 +446,13 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi if len(serverCfg.Args) > 0 { setStringArrayInMap(serverNode, "args", serverCfg.Args) } + // 保存 env 字段(环境变量) + if serverCfg.Env != nil && len(serverCfg.Env) > 0 { + envNode := ensureMap(serverNode, "env") + for envKey, envValue := range serverCfg.Env { + setStringInMap(envNode, envKey, envValue) + } + } if serverCfg.Transport != "" { setStringInMap(serverNode, "transport", serverCfg.Transport) } diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 36514ae4..d88ca4c7 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "os" "os/exec" "strings" "sync" @@ -246,6 +247,7 @@ func (c *HTTPMCPClient) Close() error { type StdioMCPClient struct { command string args []string + env map[string]string timeout time.Duration cmd *exec.Cmd stdin io.WriteCloser @@ -263,7 +265,7 @@ type StdioMCPClient struct { } // NewStdioMCPClient 创建stdio模式的MCP客户端 -func NewStdioMCPClient(command string, args []string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient { +func NewStdioMCPClient(command string, args []string, env map[string]string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient { if timeout <= 0 { timeout = 30 * time.Second } @@ -271,6 +273,7 @@ func NewStdioMCPClient(command string, args []string, timeout time.Duration, log return &StdioMCPClient{ command: command, args: args, + env: env, timeout: timeout, logger: logger, status: "disconnected", @@ -354,6 +357,27 @@ func (c *StdioMCPClient) Initialize(ctx context.Context) error { func (c *StdioMCPClient) startProcess() error { cmd := exec.CommandContext(c.ctx, c.command, c.args...) + // 设置环境变量 + if c.env != nil && len(c.env) > 0 { + // 获取当前环境变量 + cmd.Env = os.Environ() + // 添加或覆盖配置的环境变量 + for key, value := range c.env { + // 检查是否已存在该环境变量 + found := false + for i, envVar := range cmd.Env { + if strings.HasPrefix(envVar, key+"=") { + cmd.Env[i] = key + "=" + value + found = true + break + } + } + if !found { + cmd.Env = append(cmd.Env, key+"="+value) + } + } + } + stdin, err := cmd.StdinPipe() if err != nil { return err diff --git a/internal/mcp/external_manager.go b/internal/mcp/external_manager.go index 8ef1335d..4eb3410b 100644 --- a/internal/mcp/external_manager.go +++ b/internal/mcp/external_manager.go @@ -738,7 +738,7 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf if serverCfg.Command == "" { return nil } - return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, timeout, m.logger) + return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, serverCfg.Env, timeout, m.logger) case "sse": if serverCfg.URL == "" { return nil diff --git a/internal/mcp/external_manager_test.go b/internal/mcp/external_manager_test.go index 069af5c6..ffd0c8a2 100644 --- a/internal/mcp/external_manager_test.go +++ b/internal/mcp/external_manager_test.go @@ -180,7 +180,7 @@ func TestStdioMCPClient_Initialize(t *testing.T) { // 注意:这个测试需要一个真实的stdio MCP服务器 // 如果没有服务器,这个测试会失败 logger := zap.NewNop() - client := NewStdioMCPClient("echo", []string{"test"}, 5*time.Second, logger) + client := NewStdioMCPClient("echo", []string{"test"}, nil, 5*time.Second, logger) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel()