mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
236 lines
7.0 KiB
Go
236 lines
7.0 KiB
Go
package knowledge
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"sort"
|
||
"strings"
|
||
|
||
"cyberstrike-ai/internal/mcp"
|
||
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// RegisterKnowledgeTool 注册知识检索工具到MCP服务器
|
||
func RegisterKnowledgeTool(
|
||
mcpServer *mcp.Server,
|
||
retriever *Retriever,
|
||
manager *Manager,
|
||
logger *zap.Logger,
|
||
) {
|
||
// manager 和 retriever 在 handler 中直接使用参数
|
||
_ = manager // 保留参数,可能将来用于日志记录等
|
||
tool := mcp.Tool{
|
||
Name: "search_knowledge_base",
|
||
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。",
|
||
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"query": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "搜索查询内容,描述你想要了解的安全知识主题",
|
||
},
|
||
"risk_type": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等),如果不指定则搜索所有类型",
|
||
},
|
||
},
|
||
"required": []string{"query"},
|
||
},
|
||
}
|
||
|
||
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
query, ok := args["query"].(string)
|
||
if !ok || query == "" {
|
||
return &mcp.ToolResult{
|
||
Content: []mcp.Content{
|
||
{
|
||
Type: "text",
|
||
Text: "错误: 查询参数不能为空",
|
||
},
|
||
},
|
||
IsError: true,
|
||
}, nil
|
||
}
|
||
|
||
riskType := ""
|
||
if rt, ok := args["risk_type"].(string); ok && rt != "" {
|
||
riskType = rt
|
||
}
|
||
|
||
logger.Info("执行知识库检索",
|
||
zap.String("query", query),
|
||
zap.String("riskType", riskType),
|
||
)
|
||
|
||
// 执行检索
|
||
searchReq := &SearchRequest{
|
||
Query: query,
|
||
RiskType: riskType,
|
||
TopK: 5,
|
||
}
|
||
|
||
results, err := retriever.Search(ctx, searchReq)
|
||
if err != nil {
|
||
logger.Error("知识库检索失败", zap.Error(err))
|
||
return &mcp.ToolResult{
|
||
Content: []mcp.Content{
|
||
{
|
||
Type: "text",
|
||
Text: fmt.Sprintf("检索失败: %v", err),
|
||
},
|
||
},
|
||
IsError: true,
|
||
}, nil
|
||
}
|
||
|
||
if len(results) == 0 {
|
||
return &mcp.ToolResult{
|
||
Content: []mcp.Content{
|
||
{
|
||
Type: "text",
|
||
Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query),
|
||
},
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
// 格式化结果
|
||
var resultText strings.Builder
|
||
|
||
// 按文档分组结果,以便更好地展示上下文
|
||
resultsByItem := make(map[string][]*RetrievalResult)
|
||
for _, result := range results {
|
||
itemID := result.Item.ID
|
||
resultsByItem[itemID] = append(resultsByItem[itemID], result)
|
||
}
|
||
|
||
// 收集检索到的知识项ID(用于日志)
|
||
retrievedItemIDs := make([]string, 0, len(resultsByItem))
|
||
|
||
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识(包含上下文扩展):\n\n", len(results)))
|
||
|
||
resultIndex := 1
|
||
for itemID, itemResults := range resultsByItem {
|
||
// 找到相似度最高的作为主结果
|
||
mainResult := itemResults[0]
|
||
maxSimilarity := mainResult.Similarity
|
||
for _, result := range itemResults {
|
||
if result.Similarity > maxSimilarity {
|
||
maxSimilarity = result.Similarity
|
||
mainResult = result
|
||
}
|
||
}
|
||
|
||
// 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序)
|
||
sort.Slice(itemResults, func(i, j int) bool {
|
||
return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
|
||
})
|
||
|
||
// 显示主结果(相似度最高的)
|
||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", resultIndex, mainResult.Similarity*100))
|
||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
|
||
|
||
// 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk)
|
||
if len(itemResults) == 1 {
|
||
// 只有一个chunk,直接显示
|
||
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText))
|
||
} else {
|
||
// 多个chunk,按逻辑顺序显示
|
||
resultText.WriteString("内容片段(按文档顺序):\n")
|
||
for i, result := range itemResults {
|
||
// 标记主结果
|
||
marker := ""
|
||
if result.Chunk.ID == mainResult.Chunk.ID {
|
||
marker = " [主匹配]"
|
||
}
|
||
resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText))
|
||
}
|
||
}
|
||
resultText.WriteString("\n")
|
||
|
||
if !contains(retrievedItemIDs, itemID) {
|
||
retrievedItemIDs = append(retrievedItemIDs, itemID)
|
||
}
|
||
resultIndex++
|
||
}
|
||
|
||
// 在结果末尾添加元数据(JSON格式,用于提取知识项ID)
|
||
// 使用特殊标记,避免影响AI阅读结果
|
||
if len(retrievedItemIDs) > 0 {
|
||
metadataJSON, _ := json.Marshal(map[string]interface{}{
|
||
"_metadata": map[string]interface{}{
|
||
"retrievedItemIDs": retrievedItemIDs,
|
||
},
|
||
})
|
||
resultText.WriteString(fmt.Sprintf("\n<!-- METADATA: %s -->", string(metadataJSON)))
|
||
}
|
||
|
||
// 记录检索日志(异步,不阻塞)
|
||
// 注意:这里没有conversationID和messageID,需要在Agent层面记录
|
||
// 实际的日志记录应该在Agent的progressCallback中完成
|
||
|
||
return &mcp.ToolResult{
|
||
Content: []mcp.Content{
|
||
{
|
||
Type: "text",
|
||
Text: resultText.String(),
|
||
},
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
mcpServer.RegisterTool(tool, handler)
|
||
logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name))
|
||
}
|
||
|
||
// contains 检查切片是否包含元素
|
||
func contains(slice []string, item string) bool {
|
||
for _, s := range slice {
|
||
if s == item {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// GetRetrievalMetadata 从工具调用中提取检索元数据(用于日志记录)
|
||
func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) {
|
||
if q, ok := args["query"].(string); ok {
|
||
query = q
|
||
}
|
||
if rt, ok := args["risk_type"].(string); ok {
|
||
riskType = rt
|
||
}
|
||
return
|
||
}
|
||
|
||
// FormatRetrievalResults 格式化检索结果为字符串(用于日志)
|
||
func FormatRetrievalResults(results []*RetrievalResult) string {
|
||
if len(results) == 0 {
|
||
return "未找到相关结果"
|
||
}
|
||
|
||
var builder strings.Builder
|
||
builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results)))
|
||
|
||
itemIDs := make(map[string]bool)
|
||
for i, result := range results {
|
||
builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n",
|
||
i+1, result.Item.Category, result.Item.Title, result.Similarity*100))
|
||
itemIDs[result.Item.ID] = true
|
||
}
|
||
|
||
// 返回知识项ID列表(JSON格式)
|
||
ids := make([]string, 0, len(itemIDs))
|
||
for id := range itemIDs {
|
||
ids = append(ids, id)
|
||
}
|
||
idsJSON, _ := json.Marshal(ids)
|
||
builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON)))
|
||
|
||
return builder.String()
|
||
}
|