diff --git a/internal/handler/auth.go b/internal/handler/auth.go index 6af09008..508553c1 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -103,7 +103,7 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) { } if !h.manager.CheckPassword(oldPassword) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "当前密码不正确"}) + c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"}) return } diff --git a/internal/knowledge/indexer.go b/internal/knowledge/indexer.go index 0c35a74a..275f787f 100644 --- a/internal/knowledge/indexer.go +++ b/internal/knowledge/indexer.go @@ -250,9 +250,9 @@ func (idx *Indexer) estimateTokens(text string) int { // IndexItem 索引知识项(分块并向量化) func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { - // 获取知识项 - var content string - err := idx.db.QueryRow("SELECT content FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content) + // 获取知识项(包含category和title,用于向量化) + var content, category, title string + err := idx.db.QueryRow("SELECT content, category, title FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title) if err != nil { return fmt.Errorf("获取知识项失败: %w", err) } @@ -267,13 +267,19 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { chunks := idx.ChunkText(content) idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks))) - // 向量化每个块 + // 向量化每个块(包含category和title信息,以便向量检索时能匹配到风险类型) for i, chunk := range chunks { chunkPreview := chunk if len(chunkPreview) > 200 { chunkPreview = chunkPreview[:200] + "..." } - embedding, err := idx.embedder.EmbedText(ctx, chunk) + + // 将category和title信息包含到向量化的文本中 + // 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}" + // 这样向量嵌入就会包含风险类型信息,即使SQL过滤失败,向量相似度也能帮助匹配 + textForEmbedding := fmt.Sprintf("[风险类型: %s] [标题: %s]\n%s", category, title, chunk) + + embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding) if err != nil { idx.logger.Warn("向量化失败", zap.String("itemId", itemID), diff --git a/internal/knowledge/retriever.go b/internal/knowledge/retriever.go index ad7245e9..d859fcaf 100644 --- a/internal/knowledge/retriever.go +++ b/internal/knowledge/retriever.go @@ -102,20 +102,32 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva threshold = 0.7 } - // 向量化查询 - queryEmbedding, err := r.embedder.EmbedText(ctx, req.Query) + // 向量化查询(如果提供了risk_type,也包含在查询文本中,以便更好地匹配) + queryText := req.Query + if req.RiskType != "" { + // 将risk_type信息包含到查询中,格式与索引时保持一致 + queryText = fmt.Sprintf("[风险类型: %s] %s", req.RiskType, req.Query) + } + queryEmbedding, err := r.embedder.EmbedText(ctx, queryText) if err != nil { return nil, fmt.Errorf("向量化查询失败: %w", err) } // 查询所有向量(或按风险类型过滤) + // 使用精确匹配(=)以提高性能和准确性 + // 由于系统提供了 list_knowledge_risk_types 工具,用户应该使用准确的category名称 + // 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配 var rows *sql.Rows if req.RiskType != "" { + // 使用精确匹配(=),性能更好且更准确 + // 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性 + // 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到 + // 建议用户先调用 list_knowledge_risk_types 获取准确的category名称 rows, err = r.db.Query(` SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title FROM knowledge_embeddings e JOIN knowledge_base_items i ON e.item_id = i.id - WHERE i.category = ? + WHERE i.category = ? COLLATE NOCASE `, req.RiskType) } else { rows, err = r.db.Query(` diff --git a/internal/knowledge/tool.go b/internal/knowledge/tool.go index 075c35cb..b6eb2c52 100644 --- a/internal/knowledge/tool.go +++ b/internal/knowledge/tool.go @@ -19,11 +19,68 @@ func RegisterKnowledgeTool( manager *Manager, logger *zap.Logger, ) { - // manager 和 retriever 在 handler 中直接使用参数 - _ = manager // 保留参数,可能将来用于日志记录等 - tool := mcp.Tool{ + // 注册第一个工具:获取所有可用的风险类型列表 + listRiskTypesTool := mcp.Tool{ + Name: "list_knowledge_risk_types", + Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。", + ShortDescription: "获取知识库中所有可用的风险类型列表", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + "required": []string{}, + }, + } + + listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + categories, err := manager.GetCategories() + if err != nil { + logger.Error("获取风险类型列表失败", zap.Error(err)) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("获取风险类型列表失败: %v", err), + }, + }, + IsError: true, + }, nil + } + + if len(categories) == 0 { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "知识库中暂无风险类型。", + }, + }, + }, nil + } + + var resultText strings.Builder + resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories))) + for i, category := range categories { + resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category)) + } + resultText.WriteString("\n提示:在调用 search_knowledge_base 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。") + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: resultText.String(), + }, + }, + }, nil + } + + mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler) + logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name)) + + // 注册第二个工具:搜索知识库(保持原有功能) + searchTool := mcp.Tool{ Name: "search_knowledge_base", - Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。", + Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 list_knowledge_risk_types 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。", ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)", InputSchema: map[string]interface{}{ "type": "object", @@ -34,14 +91,14 @@ func RegisterKnowledgeTool( }, "risk_type": map[string]interface{}{ "type": "string", - "description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等),如果不指定则搜索所有类型", + "description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 list_knowledge_risk_types 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。", }, }, "required": []string{"query"}, }, } - handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { query, ok := args["query"].(string) if !ok || query == "" { return &mcp.ToolResult{ @@ -182,8 +239,8 @@ func RegisterKnowledgeTool( }, nil } - mcpServer.RegisterTool(tool, handler) - logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name)) + mcpServer.RegisterTool(searchTool, searchHandler) + logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name)) } // contains 检查切片是否包含元素