Files
CyberStrikeAI/internal/mcp/client.go
2025-11-15 17:50:47 +08:00

475 lines
10 KiB
Go

package mcp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os/exec"
"sync"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// ExternalMCPClient is the interface implemented by clients of an external MCP server.
type ExternalMCPClient interface {
// Initialize initializes the connection.
Initialize(ctx context.Context) error
// ListTools lists the tools.
ListTools(ctx context.Context) ([]Tool, error)
// CallTool invokes the tool called name with the given arguments.
CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error)
// Close closes the connection.
Close() error
// IsConnected reports whether the client is connected.
IsConnected() bool
// GetStatus returns the current status.
GetStatus() string
}
// HTTPMCPClient is the MCP client for HTTP mode.
type HTTPMCPClient struct {
// url is the address every request is POSTed to.
url string
// timeout is the value the constructor also installs as client.Timeout.
timeout time.Duration
// client performs the HTTP round trips.
client *http.Client
logger *zap.Logger
// mu guards status.
mu sync.RWMutex
status string // "disconnected", "connecting", "connected", "error"
}
// NewHTTPMCPClient creates an MCP client for HTTP mode.
//
// A non-positive timeout falls back to 30 seconds. The same timeout is applied
// to the underlying http.Client. The returned client starts in the
// "disconnected" state.
func NewHTTPMCPClient(url string, timeout time.Duration, logger *zap.Logger) *HTTPMCPClient {
	effective := timeout
	if effective <= 0 {
		effective = 30 * time.Second
	}

	transport := &http.Client{Timeout: effective}

	mcpClient := new(HTTPMCPClient)
	mcpClient.url = url
	mcpClient.timeout = effective
	mcpClient.client = transport
	mcpClient.logger = logger
	mcpClient.status = "disconnected"
	return mcpClient
}
// setStatus records status as the client's current state while holding the
// write lock.
func (c *HTTPMCPClient) setStatus(status string) {
	c.mu.Lock()
	c.status = status
	c.mu.Unlock()
}
// GetStatus returns the client's current state, read under the read lock.
func (c *HTTPMCPClient) GetStatus() string {
	c.mu.RLock()
	current := c.status
	c.mu.RUnlock()
	return current
}
// IsConnected reports whether the current status is "connected".
func (c *HTTPMCPClient) IsConnected() bool {
	current := c.GetStatus()
	return current == "connected"
}
// Initialize performs the MCP "initialize" handshake over HTTP.
//
// The status is set to "connecting" up front, then to "connected" once the
// server has answered without error. If the parameters cannot be serialized or
// the request fails, the status becomes "error" and a wrapped error is
// returned.
func (c *HTTPMCPClient) Initialize(ctx context.Context) error {
	c.setStatus("connecting")

	params := InitializeRequest{
		ProtocolVersion: ProtocolVersion,
		Capabilities:    make(map[string]interface{}),
		ClientInfo: ClientInfo{
			Name:    "CyberStrikeAI",
			Version: "1.0.0",
		},
	}
	// Surface serialization failures instead of silently sending a request
	// without parameters.
	paramsJSON, err := json.Marshal(params)
	if err != nil {
		c.setStatus("error")
		return fmt.Errorf("序列化初始化参数失败: %w", err)
	}

	req := Message{
		ID:      MessageID{value: "1"},
		Method:  "initialize",
		Version: "2.0",
	}
	req.Params = paramsJSON

	if _, err := c.sendRequest(ctx, &req); err != nil {
		c.setStatus("error")
		return fmt.Errorf("初始化失败: %w", err)
	}

	c.setStatus("connected")
	return nil
}
// ListTools sends a "tools/list" request with empty parameters and returns the
// tools decoded from the result.
//
// A fresh UUID is used as the request ID. Transport and decoding failures are
// returned as wrapped errors.
func (c *HTTPMCPClient) ListTools(ctx context.Context) ([]Tool, error) {
	request := &Message{
		ID:      MessageID{value: uuid.New().String()},
		Method:  "tools/list",
		Version: "2.0",
		Params:  json.RawMessage("{}"),
	}

	reply, sendErr := c.sendRequest(ctx, request)
	if sendErr != nil {
		return nil, fmt.Errorf("获取工具列表失败: %w", sendErr)
	}

	var decoded ListToolsResponse
	if decodeErr := json.Unmarshal(reply.Result, &decoded); decodeErr != nil {
		return nil, fmt.Errorf("解析工具列表失败: %w", decodeErr)
	}
	return decoded.Tools, nil
}
// CallTool sends a "tools/call" request for the tool called name with the
// given arguments and returns the decoded result.
//
// A fresh UUID is used as the request ID. An error is returned when the
// arguments cannot be serialized, when the request fails, or when the result
// cannot be decoded. On success the returned ToolResult carries the Content
// and IsError fields of the response.
func (c *HTTPMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
	callReq := CallToolRequest{
		Name:      name,
		Arguments: args,
	}
	// args is caller-supplied and may hold values encoding/json cannot encode;
	// report that instead of sending a request with empty parameters.
	paramsJSON, err := json.Marshal(callReq)
	if err != nil {
		return nil, fmt.Errorf("序列化工具参数失败: %w", err)
	}

	req := Message{
		ID:      MessageID{value: uuid.New().String()},
		Method:  "tools/call",
		Version: "2.0",
	}
	req.Params = paramsJSON

	resp, err := c.sendRequest(ctx, &req)
	if err != nil {
		return nil, fmt.Errorf("调用工具失败: %w", err)
	}

	var callResp CallToolResponse
	if err := json.Unmarshal(resp.Result, &callResp); err != nil {
		return nil, fmt.Errorf("解析工具调用结果失败: %w", err)
	}

	return &ToolResult{
		Content: callResp.Content,
		IsError: callResp.IsError,
	}, nil
}
// sendRequest serializes msg, POSTs it to the configured URL as JSON, and
// decodes the response body into a Message.
//
// An error is returned when serialization or the HTTP round trip fails, when
// the status code is not 200, when the body is not a decodable Message, or
// when the decoded Message carries an MCP error.
func (c *HTTPMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
	// maxErrorBodyBytes caps how much of a non-200 response body is copied into
	// the returned error, so a misbehaving server cannot force an unbounded read.
	const maxErrorBodyBytes = 64 << 10

	body, err := json.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("HTTP请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code is the primary signal, so a failure while
		// reading the diagnostic body is deliberately ignored.
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("HTTP错误 %d: %s", resp.StatusCode, string(bodyBytes))
	}

	var mcpResp Message
	if err := json.NewDecoder(resp.Body).Decode(&mcpResp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	if mcpResp.Error != nil {
		return nil, fmt.Errorf("MCP错误: %s (code: %d)", mcpResp.Error.Message, mcpResp.Error.Code)
	}
	return &mcpResp, nil
}
// Close marks the client as "disconnected". It always returns nil.
func (c *HTTPMCPClient) Close() error {
	c.mu.Lock()
	c.status = "disconnected"
	c.mu.Unlock()
	return nil
}
// StdioMCPClient is the MCP client for stdio mode.
type StdioMCPClient struct {
// command and args describe the child process launched by startProcess.
command string
args []string
// timeout is how long sendRequest waits for a response.
timeout time.Duration
// cmd is the started child process; stdin and stdout are its pipes.
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.ReadCloser
// decoder reads messages from stdout; encoder writes messages to stdin.
decoder *json.Decoder
encoder *json.Encoder
logger *zap.Logger
// mu guards status and requestID.
mu sync.RWMutex
status string
// requestID is the counter used to generate an ID for messages that have none.
requestID int64
// responses maps the ID of each pending request to the channel its response is delivered on.
responses map[string]chan *Message
// responsesMu guards responses.
responsesMu sync.Mutex
// ctx is passed to exec.CommandContext for the child process; cancel is called by Close.
ctx context.Context
cancel context.CancelFunc
}
// NewStdioMCPClient creates an MCP client for stdio mode.
//
// A non-positive timeout falls back to 30 seconds. The client owns a
// cancellable context derived from context.Background and starts in the
// "disconnected" state with an empty pending-response table. No process is
// started here.
func NewStdioMCPClient(command string, args []string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient {
	effective := timeout
	if effective <= 0 {
		effective = 30 * time.Second
	}

	lifetime, stop := context.WithCancel(context.Background())

	stdioClient := new(StdioMCPClient)
	stdioClient.command = command
	stdioClient.args = args
	stdioClient.timeout = effective
	stdioClient.logger = logger
	stdioClient.status = "disconnected"
	stdioClient.responses = make(map[string]chan *Message)
	stdioClient.ctx = lifetime
	stdioClient.cancel = stop
	return stdioClient
}
// setStatus records status as the client's current state while holding the
// write lock.
func (c *StdioMCPClient) setStatus(status string) {
	c.mu.Lock()
	c.status = status
	c.mu.Unlock()
}
// GetStatus returns the client's current state, read under the read lock.
func (c *StdioMCPClient) GetStatus() string {
	c.mu.RLock()
	current := c.status
	c.mu.RUnlock()
	return current
}
// IsConnected reports whether the current status is "connected".
func (c *StdioMCPClient) IsConnected() bool {
	current := c.GetStatus()
	return current == "connected"
}
// Initialize starts the child process, begins reading its responses in a
// background goroutine, and performs the MCP "initialize" handshake.
//
// The status is set to "connecting" up front and to "connected" once the
// handshake succeeds. If the parameters cannot be serialized or the process
// cannot be started, the status becomes "error". If the handshake request
// fails, the client is closed and the status is then set to "error".
func (c *StdioMCPClient) Initialize(ctx context.Context) error {
	c.setStatus("connecting")

	// Build the request before spawning anything so that a serialization
	// failure cannot leave a freshly started child process behind.
	params := InitializeRequest{
		ProtocolVersion: ProtocolVersion,
		Capabilities:    make(map[string]interface{}),
		ClientInfo: ClientInfo{
			Name:    "CyberStrikeAI",
			Version: "1.0.0",
		},
	}
	paramsJSON, err := json.Marshal(params)
	if err != nil {
		c.setStatus("error")
		return fmt.Errorf("序列化初始化参数失败: %w", err)
	}
	req := Message{
		ID:      MessageID{value: "1"},
		Method:  "initialize",
		Version: "2.0",
	}
	req.Params = paramsJSON

	if err := c.startProcess(); err != nil {
		c.setStatus("error")
		return fmt.Errorf("启动进程失败: %w", err)
	}

	// Start the goroutine that reads responses.
	go c.readResponses()

	// Send the initialize request.
	if _, err := c.sendRequest(ctx, &req); err != nil {
		// Close ends by writing "disconnected", so the "error" state is recorded
		// afterwards; otherwise it would be overwritten immediately.
		_ = c.Close()
		c.setStatus("error")
		return fmt.Errorf("初始化失败: %w", err)
	}

	c.setStatus("connected")
	return nil
}
// startProcess launches the configured command under the client's context and
// wires its stdin and stdout pipes to a JSON encoder and decoder.
//
// Any pipe already obtained is closed before an error is returned. The
// client's fields are only assigned once the process has started.
func (c *StdioMCPClient) startProcess() error {
	child := exec.CommandContext(c.ctx, c.command, c.args...)

	toChild, err := child.StdinPipe()
	if err != nil {
		return err
	}

	fromChild, err := child.StdoutPipe()
	if err != nil {
		toChild.Close()
		return err
	}

	if err = child.Start(); err != nil {
		toChild.Close()
		fromChild.Close()
		return err
	}

	c.cmd, c.stdin, c.stdout = child, toChild, fromChild
	c.decoder, c.encoder = json.NewDecoder(fromChild), json.NewEncoder(toChild)
	return nil
}
// readResponses decodes messages from the child's stdout until decoding fails
// and hands each one to the request waiting on its ID.
//
// On io.EOF the status becomes "disconnected". On any other decode error the
// error is logged and, unless the client's context has already been cancelled,
// the status becomes "error", because no further response can be delivered
// once this loop has exited. Messages whose ID has no pending request are
// dropped. A panic inside the loop is recovered and logged.
func (c *StdioMCPClient) readResponses() {
	defer func() {
		if r := recover(); r != nil {
			c.logger.Error("读取响应时发生panic", zap.Any("error", r))
		}
	}()

	for {
		var msg Message
		if err := c.decoder.Decode(&msg); err != nil {
			if err == io.EOF {
				c.setStatus("disconnected")
				return
			}
			c.logger.Error("读取响应失败", zap.Error(err))
			// A cancelled context means the client is being shut down on purpose;
			// leave the status to whoever cancelled it. Otherwise the reader died
			// unexpectedly and the client must stop reporting itself as connected.
			if c.ctx.Err() == nil {
				c.setStatus("error")
			}
			return
		}

		// Deliver the response to its waiter, if one is still registered.
		id := msg.ID.String()
		c.responsesMu.Lock()
		if ch, ok := c.responses[id]; ok {
			// Non-blocking send: never stall the reader on a full channel.
			select {
			case ch <- &msg:
			default:
			}
			delete(c.responses, id)
		}
		c.responsesMu.Unlock()
	}
}
// sendRequest writes msg to the child's stdin and waits for the response that
// carries the same ID.
//
// If msg has no ID, one is generated from the client's request counter and
// stored in msg. An error is returned when the process has not been started,
// when writing fails, when the response carries an MCP error, when ctx is
// done, when the client's own context is cancelled, or when no response
// arrives within the client's timeout.
func (c *StdioMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
	if c.encoder == nil {
		return nil, fmt.Errorf("进程未启动")
	}

	id := msg.ID.String()
	if id == "" {
		c.mu.Lock()
		c.requestID++
		id = fmt.Sprintf("%d", c.requestID)
		msg.ID = MessageID{value: id}
		c.mu.Unlock()
	}

	// Register the response channel before writing so the reply cannot be missed.
	responseCh := make(chan *Message, 1)
	c.responsesMu.Lock()
	c.responses[id] = responseCh
	c.responsesMu.Unlock()

	// Unregister on every exit path. The identity check keeps this from removing
	// an entry that a later request registered under the same ID.
	defer func() {
		c.responsesMu.Lock()
		if c.responses[id] == responseCh {
			delete(c.responses, id)
		}
		c.responsesMu.Unlock()
	}()

	// Send the request.
	if err := c.encoder.Encode(msg); err != nil {
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}

	// An explicit timer is stopped as soon as the request finishes, rather than
	// lingering until the full timeout elapses as a time.After timer would.
	timer := time.NewTimer(c.timeout)
	defer timer.Stop()

	// Wait for the response.
	select {
	case resp := <-responseCh:
		if resp.Error != nil {
			return nil, fmt.Errorf("MCP错误: %s (code: %d)", resp.Error.Message, resp.Error.Code)
		}
		return resp, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-c.ctx.Done():
		// The client's context was cancelled; do not wait out the timeout.
		return nil, fmt.Errorf("客户端已关闭")
	case <-timer.C:
		return nil, fmt.Errorf("请求超时")
	}
}
// ListTools sends a "tools/list" request with empty parameters to the child
// process and returns the tools decoded from the result.
//
// A fresh UUID is used as the request ID. Send and decoding failures are
// returned as wrapped errors.
func (c *StdioMCPClient) ListTools(ctx context.Context) ([]Tool, error) {
	request := &Message{
		ID:      MessageID{value: uuid.New().String()},
		Method:  "tools/list",
		Version: "2.0",
		Params:  json.RawMessage("{}"),
	}

	reply, sendErr := c.sendRequest(ctx, request)
	if sendErr != nil {
		return nil, fmt.Errorf("获取工具列表失败: %w", sendErr)
	}

	var decoded ListToolsResponse
	if decodeErr := json.Unmarshal(reply.Result, &decoded); decodeErr != nil {
		return nil, fmt.Errorf("解析工具列表失败: %w", decodeErr)
	}
	return decoded.Tools, nil
}
// CallTool sends a "tools/call" request for the tool called name with the
// given arguments to the child process and returns the decoded result.
//
// A fresh UUID is used as the request ID. An error is returned when the
// arguments cannot be serialized, when the request fails, or when the result
// cannot be decoded. On success the returned ToolResult carries the Content
// and IsError fields of the response.
func (c *StdioMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
	callReq := CallToolRequest{
		Name:      name,
		Arguments: args,
	}
	// args is caller-supplied and may hold values encoding/json cannot encode;
	// report that instead of sending a request with empty parameters.
	paramsJSON, err := json.Marshal(callReq)
	if err != nil {
		return nil, fmt.Errorf("序列化工具参数失败: %w", err)
	}

	req := Message{
		ID:      MessageID{value: uuid.New().String()},
		Method:  "tools/call",
		Version: "2.0",
	}
	req.Params = paramsJSON

	resp, err := c.sendRequest(ctx, &req)
	if err != nil {
		return nil, fmt.Errorf("调用工具失败: %w", err)
	}

	var callResp CallToolResponse
	if err := json.Unmarshal(resp.Result, &callResp); err != nil {
		return nil, fmt.Errorf("解析工具调用结果失败: %w", err)
	}

	return &ToolResult{
		Content: callResp.Content,
		IsError: callResp.IsError,
	}, nil
}
// Close cancels the client's context, closes the stdin and stdout pipes if
// they exist, kills and waits for the child process if one was started, and
// marks the client as "disconnected".
//
// Teardown is best effort: errors from the individual steps are discarded and
// Close always returns nil.
func (c *StdioMCPClient) Close() error {
	c.cancel()

	// stdin first, then stdout.
	for _, pipe := range []io.Closer{c.stdin, c.stdout} {
		if pipe != nil {
			_ = pipe.Close()
		}
	}

	if child := c.cmd; child != nil {
		_ = child.Process.Kill()
		_ = child.Wait()
	}

	c.setStatus("disconnected")
	return nil
}