mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-21 18:26:38 +02:00
275 lines
8.5 KiB
Go
275 lines
8.5 KiB
Go
package knowledge
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"cyberstrike-ai/internal/mcp"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// RegisterKnowledgeTool 注册知识检索工具到MCP服务器
|
|
func RegisterKnowledgeTool(
|
|
mcpServer *mcp.Server,
|
|
retriever *Retriever,
|
|
manager *Manager,
|
|
logger *zap.Logger,
|
|
) {
|
|
// manager 和 retriever 在 handler 中直接使用参数
|
|
_ = manager // 保留参数,可能将来用于日志记录等
|
|
tool := mcp.Tool{
|
|
Name: "search_knowledge_base",
|
|
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。",
|
|
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
|
|
InputSchema: map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"query": map[string]interface{}{
|
|
"type": "string",
|
|
"description": "搜索查询内容,描述你想要了解的安全知识主题",
|
|
},
|
|
"risk_type": map[string]interface{}{
|
|
"type": "string",
|
|
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等),如果不指定则搜索所有类型",
|
|
},
|
|
},
|
|
"required": []string{"query"},
|
|
},
|
|
}
|
|
|
|
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
|
query, ok := args["query"].(string)
|
|
if !ok || query == "" {
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: "错误: 查询参数不能为空",
|
|
},
|
|
},
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
|
|
riskType := ""
|
|
if rt, ok := args["risk_type"].(string); ok && rt != "" {
|
|
riskType = rt
|
|
}
|
|
|
|
logger.Info("执行知识库检索",
|
|
zap.String("query", query),
|
|
zap.String("riskType", riskType),
|
|
)
|
|
|
|
// 执行检索
|
|
searchReq := &SearchRequest{
|
|
Query: query,
|
|
RiskType: riskType,
|
|
TopK: 5,
|
|
}
|
|
|
|
results, err := retriever.Search(ctx, searchReq)
|
|
if err != nil {
|
|
logger.Error("知识库检索失败", zap.Error(err))
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: fmt.Sprintf("检索失败: %v", err),
|
|
},
|
|
},
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
|
|
if len(results) == 0 {
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query),
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// 格式化结果
|
|
var resultText strings.Builder
|
|
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识:\n\n", len(results)))
|
|
|
|
// 收集检索到的知识项ID(用于日志)
|
|
retrievedItemIDs := make([]string, 0, len(results))
|
|
|
|
for i, result := range results {
|
|
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", i+1, result.Similarity*100))
|
|
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", result.Item.Category, result.Item.Title, result.Item.ID))
|
|
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n\n", result.Chunk.ChunkText))
|
|
|
|
if !contains(retrievedItemIDs, result.Item.ID) {
|
|
retrievedItemIDs = append(retrievedItemIDs, result.Item.ID)
|
|
}
|
|
}
|
|
|
|
// 在结果末尾添加元数据(JSON格式,用于提取知识项ID)
|
|
// 使用特殊标记,避免影响AI阅读结果
|
|
if len(retrievedItemIDs) > 0 {
|
|
metadataJSON, _ := json.Marshal(map[string]interface{}{
|
|
"_metadata": map[string]interface{}{
|
|
"retrievedItemIDs": retrievedItemIDs,
|
|
},
|
|
})
|
|
resultText.WriteString(fmt.Sprintf("\n<!-- METADATA: %s -->", string(metadataJSON)))
|
|
}
|
|
|
|
// 记录检索日志(异步,不阻塞)
|
|
// 注意:这里没有conversationID和messageID,需要在Agent层面记录
|
|
// 实际的日志记录应该在Agent的progressCallback中完成
|
|
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: resultText.String(),
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
mcpServer.RegisterTool(tool, handler)
|
|
logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name))
|
|
|
|
// 注册读取完整知识项的工具
|
|
RegisterReadKnowledgeItemTool(mcpServer, manager, logger)
|
|
}
|
|
|
|
// RegisterReadKnowledgeItemTool 注册读取完整知识项工具到MCP服务器
|
|
func RegisterReadKnowledgeItemTool(
|
|
mcpServer *mcp.Server,
|
|
manager *Manager,
|
|
logger *zap.Logger,
|
|
) {
|
|
tool := mcp.Tool{
|
|
Name: "read_knowledge_item",
|
|
Description: "根据知识项ID读取完整的知识文档内容。**重要:此工具应谨慎使用,只在检索到的片段信息明显不足时才调用。** 使用场景:1) 检索片段缺少关键上下文导致无法理解;2) 需要查看文档的完整结构或流程;3) 片段信息不完整,必须查看完整文档才能回答用户问题。**不要**仅为了获取更多信息而盲目读取完整文档,因为检索工具已经返回了最相关的片段。传入知识项ID(从search_knowledge_base的检索结果中获取)即可获取该知识项的完整内容(包括标题、分类、完整文档内容等)。",
|
|
ShortDescription: "读取完整知识项文档(仅在片段信息不足时使用)",
|
|
InputSchema: map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"item_id": map[string]interface{}{
|
|
"type": "string",
|
|
"description": "知识项ID(可以从 search_knowledge_base 的检索结果中获取)",
|
|
},
|
|
},
|
|
"required": []string{"item_id"},
|
|
},
|
|
}
|
|
|
|
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
|
itemID, ok := args["item_id"].(string)
|
|
if !ok || itemID == "" {
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: "错误: item_id 参数不能为空",
|
|
},
|
|
},
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
|
|
logger.Info("读取知识项", zap.String("itemId", itemID))
|
|
|
|
// 获取完整知识项
|
|
item, err := manager.GetItem(itemID)
|
|
if err != nil {
|
|
logger.Error("读取知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: fmt.Sprintf("读取知识项失败: %v", err),
|
|
},
|
|
},
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
|
|
// 格式化结果
|
|
var resultText strings.Builder
|
|
resultText.WriteString("=== 完整知识项内容 ===\n\n")
|
|
resultText.WriteString(fmt.Sprintf("ID: %s\n", item.ID))
|
|
resultText.WriteString(fmt.Sprintf("分类: %s\n", item.Category))
|
|
resultText.WriteString(fmt.Sprintf("标题: %s\n", item.Title))
|
|
if item.FilePath != "" {
|
|
resultText.WriteString(fmt.Sprintf("文件路径: %s\n", item.FilePath))
|
|
}
|
|
resultText.WriteString("\n--- 完整内容 ---\n\n")
|
|
resultText.WriteString(item.Content)
|
|
resultText.WriteString("\n\n")
|
|
|
|
return &mcp.ToolResult{
|
|
Content: []mcp.Content{
|
|
{
|
|
Type: "text",
|
|
Text: resultText.String(),
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
mcpServer.RegisterTool(tool, handler)
|
|
logger.Info("读取知识项工具已注册", zap.String("toolName", tool.Name))
|
|
}
|
|
|
|
// contains 检查切片是否包含元素
|
|
func contains(slice []string, item string) bool {
|
|
for _, s := range slice {
|
|
if s == item {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// GetRetrievalMetadata 从工具调用中提取检索元数据(用于日志记录)
|
|
func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) {
|
|
if q, ok := args["query"].(string); ok {
|
|
query = q
|
|
}
|
|
if rt, ok := args["risk_type"].(string); ok {
|
|
riskType = rt
|
|
}
|
|
return
|
|
}
|
|
|
|
// FormatRetrievalResults 格式化检索结果为字符串(用于日志)
|
|
func FormatRetrievalResults(results []*RetrievalResult) string {
|
|
if len(results) == 0 {
|
|
return "未找到相关结果"
|
|
}
|
|
|
|
var builder strings.Builder
|
|
builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results)))
|
|
|
|
itemIDs := make(map[string]bool)
|
|
for i, result := range results {
|
|
builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n",
|
|
i+1, result.Item.Category, result.Item.Title, result.Similarity*100))
|
|
itemIDs[result.Item.ID] = true
|
|
}
|
|
|
|
// 返回知识项ID列表(JSON格式)
|
|
ids := make([]string, 0, len(itemIDs))
|
|
for id := range itemIDs {
|
|
ids = append(ids, id)
|
|
}
|
|
idsJSON, _ := json.Marshal(ids)
|
|
builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON)))
|
|
|
|
return builder.String()
|
|
}
|