mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-22 02:36:40 +02:00
396 lines
8.7 KiB
Go
396 lines
8.7 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
const ProtocolVersion = "2024-11-05"
|
|
|
|
// Message MCP消息
|
|
type Message struct {
|
|
ID interface{} `json:"id,omitempty"`
|
|
Method string `json:"method,omitempty"`
|
|
Params json.RawMessage `json:"params,omitempty"`
|
|
Result json.RawMessage `json:"result,omitempty"`
|
|
Error *Error `json:"error,omitempty"`
|
|
Version string `json:"jsonrpc,omitempty"`
|
|
}
|
|
|
|
// Error MCP错误
|
|
type Error struct {
|
|
Code int `json:"code"`
|
|
Message string `json:"message"`
|
|
Data interface{} `json:"data,omitempty"`
|
|
}
|
|
|
|
// InitializeRequest 初始化请求
|
|
type InitializeRequest struct {
|
|
ProtocolVersion string `json:"protocolVersion"`
|
|
Capabilities map[string]interface{} `json:"capabilities"`
|
|
ClientInfo ClientInfo `json:"clientInfo"`
|
|
}
|
|
|
|
// ClientInfo 客户端信息
|
|
type ClientInfo struct {
|
|
Name string `json:"name"`
|
|
Version string `json:"version"`
|
|
}
|
|
|
|
// InitializeResponse 初始化响应
|
|
type InitializeResponse struct {
|
|
ProtocolVersion string `json:"protocolVersion"`
|
|
Capabilities ServerCapabilities `json:"capabilities"`
|
|
ServerInfo ServerInfo `json:"serverInfo"`
|
|
}
|
|
|
|
// ServerCapabilities 服务器能力
|
|
type ServerCapabilities struct {
|
|
Tools map[string]interface{} `json:"tools,omitempty"`
|
|
}
|
|
|
|
// ServerInfo 服务器信息
|
|
type ServerInfo struct {
|
|
Name string `json:"name"`
|
|
Version string `json:"version"`
|
|
}
|
|
|
|
// Tool 工具定义
|
|
type Tool struct {
|
|
Name string `json:"name"`
|
|
Description string `json:"description"`
|
|
InputSchema map[string]interface{} `json:"inputSchema"`
|
|
}
|
|
|
|
// ListToolsResponse 列出工具响应
|
|
type ListToolsResponse struct {
|
|
Tools []Tool `json:"tools"`
|
|
}
|
|
|
|
// CallToolRequest 调用工具请求
|
|
type CallToolRequest struct {
|
|
Name string `json:"name"`
|
|
Arguments map[string]interface{} `json:"arguments"`
|
|
}
|
|
|
|
// CallToolResponse 调用工具响应
|
|
type CallToolResponse struct {
|
|
Content []Content `json:"content"`
|
|
IsError bool `json:"isError,omitempty"`
|
|
}
|
|
|
|
// Content 内容
|
|
type Content struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text"`
|
|
}
|
|
|
|
// SSEServer SSE MCP服务器
|
|
type SSEServer struct {
|
|
sseClients map[string]chan []byte
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func NewSSEServer() *SSEServer {
|
|
return &SSEServer{
|
|
sseClients: make(map[string]chan []byte),
|
|
}
|
|
}
|
|
|
|
// handleSSE 处理SSE连接
|
|
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no")
|
|
|
|
clientID := uuid.New().String()
|
|
clientChan := make(chan []byte, 10)
|
|
|
|
s.mu.Lock()
|
|
s.sseClients[clientID] = clientChan
|
|
s.mu.Unlock()
|
|
|
|
defer func() {
|
|
s.mu.Lock()
|
|
delete(s.sseClients, clientID)
|
|
close(clientChan)
|
|
s.mu.Unlock()
|
|
}()
|
|
|
|
// 发送初始ready事件
|
|
fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n")
|
|
flusher.Flush()
|
|
|
|
log.Printf("SSE客户端连接: %s", clientID)
|
|
|
|
// 心跳
|
|
ticker := time.NewTicker(15 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-r.Context().Done():
|
|
log.Printf("SSE客户端断开: %s", clientID)
|
|
return
|
|
case msg, ok := <-clientChan:
|
|
if !ok {
|
|
return
|
|
}
|
|
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
|
|
flusher.Flush()
|
|
case <-ticker.C:
|
|
// 心跳
|
|
fmt.Fprintf(w, ": ping\n\n")
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleMessage 处理POST消息
|
|
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
var msg Message
|
|
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
|
|
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
log.Printf("收到请求: method=%s, id=%v", msg.Method, msg.ID)
|
|
|
|
// 处理消息
|
|
response := s.processMessage(&msg)
|
|
|
|
// 如果有SSE客户端,通过SSE推送响应
|
|
if response != nil {
|
|
responseJSON, _ := json.Marshal(response)
|
|
s.mu.RLock()
|
|
// 发送给所有SSE客户端
|
|
for _, ch := range s.sseClients {
|
|
select {
|
|
case ch <- responseJSON:
|
|
default:
|
|
}
|
|
}
|
|
s.mu.RUnlock()
|
|
}
|
|
|
|
// 也直接返回响应(兼容非SSE模式)
|
|
if response != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
} else {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
}
|
|
|
|
// processMessage 处理MCP消息
|
|
func (s *SSEServer) processMessage(msg *Message) *Message {
|
|
switch msg.Method {
|
|
case "initialize":
|
|
return s.handleInitialize(msg)
|
|
case "tools/list":
|
|
return s.handleListTools(msg)
|
|
case "tools/call":
|
|
return s.handleCallTool(msg)
|
|
default:
|
|
return &Message{
|
|
ID: msg.ID,
|
|
Version: "2.0",
|
|
Error: &Error{
|
|
Code: -32601,
|
|
Message: "Method not found",
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleInitialize 处理初始化
|
|
func (s *SSEServer) handleInitialize(msg *Message) *Message {
|
|
var req InitializeRequest
|
|
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
|
return &Message{
|
|
ID: msg.ID,
|
|
Version: "2.0",
|
|
Error: &Error{
|
|
Code: -32602,
|
|
Message: "Invalid params",
|
|
},
|
|
}
|
|
}
|
|
|
|
log.Printf("初始化请求: client=%s, version=%s", req.ClientInfo.Name, req.ClientInfo.Version)
|
|
|
|
response := InitializeResponse{
|
|
ProtocolVersion: ProtocolVersion,
|
|
Capabilities: ServerCapabilities{
|
|
Tools: map[string]interface{}{
|
|
"listChanged": true,
|
|
},
|
|
},
|
|
ServerInfo: ServerInfo{
|
|
Name: "Test SSE MCP Server",
|
|
Version: "1.0.0",
|
|
},
|
|
}
|
|
|
|
result, _ := json.Marshal(response)
|
|
return &Message{
|
|
ID: msg.ID,
|
|
Version: "2.0",
|
|
Result: result,
|
|
}
|
|
}
|
|
|
|
// handleListTools 处理列出工具
|
|
func (s *SSEServer) handleListTools(msg *Message) *Message {
|
|
tools := []Tool{
|
|
{
|
|
Name: "test_echo",
|
|
Description: "回显输入的文本,用于测试SSE MCP服务器",
|
|
InputSchema: map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"text": map[string]interface{}{
|
|
"type": "string",
|
|
"description": "要回显的文本",
|
|
},
|
|
},
|
|
"required": []string{"text"},
|
|
},
|
|
},
|
|
{
|
|
Name: "test_add",
|
|
Description: "计算两个数字的和,用于测试SSE MCP服务器",
|
|
InputSchema: map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"a": map[string]interface{}{
|
|
"type": "number",
|
|
"description": "第一个数字",
|
|
},
|
|
"b": map[string]interface{}{
|
|
"type": "number",
|
|
"description": "第二个数字",
|
|
},
|
|
},
|
|
"required": []string{"a", "b"},
|
|
},
|
|
},
|
|
}
|
|
|
|
response := ListToolsResponse{Tools: tools}
|
|
result, _ := json.Marshal(response)
|
|
return &Message{
|
|
ID: msg.ID,
|
|
Version: "2.0",
|
|
Result: result,
|
|
}
|
|
}
|
|
|
|
// handleCallTool 处理工具调用
|
|
func (s *SSEServer) handleCallTool(msg *Message) *Message {
|
|
var req CallToolRequest
|
|
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
|
return &Message{
|
|
ID: msg.ID,
|
|
Version: "2.0",
|
|
Error: &Error{
|
|
Code: -32602,
|
|
Message: "Invalid params",
|
|
},
|
|
}
|
|
}
|
|
|
|
log.Printf("调用工具: name=%s, args=%v", req.Name, req.Arguments)
|
|
|
|
var content []Content
|
|
|
|
switch req.Name {
|
|
case "test_echo":
|
|
text, _ := req.Arguments["text"].(string)
|
|
content = []Content{
|
|
{
|
|
Type: "text",
|
|
Text: fmt.Sprintf("回显: %s", text),
|
|
},
|
|
}
|
|
case "test_add":
|
|
var a, b float64
|
|
if val, ok := req.Arguments["a"].(float64); ok {
|
|
a = val
|
|
}
|
|
if val, ok := req.Arguments["b"].(float64); ok {
|
|
b = val
|
|
}
|
|
sum := a + b
|
|
content = []Content{
|
|
{
|
|
Type: "text",
|
|
Text: fmt.Sprintf("%.2f + %.2f = %.2f", a, b, sum),
|
|
},
|
|
}
|
|
default:
|
|
return &Message{
|
|
ID: msg.ID,
|
|
Version: "2.0",
|
|
Error: &Error{
|
|
Code: -32601,
|
|
Message: "Tool not found",
|
|
},
|
|
}
|
|
}
|
|
|
|
response := CallToolResponse{
|
|
Content: content,
|
|
IsError: false,
|
|
}
|
|
|
|
result, _ := json.Marshal(response)
|
|
return &Message{
|
|
ID: msg.ID,
|
|
Version: "2.0",
|
|
Result: result,
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
server := NewSSEServer()
|
|
|
|
http.HandleFunc("/sse", server.handleSSE)
|
|
http.HandleFunc("/message", server.handleMessage)
|
|
|
|
port := ":8082"
|
|
log.Printf("SSE MCP测试服务器启动在端口 %s", port)
|
|
log.Printf("SSE端点: http://localhost%s/sse", port)
|
|
log.Printf("消息端点: http://localhost%s/message", port)
|
|
log.Printf("配置示例:")
|
|
log.Printf(`{
|
|
"test-sse-mcp": {
|
|
"transport": "sse",
|
|
"url": "http://127.0.0.1:8082/sse"
|
|
}
|
|
}`)
|
|
|
|
if err := http.ListenAndServe(port, nil); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|