Add files via upload

This commit is contained in:
公明
2025-12-24 01:47:57 +08:00
committed by GitHub
parent 6e832601d8
commit b8e58d9e44
4 changed files with 92 additions and 17 deletions

View File

@@ -103,7 +103,7 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
}
if !h.manager.CheckPassword(oldPassword) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "当前密码不正确"})
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
return
}

View File

@@ -250,9 +250,9 @@ func (idx *Indexer) estimateTokens(text string) int {
// IndexItem 索引知识项(分块并向量化)
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
// 获取知识项
var content string
err := idx.db.QueryRow("SELECT content FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content)
// 获取知识项包含category和title用于向量化
var content, category, title string
err := idx.db.QueryRow("SELECT content, category, title FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title)
if err != nil {
return fmt.Errorf("获取知识项失败: %w", err)
}
@@ -267,13 +267,19 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
chunks := idx.ChunkText(content)
idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
// 向量化每个块
// 向量化每个块包含category和title信息以便向量检索时能匹配到风险类型
for i, chunk := range chunks {
chunkPreview := chunk
if len(chunkPreview) > 200 {
chunkPreview = chunkPreview[:200] + "..."
}
embedding, err := idx.embedder.EmbedText(ctx, chunk)
// 将category和title信息包含到向量化的文本中
// 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}"
// 这样向量嵌入就会包含风险类型信息即使SQL过滤失败向量相似度也能帮助匹配
textForEmbedding := fmt.Sprintf("[风险类型: %s] [标题: %s]\n%s", category, title, chunk)
embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding)
if err != nil {
idx.logger.Warn("向量化失败",
zap.String("itemId", itemID),

View File

@@ -102,20 +102,32 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
threshold = 0.7
}
// 向量化查询
queryEmbedding, err := r.embedder.EmbedText(ctx, req.Query)
// 向量化查询如果提供了risk_type也包含在查询文本中以便更好地匹配
queryText := req.Query
if req.RiskType != "" {
// 将risk_type信息包含到查询中格式与索引时保持一致
queryText = fmt.Sprintf("[风险类型: %s] %s", req.RiskType, req.Query)
}
queryEmbedding, err := r.embedder.EmbedText(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("向量化查询失败: %w", err)
}
// 查询所有向量(或按风险类型过滤)
// 使用精确匹配(=)以提高性能和准确性
// 由于系统提供了 list_knowledge_risk_types 工具用户应该使用准确的category名称
// 同时向量嵌入中已包含category信息即使SQL过滤不完全匹配向量相似度也能帮助匹配
var rows *sql.Rows
if req.RiskType != "" {
// 使用精确匹配(=),性能更好且更准确
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
// 注意如果用户输入的risk_type与category不完全一致可能匹配不到
// 建议用户先调用 list_knowledge_risk_types 获取准确的category名称
rows, err = r.db.Query(`
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
FROM knowledge_embeddings e
JOIN knowledge_base_items i ON e.item_id = i.id
WHERE i.category = ?
WHERE i.category = ? COLLATE NOCASE
`, req.RiskType)
} else {
rows, err = r.db.Query(`

View File

@@ -19,11 +19,68 @@ func RegisterKnowledgeTool(
manager *Manager,
logger *zap.Logger,
) {
// manager 和 retriever 在 handler 中直接使用参数
_ = manager // 保留参数,可能将来用于日志记录等
tool := mcp.Tool{
// 注册第一个工具:获取所有可用的风险类型列表
listRiskTypesTool := mcp.Tool{
Name: "list_knowledge_risk_types",
Description: "获取知识库中所有可用的风险类型risk_type列表。在搜索知识库之前可以先调用此工具获取可用的风险类型然后使用正确的风险类型进行精确搜索这样可以大幅减少检索时间并提高检索准确性。",
ShortDescription: "获取知识库中所有可用的风险类型列表",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
"required": []string{},
},
}
listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
categories, err := manager.GetCategories()
if err != nil {
logger.Error("获取风险类型列表失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("获取风险类型列表失败: %v", err),
},
},
IsError: true,
}, nil
}
if len(categories) == 0 {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "知识库中暂无风险类型。",
},
},
}, nil
}
var resultText strings.Builder
resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories)))
for i, category := range categories {
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
}
resultText.WriteString("\n提示在调用 search_knowledge_base 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: resultText.String(),
},
},
}, nil
}
mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler)
logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name))
// 注册第二个工具:搜索知识库(保持原有功能)
searchTool := mcp.Tool{
Name: "search_knowledge_base",
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。",
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 list_knowledge_risk_types 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
InputSchema: map[string]interface{}{
"type": "object",
@@ -34,14 +91,14 @@ func RegisterKnowledgeTool(
},
"risk_type": map[string]interface{}{
"type": "string",
"description": "可选指定风险类型SQL注入、XSS、文件上传等如果不指定则搜索所有类型",
"description": "可选指定风险类型SQL注入、XSS、文件上传等。建议先调用 list_knowledge_risk_types 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型",
},
},
"required": []string{"query"},
},
}
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
query, ok := args["query"].(string)
if !ok || query == "" {
return &mcp.ToolResult{
@@ -182,8 +239,8 @@ func RegisterKnowledgeTool(
}, nil
}
mcpServer.RegisterTool(tool, handler)
logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name))
mcpServer.RegisterTool(searchTool, searchHandler)
logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name))
}
// contains 检查切片是否包含元素