mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-01 00:30:33 +02:00
912 lines
22 KiB
Go
912 lines
22 KiB
Go
package mcp
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// Server MCP服务器
|
||
type Server struct {
|
||
tools map[string]ToolHandler
|
||
toolDefs map[string]Tool // 工具定义
|
||
executions map[string]*ToolExecution
|
||
stats map[string]*ToolStats
|
||
prompts map[string]*Prompt // 提示词模板
|
||
resources map[string]*Resource // 资源
|
||
mu sync.RWMutex
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// ToolHandler 工具处理函数
|
||
type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error)
|
||
|
||
// NewServer 创建新的MCP服务器
|
||
func NewServer(logger *zap.Logger) *Server {
|
||
s := &Server{
|
||
tools: make(map[string]ToolHandler),
|
||
toolDefs: make(map[string]Tool),
|
||
executions: make(map[string]*ToolExecution),
|
||
stats: make(map[string]*ToolStats),
|
||
prompts: make(map[string]*Prompt),
|
||
resources: make(map[string]*Resource),
|
||
logger: logger,
|
||
}
|
||
|
||
// 初始化默认提示词和资源
|
||
s.initDefaultPrompts()
|
||
s.initDefaultResources()
|
||
|
||
return s
|
||
}
|
||
|
||
// RegisterTool 注册工具
|
||
func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
s.tools[tool.Name] = handler
|
||
s.toolDefs[tool.Name] = tool
|
||
|
||
// 自动为工具创建资源文档
|
||
resourceURI := fmt.Sprintf("tool://%s", tool.Name)
|
||
s.resources[resourceURI] = &Resource{
|
||
URI: resourceURI,
|
||
Name: fmt.Sprintf("%s工具文档", tool.Name),
|
||
Description: tool.Description,
|
||
MimeType: "text/plain",
|
||
}
|
||
}
|
||
|
||
// ClearTools 清空所有工具(用于重新加载配置)
|
||
func (s *Server) ClearTools() {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
// 清空工具和工具定义
|
||
s.tools = make(map[string]ToolHandler)
|
||
s.toolDefs = make(map[string]Tool)
|
||
|
||
// 清空工具相关的资源(保留其他资源)
|
||
newResources := make(map[string]*Resource)
|
||
for uri, resource := range s.resources {
|
||
// 保留非工具资源
|
||
if !strings.HasPrefix(uri, "tool://") {
|
||
newResources[uri] = resource
|
||
}
|
||
}
|
||
s.resources = newResources
|
||
}
|
||
|
||
// HandleHTTP 处理HTTP请求
|
||
func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
body, err := io.ReadAll(r.Body)
|
||
if err != nil {
|
||
s.sendError(w, nil, -32700, "Parse error", err.Error())
|
||
return
|
||
}
|
||
|
||
var msg Message
|
||
if err := json.Unmarshal(body, &msg); err != nil {
|
||
s.sendError(w, nil, -32700, "Parse error", err.Error())
|
||
return
|
||
}
|
||
|
||
// 处理消息
|
||
response := s.handleMessage(&msg)
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|
||
|
||
// handleMessage 处理MCP消息
|
||
func (s *Server) handleMessage(msg *Message) *Message {
|
||
// 检查是否是通知(notification)- 通知没有id字段,不需要响应
|
||
isNotification := msg.ID.Value() == nil || msg.ID.String() == ""
|
||
|
||
// 如果不是通知且ID为空,生成新的UUID
|
||
if !isNotification && msg.ID.String() == "" {
|
||
msg.ID = MessageID{value: uuid.New().String()}
|
||
}
|
||
|
||
switch msg.Method {
|
||
case "initialize":
|
||
return s.handleInitialize(msg)
|
||
case "tools/list":
|
||
return s.handleListTools(msg)
|
||
case "tools/call":
|
||
return s.handleCallTool(msg)
|
||
case "prompts/list":
|
||
return s.handleListPrompts(msg)
|
||
case "prompts/get":
|
||
return s.handleGetPrompt(msg)
|
||
case "resources/list":
|
||
return s.handleListResources(msg)
|
||
case "resources/read":
|
||
return s.handleReadResource(msg)
|
||
case "sampling/request":
|
||
return s.handleSamplingRequest(msg)
|
||
case "notifications/initialized":
|
||
// 通知类型,不需要响应
|
||
s.logger.Debug("收到 initialized 通知")
|
||
return nil
|
||
case "":
|
||
// 空方法名,可能是通知,不返回错误
|
||
if isNotification {
|
||
s.logger.Debug("收到无方法名的通知消息")
|
||
return nil
|
||
}
|
||
fallthrough
|
||
default:
|
||
// 如果是通知,不返回错误响应
|
||
if isNotification {
|
||
s.logger.Debug("收到未知通知", zap.String("method", msg.Method))
|
||
return nil
|
||
}
|
||
// 对于请求,返回方法未找到错误
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32601, Message: "Method not found"},
|
||
}
|
||
}
|
||
}
|
||
|
||
// handleInitialize 处理初始化请求
|
||
func (s *Server) handleInitialize(msg *Message) *Message {
|
||
var req InitializeRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
response := InitializeResponse{
|
||
ProtocolVersion: ProtocolVersion,
|
||
Capabilities: ServerCapabilities{
|
||
Tools: map[string]interface{}{
|
||
"listChanged": true,
|
||
},
|
||
Prompts: map[string]interface{}{
|
||
"listChanged": true,
|
||
},
|
||
Resources: map[string]interface{}{
|
||
"subscribe": true,
|
||
"listChanged": true,
|
||
},
|
||
Sampling: map[string]interface{}{},
|
||
},
|
||
ServerInfo: ServerInfo{
|
||
Name: "CyberStrikeAI",
|
||
Version: "1.0.0",
|
||
},
|
||
}
|
||
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleListTools 处理列出工具请求
|
||
func (s *Server) handleListTools(msg *Message) *Message {
|
||
s.mu.RLock()
|
||
tools := make([]Tool, 0, len(s.toolDefs))
|
||
for _, tool := range s.toolDefs {
|
||
tools = append(tools, tool)
|
||
}
|
||
s.mu.RUnlock()
|
||
|
||
response := ListToolsResponse{Tools: tools}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleCallTool 处理工具调用请求
|
||
func (s *Server) handleCallTool(msg *Message) *Message {
|
||
var req CallToolRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
// 创建执行记录
|
||
executionID := uuid.New().String()
|
||
execution := &ToolExecution{
|
||
ID: executionID,
|
||
ToolName: req.Name,
|
||
Arguments: req.Arguments,
|
||
Status: "running",
|
||
StartTime: time.Now(),
|
||
}
|
||
|
||
s.mu.Lock()
|
||
s.executions[executionID] = execution
|
||
s.mu.Unlock()
|
||
|
||
// 更新统计
|
||
s.updateStats(req.Name, false)
|
||
|
||
// 执行工具
|
||
s.mu.RLock()
|
||
handler, exists := s.tools[req.Name]
|
||
s.mu.RUnlock()
|
||
|
||
if !exists {
|
||
execution.Status = "failed"
|
||
execution.Error = "Tool not found"
|
||
now := time.Now()
|
||
execution.EndTime = &now
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32601, Message: "Tool not found"},
|
||
}
|
||
}
|
||
|
||
// 同步执行所有工具,确保错误能正确返回
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||
defer cancel()
|
||
|
||
s.logger.Info("开始执行工具",
|
||
zap.String("toolName", req.Name),
|
||
zap.Any("arguments", req.Arguments),
|
||
)
|
||
|
||
result, err := handler(ctx, req.Arguments)
|
||
|
||
s.mu.Lock()
|
||
now := time.Now()
|
||
execution.EndTime = &now
|
||
execution.Duration = now.Sub(execution.StartTime)
|
||
|
||
if err != nil {
|
||
execution.Status = "failed"
|
||
execution.Error = err.Error()
|
||
s.updateStats(req.Name, true)
|
||
s.mu.Unlock()
|
||
|
||
s.logger.Error("工具执行失败",
|
||
zap.String("toolName", req.Name),
|
||
zap.Error(err),
|
||
)
|
||
|
||
// 返回错误结果
|
||
errorResult, _ := json.Marshal(CallToolResponse{
|
||
Content: []Content{
|
||
{Type: "text", Text: fmt.Sprintf("工具执行失败: %v", err)},
|
||
},
|
||
IsError: true,
|
||
})
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: errorResult,
|
||
}
|
||
}
|
||
|
||
// 检查result是否为错误
|
||
if result != nil && result.IsError {
|
||
execution.Status = "failed"
|
||
if len(result.Content) > 0 {
|
||
execution.Error = result.Content[0].Text
|
||
}
|
||
s.updateStats(req.Name, true)
|
||
} else {
|
||
execution.Status = "completed"
|
||
execution.Result = result
|
||
s.updateStats(req.Name, false)
|
||
}
|
||
s.mu.Unlock()
|
||
|
||
// 返回执行结果
|
||
if result == nil {
|
||
result = &ToolResult{
|
||
Content: []Content{
|
||
{Type: "text", Text: "工具执行完成,但未返回结果"},
|
||
},
|
||
}
|
||
}
|
||
|
||
resultJSON, _ := json.Marshal(CallToolResponse{
|
||
Content: result.Content,
|
||
IsError: result.IsError,
|
||
})
|
||
|
||
s.logger.Info("工具执行完成",
|
||
zap.String("toolName", req.Name),
|
||
zap.Bool("isError", result.IsError),
|
||
)
|
||
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: resultJSON,
|
||
}
|
||
}
|
||
|
||
// updateStats 更新统计信息
|
||
func (s *Server) updateStats(toolName string, failed bool) {
|
||
if s.stats[toolName] == nil {
|
||
s.stats[toolName] = &ToolStats{
|
||
ToolName: toolName,
|
||
}
|
||
}
|
||
|
||
stats := s.stats[toolName]
|
||
stats.TotalCalls++
|
||
now := time.Now()
|
||
stats.LastCallTime = &now
|
||
|
||
if failed {
|
||
stats.FailedCalls++
|
||
} else {
|
||
stats.SuccessCalls++
|
||
}
|
||
}
|
||
|
||
// GetExecution 获取执行记录
|
||
func (s *Server) GetExecution(id string) (*ToolExecution, bool) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
exec, exists := s.executions[id]
|
||
return exec, exists
|
||
}
|
||
|
||
// GetAllExecutions 获取所有执行记录
|
||
func (s *Server) GetAllExecutions() []*ToolExecution {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
executions := make([]*ToolExecution, 0, len(s.executions))
|
||
for _, exec := range s.executions {
|
||
executions = append(executions, exec)
|
||
}
|
||
return executions
|
||
}
|
||
|
||
// GetStats 获取统计信息
|
||
func (s *Server) GetStats() map[string]*ToolStats {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
stats := make(map[string]*ToolStats)
|
||
for k, v := range s.stats {
|
||
stats[k] = v
|
||
}
|
||
return stats
|
||
}
|
||
|
||
// GetAllTools 获取所有已注册的工具(用于Agent动态获取工具列表)
|
||
func (s *Server) GetAllTools() []Tool {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
tools := make([]Tool, 0, len(s.toolDefs))
|
||
for _, tool := range s.toolDefs {
|
||
tools = append(tools, tool)
|
||
}
|
||
return tools
|
||
}
|
||
|
||
// CallTool 直接调用工具(用于内部调用)
|
||
func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
|
||
s.mu.RLock()
|
||
handler, exists := s.tools[toolName]
|
||
s.mu.RUnlock()
|
||
|
||
if !exists {
|
||
return nil, "", fmt.Errorf("工具 %s 未找到", toolName)
|
||
}
|
||
|
||
// 创建执行记录
|
||
executionID := uuid.New().String()
|
||
execution := &ToolExecution{
|
||
ID: executionID,
|
||
ToolName: toolName,
|
||
Arguments: args,
|
||
Status: "running",
|
||
StartTime: time.Now(),
|
||
}
|
||
|
||
s.mu.Lock()
|
||
s.executions[executionID] = execution
|
||
s.mu.Unlock()
|
||
|
||
// 更新统计
|
||
s.updateStats(toolName, false)
|
||
|
||
// 执行工具
|
||
result, err := handler(ctx, args)
|
||
|
||
s.mu.Lock()
|
||
now := time.Now()
|
||
execution.EndTime = &now
|
||
execution.Duration = now.Sub(execution.StartTime)
|
||
|
||
if err != nil {
|
||
execution.Status = "failed"
|
||
execution.Error = err.Error()
|
||
s.updateStats(toolName, true)
|
||
s.mu.Unlock()
|
||
return nil, executionID, err
|
||
} else {
|
||
execution.Status = "completed"
|
||
execution.Result = result
|
||
s.updateStats(toolName, false)
|
||
s.mu.Unlock()
|
||
return result, executionID, nil
|
||
}
|
||
}
|
||
|
||
// initDefaultPrompts 初始化默认提示词模板
|
||
func (s *Server) initDefaultPrompts() {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
// 网络安全测试提示词
|
||
s.prompts["security_scan"] = &Prompt{
|
||
Name: "security_scan",
|
||
Description: "生成网络安全扫描任务的提示词",
|
||
Arguments: []PromptArgument{
|
||
{Name: "target", Description: "扫描目标(IP地址或域名)", Required: true},
|
||
{Name: "scan_type", Description: "扫描类型(port, vuln, web等)", Required: false},
|
||
},
|
||
}
|
||
|
||
// 渗透测试提示词
|
||
s.prompts["penetration_test"] = &Prompt{
|
||
Name: "penetration_test",
|
||
Description: "生成渗透测试任务的提示词",
|
||
Arguments: []PromptArgument{
|
||
{Name: "target", Description: "测试目标", Required: true},
|
||
{Name: "scope", Description: "测试范围", Required: false},
|
||
},
|
||
}
|
||
}
|
||
|
||
// initDefaultResources 初始化默认资源
|
||
// 注意:工具资源现在在 RegisterTool 时自动创建,此函数保留用于其他非工具资源
|
||
func (s *Server) initDefaultResources() {
|
||
// 工具资源已改为在 RegisterTool 时自动创建,无需在此硬编码
|
||
}
|
||
|
||
// handleListPrompts 处理列出提示词请求
|
||
func (s *Server) handleListPrompts(msg *Message) *Message {
|
||
s.mu.RLock()
|
||
prompts := make([]Prompt, 0, len(s.prompts))
|
||
for _, prompt := range s.prompts {
|
||
prompts = append(prompts, *prompt)
|
||
}
|
||
s.mu.RUnlock()
|
||
|
||
response := ListPromptsResponse{
|
||
Prompts: prompts,
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleGetPrompt 处理获取提示词请求
|
||
func (s *Server) handleGetPrompt(msg *Message) *Message {
|
||
var req GetPromptRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
s.mu.RLock()
|
||
prompt, exists := s.prompts[req.Name]
|
||
s.mu.RUnlock()
|
||
|
||
if !exists {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32601, Message: "Prompt not found"},
|
||
}
|
||
}
|
||
|
||
// 根据提示词名称生成消息
|
||
messages := s.generatePromptMessages(prompt, req.Arguments)
|
||
|
||
response := GetPromptResponse{
|
||
Messages: messages,
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// generatePromptMessages 生成提示词消息
|
||
func (s *Server) generatePromptMessages(prompt *Prompt, args map[string]interface{}) []PromptMessage {
|
||
messages := []PromptMessage{}
|
||
|
||
switch prompt.Name {
|
||
case "security_scan":
|
||
target, _ := args["target"].(string)
|
||
scanType, _ := args["scan_type"].(string)
|
||
if scanType == "" {
|
||
scanType = "comprehensive"
|
||
}
|
||
|
||
content := fmt.Sprintf(`请对目标 %s 执行%s安全扫描。包括:
|
||
1. 端口扫描和服务识别
|
||
2. 漏洞检测
|
||
3. Web应用安全测试
|
||
4. 生成详细的安全报告`, target, scanType)
|
||
|
||
messages = append(messages, PromptMessage{
|
||
Role: "user",
|
||
Content: content,
|
||
})
|
||
|
||
case "penetration_test":
|
||
target, _ := args["target"].(string)
|
||
scope, _ := args["scope"].(string)
|
||
|
||
content := fmt.Sprintf(`请对目标 %s 执行渗透测试。`, target)
|
||
if scope != "" {
|
||
content += fmt.Sprintf("测试范围:%s", scope)
|
||
}
|
||
content += "\n请按照OWASP Top 10进行全面的安全测试。"
|
||
|
||
messages = append(messages, PromptMessage{
|
||
Role: "user",
|
||
Content: content,
|
||
})
|
||
|
||
default:
|
||
messages = append(messages, PromptMessage{
|
||
Role: "user",
|
||
Content: "请执行安全测试任务",
|
||
})
|
||
}
|
||
|
||
return messages
|
||
}
|
||
|
||
// handleListResources 处理列出资源请求
|
||
func (s *Server) handleListResources(msg *Message) *Message {
|
||
s.mu.RLock()
|
||
resources := make([]Resource, 0, len(s.resources))
|
||
for _, resource := range s.resources {
|
||
resources = append(resources, *resource)
|
||
}
|
||
s.mu.RUnlock()
|
||
|
||
response := ListResourcesResponse{
|
||
Resources: resources,
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// handleReadResource 处理读取资源请求
|
||
func (s *Server) handleReadResource(msg *Message) *Message {
|
||
var req ReadResourceRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
s.mu.RLock()
|
||
resource, exists := s.resources[req.URI]
|
||
s.mu.RUnlock()
|
||
|
||
if !exists {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32601, Message: "Resource not found"},
|
||
}
|
||
}
|
||
|
||
// 生成资源内容
|
||
content := s.generateResourceContent(resource)
|
||
|
||
response := ReadResourceResponse{
|
||
Contents: []ResourceContent{content},
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// generateResourceContent 生成资源内容
|
||
func (s *Server) generateResourceContent(resource *Resource) ResourceContent {
|
||
content := ResourceContent{
|
||
URI: resource.URI,
|
||
MimeType: resource.MimeType,
|
||
}
|
||
|
||
// 如果是工具资源,生成详细文档
|
||
if strings.HasPrefix(resource.URI, "tool://") {
|
||
toolName := strings.TrimPrefix(resource.URI, "tool://")
|
||
content.Text = s.generateToolDocumentation(toolName, resource)
|
||
} else {
|
||
// 其他资源使用描述或默认内容
|
||
content.Text = resource.Description
|
||
}
|
||
|
||
return content
|
||
}
|
||
|
||
// generateToolDocumentation 生成工具文档
|
||
func (s *Server) generateToolDocumentation(toolName string, resource *Resource) string {
|
||
// 获取工具定义以获取更详细的信息
|
||
s.mu.RLock()
|
||
tool, hasTool := s.toolDefs[toolName]
|
||
s.mu.RUnlock()
|
||
|
||
// 为常见工具生成详细文档
|
||
switch toolName {
|
||
case "nmap":
|
||
return `Nmap (Network Mapper) 是一个强大的网络扫描工具。
|
||
|
||
主要功能:
|
||
- 端口扫描:发现目标主机开放的端口
|
||
- 服务识别:识别运行在端口上的服务
|
||
- 版本检测:检测服务版本信息
|
||
- 操作系统检测:识别目标操作系统
|
||
|
||
常用命令:
|
||
- nmap -sT target # TCP连接扫描
|
||
- nmap -sV target # 版本检测
|
||
- nmap -sC target # 默认脚本扫描
|
||
- nmap -p 1-1000 target # 扫描指定端口范围
|
||
|
||
参数说明:
|
||
- target: 目标IP地址或域名(必需)
|
||
- ports: 端口范围,例如: 1-1000(可选)`
|
||
|
||
case "sqlmap":
|
||
return `SQLMap 是一个自动化的SQL注入检测和利用工具。
|
||
|
||
主要功能:
|
||
- 自动检测SQL注入漏洞
|
||
- 数据库指纹识别
|
||
- 数据提取
|
||
- 文件系统访问
|
||
|
||
常用命令:
|
||
- sqlmap -u "http://target.com/page?id=1" # 检测URL参数
|
||
- sqlmap -u "http://target.com" --forms # 检测表单
|
||
- sqlmap -u "http://target.com" --dbs # 列出数据库
|
||
|
||
参数说明:
|
||
- url: 目标URL(必需)`
|
||
|
||
case "nikto":
|
||
return `Nikto 是一个Web服务器扫描工具。
|
||
|
||
主要功能:
|
||
- Web服务器漏洞扫描
|
||
- 检测过时的服务器软件
|
||
- 检测危险文件和程序
|
||
- 检测服务器配置问题
|
||
|
||
常用命令:
|
||
- nikto -h target # 扫描目标主机
|
||
- nikto -h target -p 80,443 # 扫描指定端口
|
||
|
||
参数说明:
|
||
- target: 目标URL(必需)`
|
||
|
||
case "dirb":
|
||
return `Dirb 是一个Web内容扫描器。
|
||
|
||
主要功能:
|
||
- 扫描Web目录和文件
|
||
- 发现隐藏的目录和文件
|
||
- 支持自定义字典
|
||
|
||
常用命令:
|
||
- dirb url # 扫描目标URL
|
||
- dirb url -w wordlist.txt # 使用自定义字典
|
||
|
||
参数说明:
|
||
- target: 目标URL(必需)`
|
||
|
||
case "exec":
|
||
return `Exec 工具用于执行系统命令。
|
||
|
||
⚠️ 警告:此工具可以执行任意系统命令,请谨慎使用!
|
||
|
||
参数说明:
|
||
- command: 要执行的系统命令(必需)
|
||
- shell: 使用的shell,默认为sh(可选)
|
||
- workdir: 工作目录(可选)`
|
||
|
||
default:
|
||
// 对于其他工具,使用工具定义中的描述信息
|
||
if hasTool {
|
||
doc := fmt.Sprintf("%s\n\n", resource.Description)
|
||
if tool.InputSchema != nil {
|
||
if props, ok := tool.InputSchema["properties"].(map[string]interface{}); ok {
|
||
doc += "参数说明:\n"
|
||
for paramName, paramInfo := range props {
|
||
if paramMap, ok := paramInfo.(map[string]interface{}); ok {
|
||
if desc, ok := paramMap["description"].(string); ok {
|
||
doc += fmt.Sprintf("- %s: %s\n", paramName, desc)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return doc
|
||
}
|
||
return resource.Description
|
||
}
|
||
}
|
||
|
||
// handleSamplingRequest 处理采样请求
|
||
func (s *Server) handleSamplingRequest(msg *Message) *Message {
|
||
var req SamplingRequest
|
||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||
}
|
||
}
|
||
|
||
// 注意:采样功能通常需要连接到实际的LLM服务
|
||
// 这里返回一个占位符响应,实际实现需要集成LLM API
|
||
s.logger.Warn("Sampling request received but not fully implemented",
|
||
zap.Any("request", req),
|
||
)
|
||
|
||
response := SamplingResponse{
|
||
Content: []SamplingContent{
|
||
{
|
||
Type: "text",
|
||
Text: "采样功能需要配置LLM服务。请使用Agent Loop API进行AI对话。",
|
||
},
|
||
},
|
||
StopReason: "length",
|
||
}
|
||
result, _ := json.Marshal(response)
|
||
return &Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeResponse,
|
||
Version: "2.0",
|
||
Result: result,
|
||
}
|
||
}
|
||
|
||
// RegisterPrompt 注册提示词模板
|
||
func (s *Server) RegisterPrompt(prompt *Prompt) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
s.prompts[prompt.Name] = prompt
|
||
}
|
||
|
||
// RegisterResource 注册资源
|
||
func (s *Server) RegisterResource(resource *Resource) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
s.resources[resource.URI] = resource
|
||
}
|
||
|
||
// HandleStdio 处理标准输入输出(用于 stdio 传输模式)
|
||
// MCP 协议使用换行分隔的 JSON-RPC 消息
|
||
func (s *Server) HandleStdio() error {
|
||
decoder := json.NewDecoder(os.Stdin)
|
||
encoder := json.NewEncoder(os.Stdout)
|
||
// 注意:不设置缩进,MCP 协议期望紧凑的 JSON 格式
|
||
|
||
for {
|
||
var msg Message
|
||
if err := decoder.Decode(&msg); err != nil {
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
// 日志输出到 stderr,避免干扰 stdout 的 JSON-RPC 通信
|
||
s.logger.Error("读取消息失败", zap.Error(err))
|
||
// 发送错误响应
|
||
errorMsg := Message{
|
||
ID: msg.ID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: -32700, Message: "Parse error", Data: err.Error()},
|
||
}
|
||
if err := encoder.Encode(errorMsg); err != nil {
|
||
return fmt.Errorf("发送错误响应失败: %w", err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 处理消息
|
||
response := s.handleMessage(&msg)
|
||
|
||
// 如果是通知(response 为 nil),不需要发送响应
|
||
if response == nil {
|
||
continue
|
||
}
|
||
|
||
// 发送响应
|
||
if err := encoder.Encode(response); err != nil {
|
||
return fmt.Errorf("发送响应失败: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// sendError 发送错误响应
|
||
func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) {
|
||
var msgID MessageID
|
||
if id != nil {
|
||
msgID = MessageID{value: id}
|
||
}
|
||
response := Message{
|
||
ID: msgID,
|
||
Type: MessageTypeError,
|
||
Version: "2.0",
|
||
Error: &Error{Code: code, Message: message, Data: data},
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|
||
|