Add files via upload

This commit is contained in:
公明
2026-01-09 19:44:59 +08:00
committed by GitHub
parent c3a1d95a92
commit 2c973f8c3b
6 changed files with 63 additions and 31 deletions

View File

@@ -81,6 +81,7 @@ type ExternalMCPServerConfig struct {
// stdio模式配置
Command string `yaml:"command,omitempty" json:"command,omitempty"`
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量用于stdio模式
// HTTP模式配置
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio"

View File

@@ -47,17 +47,17 @@ type ConfigHandler struct {
config *config.Config
mcpServer *mcp.Server
executor *security.Executor
agent AgentUpdater // Agent接口用于更新Agent配置
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
agent AgentUpdater // Agent接口用于更新Agent配置
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器可选
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器可选
logger *zap.Logger
mu sync.RWMutex
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
}
// AttackChainUpdater 攻击链处理器更新接口
@@ -790,30 +790,30 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("AttackChainHandler配置已更新")
}
// 更新检索器配置(如果知识库启用)
if h.config.Knowledge.Enabled && h.retrieverUpdater != nil {
retrievalConfig := &knowledge.RetrievalConfig{
TopK: h.config.Knowledge.Retrieval.TopK,
SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold,
HybridWeight: h.config.Knowledge.Retrieval.HybridWeight,
}
h.retrieverUpdater.UpdateConfig(retrievalConfig)
h.logger.Info("检索器配置已更新",
zap.Int("top_k", retrievalConfig.TopK),
zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold),
zap.Float64("hybrid_weight", retrievalConfig.HybridWeight),
)
// 更新检索器配置(如果知识库启用)
if h.config.Knowledge.Enabled && h.retrieverUpdater != nil {
retrievalConfig := &knowledge.RetrievalConfig{
TopK: h.config.Knowledge.Retrieval.TopK,
SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold,
HybridWeight: h.config.Knowledge.Retrieval.HybridWeight,
}
h.retrieverUpdater.UpdateConfig(retrievalConfig)
h.logger.Info("检索器配置已更新",
zap.Int("top_k", retrievalConfig.TopK),
zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold),
zap.Float64("hybrid_weight", retrievalConfig.HybridWeight),
)
}
// 更新嵌入模型配置记录(如果知识库启用)
if h.config.Knowledge.Enabled {
h.lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: h.config.Knowledge.Embedding.Provider,
Model: h.config.Knowledge.Embedding.Model,
BaseURL: h.config.Knowledge.Embedding.BaseURL,
APIKey: h.config.Knowledge.Embedding.APIKey,
}
// 更新嵌入模型配置记录(如果知识库启用)
if h.config.Knowledge.Enabled {
h.lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: h.config.Knowledge.Embedding.Provider,
Model: h.config.Knowledge.Embedding.Model,
BaseURL: h.config.Knowledge.Embedding.BaseURL,
APIKey: h.config.Knowledge.Embedding.APIKey,
}
}
h.logger.Info("配置已应用",
zap.Int("tools_count", len(h.config.Security.Tools)),

View File

@@ -446,6 +446,13 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
if len(serverCfg.Args) > 0 {
setStringArrayInMap(serverNode, "args", serverCfg.Args)
}
// 保存 env 字段(环境变量)
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
envNode := ensureMap(serverNode, "env")
for envKey, envValue := range serverCfg.Env {
setStringInMap(envNode, envKey, envValue)
}
}
if serverCfg.Transport != "" {
setStringInMap(serverNode, "transport", serverCfg.Transport)
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"os"
"os/exec"
"strings"
"sync"
@@ -246,6 +247,7 @@ func (c *HTTPMCPClient) Close() error {
type StdioMCPClient struct {
command string
args []string
env map[string]string
timeout time.Duration
cmd *exec.Cmd
stdin io.WriteCloser
@@ -263,7 +265,7 @@ type StdioMCPClient struct {
}
// NewStdioMCPClient 创建stdio模式的MCP客户端
func NewStdioMCPClient(command string, args []string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient {
func NewStdioMCPClient(command string, args []string, env map[string]string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient {
if timeout <= 0 {
timeout = 30 * time.Second
}
@@ -271,6 +273,7 @@ func NewStdioMCPClient(command string, args []string, timeout time.Duration, log
return &StdioMCPClient{
command: command,
args: args,
env: env,
timeout: timeout,
logger: logger,
status: "disconnected",
@@ -354,6 +357,27 @@ func (c *StdioMCPClient) Initialize(ctx context.Context) error {
func (c *StdioMCPClient) startProcess() error {
cmd := exec.CommandContext(c.ctx, c.command, c.args...)
// 设置环境变量
if c.env != nil && len(c.env) > 0 {
// 获取当前环境变量
cmd.Env = os.Environ()
// 添加或覆盖配置的环境变量
for key, value := range c.env {
// 检查是否已存在该环境变量
found := false
for i, envVar := range cmd.Env {
if strings.HasPrefix(envVar, key+"=") {
cmd.Env[i] = key + "=" + value
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, key+"="+value)
}
}
}
stdin, err := cmd.StdinPipe()
if err != nil {
return err

View File

@@ -738,7 +738,7 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
if serverCfg.Command == "" {
return nil
}
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, timeout, m.logger)
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, serverCfg.Env, timeout, m.logger)
case "sse":
if serverCfg.URL == "" {
return nil

View File

@@ -180,7 +180,7 @@ func TestStdioMCPClient_Initialize(t *testing.T) {
// 注意这个测试需要一个真实的stdio MCP服务器
// 如果没有服务器,这个测试会失败
logger := zap.NewNop()
client := NewStdioMCPClient("echo", []string{"test"}, 5*time.Second, logger)
client := NewStdioMCPClient("echo", []string{"test"}, nil, 5*time.Second, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()