mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 00:09:29 +02:00
1238 lines
31 KiB
Go
1238 lines
31 KiB
Go
package mcp
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"sort"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// MonitorStorage 监控数据存储接口
|
||
type MonitorStorage interface {
|
||
SaveToolExecution(exec *ToolExecution) error
|
||
LoadToolExecutions() ([]*ToolExecution, error)
|
||
GetToolExecution(id string) (*ToolExecution, error)
|
||
SaveToolStats(toolName string, stats *ToolStats) error
|
||
LoadToolStats() (map[string]*ToolStats, error)
|
||
UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error
|
||
}
|
||
|
||
// Server MCP服务器
|
||
type Server struct {
|
||
tools map[string]ToolHandler
|
||
toolDefs map[string]Tool // 工具定义
|
||
executions map[string]*ToolExecution
|
||
stats map[string]*ToolStats
|
||
prompts map[string]*Prompt // 提示词模板
|
||
resources map[string]*Resource // 资源
|
||
storage MonitorStorage // 可选的持久化存储
|
||
mu sync.RWMutex
|
||
logger *zap.Logger
|
||
maxExecutionsInMemory int // 内存中最大执行记录数
|
||
sseClients map[string]*sseClient
|
||
}
|
||
|
||
type sseClient struct {
|
||
id string
|
||
send chan []byte
|
||
}
|
||
|
||
// ToolHandler 工具处理函数
|
||
type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error)
|
||
|
||
// NewServer 创建新的MCP服务器
|
||
func NewServer(logger *zap.Logger) *Server {
|
||
return NewServerWithStorage(logger, nil)
|
||
}
|
||
|
||
// NewServerWithStorage 创建新的MCP服务器(带持久化存储)
|
||
func NewServerWithStorage(logger *zap.Logger, storage MonitorStorage) *Server {
|
||
s := &Server{
|
||
tools: make(map[string]ToolHandler),
|
||
toolDefs: make(map[string]Tool),
|
||
executions: make(map[string]*ToolExecution),
|
||
stats: make(map[string]*ToolStats),
|
||
prompts: make(map[string]*Prompt),
|
||
resources: make(map[string]*Resource),
|
||
storage: storage,
|
||
logger: logger,
|
||
maxExecutionsInMemory: 1000, // 默认最多在内存中保留1000条执行记录
|
||
sseClients: make(map[string]*sseClient),
|
||
}
|
||
|
||
// 初始化默认提示词和资源
|
||
s.initDefaultPrompts()
|
||
s.initDefaultResources()
|
||
|
||
return s
|
||
}
|
||
|
||
// RegisterTool 注册工具
|
||
func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
s.tools[tool.Name] = handler
|
||
s.toolDefs[tool.Name] = tool
|
||
|
||
// 自动为工具创建资源文档
|
||
resourceURI := fmt.Sprintf("tool://%s", tool.Name)
|
||
s.resources[resourceURI] = &Resource{
|
||
URI: resourceURI,
|
||
Name: fmt.Sprintf("%s工具文档", tool.Name),
|
||
Description: tool.Description,
|
||
MimeType: "text/plain",
|
||
}
|
||
}
|
||
|
||
// ClearTools 清空所有工具(用于重新加载配置)
|
||
func (s *Server) ClearTools() {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
// 清空工具和工具定义
|
||
s.tools = make(map[string]ToolHandler)
|
||
s.toolDefs = make(map[string]Tool)
|
||
|
||
// 清空工具相关的资源(保留其他资源)
|
||
newResources := make(map[string]*Resource)
|
||
for uri, resource := range s.resources {
|
||
// 保留非工具资源
|
||
if !strings.HasPrefix(uri, "tool://") {
|
||
newResources[uri] = resource
|
||
}
|
||
}
|
||
s.resources = newResources
|
||
}
|
||
|
||
// HandleHTTP 处理HTTP请求
|
||
func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method == http.MethodGet && strings.Contains(r.Header.Get("Accept"), "text/event-stream") {
|
||
s.handleSSE(w, r)
|
||
return
|
||
}
|
||
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 官方 MCP SSE 规范:带 sessionid 的 POST 表示消息发往该 SSE 会话,响应通过 SSE 流返回
|
||
if sessionID := r.URL.Query().Get("sessionid"); sessionID != "" {
|
||
s.serveSSESessionMessage(w, r, sessionID)
|
||
return
|
||
}
|
||
|
||
// 简单 POST:请求体为 JSON-RPC,响应在 body 中返回
|
||
body, err := io.ReadAll(r.Body)
|
||
if err != nil {
|
||
s.sendError(w, nil, -32700, "Parse error", err.Error())
|
||
return
|
||
}
|
||
|
||
var msg Message
|
||
if err := json.Unmarshal(body, &msg); err != nil {
|
||
s.sendError(w, nil, -32700, "Parse error", err.Error())
|
||
return
|
||
}
|
||
|
||
response := s.handleMessage(&msg)
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|
||
|
||
// serveSSESessionMessage 处理发往 SSE 会话的 POST:读取 JSON-RPC 请求,处理后将响应通过该会话的 SSE 流推送
|
||
func (s *Server) serveSSESessionMessage(w http.ResponseWriter, r *http.Request, sessionID string) {
|
||
s.mu.RLock()
|
||
client, exists := s.sseClients[sessionID]
|
||
s.mu.RUnlock()
|
||
if !exists || client == nil {
|
||
http.Error(w, "session not found", http.StatusNotFound)
|
||
return
|
||
}
|
||
|
||
body, err := io.ReadAll(r.Body)
|
||
if err != nil {
|
||
http.Error(w, "failed to read body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
var msg Message
|
||
if err := json.Unmarshal(body, &msg); err != nil {
|
||
http.Error(w, "failed to parse body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
response := s.handleMessage(&msg)
|
||
if response == nil {
|
||
w.WriteHeader(http.StatusAccepted)
|
||
return
|
||
}
|
||
|
||
respBytes, err := json.Marshal(response)
|
||
if err != nil {
|
||
http.Error(w, "failed to encode response", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
select {
|
||
case client.send <- respBytes:
|
||
w.WriteHeader(http.StatusAccepted)
|
||
default:
|
||
http.Error(w, "session send buffer full", http.StatusServiceUnavailable)
|
||
}
|
||
}
|
||
|
||
// handleSSE 处理 SSE 连接,兼容官方 MCP 2024-11-05 SSE 规范:
|
||
// 1. 首个事件必须为 event: endpoint,data 为客户端 POST 消息的 URL(含 sessionid)
|
||
// 2. 后续事件为 event: message,data 为 JSON-RPC 响应
|
||
func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||
flusher, ok := w.(http.Flusher)
|
||
if !ok {
|
||
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "text/event-stream")
|
||
w.Header().Set("Cache-Control", "no-cache")
|
||
w.Header().Set("Connection", "keep-alive")
|
||
w.Header().Set("X-Accel-Buffering", "no")
|
||
|
||
sessionID := uuid.New().String()
|
||
client := &sseClient{
|
||
id: sessionID,
|
||
send: make(chan []byte, 32),
|
||
}
|
||
|
||
s.addSSEClient(client)
|
||
defer s.removeSSEClient(client.id)
|
||
|
||
// 官方规范:首个事件为 endpoint,data 为消息端点 URL(客户端将向该 URL POST 请求)
|
||
scheme := "http"
|
||
if r.TLS != nil {
|
||
scheme = "https"
|
||
}
|
||
if r.URL.Scheme != "" {
|
||
scheme = r.URL.Scheme
|
||
}
|
||
endpointURL := fmt.Sprintf("%s://%s%s?sessionid=%s", scheme, r.Host, r.URL.Path, sessionID)
|
||
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpointURL)
|
||
flusher.Flush()
|
||
|
||
ticker := time.NewTicker(15 * time.Second)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-r.Context().Done():
|
||
return
|
||
case msg, ok := <-client.send:
|
||
if !ok {
|
||
return
|
||
}
|
||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
|
||
flusher.Flush()
|
||
case <-ticker.C:
|
||
fmt.Fprintf(w, ": ping\n\n")
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
}
|
||
|
||
// addSSEClient 注册SSE客户端
|
||
func (s *Server) addSSEClient(client *sseClient) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
s.sseClients[client.id] = client
|
||
}
|
||
|
||
// removeSSEClient 移除SSE客户端
|
||
func (s *Server) removeSSEClient(id string) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
if client, exists := s.sseClients[id]; exists {
|
||
close(client.send)
|
||
delete(s.sseClients, id)
|
||
}
|
||
}
|
||
|
||
// handleMessage 处理MCP消息
|
||
func (s *Server) handleMessage(msg *Message) *Message {
|
||
// 检查是否是通知(notification)- 通知没有id字段,不需要响应
|
||
isNotification := msg.ID.Value() == nil || msg.ID.String() == ""
|
||
|
||
// 如果不是通知且ID为空,生成新的UUID
|
||
if !isNotification && msg.ID.String() == "" {
|
||
msg.ID = MessageID{value: uuid.New().String()}
|
||
}
|
||
|
||
switch msg.Method {
|
||
case "initialize":
|
||
return s.handleInitialize(msg)
|
||
case "tools/list":
|
||
return s.handleListTools(msg)
|
||
case "tools/call":
|
||
return s.handleCallTool(msg)
|
||
case "prompts/list":
|
||
return s.handleListPrompts(msg)
|
||
case "prompts/get":
|
||
return s.handleGetPrompt(msg)
|
||
case "resources/list":
|
||
return s.handleListResources(msg)
|
||
case "resources/read":
|
||
return s.handleReadResource(msg)
|
||
case "sampling/request":
|
||
return s.handleSamplingRequest(msg)
|
||
case "notifications/initialized":
|
||
// 通知类型,不需要响应
|
||
s.logger.Debug("收到 initialized 通知")
|
||
return nil
|
||
case "":
|
||
// 空方法名,可能是通知,不返回错误
|
||
if isNotification {
|
||
s.logger.Debug("收到无方法名的通知消息")
|
||
return nil
|
||
}
|
||
fallthrough
|
||
default:
|
||
// 如果是通知,不返回错误响应
|
||
if isNotification {
|
||
s.logger.Debug("收到未知通知", zap.String("method", msg.Method))
|
||
return nil
|
||
}
|
||
// 对于请求,返回方法未找到错误
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32601, Message: "Method not found"},
|
||
}
|
||
}
|
||
}
|
||
|
||
// handleInitialize 处理初始化请求
|
||
func (s *Server) handleInitialize(msg *Message) *Message {
|
||
var req InitializeRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
response := InitializeResponse{
|
||
ProtocolVersion: ProtocolVersion,
|
||
Capabilities: ServerCapabilities{
|
||
Tools: map[string]interface{}{
|
||
"listChanged": true,
|
||
},
|
||
Prompts: map[string]interface{}{
|
||
"listChanged": true,
|
||
},
|
||
Resources: map[string]interface{}{
|
||
"subscribe": true,
|
||
"listChanged": true,
|
||
},
|
||
Sampling: map[string]interface{}{},
|
||
},
|
||
ServerInfo: ServerInfo{
|
||
Name: "CyberStrikeAI",
|
||
Version: "1.0.0",
|
||
},
|
||
}
|
||
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleListTools 处理列出工具请求
|
||
func (s *Server) handleListTools(msg *Message) *Message {
|
||
s.mu.RLock()
|
||
tools := make([]Tool, 0, len(s.toolDefs))
|
||
for _, tool := range s.toolDefs {
|
||
tools = append(tools, tool)
|
||
}
|
||
s.mu.RUnlock()
|
||
s.logger.Debug("tools/list 请求", zap.Int("返回工具数", len(tools)))
|
||
|
||
response := ListToolsResponse{Tools: tools}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleCallTool 处理工具调用请求
|
||
func (s *Server) handleCallTool(msg *Message) *Message {
|
||
var req CallToolRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
executionID := uuid.New().String()
|
||
execution := &ToolExecution{
|
||
ID: executionID,
|
||
ToolName: req.Name,
|
||
Arguments: req.Arguments,
|
||
Status: "running",
|
||
StartTime: time.Now(),
|
||
}
|
||
|
||
s.mu.Lock()
|
||
s.executions[executionID] = execution
|
||
// 如果内存中的执行记录超过限制,清理最旧的记录
|
||
s.cleanupOldExecutions()
|
||
s.mu.Unlock()
|
||
|
||
if s.storage != nil {
|
||
if err := s.storage.SaveToolExecution(execution); err != nil {
|
||
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
s.mu.RLock()
|
||
handler, exists := s.tools[req.Name]
|
||
s.mu.RUnlock()
|
||
|
||
if !exists {
|
||
execution.Status = "failed"
|
||
execution.Error = "Tool not found"
|
||
now := time.Now()
|
||
execution.EndTime = &now
|
||
execution.Duration = now.Sub(execution.StartTime)
|
||
|
||
if s.storage != nil {
|
||
if err := s.storage.SaveToolExecution(execution); err != nil {
|
||
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||
}
|
||
s.mu.Lock()
|
||
delete(s.executions, executionID)
|
||
s.mu.Unlock()
|
||
}
|
||
|
||
s.updateStats(req.Name, true)
|
||
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32601, Message: "Tool not found"},
|
||
}
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||
defer cancel()
|
||
|
||
s.logger.Info("开始执行工具",
|
||
zap.String("toolName", req.Name),
|
||
zap.Any("arguments", req.Arguments),
|
||
)
|
||
|
||
result, err := handler(ctx, req.Arguments)
|
||
now := time.Now()
|
||
var failed bool
|
||
var finalResult *ToolResult
|
||
|
||
s.mu.Lock()
|
||
execution.EndTime = &now
|
||
execution.Duration = now.Sub(execution.StartTime)
|
||
|
||
if err != nil {
|
||
execution.Status = "failed"
|
||
execution.Error = err.Error()
|
||
failed = true
|
||
} else if result != nil && result.IsError {
|
||
execution.Status = "failed"
|
||
if len(result.Content) > 0 {
|
||
execution.Error = result.Content[0].Text
|
||
} else {
|
||
execution.Error = "工具执行返回错误结果"
|
||
}
|
||
execution.Result = result
|
||
failed = true
|
||
} else {
|
||
execution.Status = "completed"
|
||
if result == nil {
|
||
result = &ToolResult{
|
||
Content: []Content{
|
||
{Type: "text", Text: "工具执行完成,但未返回结果"},
|
||
},
|
||
}
|
||
}
|
||
execution.Result = result
|
||
failed = false
|
||
}
|
||
|
||
finalResult = execution.Result
|
||
s.mu.Unlock()
|
||
|
||
if s.storage != nil {
|
||
if err := s.storage.SaveToolExecution(execution); err != nil {
|
||
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
s.updateStats(req.Name, failed)
|
||
|
||
if s.storage != nil {
|
||
s.mu.Lock()
|
||
delete(s.executions, executionID)
|
||
s.mu.Unlock()
|
||
}
|
||
|
||
if err != nil {
|
||
s.logger.Error("工具执行失败",
|
||
zap.String("toolName", req.Name),
|
||
zap.Error(err),
|
||
)
|
||
|
||
errorResult, _ := json.Marshal(CallToolResponse{
|
||
Content: []Content{
|
||
{Type: "text", Text: fmt.Sprintf("工具执行失败: %v", err)},
|
||
},
|
||
IsError: true,
|
||
})
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: errorResult,
|
||
}
|
||
}
|
||
|
||
if finalResult != nil && finalResult.IsError {
|
||
s.logger.Warn("工具执行返回错误结果",
|
||
zap.String("toolName", req.Name),
|
||
)
|
||
|
||
errorResult, _ := json.Marshal(CallToolResponse{
|
||
Content: finalResult.Content,
|
||
IsError: true,
|
||
})
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: errorResult,
|
||
}
|
||
}
|
||
|
||
if finalResult == nil {
|
||
finalResult = &ToolResult{
|
||
Content: []Content{
|
||
{Type: "text", Text: "工具执行完成,但未返回结果"},
|
||
},
|
||
}
|
||
}
|
||
|
||
resultJSON, _ := json.Marshal(CallToolResponse{
|
||
Content: finalResult.Content,
|
||
IsError: false,
|
||
})
|
||
|
||
s.logger.Info("工具执行完成",
|
||
zap.String("toolName", req.Name),
|
||
zap.Bool("isError", finalResult.IsError),
|
||
)
|
||
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: resultJSON,
|
||
}
|
||
}
|
||
|
||
// updateStats 更新统计信息
|
||
func (s *Server) updateStats(toolName string, failed bool) {
|
||
now := time.Now()
|
||
if s.storage != nil {
|
||
totalCalls := 1
|
||
successCalls := 0
|
||
failedCalls := 0
|
||
if failed {
|
||
failedCalls = 1
|
||
} else {
|
||
successCalls = 1
|
||
}
|
||
if err := s.storage.UpdateToolStats(toolName, totalCalls, successCalls, failedCalls, &now); err != nil {
|
||
s.logger.Warn("保存统计信息到数据库失败", zap.Error(err))
|
||
}
|
||
return
|
||
}
|
||
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
if s.stats[toolName] == nil {
|
||
s.stats[toolName] = &ToolStats{
|
||
ToolName: toolName,
|
||
}
|
||
}
|
||
|
||
stats := s.stats[toolName]
|
||
stats.TotalCalls++
|
||
stats.LastCallTime = &now
|
||
|
||
if failed {
|
||
stats.FailedCalls++
|
||
} else {
|
||
stats.SuccessCalls++
|
||
}
|
||
}
|
||
|
||
// GetExecution 获取执行记录(先从内存查找,再从数据库查找)
|
||
func (s *Server) GetExecution(id string) (*ToolExecution, bool) {
|
||
s.mu.RLock()
|
||
exec, exists := s.executions[id]
|
||
s.mu.RUnlock()
|
||
|
||
if exists {
|
||
return exec, true
|
||
}
|
||
|
||
if s.storage != nil {
|
||
exec, err := s.storage.GetToolExecution(id)
|
||
if err == nil {
|
||
return exec, true
|
||
}
|
||
}
|
||
|
||
return nil, false
|
||
}
|
||
|
||
// loadHistoricalData 从数据库加载历史数据
|
||
func (s *Server) loadHistoricalData() {
|
||
if s.storage == nil {
|
||
return
|
||
}
|
||
|
||
// 加载历史执行记录(最近1000条)
|
||
executions, err := s.storage.LoadToolExecutions()
|
||
if err != nil {
|
||
s.logger.Warn("加载历史执行记录失败", zap.Error(err))
|
||
} else {
|
||
s.mu.Lock()
|
||
for _, exec := range executions {
|
||
// 只加载最近 maxExecutionsInMemory 条,避免内存占用过大
|
||
if len(s.executions) < s.maxExecutionsInMemory {
|
||
s.executions[exec.ID] = exec
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
s.mu.Unlock()
|
||
s.logger.Info("加载历史执行记录", zap.Int("count", len(executions)))
|
||
}
|
||
|
||
// 加载历史统计信息
|
||
stats, err := s.storage.LoadToolStats()
|
||
if err != nil {
|
||
s.logger.Warn("加载历史统计信息失败", zap.Error(err))
|
||
} else {
|
||
s.mu.Lock()
|
||
for k, v := range stats {
|
||
s.stats[k] = v
|
||
}
|
||
s.mu.Unlock()
|
||
s.logger.Info("加载历史统计信息", zap.Int("count", len(stats)))
|
||
}
|
||
}
|
||
|
||
// GetAllExecutions 获取所有执行记录(合并内存和数据库)
|
||
func (s *Server) GetAllExecutions() []*ToolExecution {
|
||
if s.storage != nil {
|
||
dbExecutions, err := s.storage.LoadToolExecutions()
|
||
if err == nil {
|
||
execMap := make(map[string]*ToolExecution)
|
||
for _, exec := range dbExecutions {
|
||
if _, exists := execMap[exec.ID]; !exists {
|
||
execMap[exec.ID] = exec
|
||
}
|
||
}
|
||
|
||
s.mu.RLock()
|
||
for id, exec := range s.executions {
|
||
if _, exists := execMap[id]; !exists {
|
||
execMap[id] = exec
|
||
}
|
||
}
|
||
s.mu.RUnlock()
|
||
|
||
result := make([]*ToolExecution, 0, len(execMap))
|
||
for _, exec := range execMap {
|
||
result = append(result, exec)
|
||
}
|
||
return result
|
||
} else {
|
||
s.logger.Warn("从数据库加载执行记录失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
memExecutions := make([]*ToolExecution, 0, len(s.executions))
|
||
for _, exec := range s.executions {
|
||
memExecutions = append(memExecutions, exec)
|
||
}
|
||
return memExecutions
|
||
}
|
||
|
||
// GetStats 获取统计信息(合并内存和数据库)
|
||
func (s *Server) GetStats() map[string]*ToolStats {
|
||
if s.storage != nil {
|
||
dbStats, err := s.storage.LoadToolStats()
|
||
if err == nil {
|
||
return dbStats
|
||
}
|
||
s.logger.Warn("从数据库加载统计信息失败", zap.Error(err))
|
||
}
|
||
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
memStats := make(map[string]*ToolStats)
|
||
for k, v := range s.stats {
|
||
statCopy := *v
|
||
memStats[k] = &statCopy
|
||
}
|
||
|
||
return memStats
|
||
}
|
||
|
||
// GetAllTools 获取所有已注册的工具(用于Agent动态获取工具列表)
|
||
func (s *Server) GetAllTools() []Tool {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
tools := make([]Tool, 0, len(s.toolDefs))
|
||
for _, tool := range s.toolDefs {
|
||
tools = append(tools, tool)
|
||
}
|
||
return tools
|
||
}
|
||
|
||
// CallTool 直接调用工具(用于内部调用)
|
||
func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
|
||
s.mu.RLock()
|
||
handler, exists := s.tools[toolName]
|
||
s.mu.RUnlock()
|
||
|
||
if !exists {
|
||
return nil, "", fmt.Errorf("工具 %s 未找到", toolName)
|
||
}
|
||
|
||
// 创建执行记录
|
||
executionID := uuid.New().String()
|
||
execution := &ToolExecution{
|
||
ID: executionID,
|
||
ToolName: toolName,
|
||
Arguments: args,
|
||
Status: "running",
|
||
StartTime: time.Now(),
|
||
}
|
||
|
||
s.mu.Lock()
|
||
s.executions[executionID] = execution
|
||
// 如果内存中的执行记录超过限制,清理最旧的记录
|
||
s.cleanupOldExecutions()
|
||
s.mu.Unlock()
|
||
|
||
if s.storage != nil {
|
||
if err := s.storage.SaveToolExecution(execution); err != nil {
|
||
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
result, err := handler(ctx, args)
|
||
|
||
s.mu.Lock()
|
||
now := time.Now()
|
||
execution.EndTime = &now
|
||
execution.Duration = now.Sub(execution.StartTime)
|
||
var failed bool
|
||
var finalResult *ToolResult
|
||
|
||
if err != nil {
|
||
execution.Status = "failed"
|
||
execution.Error = err.Error()
|
||
failed = true
|
||
} else if result != nil && result.IsError {
|
||
execution.Status = "failed"
|
||
if len(result.Content) > 0 {
|
||
execution.Error = result.Content[0].Text
|
||
} else {
|
||
execution.Error = "工具执行返回错误结果"
|
||
}
|
||
execution.Result = result
|
||
failed = true
|
||
finalResult = result
|
||
} else {
|
||
execution.Status = "completed"
|
||
if result == nil {
|
||
result = &ToolResult{
|
||
Content: []Content{
|
||
{Type: "text", Text: "工具执行完成,但未返回结果"},
|
||
},
|
||
}
|
||
}
|
||
execution.Result = result
|
||
finalResult = result
|
||
failed = false
|
||
}
|
||
|
||
if finalResult == nil {
|
||
finalResult = execution.Result
|
||
}
|
||
s.mu.Unlock()
|
||
|
||
if s.storage != nil {
|
||
if err := s.storage.SaveToolExecution(execution); err != nil {
|
||
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
s.updateStats(toolName, failed)
|
||
|
||
if s.storage != nil {
|
||
s.mu.Lock()
|
||
delete(s.executions, executionID)
|
||
s.mu.Unlock()
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, executionID, err
|
||
}
|
||
|
||
return finalResult, executionID, nil
|
||
}
|
||
|
||
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
|
||
func (s *Server) cleanupOldExecutions() {
|
||
if len(s.executions) <= s.maxExecutionsInMemory {
|
||
return
|
||
}
|
||
|
||
// 按开始时间排序,找出最旧的记录
|
||
type execWithTime struct {
|
||
id string
|
||
startTime time.Time
|
||
}
|
||
execs := make([]execWithTime, 0, len(s.executions))
|
||
for id, exec := range s.executions {
|
||
execs = append(execs, execWithTime{
|
||
id: id,
|
||
startTime: exec.StartTime,
|
||
})
|
||
}
|
||
|
||
// 使用 sort 包进行高效排序(最旧的在前)
|
||
sort.Slice(execs, func(i, j int) bool {
|
||
return execs[i].startTime.Before(execs[j].startTime)
|
||
})
|
||
|
||
// 删除最旧的记录,保留 maxExecutionsInMemory 条
|
||
toDelete := len(s.executions) - s.maxExecutionsInMemory
|
||
for i := 0; i < toDelete; i++ {
|
||
delete(s.executions, execs[i].id)
|
||
}
|
||
|
||
s.logger.Debug("清理旧的执行记录",
|
||
zap.Int("before", len(execs)),
|
||
zap.Int("after", len(s.executions)),
|
||
zap.Int("deleted", toDelete),
|
||
)
|
||
}
|
||
|
||
// initDefaultPrompts 初始化默认提示词模板
|
||
func (s *Server) initDefaultPrompts() {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
// 网络安全测试提示词
|
||
s.prompts["security_scan"] = &Prompt{
|
||
Name: "security_scan",
|
||
Description: "生成网络安全扫描任务的提示词",
|
||
Arguments: []PromptArgument{
|
||
{Name: "target", Description: "扫描目标(IP地址或域名)", Required: true},
|
||
{Name: "scan_type", Description: "扫描类型(port, vuln, web等)", Required: false},
|
||
},
|
||
}
|
||
|
||
// 渗透测试提示词
|
||
s.prompts["penetration_test"] = &Prompt{
|
||
Name: "penetration_test",
|
||
Description: "生成渗透测试任务的提示词",
|
||
Arguments: []PromptArgument{
|
||
{Name: "target", Description: "测试目标", Required: true},
|
||
{Name: "scope", Description: "测试范围", Required: false},
|
||
},
|
||
}
|
||
}
|
||
|
||
// initDefaultResources 初始化默认资源
|
||
// 注意:工具资源现在在 RegisterTool 时自动创建,此函数保留用于其他非工具资源
|
||
func (s *Server) initDefaultResources() {
|
||
// 工具资源已改为在 RegisterTool 时自动创建,无需在此硬编码
|
||
}
|
||
|
||
// handleListPrompts 处理列出提示词请求
|
||
func (s *Server) handleListPrompts(msg *Message) *Message {
|
||
s.mu.RLock()
|
||
prompts := make([]Prompt, 0, len(s.prompts))
|
||
for _, prompt := range s.prompts {
|
||
prompts = append(prompts, *prompt)
|
||
}
|
||
s.mu.RUnlock()
|
||
|
||
response := ListPromptsResponse{
|
||
Prompts: prompts,
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleGetPrompt 处理获取提示词请求
|
||
func (s *Server) handleGetPrompt(msg *Message) *Message {
|
||
var req GetPromptRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
s.mu.RLock()
|
||
prompt, exists := s.prompts[req.Name]
|
||
s.mu.RUnlock()
|
||
|
||
if !exists {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32601, Message: "Prompt not found"},
|
||
}
|
||
}
|
||
|
||
// 根据提示词名称生成消息
|
||
messages := s.generatePromptMessages(prompt, req.Arguments)
|
||
|
||
response := GetPromptResponse{
|
||
Messages: messages,
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// generatePromptMessages 生成提示词消息
|
||
func (s *Server) generatePromptMessages(prompt *Prompt, args map[string]interface{}) []PromptMessage {
|
||
messages := []PromptMessage{}
|
||
|
||
switch prompt.Name {
|
||
case "security_scan":
|
||
target, _ := args["target"].(string)
|
||
scanType, _ := args["scan_type"].(string)
|
||
if scanType == "" {
|
||
scanType = "comprehensive"
|
||
}
|
||
|
||
content := fmt.Sprintf(`请对目标 %s 执行%s安全扫描。包括:
|
||
1. 端口扫描和服务识别
|
||
2. 漏洞检测
|
||
3. Web应用安全测试
|
||
4. 生成详细的安全报告`, target, scanType)
|
||
|
||
messages = append(messages, PromptMessage{
|
||
Role: "user",
|
||
Content: content,
|
||
})
|
||
|
||
case "penetration_test":
|
||
target, _ := args["target"].(string)
|
||
scope, _ := args["scope"].(string)
|
||
|
||
content := fmt.Sprintf(`请对目标 %s 执行渗透测试。`, target)
|
||
if scope != "" {
|
||
content += fmt.Sprintf("测试范围:%s", scope)
|
||
}
|
||
content += "\n请按照OWASP Top 10进行全面的安全测试。"
|
||
|
||
messages = append(messages, PromptMessage{
|
||
Role: "user",
|
||
Content: content,
|
||
})
|
||
|
||
default:
|
||
messages = append(messages, PromptMessage{
|
||
Role: "user",
|
||
Content: "请执行安全测试任务",
|
||
})
|
||
}
|
||
|
||
return messages
|
||
}
|
||
|
||
// handleListResources 处理列出资源请求
|
||
func (s *Server) handleListResources(msg *Message) *Message {
|
||
s.mu.RLock()
|
||
resources := make([]Resource, 0, len(s.resources))
|
||
for _, resource := range s.resources {
|
||
resources = append(resources, *resource)
|
||
}
|
||
s.mu.RUnlock()
|
||
|
||
response := ListResourcesResponse{
|
||
Resources: resources,
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleReadResource 处理读取资源请求
|
||
func (s *Server) handleReadResource(msg *Message) *Message {
|
||
var req ReadResourceRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
s.mu.RLock()
|
||
resource, exists := s.resources[req.URI]
|
||
s.mu.RUnlock()
|
||
|
||
if !exists {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32601, Message: "Resource not found"},
|
||
}
|
||
}
|
||
|
||
// 生成资源内容
|
||
content := s.generateResourceContent(resource)
|
||
|
||
response := ReadResourceResponse{
|
||
Contents: []ResourceContent{content},
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// generateResourceContent 生成资源内容
|
||
func (s *Server) generateResourceContent(resource *Resource) ResourceContent {
|
||
content := ResourceContent{
|
||
URI: resource.URI,
|
||
MimeType: resource.MimeType,
|
||
}
|
||
|
||
// 如果是工具资源,生成详细文档
|
||
if strings.HasPrefix(resource.URI, "tool://") {
|
||
toolName := strings.TrimPrefix(resource.URI, "tool://")
|
||
content.Text = s.generateToolDocumentation(toolName, resource)
|
||
} else {
|
||
// 其他资源使用描述或默认内容
|
||
content.Text = resource.Description
|
||
}
|
||
|
||
return content
|
||
}
|
||
|
||
// generateToolDocumentation 生成工具文档
|
||
// 注意:硬编码的工具文档已移除,现在只使用工具定义中的信息
|
||
func (s *Server) generateToolDocumentation(toolName string, resource *Resource) string {
|
||
// 获取工具定义以获取更详细的信息
|
||
s.mu.RLock()
|
||
tool, hasTool := s.toolDefs[toolName]
|
||
s.mu.RUnlock()
|
||
|
||
// 使用工具定义中的描述信息
|
||
if hasTool {
|
||
doc := fmt.Sprintf("%s\n\n", resource.Description)
|
||
if tool.InputSchema != nil {
|
||
if props, ok := tool.InputSchema["properties"].(map[string]interface{}); ok {
|
||
doc += "参数说明:\n"
|
||
for paramName, paramInfo := range props {
|
||
if paramMap, ok := paramInfo.(map[string]interface{}); ok {
|
||
if desc, ok := paramMap["description"].(string); ok {
|
||
doc += fmt.Sprintf("- %s: %s\n", paramName, desc)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return doc
|
||
}
|
||
return resource.Description
|
||
}
|
||
|
||
// handleSamplingRequest 处理采样请求
|
||
func (s *Server) handleSamplingRequest(msg *Message) *Message {
|
||
var req SamplingRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
// 注意:采样功能通常需要连接到实际的LLM服务
|
||
// 这里返回一个占位符响应,实际实现需要集成LLM API
|
||
s.logger.Warn("Sampling request received but not fully implemented",
|
||
zap.Any("request", req),
|
||
)
|
||
|
||
response := SamplingResponse{
|
||
Content: []SamplingContent{
|
||
{
|
||
Type: "text",
|
||
Text: "采样功能需要配置LLM服务。请使用Agent Loop API进行AI对话。",
|
||
},
|
||
},
|
||
StopReason: "length",
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// RegisterPrompt 注册提示词模板
|
||
func (s *Server) RegisterPrompt(prompt *Prompt) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
s.prompts[prompt.Name] = prompt
|
||
}
|
||
|
||
// RegisterResource 注册资源
|
||
func (s *Server) RegisterResource(resource *Resource) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
s.resources[resource.URI] = resource
|
||
}
|
||
|
||
// HandleStdio 处理标准输入输出(用于 stdio 传输模式)
|
||
// MCP 协议使用换行分隔的 JSON-RPC 消息;管道下需每次写入后 Flush,否则客户端会读不到响应
|
||
func (s *Server) HandleStdio() error {
|
||
decoder := json.NewDecoder(os.Stdin)
|
||
stdout := bufio.NewWriter(os.Stdout)
|
||
encoder := json.NewEncoder(stdout)
|
||
// 注意:不设置缩进,MCP 协议期望紧凑的 JSON 格式
|
||
|
||
for {
|
||
var msg Message
|
||
if err := decoder.Decode(&msg); err != nil {
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
// 日志输出到 stderr,避免干扰 stdout 的 JSON-RPC 通信
|
||
s.logger.Error("读取消息失败", zap.Error(err))
|
||
// 发送错误响应
|
||
errorMsg := Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32700, Message: "Parse error", Data: err.Error()},
|
||
}
|
||
if err := encoder.Encode(errorMsg); err != nil {
|
||
return fmt.Errorf("发送错误响应失败: %w", err)
|
||
}
|
||
if err := stdout.Flush(); err != nil {
|
||
return fmt.Errorf("刷新 stdout 失败: %w", err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 处理消息
|
||
response := s.handleMessage(&msg)
|
||
|
||
// 如果是通知(response 为 nil),不需要发送响应
|
||
if response == nil {
|
||
continue
|
||
}
|
||
|
||
// 发送响应
|
||
if err := encoder.Encode(response); err != nil {
|
||
return fmt.Errorf("发送响应失败: %w", err)
|
||
}
|
||
if err := stdout.Flush(); err != nil {
|
||
return fmt.Errorf("刷新 stdout 失败: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// sendError 发送错误响应
|
||
func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) {
|
||
var msgID MessageID
|
||
if id != nil {
|
||
msgID = MessageID{value: id}
|
||
}
|
||
response := Message{
|
||
ID: msgID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: code, Message: message, Data: data},
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|