mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 00:09:29 +02:00
Add files via upload
This commit is contained in:
@@ -103,7 +103,7 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
}
|
||||
|
||||
if !h.manager.CheckPassword(oldPassword) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "当前密码不正确"})
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -250,9 +250,9 @@ func (idx *Indexer) estimateTokens(text string) int {
|
||||
|
||||
// IndexItem 索引知识项(分块并向量化)
|
||||
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
|
||||
// 获取知识项
|
||||
var content string
|
||||
err := idx.db.QueryRow("SELECT content FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content)
|
||||
// 获取知识项(包含category和title,用于向量化)
|
||||
var content, category, title string
|
||||
err := idx.db.QueryRow("SELECT content, category, title FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取知识项失败: %w", err)
|
||||
}
|
||||
@@ -267,13 +267,19 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
|
||||
chunks := idx.ChunkText(content)
|
||||
idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
|
||||
|
||||
// 向量化每个块
|
||||
// 向量化每个块(包含category和title信息,以便向量检索时能匹配到风险类型)
|
||||
for i, chunk := range chunks {
|
||||
chunkPreview := chunk
|
||||
if len(chunkPreview) > 200 {
|
||||
chunkPreview = chunkPreview[:200] + "..."
|
||||
}
|
||||
embedding, err := idx.embedder.EmbedText(ctx, chunk)
|
||||
|
||||
// 将category和title信息包含到向量化的文本中
|
||||
// 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}"
|
||||
// 这样向量嵌入就会包含风险类型信息,即使SQL过滤失败,向量相似度也能帮助匹配
|
||||
textForEmbedding := fmt.Sprintf("[风险类型: %s] [标题: %s]\n%s", category, title, chunk)
|
||||
|
||||
embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding)
|
||||
if err != nil {
|
||||
idx.logger.Warn("向量化失败",
|
||||
zap.String("itemId", itemID),
|
||||
|
||||
@@ -102,20 +102,32 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
threshold = 0.7
|
||||
}
|
||||
|
||||
// 向量化查询
|
||||
queryEmbedding, err := r.embedder.EmbedText(ctx, req.Query)
|
||||
// 向量化查询(如果提供了risk_type,也包含在查询文本中,以便更好地匹配)
|
||||
queryText := req.Query
|
||||
if req.RiskType != "" {
|
||||
// 将risk_type信息包含到查询中,格式与索引时保持一致
|
||||
queryText = fmt.Sprintf("[风险类型: %s] %s", req.RiskType, req.Query)
|
||||
}
|
||||
queryEmbedding, err := r.embedder.EmbedText(ctx, queryText)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量化查询失败: %w", err)
|
||||
}
|
||||
|
||||
// 查询所有向量(或按风险类型过滤)
|
||||
// 使用精确匹配(=)以提高性能和准确性
|
||||
// 由于系统提供了 list_knowledge_risk_types 工具,用户应该使用准确的category名称
|
||||
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
|
||||
var rows *sql.Rows
|
||||
if req.RiskType != "" {
|
||||
// 使用精确匹配(=),性能更好且更准确
|
||||
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
|
||||
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
|
||||
// 建议用户先调用 list_knowledge_risk_types 获取准确的category名称
|
||||
rows, err = r.db.Query(`
|
||||
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
|
||||
FROM knowledge_embeddings e
|
||||
JOIN knowledge_base_items i ON e.item_id = i.id
|
||||
WHERE i.category = ?
|
||||
WHERE i.category = ? COLLATE NOCASE
|
||||
`, req.RiskType)
|
||||
} else {
|
||||
rows, err = r.db.Query(`
|
||||
|
||||
@@ -19,11 +19,68 @@ func RegisterKnowledgeTool(
|
||||
manager *Manager,
|
||||
logger *zap.Logger,
|
||||
) {
|
||||
// manager 和 retriever 在 handler 中直接使用参数
|
||||
_ = manager // 保留参数,可能将来用于日志记录等
|
||||
tool := mcp.Tool{
|
||||
// 注册第一个工具:获取所有可用的风险类型列表
|
||||
listRiskTypesTool := mcp.Tool{
|
||||
Name: "list_knowledge_risk_types",
|
||||
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
|
||||
ShortDescription: "获取知识库中所有可用的风险类型列表",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
"required": []string{},
|
||||
},
|
||||
}
|
||||
|
||||
listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
categories, err := manager.GetCategories()
|
||||
if err != nil {
|
||||
logger.Error("获取风险类型列表失败", zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("获取风险类型列表失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(categories) == 0 {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "知识库中暂无风险类型。",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
var resultText strings.Builder
|
||||
resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories)))
|
||||
for i, category := range categories {
|
||||
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
|
||||
}
|
||||
resultText.WriteString("\n提示:在调用 search_knowledge_base 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: resultText.String(),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler)
|
||||
logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name))
|
||||
|
||||
// 注册第二个工具:搜索知识库(保持原有功能)
|
||||
searchTool := mcp.Tool{
|
||||
Name: "search_knowledge_base",
|
||||
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。",
|
||||
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 list_knowledge_risk_types 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
|
||||
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
@@ -34,14 +91,14 @@ func RegisterKnowledgeTool(
|
||||
},
|
||||
"risk_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等),如果不指定则搜索所有类型",
|
||||
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 list_knowledge_risk_types 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
}
|
||||
|
||||
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok || query == "" {
|
||||
return &mcp.ToolResult{
|
||||
@@ -182,8 +239,8 @@ func RegisterKnowledgeTool(
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, handler)
|
||||
logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name))
|
||||
mcpServer.RegisterTool(searchTool, searchHandler)
|
||||
logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name))
|
||||
}
|
||||
|
||||
// contains 检查切片是否包含元素
|
||||
|
||||
Reference in New Issue
Block a user