Files
CyberStrikeAI/cmd/test-sse-mcp-server/main.go
2026-01-08 22:43:41 +08:00

396 lines
8.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/google/uuid"
)
const ProtocolVersion = "2024-11-05"
// Message MCP消息
type Message struct {
ID interface{} `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
Version string `json:"jsonrpc,omitempty"`
}
// Error MCP错误
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// InitializeRequest 初始化请求
type InitializeRequest struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities map[string]interface{} `json:"capabilities"`
ClientInfo ClientInfo `json:"clientInfo"`
}
// ClientInfo 客户端信息
type ClientInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// InitializeResponse 初始化响应
type InitializeResponse struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
}
// ServerCapabilities 服务器能力
type ServerCapabilities struct {
Tools map[string]interface{} `json:"tools,omitempty"`
}
// ServerInfo 服务器信息
type ServerInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// Tool 工具定义
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema map[string]interface{} `json:"inputSchema"`
}
// ListToolsResponse 列出工具响应
type ListToolsResponse struct {
Tools []Tool `json:"tools"`
}
// CallToolRequest 调用工具请求
type CallToolRequest struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// CallToolResponse 调用工具响应
type CallToolResponse struct {
Content []Content `json:"content"`
IsError bool `json:"isError,omitempty"`
}
// Content 内容
type Content struct {
Type string `json:"type"`
Text string `json:"text"`
}
// SSEServer SSE MCP服务器
type SSEServer struct {
sseClients map[string]chan []byte
mu sync.RWMutex
}
func NewSSEServer() *SSEServer {
return &SSEServer{
sseClients: make(map[string]chan []byte),
}
}
// handleSSE 处理SSE连接
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
clientID := uuid.New().String()
clientChan := make(chan []byte, 10)
s.mu.Lock()
s.sseClients[clientID] = clientChan
s.mu.Unlock()
defer func() {
s.mu.Lock()
delete(s.sseClients, clientID)
close(clientChan)
s.mu.Unlock()
}()
// 发送初始ready事件
fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n")
flusher.Flush()
log.Printf("SSE客户端连接: %s", clientID)
// 心跳
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-r.Context().Done():
log.Printf("SSE客户端断开: %s", clientID)
return
case msg, ok := <-clientChan:
if !ok {
return
}
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
flusher.Flush()
case <-ticker.C:
// 心跳
fmt.Fprintf(w, ": ping\n\n")
flusher.Flush()
}
}
}
// handleMessage 处理POST消息
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var msg Message
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
return
}
log.Printf("收到请求: method=%s, id=%v", msg.Method, msg.ID)
// 处理消息
response := s.processMessage(&msg)
// 如果有SSE客户端通过SSE推送响应
if response != nil {
responseJSON, _ := json.Marshal(response)
s.mu.RLock()
// 发送给所有SSE客户端
for _, ch := range s.sseClients {
select {
case ch <- responseJSON:
default:
}
}
s.mu.RUnlock()
}
// 也直接返回响应兼容非SSE模式
if response != nil {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
} else {
w.WriteHeader(http.StatusOK)
}
}
// processMessage 处理MCP消息
func (s *SSEServer) processMessage(msg *Message) *Message {
switch msg.Method {
case "initialize":
return s.handleInitialize(msg)
case "tools/list":
return s.handleListTools(msg)
case "tools/call":
return s.handleCallTool(msg)
default:
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32601,
Message: "Method not found",
},
}
}
}
// handleInitialize 处理初始化
func (s *SSEServer) handleInitialize(msg *Message) *Message {
var req InitializeRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32602,
Message: "Invalid params",
},
}
}
log.Printf("初始化请求: client=%s, version=%s", req.ClientInfo.Name, req.ClientInfo.Version)
response := InitializeResponse{
ProtocolVersion: ProtocolVersion,
Capabilities: ServerCapabilities{
Tools: map[string]interface{}{
"listChanged": true,
},
},
ServerInfo: ServerInfo{
Name: "Test SSE MCP Server",
Version: "1.0.0",
},
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Version: "2.0",
Result: result,
}
}
// handleListTools 处理列出工具
func (s *SSEServer) handleListTools(msg *Message) *Message {
tools := []Tool{
{
Name: "test_echo",
Description: "回显输入的文本用于测试SSE MCP服务器",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"text": map[string]interface{}{
"type": "string",
"description": "要回显的文本",
},
},
"required": []string{"text"},
},
},
{
Name: "test_add",
Description: "计算两个数字的和用于测试SSE MCP服务器",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"a": map[string]interface{}{
"type": "number",
"description": "第一个数字",
},
"b": map[string]interface{}{
"type": "number",
"description": "第二个数字",
},
},
"required": []string{"a", "b"},
},
},
}
response := ListToolsResponse{Tools: tools}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Version: "2.0",
Result: result,
}
}
// handleCallTool 处理工具调用
func (s *SSEServer) handleCallTool(msg *Message) *Message {
var req CallToolRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32602,
Message: "Invalid params",
},
}
}
log.Printf("调用工具: name=%s, args=%v", req.Name, req.Arguments)
var content []Content
switch req.Name {
case "test_echo":
text, _ := req.Arguments["text"].(string)
content = []Content{
{
Type: "text",
Text: fmt.Sprintf("回显: %s", text),
},
}
case "test_add":
var a, b float64
if val, ok := req.Arguments["a"].(float64); ok {
a = val
}
if val, ok := req.Arguments["b"].(float64); ok {
b = val
}
sum := a + b
content = []Content{
{
Type: "text",
Text: fmt.Sprintf("%.2f + %.2f = %.2f", a, b, sum),
},
}
default:
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32601,
Message: "Tool not found",
},
}
}
response := CallToolResponse{
Content: content,
IsError: false,
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Version: "2.0",
Result: result,
}
}
func main() {
server := NewSSEServer()
http.HandleFunc("/sse", server.handleSSE)
http.HandleFunc("/message", server.handleMessage)
port := ":8082"
log.Printf("SSE MCP测试服务器启动在端口 %s", port)
log.Printf("SSE端点: http://localhost%s/sse", port)
log.Printf("消息端点: http://localhost%s/message", port)
log.Printf("配置示例:")
log.Printf(`{
"test-sse-mcp": {
"transport": "sse",
"url": "http://127.0.0.1:8082/sse"
}
}`)
if err := http.ListenAndServe(port, nil); err != nil {
log.Fatal(err)
}
}