Files
CyberStrikeAI/internal/mcp/server.go
2025-11-13 01:26:40 +08:00

912 lines
22 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package mcp
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"go.uber.org/zap"
)
// Server is the MCP server. It holds the registered tools, prompt templates
// and resources, plus a record of tool executions and per-tool call
// statistics. mu is meant to guard every map below.
type Server struct {
	tools      map[string]ToolHandler    // tool handlers, keyed by tool name
	toolDefs   map[string]Tool           // tool definitions, keyed by tool name
	executions map[string]*ToolExecution // execution records, keyed by execution ID
	stats      map[string]*ToolStats     // per-tool call statistics, keyed by tool name
	prompts    map[string]*Prompt        // prompt templates, keyed by prompt name
	resources  map[string]*Resource      // resources, keyed by URI ("tool://<name>" for tool docs)
	mu         sync.RWMutex              // guards the maps above
	logger     *zap.Logger
}
// ToolHandler implements one tool. It receives the call's context and the
// "arguments" object of a tools/call request (or the args passed to
// CallTool), and returns the tool's result or an error.
type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error)
// NewServer builds an empty MCP server that logs through logger. It then
// seeds the server with the built-in prompt templates (initDefaultPrompts)
// and resources (initDefaultResources).
func NewServer(logger *zap.Logger) *Server {
	srv := &Server{logger: logger}
	srv.tools = map[string]ToolHandler{}
	srv.toolDefs = map[string]Tool{}
	srv.executions = map[string]*ToolExecution{}
	srv.stats = map[string]*ToolStats{}
	srv.prompts = map[string]*Prompt{}
	srv.resources = map[string]*Resource{}

	srv.initDefaultPrompts()
	srv.initDefaultResources()
	return srv
}
// RegisterTool adds, or replaces, the tool named tool.Name. handler serves
// the tool's calls, and tool is the definition reported by tools/list.
//
// It also publishes a "tool://<name>" text/plain resource that carries the
// tool's description, so clients can read the tool's documentation.
func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
	name := tool.Name
	docURI := "tool://" + name
	doc := &Resource{
		URI:         docURI,
		Name:        name + "工具文档",
		Description: tool.Description,
		MimeType:    "text/plain",
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.toolDefs[name] = tool
	s.tools[name] = handler
	s.resources[docURI] = doc
}
// ClearTools removes every registered tool, for example before a
// configuration reload.
//
// It drops all handlers and definitions, along with the auto-generated
// "tool://" documentation resources. Non-tool resources are kept.
// Prompts, executions and stats are not touched.
func (s *Server) ClearTools() {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.tools = map[string]ToolHandler{}
	s.toolDefs = map[string]Tool{}

	kept := make(map[string]*Resource, len(s.resources))
	for uri, res := range s.resources {
		if strings.HasPrefix(uri, "tool://") {
			continue // generated by RegisterTool; goes away with the tool
		}
		kept[uri] = res
	}
	s.resources = kept
}
// HandleHTTP serves one JSON-RPC message POSTed to the MCP endpoint.
//
// Only POST is accepted. A body that cannot be read or decoded gets a
// JSON-RPC parse error (-32700).
//
// A request gets its JSON-RPC response as a JSON body. A notification (for
// which handleMessage returns nil) is acknowledged with 202 Accepted and an
// empty body, as the MCP Streamable HTTP transport requires. Previously a
// literal JSON "null" was written.
func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	body, err := io.ReadAll(r.Body)
	if err != nil {
		s.sendError(w, nil, -32700, "Parse error", err.Error())
		return
	}

	var msg Message
	if err := json.Unmarshal(body, &msg); err != nil {
		s.sendError(w, nil, -32700, "Parse error", err.Error())
		return
	}

	response := s.handleMessage(&msg)
	if response == nil {
		// Notification: nothing to return, just acknowledge receipt.
		w.WriteHeader(http.StatusAccepted)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		// The status line may already be on the wire; logging is all we can do.
		s.logger.Warn("写入HTTP响应失败", zap.Error(err))
	}
}
// handleMessage dispatches one decoded JSON-RPC message to its handler.
//
// A message whose ID is absent or empty is treated as a notification. The
// return value is the response to send, or nil when nothing must be sent:
// JSON-RPC forbids replying to notifications, including unknown ones.
// A request for an unknown method gets -32601 "Method not found".
//
// The former "generate a UUID for an empty ID" branch was removed. It could
// never run, because an empty ID always classifies the message as a
// notification.
func (s *Server) handleMessage(msg *Message) *Message {
	isNotification := msg.ID.Value() == nil || msg.ID.String() == ""

	switch msg.Method {
	case "initialize":
		return s.handleInitialize(msg)
	case "ping":
		// MCP basic utility: liveness check answered with an empty result.
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeResponse,
			Version: "2.0",
			Result:  json.RawMessage(`{}`),
		}
	case "tools/list":
		return s.handleListTools(msg)
	case "tools/call":
		return s.handleCallTool(msg)
	case "prompts/list":
		return s.handleListPrompts(msg)
	case "prompts/get":
		return s.handleGetPrompt(msg)
	case "resources/list":
		return s.handleListResources(msg)
	case "resources/read":
		return s.handleReadResource(msg)
	case "sampling/request":
		return s.handleSamplingRequest(msg)
	case "notifications/initialized":
		// Notification: no response.
		s.logger.Debug("收到 initialized 通知")
		return nil
	}

	// Unrecognized method (possibly empty). Notifications are dropped silently.
	if isNotification {
		if msg.Method == "" {
			s.logger.Debug("收到无方法名的通知消息")
		} else {
			s.logger.Debug("收到未知通知", zap.String("method", msg.Method))
		}
		return nil
	}

	return &Message{
		ID:      msg.ID,
		Type:    MessageTypeError,
		Version: "2.0",
		Error:   &Error{Code: -32601, Message: "Method not found"},
	}
}
// handleInitialize answers the MCP "initialize" handshake.
//
// The params are decoded only to check that they are well-formed; malformed
// params get -32602. The response always advertises this server's fixed
// protocol version, capabilities and identity.
func (s *Server) handleInitialize(msg *Message) *Message {
	var req InitializeRequest
	if err := json.Unmarshal(msg.Params, &req); err != nil {
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32602, Message: "Invalid params"},
		}
	}

	response := InitializeResponse{
		ProtocolVersion: ProtocolVersion,
		Capabilities: ServerCapabilities{
			// NOTE(review): no code in this file emits
			// notifications/*/list_changed; confirm listChanged is honoured
			// elsewhere before relying on it.
			Tools: map[string]interface{}{
				"listChanged": true,
			},
			Prompts: map[string]interface{}{
				"listChanged": true,
			},
			Resources: map[string]interface{}{
				// handleMessage has no "resources/subscribe" route, so
				// advertising subscription support would only invite
				// requests that fail with "Method not found".
				"subscribe":   false,
				"listChanged": true,
			},
			Sampling: map[string]interface{}{},
		},
		ServerInfo: ServerInfo{
			Name:    "CyberStrikeAI",
			Version: "1.0.0",
		},
	}

	// Every value above is a fixed string/bool/map literal, so Marshal cannot fail.
	result, _ := json.Marshal(response)
	return &Message{
		ID:      msg.ID,
		Type:    MessageTypeResponse,
		Version: "2.0",
		Result:  result,
	}
}
// handleListTools answers "tools/list" with every registered tool definition.
//
// The list is sorted by name so clients see a stable order. Go map iteration
// order is random, so the list used to shuffle between calls.
func (s *Server) handleListTools(msg *Message) *Message {
	s.mu.RLock()
	tools := make([]Tool, 0, len(s.toolDefs))
	for _, tool := range s.toolDefs {
		tools = append(tools, tool)
	}
	s.mu.RUnlock()

	sort.Slice(tools, func(i, j int) bool { return tools[i].Name < tools[j].Name })

	// Input schemas are arbitrary maps and may hold values encoding/json
	// cannot represent. Report that instead of replying with a nil result.
	result, err := json.Marshal(ListToolsResponse{Tools: tools})
	if err != nil {
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32603, Message: "Internal error", Data: err.Error()},
		}
	}

	return &Message{
		ID:      msg.ID,
		Type:    MessageTypeResponse,
		Version: "2.0",
		Result:  result,
	}
}
// handleCallTool executes a "tools/call" request synchronously and reports
// the outcome.
//
// Every call whose params parse is recorded in s.executions and counted
// exactly once in s.stats. Previously each call was counted twice, and the
// first count ran without the lock. An unknown tool name is a JSON-RPC
// error (-32602, as the MCP spec prescribes).
//
// Handler errors and results flagged IsError are returned as a normal
// result with IsError set, so the model can read the failure text. Handlers
// run with a 5-minute timeout.
func (s *Server) handleCallTool(msg *Message) *Message {
	var req CallToolRequest
	if err := json.Unmarshal(msg.Params, &req); err != nil {
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32602, Message: "Invalid params"},
		}
	}

	execution := &ToolExecution{
		ID:        uuid.New().String(),
		ToolName:  req.Name,
		Arguments: req.Arguments,
		Status:    "running",
		StartTime: time.Now(),
	}

	// Record the execution and resolve the handler under one lock. The
	// execution is reachable via GetExecution as soon as it is in the map, so
	// every later mutation must also hold s.mu.
	s.mu.Lock()
	s.executions[execution.ID] = execution
	handler, exists := s.tools[req.Name]
	if !exists {
		now := time.Now()
		execution.Status = "failed"
		execution.Error = "Tool not found"
		execution.EndTime = &now
		execution.Duration = now.Sub(execution.StartTime)
		s.updateStats(req.Name, true)
	}
	s.mu.Unlock()

	if !exists {
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32602, Message: "Tool not found"},
		}
	}

	// Run synchronously so handler errors reach the caller; the timeout bounds
	// how long a single request can block.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()

	s.logger.Info("开始执行工具",
		zap.String("toolName", req.Name),
		zap.Any("arguments", req.Arguments),
	)

	result, err := handler(ctx, req.Arguments)

	s.mu.Lock()
	now := time.Now()
	execution.EndTime = &now
	execution.Duration = now.Sub(execution.StartTime)
	failed := true
	switch {
	case err != nil:
		execution.Status = "failed"
		execution.Error = err.Error()
	case result != nil && result.IsError:
		execution.Status = "failed"
		if len(result.Content) > 0 {
			execution.Error = result.Content[0].Text
		}
	default:
		execution.Status = "completed"
		execution.Result = result
		failed = false
	}
	s.updateStats(req.Name, failed)
	s.mu.Unlock()

	if err != nil {
		s.logger.Error("工具执行失败",
			zap.String("toolName", req.Name),
			zap.Error(err),
		)
		errorResult, _ := json.Marshal(CallToolResponse{
			Content: []Content{
				{Type: "text", Text: fmt.Sprintf("工具执行失败: %v", err)},
			},
			IsError: true,
		})
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeResponse,
			Version: "2.0",
			Result:  errorResult,
		}
	}

	if result == nil {
		result = &ToolResult{
			Content: []Content{
				{Type: "text", Text: "工具执行完成,但未返回结果"},
			},
		}
	}

	resultJSON, _ := json.Marshal(CallToolResponse{
		Content: result.Content,
		IsError: result.IsError,
	})

	s.logger.Info("工具执行完成",
		zap.String("toolName", req.Name),
		zap.Bool("isError", result.IsError),
	)

	return &Message{
		ID:      msg.ID,
		Type:    MessageTypeResponse,
		Version: "2.0",
		Result:  resultJSON,
	}
}
// updateStats records one call of toolName in s.stats, creating the entry on
// first use.
//
// It increments TotalCalls, stamps LastCallTime with the current time, and
// increments FailedCalls when failed is true or SuccessCalls otherwise.
//
// It does not lock s.mu itself. Callers are expected to hold the write lock,
// because s.stats is a plain map.
func (s *Server) updateStats(toolName string, failed bool) {
	entry := s.stats[toolName]
	if entry == nil {
		entry = &ToolStats{ToolName: toolName}
		s.stats[toolName] = entry
	}

	calledAt := time.Now()
	entry.LastCallTime = &calledAt
	entry.TotalCalls++
	switch {
	case failed:
		entry.FailedCalls++
	default:
		entry.SuccessCalls++
	}
}
// GetExecution looks up the execution record with the given ID and reports
// whether it exists.
//
// NOTE(review): the returned pointer is the live record. handleCallTool and
// CallTool keep updating its fields under s.mu, so reading them without the
// lock may race.
func (s *Server) GetExecution(id string) (*ToolExecution, bool) {
	s.mu.RLock()
	record, found := s.executions[id]
	s.mu.RUnlock()
	return record, found
}
// GetAllExecutions returns every recorded tool execution, in no particular
// order.
//
// The slice is freshly allocated, but its elements are the server's own
// *ToolExecution records, not copies.
func (s *Server) GetAllExecutions() []*ToolExecution {
	s.mu.RLock()
	defer s.mu.RUnlock()

	all := make([]*ToolExecution, len(s.executions))
	i := 0
	for _, record := range s.executions {
		all[i] = record
		i++
	}
	return all
}
// GetStats returns a snapshot of the per-tool call statistics, keyed by tool
// name.
//
// Each value is a copy taken under the read lock. The live *ToolStats keep
// being mutated by updateStats under s.mu, so handing them out directly let
// callers race with in-flight tool calls.
func (s *Server) GetStats() map[string]*ToolStats {
	s.mu.RLock()
	defer s.mu.RUnlock()

	stats := make(map[string]*ToolStats, len(s.stats))
	for name, live := range s.stats {
		// Shallow copy is enough: updateStats replaces LastCallTime with a
		// fresh pointer rather than writing through the old one.
		snapshot := *live
		stats[name] = &snapshot
	}
	return stats
}
// GetAllTools returns a snapshot of every registered tool definition, which
// the agent uses to build its tool list dynamically.
//
// Tools are sorted by name, so the agent sees the same list on every call;
// map iteration order is random.
func (s *Server) GetAllTools() []Tool {
	s.mu.RLock()
	tools := make([]Tool, 0, len(s.toolDefs))
	for _, tool := range s.toolDefs {
		tools = append(tools, tool)
	}
	s.mu.RUnlock()

	sort.Slice(tools, func(i, j int) bool { return tools[i].Name < tools[j].Name })
	return tools
}
// CallTool invokes a registered tool directly, bypassing JSON-RPC. It is
// used by internal callers.
//
// It returns:
//   - the tool's result;
//   - the ID of the execution record created for the call (empty when the
//     tool does not exist);
//   - the handler's error.
//
// The execution is recorded in s.executions and counted once in s.stats.
// Previously each call was counted twice, the first time without the lock.
// A result flagged IsError is recorded as failed, matching handleCallTool,
// but is still returned to the caller with a nil error.
func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
	s.mu.RLock()
	handler, exists := s.tools[toolName]
	s.mu.RUnlock()
	if !exists {
		return nil, "", fmt.Errorf("工具 %s 未找到", toolName)
	}

	executionID := uuid.New().String()
	execution := &ToolExecution{
		ID:        executionID,
		ToolName:  toolName,
		Arguments: args,
		Status:    "running",
		StartTime: time.Now(),
	}
	s.mu.Lock()
	s.executions[executionID] = execution
	s.mu.Unlock()

	result, err := handler(ctx, args)

	// The record is visible through GetExecution, so finish it under the lock.
	s.mu.Lock()
	defer s.mu.Unlock()
	now := time.Now()
	execution.EndTime = &now
	execution.Duration = now.Sub(execution.StartTime)

	switch {
	case err != nil:
		execution.Status = "failed"
		execution.Error = err.Error()
		s.updateStats(toolName, true)
		return nil, executionID, err
	case result != nil && result.IsError:
		execution.Status = "failed"
		if len(result.Content) > 0 {
			execution.Error = result.Content[0].Text
		}
		s.updateStats(toolName, true)
	default:
		execution.Status = "completed"
		execution.Result = result
		s.updateStats(toolName, false)
	}
	return result, executionID, nil
}
// initDefaultPrompts registers the built-in prompt templates:
//   - "security_scan", a network security scan task;
//   - "penetration_test", a penetration-testing task.
//
// Their message text is produced by generatePromptMessages.
func (s *Server) initDefaultPrompts() {
	defaults := []*Prompt{
		{
			Name:        "security_scan",
			Description: "生成网络安全扫描任务的提示词",
			Arguments: []PromptArgument{
				{Name: "target", Description: "扫描目标IP地址或域名", Required: true},
				{Name: "scan_type", Description: "扫描类型port, vuln, web等", Required: false},
			},
		},
		{
			Name:        "penetration_test",
			Description: "生成渗透测试任务的提示词",
			Arguments: []PromptArgument{
				{Name: "target", Description: "测试目标", Required: true},
				{Name: "scope", Description: "测试范围", Required: false},
			},
		},
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	for _, p := range defaults {
		s.prompts[p.Name] = p
	}
}
// initDefaultResources registers built-in non-tool resources.
// It currently registers nothing. Tool documentation resources
// ("tool://<name>") are created by RegisterTool, and this hook is kept for
// future non-tool resources.
func (s *Server) initDefaultResources() {
	// Tool resources are created automatically in RegisterTool; nothing is hard-coded here.
}
// handleListPrompts answers "prompts/list" with every registered prompt
// template.
//
// Templates are sorted by name so the order is stable across calls; map
// iteration order is random.
func (s *Server) handleListPrompts(msg *Message) *Message {
	s.mu.RLock()
	prompts := make([]Prompt, 0, len(s.prompts))
	for _, prompt := range s.prompts {
		prompts = append(prompts, *prompt)
	}
	s.mu.RUnlock()

	sort.Slice(prompts, func(i, j int) bool { return prompts[i].Name < prompts[j].Name })

	result, _ := json.Marshal(ListPromptsResponse{Prompts: prompts})
	return &Message{
		ID:      msg.ID,
		Type:    MessageTypeResponse,
		Version: "2.0",
		Result:  result,
	}
}
// handleGetPrompt answers "prompts/get". It looks up the named template,
// checks that every argument the template marks Required is present, and
// renders the template's messages.
//
// Errors follow the MCP prompts spec. Malformed params, an unknown prompt
// name and a missing required argument all yield -32602 (Invalid params).
// Previously an unknown prompt returned -32601, which means "method not
// found".
func (s *Server) handleGetPrompt(msg *Message) *Message {
	var req GetPromptRequest
	if err := json.Unmarshal(msg.Params, &req); err != nil {
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32602, Message: "Invalid params"},
		}
	}

	s.mu.RLock()
	prompt, exists := s.prompts[req.Name]
	s.mu.RUnlock()

	if !exists {
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32602, Message: "Prompt not found"},
		}
	}

	// Reject requests that omit a required argument instead of rendering a
	// template with an empty placeholder.
	for _, arg := range prompt.Arguments {
		if _, present := req.Arguments[arg.Name]; arg.Required && !present {
			return &Message{
				ID:      msg.ID,
				Type:    MessageTypeError,
				Version: "2.0",
				Error: &Error{
					Code:    -32602,
					Message: "Invalid params",
					Data:    fmt.Sprintf("missing required argument %q", arg.Name),
				},
			}
		}
	}

	messages := s.generatePromptMessages(prompt, req.Arguments)
	result, _ := json.Marshal(GetPromptResponse{Messages: messages})
	return &Message{
		ID:      msg.ID,
		Type:    MessageTypeResponse,
		Version: "2.0",
		Result:  result,
	}
}
// generatePromptMessages renders the single "user" message for a prompt
// template. The arguments are read from args:
//   - security_scan uses target and scan_type (default "comprehensive");
//   - penetration_test uses target and an optional scope.
//
// Unknown templates get a generic task message. Non-string argument values
// are treated as empty.
func (s *Server) generatePromptMessages(prompt *Prompt, args map[string]interface{}) []PromptMessage {
	var text string

	switch prompt.Name {
	case "security_scan":
		target, _ := args["target"].(string)
		kind, _ := args["scan_type"].(string)
		if kind == "" {
			kind = "comprehensive"
		}
		text = fmt.Sprintf(`请对目标 %s 执行%s安全扫描。包括
1. 端口扫描和服务识别
2. 漏洞检测
3. Web应用安全测试
4. 生成详细的安全报告`, target, kind)

	case "penetration_test":
		target, _ := args["target"].(string)
		scope, _ := args["scope"].(string)
		text = fmt.Sprintf(`请对目标 %s 执行渗透测试。`, target)
		if scope != "" {
			text += fmt.Sprintf("测试范围:%s", scope)
		}
		text += "\n请按照OWASP Top 10进行全面的安全测试。"

	default:
		text = "请执行安全测试任务"
	}

	return []PromptMessage{{Role: "user", Content: text}}
}
// handleListResources answers "resources/list" with every registered
// resource, including the "tool://" docs created by RegisterTool.
//
// Resources are sorted by URI so the order is stable across calls; map
// iteration order is random.
func (s *Server) handleListResources(msg *Message) *Message {
	s.mu.RLock()
	resources := make([]Resource, 0, len(s.resources))
	for _, resource := range s.resources {
		resources = append(resources, *resource)
	}
	s.mu.RUnlock()

	sort.Slice(resources, func(i, j int) bool { return resources[i].URI < resources[j].URI })

	result, _ := json.Marshal(ListResourcesResponse{Resources: resources})
	return &Message{
		ID:      msg.ID,
		Type:    MessageTypeResponse,
		Version: "2.0",
		Result:  result,
	}
}
// handleReadResource answers "resources/read" by rendering the content of
// the resource at the requested URI.
//
// Malformed params yield -32602. An unknown URI yields -32002, the MCP
// spec's "Resource not found" code; previously it was -32601, which means
// "method not found". The unknown URI is echoed in the error data.
func (s *Server) handleReadResource(msg *Message) *Message {
	var req ReadResourceRequest
	if err := json.Unmarshal(msg.Params, &req); err != nil {
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32602, Message: "Invalid params"},
		}
	}

	s.mu.RLock()
	resource, exists := s.resources[req.URI]
	s.mu.RUnlock()

	if !exists {
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32002, Message: "Resource not found", Data: req.URI},
		}
	}

	content := s.generateResourceContent(resource)
	result, _ := json.Marshal(ReadResourceResponse{
		Contents: []ResourceContent{content},
	})
	return &Message{
		ID:      msg.ID,
		Type:    MessageTypeResponse,
		Version: "2.0",
		Result:  result,
	}
}
// generateResourceContent builds the readable content of resource.
//
// For "tool://<name>" URIs the text is the generated tool documentation.
// Every other resource simply exposes its description.
func (s *Server) generateResourceContent(resource *Resource) ResourceContent {
	var text string
	if strings.HasPrefix(resource.URI, "tool://") {
		text = s.generateToolDocumentation(strings.TrimPrefix(resource.URI, "tool://"), resource)
	} else {
		text = resource.Description
	}

	return ResourceContent{
		URI:      resource.URI,
		MimeType: resource.MimeType,
		Text:     text,
	}
}
// generateToolDocumentation returns the documentation text for toolName.
//
// nmap, sqlmap, nikto, dirb and exec have hand-written docs. Any other
// registered tool gets its resource description followed by a "参数说明"
// (parameter) list, built from the "description" of each entry under its
// input schema's "properties". A tool that is not registered gets just the
// resource description.
//
// Parameters are listed in sorted order so the text is identical on every
// read; it used to follow random map iteration order.
func (s *Server) generateToolDocumentation(toolName string, resource *Resource) string {
	s.mu.RLock()
	tool, hasTool := s.toolDefs[toolName]
	s.mu.RUnlock()

	switch toolName {
	case "nmap":
		return `Nmap (Network Mapper) 是一个强大的网络扫描工具。
主要功能:
- 端口扫描:发现目标主机开放的端口
- 服务识别:识别运行在端口上的服务
- 版本检测:检测服务版本信息
- 操作系统检测:识别目标操作系统
常用命令:
- nmap -sT target # TCP连接扫描
- nmap -sV target # 版本检测
- nmap -sC target # 默认脚本扫描
- nmap -p 1-1000 target # 扫描指定端口范围
参数说明:
- target: 目标IP地址或域名必需
- ports: 端口范围,例如: 1-1000可选`
	case "sqlmap":
		return `SQLMap 是一个自动化的SQL注入检测和利用工具。
主要功能:
- 自动检测SQL注入漏洞
- 数据库指纹识别
- 数据提取
- 文件系统访问
常用命令:
- sqlmap -u "http://target.com/page?id=1" # 检测URL参数
- sqlmap -u "http://target.com" --forms # 检测表单
- sqlmap -u "http://target.com" --dbs # 列出数据库
参数说明:
- url: 目标URL必需`
	case "nikto":
		return `Nikto 是一个Web服务器扫描工具。
主要功能:
- Web服务器漏洞扫描
- 检测过时的服务器软件
- 检测危险文件和程序
- 检测服务器配置问题
常用命令:
- nikto -h target # 扫描目标主机
- nikto -h target -p 80,443 # 扫描指定端口
参数说明:
- target: 目标URL必需`
	case "dirb":
		return `Dirb 是一个Web内容扫描器。
主要功能:
- 扫描Web目录和文件
- 发现隐藏的目录和文件
- 支持自定义字典
常用命令:
- dirb url # 扫描目标URL
- dirb url -w wordlist.txt # 使用自定义字典
参数说明:
- target: 目标URL必需`
	case "exec":
		return `Exec 工具用于执行系统命令。
⚠️ 警告:此工具可以执行任意系统命令,请谨慎使用!
参数说明:
- command: 要执行的系统命令(必需)
- shell: 使用的shell默认为sh可选
- workdir: 工作目录(可选)`
	}

	if !hasTool {
		return resource.Description
	}

	var doc strings.Builder
	doc.WriteString(resource.Description)
	doc.WriteString("\n\n")

	// Indexing a nil schema yields nil, so the assertion also covers that case.
	props, ok := tool.InputSchema["properties"].(map[string]interface{})
	if !ok {
		return doc.String()
	}

	names := make([]string, 0, len(props))
	for name := range props {
		names = append(names, name)
	}
	sort.Strings(names)

	doc.WriteString("参数说明:\n")
	for _, name := range names {
		paramMap, ok := props[name].(map[string]interface{})
		if !ok {
			continue
		}
		desc, ok := paramMap["description"].(string)
		if !ok {
			continue
		}
		fmt.Fprintf(&doc, "- %s: %s\n", name, desc)
	}
	return doc.String()
}
// handleSamplingRequest answers "sampling/request" with a fixed placeholder.
//
// Sampling needs a real LLM backend, which this server does not integrate.
// After validating the params (-32602 on failure), it logs a warning and
// replies with a text pointing users at the Agent Loop API.
func (s *Server) handleSamplingRequest(msg *Message) *Message {
	var req SamplingRequest
	if err := json.Unmarshal(msg.Params, &req); err != nil {
		return &Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32602, Message: "Invalid params"},
		}
	}

	s.logger.Warn("Sampling request received but not fully implemented",
		zap.Any("request", req),
	)

	placeholder := SamplingContent{
		Type: "text",
		Text: "采样功能需要配置LLM服务。请使用Agent Loop API进行AI对话。",
	}
	result, _ := json.Marshal(SamplingResponse{
		Content:    []SamplingContent{placeholder},
		StopReason: "length",
	})

	return &Message{
		ID:      msg.ID,
		Type:    MessageTypeResponse,
		Version: "2.0",
		Result:  result,
	}
}
// RegisterPrompt adds, or replaces, the prompt template named prompt.Name.
func (s *Server) RegisterPrompt(prompt *Prompt) {
	name := prompt.Name
	s.mu.Lock()
	defer s.mu.Unlock()
	s.prompts[name] = prompt
}
// RegisterResource adds, or replaces, the resource at resource.URI.
func (s *Server) RegisterResource(resource *Resource) {
	uri := resource.URI
	s.mu.Lock()
	defer s.mu.Unlock()
	s.resources[uri] = resource
}
// HandleStdio serves MCP over the process's standard input and output, i.e.
// the stdio transport.
//
// The MCP stdio transport frames each JSON-RPC message as one line of JSON.
// Input is therefore read line by line, and each line is decoded on its own.
// A malformed line gets a -32700 "Parse error" response, and the loop moves
// on to the next line.
//
// A single json.Decoder over the whole stream cannot do this. A syntax error
// is sticky, so every later Decode returns the same error; the old loop
// would spin forever writing parse errors.
//
// Responses are written as compact, newline-terminated JSON. Notifications
// get no response. s.logger must not write to stdout (presumably it is
// configured for stderr — confirm), or it will corrupt the protocol stream.
//
// It returns nil when stdin reaches EOF, or an error if reading stdin or
// writing a response fails.
func (s *Server) HandleStdio() error {
	return s.serveStdio(os.Stdin, os.Stdout)
}

// serveStdio runs the HandleStdio loop over arbitrary streams.
func (s *Server) serveStdio(in io.Reader, out io.Writer) error {
	reader := bufio.NewReader(in)
	// No indentation: MCP expects compact JSON, one message per line.
	encoder := json.NewEncoder(out)

	for {
		// ReadBytes has no line-length cap (unlike bufio.Scanner's 64 KiB
		// default), so large tool arguments are accepted.
		line, readErr := reader.ReadBytes('\n')

		// Blank lines, including a bare trailing newline, carry no message.
		if strings.TrimSpace(string(line)) != "" {
			if err := s.handleStdioLine(line, encoder); err != nil {
				return err
			}
		}

		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return fmt.Errorf("读取标准输入失败: %w", readErr)
		}
	}
}

// handleStdioLine decodes one input line, dispatches it and writes the
// response, if any.
func (s *Server) handleStdioLine(line []byte, encoder *json.Encoder) error {
	var msg Message
	if err := json.Unmarshal(line, &msg); err != nil {
		s.logger.Error("读取消息失败", zap.Error(err))
		errorMsg := Message{
			ID:      msg.ID,
			Type:    MessageTypeError,
			Version: "2.0",
			Error:   &Error{Code: -32700, Message: "Parse error", Data: err.Error()},
		}
		if err := encoder.Encode(errorMsg); err != nil {
			return fmt.Errorf("发送错误响应失败: %w", err)
		}
		return nil
	}

	response := s.handleMessage(&msg)
	if response == nil {
		// Notification: nothing to send.
		return nil
	}
	if err := encoder.Encode(response); err != nil {
		return fmt.Errorf("发送响应失败: %w", err)
	}
	return nil
}
// sendError writes a JSON-RPC error response over HTTP.
//
// When id is non-nil it becomes the response ID; otherwise the zero
// MessageID is used. code, message and data fill the error object. No
// status is set explicitly, so the HTTP status is the implicit 200. The
// encode error is ignored.
func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) {
	resp := Message{
		Type:    MessageTypeError,
		Version: "2.0",
		Error:   &Error{Code: code, Message: message, Data: data},
	}
	if id != nil {
		resp.ID = MessageID{value: id}
	}

	w.Header().Set("Content-Type", "application/json")
	_ = json.NewEncoder(w).Encode(resp)
}