mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
396 lines
8.7 KiB
Go
396 lines
8.7 KiB
Go
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
const ProtocolVersion = "2024-11-05"
|
||
|
||
// Message MCP消息
|
||
type Message struct {
|
||
ID interface{} `json:"id,omitempty"`
|
||
Method string `json:"method,omitempty"`
|
||
Params json.RawMessage `json:"params,omitempty"`
|
||
Result json.RawMessage `json:"result,omitempty"`
|
||
Error *Error `json:"error,omitempty"`
|
||
Version string `json:"jsonrpc,omitempty"`
|
||
}
|
||
|
||
// Error MCP错误
|
||
type Error struct {
|
||
Code int `json:"code"`
|
||
Message string `json:"message"`
|
||
Data interface{} `json:"data,omitempty"`
|
||
}
|
||
|
||
// InitializeRequest 初始化请求
|
||
type InitializeRequest struct {
|
||
ProtocolVersion string `json:"protocolVersion"`
|
||
Capabilities map[string]interface{} `json:"capabilities"`
|
||
ClientInfo ClientInfo `json:"clientInfo"`
|
||
}
|
||
|
||
// ClientInfo 客户端信息
|
||
type ClientInfo struct {
|
||
Name string `json:"name"`
|
||
Version string `json:"version"`
|
||
}
|
||
|
||
// InitializeResponse 初始化响应
|
||
type InitializeResponse struct {
|
||
ProtocolVersion string `json:"protocolVersion"`
|
||
Capabilities ServerCapabilities `json:"capabilities"`
|
||
ServerInfo ServerInfo `json:"serverInfo"`
|
||
}
|
||
|
||
// ServerCapabilities 服务器能力
|
||
type ServerCapabilities struct {
|
||
Tools map[string]interface{} `json:"tools,omitempty"`
|
||
}
|
||
|
||
// ServerInfo 服务器信息
|
||
type ServerInfo struct {
|
||
Name string `json:"name"`
|
||
Version string `json:"version"`
|
||
}
|
||
|
||
// Tool 工具定义
|
||
type Tool struct {
|
||
Name string `json:"name"`
|
||
Description string `json:"description"`
|
||
InputSchema map[string]interface{} `json:"inputSchema"`
|
||
}
|
||
|
||
// ListToolsResponse 列出工具响应
|
||
type ListToolsResponse struct {
|
||
Tools []Tool `json:"tools"`
|
||
}
|
||
|
||
// CallToolRequest 调用工具请求
|
||
type CallToolRequest struct {
|
||
Name string `json:"name"`
|
||
Arguments map[string]interface{} `json:"arguments"`
|
||
}
|
||
|
||
// CallToolResponse 调用工具响应
|
||
type CallToolResponse struct {
|
||
Content []Content `json:"content"`
|
||
IsError bool `json:"isError,omitempty"`
|
||
}
|
||
|
||
// Content 内容
|
||
type Content struct {
|
||
Type string `json:"type"`
|
||
Text string `json:"text"`
|
||
}
|
||
|
||
// SSEServer SSE MCP服务器
|
||
type SSEServer struct {
|
||
sseClients map[string]chan []byte
|
||
mu sync.RWMutex
|
||
}
|
||
|
||
func NewSSEServer() *SSEServer {
|
||
return &SSEServer{
|
||
sseClients: make(map[string]chan []byte),
|
||
}
|
||
}
|
||
|
||
// handleSSE 处理SSE连接
|
||
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||
flusher, ok := w.(http.Flusher)
|
||
if !ok {
|
||
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "text/event-stream")
|
||
w.Header().Set("Cache-Control", "no-cache")
|
||
w.Header().Set("Connection", "keep-alive")
|
||
w.Header().Set("X-Accel-Buffering", "no")
|
||
|
||
clientID := uuid.New().String()
|
||
clientChan := make(chan []byte, 10)
|
||
|
||
s.mu.Lock()
|
||
s.sseClients[clientID] = clientChan
|
||
s.mu.Unlock()
|
||
|
||
defer func() {
|
||
s.mu.Lock()
|
||
delete(s.sseClients, clientID)
|
||
close(clientChan)
|
||
s.mu.Unlock()
|
||
}()
|
||
|
||
// 发送初始ready事件
|
||
fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n")
|
||
flusher.Flush()
|
||
|
||
log.Printf("SSE客户端连接: %s", clientID)
|
||
|
||
// 心跳
|
||
ticker := time.NewTicker(15 * time.Second)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-r.Context().Done():
|
||
log.Printf("SSE客户端断开: %s", clientID)
|
||
return
|
||
case msg, ok := <-clientChan:
|
||
if !ok {
|
||
return
|
||
}
|
||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
|
||
flusher.Flush()
|
||
case <-ticker.C:
|
||
// 心跳
|
||
fmt.Fprintf(w, ": ping\n\n")
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
}
|
||
|
||
// handleMessage 处理POST消息
|
||
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
var msg Message
|
||
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
|
||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
log.Printf("收到请求: method=%s, id=%v", msg.Method, msg.ID)
|
||
|
||
// 处理消息
|
||
response := s.processMessage(&msg)
|
||
|
||
// 如果有SSE客户端,通过SSE推送响应
|
||
if response != nil {
|
||
responseJSON, _ := json.Marshal(response)
|
||
s.mu.RLock()
|
||
// 发送给所有SSE客户端
|
||
for _, ch := range s.sseClients {
|
||
select {
|
||
case ch <- responseJSON:
|
||
default:
|
||
}
|
||
}
|
||
s.mu.RUnlock()
|
||
}
|
||
|
||
// 也直接返回响应(兼容非SSE模式)
|
||
if response != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
} else {
|
||
w.WriteHeader(http.StatusOK)
|
||
}
|
||
}
|
||
|
||
// processMessage 处理MCP消息
|
||
func (s *SSEServer) processMessage(msg *Message) *Message {
|
||
switch msg.Method {
|
||
case "initialize":
|
||
return s.handleInitialize(msg)
|
||
case "tools/list":
|
||
return s.handleListTools(msg)
|
||
case "tools/call":
|
||
return s.handleCallTool(msg)
|
||
default:
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Version: "2.0",
|
||
Error: &Error{
|
||
Code: -32601,
|
||
Message: "Method not found",
|
||
},
|
||
}
|
||
}
|
||
}
|
||
|
||
// handleInitialize 处理初始化
|
||
func (s *SSEServer) handleInitialize(msg *Message) *Message {
|
||
var req InitializeRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Version: "2.0",
|
||
Error: &Error{
|
||
Code: -32602,
|
||
Message: "Invalid params",
|
||
},
|
||
}
|
||
}
|
||
|
||
log.Printf("初始化请求: client=%s, version=%s", req.ClientInfo.Name, req.ClientInfo.Version)
|
||
|
||
response := InitializeResponse{
|
||
ProtocolVersion: ProtocolVersion,
|
||
Capabilities: ServerCapabilities{
|
||
Tools: map[string]interface{}{
|
||
"listChanged": true,
|
||
},
|
||
},
|
||
ServerInfo: ServerInfo{
|
||
Name: "Test SSE MCP Server",
|
||
Version: "1.0.0",
|
||
},
|
||
}
|
||
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleListTools 处理列出工具
|
||
func (s *SSEServer) handleListTools(msg *Message) *Message {
|
||
tools := []Tool{
|
||
{
|
||
Name: "test_echo",
|
||
Description: "回显输入的文本,用于测试SSE MCP服务器",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"text": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "要回显的文本",
|
||
},
|
||
},
|
||
"required": []string{"text"},
|
||
},
|
||
},
|
||
{
|
||
Name: "test_add",
|
||
Description: "计算两个数字的和,用于测试SSE MCP服务器",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"a": map[string]interface{}{
|
||
"type": "number",
|
||
"description": "第一个数字",
|
||
},
|
||
"b": map[string]interface{}{
|
||
"type": "number",
|
||
"description": "第二个数字",
|
||
},
|
||
},
|
||
"required": []string{"a", "b"},
|
||
},
|
||
},
|
||
}
|
||
|
||
response := ListToolsResponse{Tools: tools}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleCallTool 处理工具调用
|
||
func (s *SSEServer) handleCallTool(msg *Message) *Message {
|
||
var req CallToolRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Version: "2.0",
|
||
Error: &Error{
|
||
Code: -32602,
|
||
Message: "Invalid params",
|
||
},
|
||
}
|
||
}
|
||
|
||
log.Printf("调用工具: name=%s, args=%v", req.Name, req.Arguments)
|
||
|
||
var content []Content
|
||
|
||
switch req.Name {
|
||
case "test_echo":
|
||
text, _ := req.Arguments["text"].(string)
|
||
content = []Content{
|
||
{
|
||
Type: "text",
|
||
Text: fmt.Sprintf("回显: %s", text),
|
||
},
|
||
}
|
||
case "test_add":
|
||
var a, b float64
|
||
if val, ok := req.Arguments["a"].(float64); ok {
|
||
a = val
|
||
}
|
||
if val, ok := req.Arguments["b"].(float64); ok {
|
||
b = val
|
||
}
|
||
sum := a + b
|
||
content = []Content{
|
||
{
|
||
Type: "text",
|
||
Text: fmt.Sprintf("%.2f + %.2f = %.2f", a, b, sum),
|
||
},
|
||
}
|
||
default:
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Version: "2.0",
|
||
Error: &Error{
|
||
Code: -32601,
|
||
Message: "Tool not found",
|
||
},
|
||
}
|
||
}
|
||
|
||
response := CallToolResponse{
|
||
Content: content,
|
||
IsError: false,
|
||
}
|
||
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
func main() {
|
||
server := NewSSEServer()
|
||
|
||
http.HandleFunc("/sse", server.handleSSE)
|
||
http.HandleFunc("/message", server.handleMessage)
|
||
|
||
port := ":8082"
|
||
log.Printf("SSE MCP测试服务器启动在端口 %s", port)
|
||
log.Printf("SSE端点: http://localhost%s/sse", port)
|
||
log.Printf("消息端点: http://localhost%s/message", port)
|
||
log.Printf("配置示例:")
|
||
log.Printf(`{
|
||
"test-sse-mcp": {
|
||
"transport": "sse",
|
||
"url": "http://127.0.0.1:8082/sse"
|
||
}
|
||
}`)
|
||
|
||
if err := http.ListenAndServe(port, nil); err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
}
|
||
|