Add files via upload

This commit is contained in:
公明
2026-01-11 02:03:33 +08:00
committed by GitHub
parent 4ca1aa9aa8
commit 3aee7022c4
30 changed files with 3759 additions and 86 deletions
+8 -8
View File
@@ -161,14 +161,14 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
// 查询所有向量(或按风险类型过滤)
// 使用精确匹配(=)以提高性能和准确性
// 由于系统提供了 list_knowledge_risk_types 工具,用户应该使用准确的category名称
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
var rows *sql.Rows
if req.RiskType != "" {
// 使用精确匹配(=),性能更好且更准确
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
// 建议用户先调用 list_knowledge_risk_types 获取准确的category名称
// 由于系统提供了内置工具来获取风险类型列表,用户应该使用准确的category名称
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
var rows *sql.Rows
if req.RiskType != "" {
// 使用精确匹配(=),性能更好且更准确
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
// 建议用户先调用相应的内置工具获取准确的category名称
rows, err = r.db.Query(`
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
FROM knowledge_embeddings e
+12 -11
View File
@@ -8,6 +8,7 @@ import (
"strings"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
@@ -21,7 +22,7 @@ func RegisterKnowledgeTool(
) {
// 注册第一个工具:获取所有可用的风险类型列表
listRiskTypesTool := mcp.Tool{
Name: "list_knowledge_risk_types",
Name: builtin.ToolListKnowledgeRiskTypes,
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
ShortDescription: "获取知识库中所有可用的风险类型列表",
InputSchema: map[string]interface{}{
@@ -62,7 +63,7 @@ func RegisterKnowledgeTool(
for i, category := range categories {
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
}
resultText.WriteString("\n提示:在调用 search_knowledge_base 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
return &mcp.ToolResult{
Content: []mcp.Content{
@@ -79,8 +80,8 @@ func RegisterKnowledgeTool(
// 注册第二个工具:搜索知识库(保持原有功能)
searchTool := mcp.Tool{
Name: "search_knowledge_base",
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 list_knowledge_risk_types 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
Name: builtin.ToolSearchKnowledgeBase,
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
InputSchema: map[string]interface{}{
"type": "object",
@@ -91,7 +92,7 @@ func RegisterKnowledgeTool(
},
"risk_type": map[string]interface{}{
"type": "string",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 list_knowledge_risk_types 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
},
},
"required": []string{"query"},
@@ -165,9 +166,9 @@ func RegisterKnowledgeTool(
// 按文档分组结果,以便更好地展示上下文
// 使用有序的slice来保持文档顺序(按最高混合分数)
type itemGroup struct {
itemID string
results []*RetrievalResult
maxScore float64 // 该文档的最高混合分数
itemID string
results []*RetrievalResult
maxScore float64 // 该文档的最高混合分数
}
itemGroups := make([]*itemGroup, 0)
itemMap := make(map[string]*itemGroup)
@@ -177,8 +178,8 @@ func RegisterKnowledgeTool(
group, exists := itemMap[itemID]
if !exists {
group = &itemGroup{
itemID: itemID,
results: make([]*RetrievalResult, 0),
itemID: itemID,
results: make([]*RetrievalResult, 0),
maxScore: result.Score,
}
itemMap[itemID] = group
@@ -219,7 +220,7 @@ func RegisterKnowledgeTool(
})
// 显示主结果(混合分数最高的,同时显示相似度和混合分数)
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
resultIndex, mainResult.Similarity*100, mainResult.Score*100))
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))