Files
2025-11-15 17:50:47 +08:00

146 lines
3.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"fmt"
"os"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp"
)
// main is a manual smoke test for the external MCP server integration.
//
// It loads the config file named by the first command-line argument, prints
// every configured external MCP server, tries to start each server that is
// enabled and not disabled, lists the tools reported by the manager, stops
// every configured server, and prints the manager statistics before the
// start attempts and again after the stop attempts.
//
// The process exits with status 1 when the argument is missing or the config
// cannot be loaded, and with status 0 when no external MCP servers are
// configured.
func main() {
	if len(os.Args) < 2 {
		// Diagnostics that accompany a non-zero exit go to stderr.
		fmt.Fprintln(os.Stderr, "Usage: go run cmd/test-external-mcp/main.go <config.yaml>")
		os.Exit(1)
	}

	cfg, err := config.Load(os.Args[1])
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error loading config: %v\n", err)
		os.Exit(1)
	}

	servers := cfg.ExternalMCP.Servers
	// len of a nil collection is 0, so no separate nil check is needed.
	if len(servers) == 0 {
		fmt.Println("No external MCP servers configured")
		os.Exit(0)
	}
	fmt.Printf("Found %d external MCP server(s)\n\n", len(servers))

	// Create the logger.
	appLog := logger.New("info", "stdout")

	// Create the external MCP manager and hand it the loaded configuration.
	manager := mcp.NewExternalMCPManager(appLog.Logger)
	manager.LoadConfigs(&cfg.ExternalMCP)

	// printStats prints the section title and then the manager's current
	// counters; it is used both before starting and after stopping clients.
	printStats := func(title string) {
		fmt.Println(title)
		stats := manager.GetStats()
		fmt.Printf("总数: %d\n", stats["total"])
		fmt.Printf("已启用: %d\n", stats["enabled"])
		fmt.Printf("已停用: %d\n", stats["disabled"])
		fmt.Printf("已连接: %d\n", stats["connected"])
	}

	// Show the configuration of every server.
	fmt.Println("=== 配置信息 ===")
	for name, srv := range servers {
		fmt.Printf("\n%s:\n", name)
		fmt.Printf(" Transport: %s\n", getTransport(srv))
		if srv.Command != "" {
			fmt.Printf(" Command: %s\n", srv.Command)
			fmt.Printf(" Args: %v\n", srv.Args)
		}
		if srv.URL != "" {
			fmt.Printf(" URL: %s\n", srv.URL)
		}
		fmt.Printf(" Description: %s\n", srv.Description)
		fmt.Printf(" Timeout: %d seconds\n", srv.Timeout)
		fmt.Printf(" Enabled: %v\n", srv.Enabled)
		fmt.Printf(" Disabled: %v\n", srv.Disabled)
	}

	// Statistics before any client is started.
	printStats("\n=== 统计信息 ===")

	// Try to start the servers (only those that are enabled and not disabled).
	fmt.Println("\n=== 测试启动 ===")
	for name, srv := range servers {
		if !srv.Enabled || srv.Disabled {
			continue
		}
		fmt.Printf("\n尝试启动 %s...\n", name)
		// Note: starting may fail here because it needs a real MCP server,
		// so a failure is reported and the run continues.
		if err := manager.StartClient(name); err != nil {
			fmt.Printf(" 启动失败这是正常的如果没有真实的MCP服务器: %v\n", err)
			continue
		}
		fmt.Printf(" 启动成功\n")
		// Report the state of the client that was just started.
		if client, exists := manager.GetClient(name); exists {
			fmt.Printf(" 状态: %s\n", client.GetStatus())
			fmt.Printf(" 已连接: %v\n", client.IsConnected())
		}
	}

	// Wait a moment before querying the tools.
	time.Sleep(2 * time.Second)

	// Fetch the tool list, bounded by a five-second timeout.
	fmt.Println("\n=== 测试获取工具列表 ===")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Only the first maxShown tools are printed; the rest are summarized.
	const maxShown = 5
	tools, err := manager.GetAllTools(ctx)
	if err != nil {
		fmt.Printf("获取工具列表失败: %v\n", err)
	} else {
		fmt.Printf("获取到 %d 个工具\n", len(tools))
		for i, tool := range tools {
			if i >= maxShown {
				// Nothing further is printed per tool, so stop iterating.
				break
			}
			fmt.Printf(" - %s: %s\n", tool.Name, tool.Description)
		}
		if len(tools) > maxShown {
			fmt.Printf(" ... 还有 %d 个工具\n", len(tools)-maxShown)
		}
	}

	// Stop every configured server, whether or not it was started above.
	fmt.Println("\n=== 测试停止 ===")
	for name := range servers {
		fmt.Printf("\n停止 %s...\n", name)
		if err := manager.StopClient(name); err != nil {
			fmt.Printf(" 停止失败: %v\n", err)
		} else {
			fmt.Printf(" 停止成功\n")
		}
	}

	// Final statistics after the stop attempts.
	printStats("\n=== 最终统计 ===")
	fmt.Println("\n=== 测试完成 ===")
}
// getTransport returns the transport name to display for srv.
//
// An explicitly configured Transport is returned unchanged. Otherwise the
// value is inferred from the other fields: "stdio" when Command is set,
// "http" when URL is set, and "unknown" when neither is present. Command
// takes precedence over URL when both are set.
func getTransport(srv config.ExternalMCPServerConfig) string {
	switch {
	case srv.Transport != "":
		return srv.Transport
	case srv.Command != "":
		return "stdio"
	case srv.URL != "":
		return "http"
	default:
		return "unknown"
	}
}