mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
Add files via upload
This commit is contained in:
@@ -81,6 +81,7 @@ type ExternalMCPServerConfig struct {
|
||||
// stdio模式配置
|
||||
Command string `yaml:"command,omitempty" json:"command,omitempty"`
|
||||
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
|
||||
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式)
|
||||
|
||||
// HTTP模式配置
|
||||
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio"
|
||||
|
||||
@@ -47,17 +47,17 @@ type ConfigHandler struct {
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
|
||||
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
|
||||
vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
|
||||
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
|
||||
}
|
||||
|
||||
// AttackChainUpdater 攻击链处理器更新接口
|
||||
@@ -790,30 +790,30 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("AttackChainHandler配置已更新")
|
||||
}
|
||||
|
||||
// 更新检索器配置(如果知识库启用)
|
||||
if h.config.Knowledge.Enabled && h.retrieverUpdater != nil {
|
||||
retrievalConfig := &knowledge.RetrievalConfig{
|
||||
TopK: h.config.Knowledge.Retrieval.TopK,
|
||||
SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold,
|
||||
HybridWeight: h.config.Knowledge.Retrieval.HybridWeight,
|
||||
}
|
||||
h.retrieverUpdater.UpdateConfig(retrievalConfig)
|
||||
h.logger.Info("检索器配置已更新",
|
||||
zap.Int("top_k", retrievalConfig.TopK),
|
||||
zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold),
|
||||
zap.Float64("hybrid_weight", retrievalConfig.HybridWeight),
|
||||
)
|
||||
// 更新检索器配置(如果知识库启用)
|
||||
if h.config.Knowledge.Enabled && h.retrieverUpdater != nil {
|
||||
retrievalConfig := &knowledge.RetrievalConfig{
|
||||
TopK: h.config.Knowledge.Retrieval.TopK,
|
||||
SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold,
|
||||
HybridWeight: h.config.Knowledge.Retrieval.HybridWeight,
|
||||
}
|
||||
h.retrieverUpdater.UpdateConfig(retrievalConfig)
|
||||
h.logger.Info("检索器配置已更新",
|
||||
zap.Int("top_k", retrievalConfig.TopK),
|
||||
zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold),
|
||||
zap.Float64("hybrid_weight", retrievalConfig.HybridWeight),
|
||||
)
|
||||
}
|
||||
|
||||
// 更新嵌入模型配置记录(如果知识库启用)
|
||||
if h.config.Knowledge.Enabled {
|
||||
h.lastEmbeddingConfig = &config.EmbeddingConfig{
|
||||
Provider: h.config.Knowledge.Embedding.Provider,
|
||||
Model: h.config.Knowledge.Embedding.Model,
|
||||
BaseURL: h.config.Knowledge.Embedding.BaseURL,
|
||||
APIKey: h.config.Knowledge.Embedding.APIKey,
|
||||
}
|
||||
// 更新嵌入模型配置记录(如果知识库启用)
|
||||
if h.config.Knowledge.Enabled {
|
||||
h.lastEmbeddingConfig = &config.EmbeddingConfig{
|
||||
Provider: h.config.Knowledge.Embedding.Provider,
|
||||
Model: h.config.Knowledge.Embedding.Model,
|
||||
BaseURL: h.config.Knowledge.Embedding.BaseURL,
|
||||
APIKey: h.config.Knowledge.Embedding.APIKey,
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("配置已应用",
|
||||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||||
|
||||
@@ -446,6 +446,13 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
|
||||
if len(serverCfg.Args) > 0 {
|
||||
setStringArrayInMap(serverNode, "args", serverCfg.Args)
|
||||
}
|
||||
// 保存 env 字段(环境变量)
|
||||
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
|
||||
envNode := ensureMap(serverNode, "env")
|
||||
for envKey, envValue := range serverCfg.Env {
|
||||
setStringInMap(envNode, envKey, envValue)
|
||||
}
|
||||
}
|
||||
if serverCfg.Transport != "" {
|
||||
setStringInMap(serverNode, "transport", serverCfg.Transport)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -246,6 +247,7 @@ func (c *HTTPMCPClient) Close() error {
|
||||
type StdioMCPClient struct {
|
||||
command string
|
||||
args []string
|
||||
env map[string]string
|
||||
timeout time.Duration
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
@@ -263,7 +265,7 @@ type StdioMCPClient struct {
|
||||
}
|
||||
|
||||
// NewStdioMCPClient 创建stdio模式的MCP客户端
|
||||
func NewStdioMCPClient(command string, args []string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient {
|
||||
func NewStdioMCPClient(command string, args []string, env map[string]string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
@@ -271,6 +273,7 @@ func NewStdioMCPClient(command string, args []string, timeout time.Duration, log
|
||||
return &StdioMCPClient{
|
||||
command: command,
|
||||
args: args,
|
||||
env: env,
|
||||
timeout: timeout,
|
||||
logger: logger,
|
||||
status: "disconnected",
|
||||
@@ -354,6 +357,27 @@ func (c *StdioMCPClient) Initialize(ctx context.Context) error {
|
||||
func (c *StdioMCPClient) startProcess() error {
|
||||
cmd := exec.CommandContext(c.ctx, c.command, c.args...)
|
||||
|
||||
// 设置环境变量
|
||||
if c.env != nil && len(c.env) > 0 {
|
||||
// 获取当前环境变量
|
||||
cmd.Env = os.Environ()
|
||||
// 添加或覆盖配置的环境变量
|
||||
for key, value := range c.env {
|
||||
// 检查是否已存在该环境变量
|
||||
found := false
|
||||
for i, envVar := range cmd.Env {
|
||||
if strings.HasPrefix(envVar, key+"=") {
|
||||
cmd.Env[i] = key + "=" + value
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cmd.Env = append(cmd.Env, key+"="+value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -738,7 +738,7 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
|
||||
if serverCfg.Command == "" {
|
||||
return nil
|
||||
}
|
||||
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, timeout, m.logger)
|
||||
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, serverCfg.Env, timeout, m.logger)
|
||||
case "sse":
|
||||
if serverCfg.URL == "" {
|
||||
return nil
|
||||
|
||||
@@ -180,7 +180,7 @@ func TestStdioMCPClient_Initialize(t *testing.T) {
|
||||
// 注意:这个测试需要一个真实的stdio MCP服务器
|
||||
// 如果没有服务器,这个测试会失败
|
||||
logger := zap.NewNop()
|
||||
client := NewStdioMCPClient("echo", []string{"test"}, 5*time.Second, logger)
|
||||
client := NewStdioMCPClient("echo", []string{"test"}, nil, 5*time.Second, logger)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
Reference in New Issue
Block a user