From abc4085c8a36f0c71847d4e3bd0acb399c0f912a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sat, 20 Dec 2025 17:36:40 +0800 Subject: [PATCH] Add files via upload --- README.md | 47 + README_CN.md | 47 + config.yaml | 14 + internal/app/app.go | 151 ++- internal/config/config.go | 44 +- internal/database/database.go | 113 ++- internal/handler/agent.go | 146 ++- internal/handler/config.go | 193 +++- internal/handler/knowledge.go | 248 +++++ internal/knowledge/embedder.go | 205 ++++ internal/knowledge/indexer.go | 247 +++++ internal/knowledge/manager.go | 447 +++++++++ internal/knowledge/retriever.go | 230 +++++ internal/knowledge/tool.go | 191 ++++ internal/knowledge/types.go | 67 ++ web/static/css/style.css | 1021 +++++++++++++++++++- web/static/js/chat.js | 7 + web/static/js/knowledge.js | 1558 +++++++++++++++++++++++++++++++ web/static/js/router.js | 17 +- web/static/js/settings.js | 75 ++ web/templates/index.html | 212 +++++ 21 files changed, 5234 insertions(+), 46 deletions(-) create mode 100644 internal/handler/knowledge.go create mode 100644 internal/knowledge/embedder.go create mode 100644 internal/knowledge/indexer.go create mode 100644 internal/knowledge/manager.go create mode 100644 internal/knowledge/retriever.go create mode 100644 internal/knowledge/tool.go create mode 100644 internal/knowledge/types.go create mode 100644 web/static/js/knowledge.js diff --git a/README.md b/README.md index 9dd78c1e..794e9696 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ CyberStrikeAI is an **AI-native penetration-testing copilot** built in Go. It co - 📄 Large-result pagination, compression, and searchable archives - 🔗 Attack-chain graph, risk scoring, and step-by-step replay - 🔒 Password-protected web UI, audit logs, and SQLite persistence +- 📚 Knowledge base with vector search and hybrid retrieval for security expertise ## Tool Overview @@ -175,6 +176,38 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain: } ``` +### Knowledge Base +- **Vector search** – AI agent can automatically search the knowledge base for relevant security knowledge during conversations using the `search_knowledge_base` tool. +- **Hybrid retrieval** – combines vector similarity search with keyword matching for better accuracy. +- **Auto-indexing** – scans the `knowledge_base/` directory for Markdown files and automatically indexes them with embeddings. +- **Web management** – create, update, delete knowledge items through the web UI, with category-based organization. +- **Retrieval logs** – tracks all knowledge retrieval operations for audit and debugging. + +**Setting up the knowledge base:** +1. **Enable in config** – set `knowledge.enabled: true` in `config.yaml`: + ```yaml + knowledge: + enabled: true + base_path: knowledge_base + embedding: + provider: openai + model: text-embedding-v4 + base_url: "https://api.openai.com/v1" # or your embedding API + api_key: "sk-xxx" + retrieval: + top_k: 5 + similarity_threshold: 0.7 + hybrid_weight: 0.7 + ``` +2. **Add knowledge files** – place Markdown files in `knowledge_base/` directory, organized by category (e.g., `knowledge_base/SQL Injection/README.md`). +3. **Scan and index** – use the web UI to scan the knowledge base directory, which will automatically import files and build vector embeddings. +4. **Use in conversations** – the AI agent will automatically use `search_knowledge_base` when it needs security knowledge. You can also explicitly ask: "Search the knowledge base for SQL injection techniques". + +**Knowledge base structure:** +- Files are organized by category (directory name becomes the category). +- Each Markdown file becomes a knowledge item with automatic chunking for vector search. +- The system supports incremental updates – modified files are re-indexed automatically. + ### Automation Hooks - **REST APIs** – everything the UI uses (auth, conversations, tool runs, monitor) is available over JSON. - **Task control** – pause/resume/stop long scans, re-run steps with new params, or stream transcripts. @@ -202,8 +235,21 @@ openai: model: "deepseek-chat" database: path: "data/conversations.db" + knowledge_db_path: "data/knowledge.db" # Optional: separate DB for knowledge base security: tools_dir: "tools" +knowledge: + enabled: false # Enable knowledge base feature + base_path: "knowledge_base" # Path to knowledge base directory + embedding: + provider: "openai" # Embedding provider (currently only "openai") + model: "text-embedding-v4" # Embedding model name + base_url: "" # Leave empty to use OpenAI base_url + api_key: "" # Leave empty to use OpenAI api_key + retrieval: + top_k: 5 # Number of top results to return + similarity_threshold: 0.7 # Minimum similarity score (0-1) + hybrid_weight: 0.7 # Weight for vector search (1.0 = pure vector, 0.0 = pure keyword) ``` ### Tool Definition Example (`tools/nmap.yaml`) @@ -261,6 +307,7 @@ Build an attack chain for the latest engagement and export the node list with se ## Changelog (Recent) +- 2025-12-20 – Added knowledge base feature with vector search, hybrid retrieval, and automatic indexing. AI agent can now search security knowledge during conversations. - 2025-12-19 – Added ZoomEye network space search engine tool (zoomeye_search) with support for IPv4/IPv6/web assets, facets statistics, and flexible query parameters. - 2025-12-18 – Optimized web frontend with enhanced sidebar navigation and improved user experience. - 2025-12-07 – Added FOFA network space search engine tool (fofa_search) with flexible query parameters and field configuration. diff --git a/README_CN.md b/README_CN.md index 76ec4d4d..8bba2baf 100644 --- a/README_CN.md +++ b/README_CN.md @@ -30,6 +30,7 @@ CyberStrikeAI 是一款 **AI 原生渗透测试协同体**,以 Go 编写,内 - 📄 大结果分页、压缩与全文检索 - 🔗 攻击链可视化、风险打分与步骤回放 - 🔒 Web 登录保护、审计日志、SQLite 持久化 +- 📚 知识库功能:向量检索与混合搜索,为 AI 提供安全专业知识 ## 工具概览 @@ -173,6 +174,38 @@ CyberStrikeAI 是一款 **AI 原生渗透测试协同体**,以 Go 编写,内 } ``` +### 知识库功能 +- **向量检索**:AI 智能体在对话过程中可自动调用 `search_knowledge_base` 工具搜索知识库中的安全知识。 +- **混合检索**:结合向量相似度搜索与关键词匹配,提升检索准确性。 +- **自动索引**:扫描 `knowledge_base/` 目录下的 Markdown 文件,自动构建向量嵌入索引。 +- **Web 管理**:通过 Web 界面创建、更新、删除知识项,支持分类管理。 +- **检索日志**:记录所有知识检索操作,便于审计与调试。 + +**知识库配置步骤:** +1. **启用功能**:在 `config.yaml` 中设置 `knowledge.enabled: true`: + ```yaml + knowledge: + enabled: true + base_path: knowledge_base + embedding: + provider: openai + model: text-embedding-v4 + base_url: "https://api.openai.com/v1" # 或你的嵌入模型 API + api_key: "sk-xxx" + retrieval: + top_k: 5 + similarity_threshold: 0.7 + hybrid_weight: 0.7 + ``` +2. **添加知识文件**:将 Markdown 文件放入 `knowledge_base/` 目录,按分类组织(如 `knowledge_base/SQL注入/README.md`)。 +3. **扫描索引**:在 Web 界面中点击"扫描知识库",系统会自动导入文件并构建向量索引。 +4. **对话中使用**:AI 智能体在需要安全知识时会自动调用知识检索工具。你也可以显式要求:"搜索知识库中关于 SQL 注入的技术"。 + +**知识库结构说明:** +- 文件按分类组织(目录名作为分类)。 +- 每个 Markdown 文件自动切块并生成向量嵌入。 +- 支持增量更新,修改后的文件会自动重新索引。 + ### 自动化与安全 - **REST API**:认证、会话、任务、监控等接口全部开放,可与 CI/CD 集成。 - **任务控制**:支持暂停/终止长任务、修改参数后重跑、流式获取日志。 @@ -200,8 +233,21 @@ openai: model: "deepseek-chat" database: path: "data/conversations.db" + knowledge_db_path: "data/knowledge.db" # 可选:知识库独立数据库 security: tools_dir: "tools" +knowledge: + enabled: false # 是否启用知识库功能 + base_path: "knowledge_base" # 知识库目录路径 + embedding: + provider: "openai" # 嵌入模型提供商(目前仅支持 openai) + model: "text-embedding-v4" # 嵌入模型名称 + base_url: "" # 留空则使用 OpenAI 配置的 base_url + api_key: "" # 留空则使用 OpenAI 配置的 api_key + retrieval: + top_k: 5 # 检索返回的 Top-K 结果数量 + similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤 + hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0 表示纯向量检索,0.0 表示纯关键词检索 ``` ### 工具模版示例(`tools/nmap.yaml`) @@ -258,6 +304,7 @@ CyberStrikeAI/ ``` ## Changelog(近期) +- 2025-12-20 —— 新增知识库功能:支持向量检索、混合搜索与自动索引,AI 智能体可在对话中自动搜索安全知识。 - 2025-12-19 —— 新增钟馗之眼(ZoomEye)网络空间搜索引擎工具(zoomeye_search),支持 IPv4/IPv6/Web 等资产搜索、统计项查询与灵活的查询参数配置。 - 2025-12-18 —— 优化 Web 前端界面,增加侧边栏导航,提升用户体验。 - 2025-12-07 —— 新增 FOFA 网络空间搜索引擎工具(fofa_search),支持灵活的查询参数与字段配置。 diff --git a/config.yaml b/config.yaml index 2c63c225..d1113395 100644 --- a/config.yaml +++ b/config.yaml @@ -44,6 +44,7 @@ agent: # 数据库配置 database: path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息 + knowledge_db_path: data/knowledge.db # 知识库数据库文件路径(可选,为空则使用会话数据库),用于存储知识库项和向量嵌入,可独立复制和复用 # 安全工具配置 security: tools_dir: tools # 工具配置文件目录(相对于配置文件所在目录) @@ -52,3 +53,16 @@ security: # 外部MCP配置 external_mcp: servers: {} +# 知识库配置 +knowledge: + enabled: true # 是否启用知识检索功能 + base_path: knowledge_base # 知识库目录路径(相对于配置文件所在目录) + embedding: + provider: openai # 嵌入模型提供商(目前仅支持openai) + model: text-embedding-v4 # 嵌入模型名称 + base_url: https://api.deepseek.com/v1 # 留空则使用OpenAI配置的base_url + api_key: sk-xxxxxx # 留空则使用OpenAI配置的api_key + retrieval: + top_k: 5 # 检索返回的Top-K结果数量 + similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤 + hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0表示纯向量检索,0.0表示纯关键词检索 diff --git a/internal/app/app.go b/internal/app/app.go index 845f18bf..1285b187 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -1,17 +1,22 @@ package app import ( + "context" + "database/sql" "fmt" "net/http" "os" "path/filepath" + "time" "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/handler" + "cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/security" "cyberstrike-ai/internal/storage" @@ -29,6 +34,7 @@ type App struct { agent *agent.Agent executor *security.Executor db *database.DB + knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库) auth *security.AuthManager } @@ -91,31 +97,128 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { if cfg.Agent.ResultStorageDir != "" { resultStorageDir = cfg.Agent.ResultStorageDir } - + // 确保存储目录存在 if err := os.MkdirAll(resultStorageDir, 0755); err != nil { return nil, fmt.Errorf("创建结果存储目录失败: %w", err) } - + // 创建结果存储实例 resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger) if err != nil { return nil, fmt.Errorf("初始化结果存储失败: %w", err) } - + // 创建Agent maxIterations := cfg.Agent.MaxIterations if maxIterations <= 0 { maxIterations = 30 // 默认值 } agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations) - + // 设置结果存储到Agent agent.SetResultStorage(resultStorage) - + // 设置结果存储到Executor(用于查询工具) executor.SetResultStorage(resultStorage) + // 初始化知识库模块(如果启用) + var knowledgeManager *knowledge.Manager + var knowledgeRetriever *knowledge.Retriever + var knowledgeIndexer *knowledge.Indexer + var knowledgeHandler *handler.KnowledgeHandler + + var knowledgeDBConn *database.DB + log.Logger.Info("检查知识库配置", zap.Bool("enabled", cfg.Knowledge.Enabled)) + if cfg.Knowledge.Enabled { + // 确定知识库数据库路径 + knowledgeDBPath := cfg.Database.KnowledgeDBPath + var knowledgeDB *sql.DB + + if knowledgeDBPath != "" { + // 使用独立的知识库数据库 + // 确保目录存在 + if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil { + return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err) + } + + var err error + knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, log.Logger) + if err != nil { + return nil, fmt.Errorf("初始化知识库数据库失败: %w", err) + } + knowledgeDB = knowledgeDBConn.DB + log.Logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath)) + } else { + // 向后兼容:使用会话数据库 + knowledgeDB = db.DB + log.Logger.Info("使用会话数据库存储知识库数据(建议配置knowledge_db_path以分离数据)") + } + + // 创建知识库管理器 + knowledgeManager = knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, log.Logger) + + // 创建嵌入器 + // 使用OpenAI配置的API Key(如果知识库配置中没有指定) + if cfg.Knowledge.Embedding.APIKey == "" { + cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey + } + if cfg.Knowledge.Embedding.BaseURL == "" { + cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL + } + + httpClient := &http.Client{ + Timeout: 30 * time.Minute, + } + openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, log.Logger) + embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, log.Logger) + + // 创建检索器 + retrievalConfig := &knowledge.RetrievalConfig{ + TopK: cfg.Knowledge.Retrieval.TopK, + SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold, + HybridWeight: cfg.Knowledge.Retrieval.HybridWeight, + } + knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger) + + // 创建索引器 + knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger) + + // 注册知识检索工具到MCP服务器 + knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger) + + // 创建知识库API处理器 + knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger) + log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil)) + + // 扫描知识库并建立索引(异步) + go func() { + if err := knowledgeManager.ScanKnowledgeBase(); err != nil { + log.Logger.Warn("扫描知识库失败", zap.Error(err)) + return + } + + // 检查是否已有索引,如果有则跳过自动重建 + hasIndex, err := knowledgeIndexer.HasIndex() + if err != nil { + log.Logger.Warn("检查索引状态失败", zap.Error(err)) + return + } + + if hasIndex { + log.Logger.Info("检测到已有知识库索引,跳过自动重建。如需重建,请手动点击重建索引按钮") + return + } + + // 只有在没有索引时才自动重建 + log.Logger.Info("未检测到知识库索引,开始自动构建索引") + ctx := context.Background() + if err := knowledgeIndexer.RebuildIndex(ctx); err != nil { + log.Logger.Warn("重建知识库索引失败", zap.Error(err)) + } + }() + } + // 获取配置文件路径 configPath := "config.yaml" if len(os.Args) > 1 { @@ -124,12 +227,25 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { // 创建处理器 agentHandler := handler.NewAgentHandler(agent, db, log.Logger) + // 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志 + if knowledgeManager != nil { + agentHandler.SetKnowledgeManager(knowledgeManager) + } monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger) monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 conversationHandler := handler.NewConversationHandler(db, log.Logger) authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) + // 如果知识库已启用,设置知识库工具注册器,以便在ApplyConfig时重新注册知识库工具 + if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil { + // 创建闭包,捕获knowledgeRetriever和knowledgeManager的引用 + registrar := func() error { + knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger) + return nil + } + configHandler.SetKnowledgeToolRegistrar(registrar) + } externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) // 设置路由 @@ -142,6 +258,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { configHandler, externalMCPHandler, attackChainHandler, + knowledgeHandler, mcpServer, authManager, ) @@ -155,6 +272,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { agent: agent, executor: executor, db: db, + knowledgeDB: knowledgeDBConn, auth: authManager, }, nil } @@ -189,6 +307,13 @@ func (a *App) Shutdown() { if a.externalMCPMgr != nil { a.externalMCPMgr.StopAll() } + + // 关闭知识库数据库连接(如果使用独立数据库) + if a.knowledgeDB != nil { + if err := a.knowledgeDB.Close(); err != nil { + a.logger.Logger.Warn("关闭知识库数据库连接失败", zap.Error(err)) + } + } } // setupRoutes 设置路由 @@ -201,6 +326,7 @@ func setupRoutes( configHandler *handler.ConfigHandler, externalMCPHandler *handler.ExternalMCPHandler, attackChainHandler *handler.AttackChainHandler, + knowledgeHandler *handler.KnowledgeHandler, mcpServer *mcp.Server, authManager *security.AuthManager, ) { @@ -258,6 +384,21 @@ func setupRoutes( protected.GET("/attack-chain/:conversationId", attackChainHandler.GetAttackChain) protected.POST("/attack-chain/:conversationId/regenerate", attackChainHandler.RegenerateAttackChain) + // 知识库管理(如果启用) + if knowledgeHandler != nil { + protected.GET("/knowledge/categories", knowledgeHandler.GetCategories) + protected.GET("/knowledge/items", knowledgeHandler.GetItems) + protected.GET("/knowledge/items/:id", knowledgeHandler.GetItem) + protected.POST("/knowledge/items", knowledgeHandler.CreateItem) + protected.PUT("/knowledge/items/:id", knowledgeHandler.UpdateItem) + protected.DELETE("/knowledge/items/:id", knowledgeHandler.DeleteItem) + protected.GET("/knowledge/index-status", knowledgeHandler.GetIndexStatus) + protected.POST("/knowledge/index", knowledgeHandler.RebuildIndex) + protected.POST("/knowledge/scan", knowledgeHandler.ScanKnowledgeBase) + protected.GET("/knowledge/retrieval-logs", knowledgeHandler.GetRetrievalLogs) + protected.POST("/knowledge/search", knowledgeHandler.Search) + } + // MCP端点 protected.POST("/mcp", func(c *gin.Context) { mcpServer.HandleHTTP(c.Writer, c.Request) diff --git a/internal/config/config.go b/internal/config/config.go index c8d61e4c..975e4fae 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,6 +21,7 @@ type Config struct { Database DatabaseConfig `yaml:"database"` Auth AuthConfig `yaml:"auth"` ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` + Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` } type ServerConfig struct { @@ -52,7 +53,8 @@ type SecurityConfig struct { } type DatabaseConfig struct { - Path string `yaml:"path"` + Path string `yaml:"path"` // 会话数据库路径 + KnowledgeDBPath string `yaml:"knowledge_db_path,omitempty"` // 知识库数据库路径(可选,为空则使用会话数据库) } type AgentConfig struct { @@ -399,10 +401,48 @@ func Default() *Config { ToolsDir: "tools", // 默认工具目录 }, Database: DatabaseConfig{ - Path: "data/conversations.db", + Path: "data/conversations.db", + KnowledgeDBPath: "data/knowledge.db", // 默认知识库数据库路径 }, Auth: AuthConfig{ SessionDurationHours: 12, }, + Knowledge: KnowledgeConfig{ + Enabled: true, + BasePath: "knowledge_base", + Embedding: EmbeddingConfig{ + Provider: "openai", + Model: "text-embedding-3-small", + BaseURL: "https://api.openai.com/v1", + }, + Retrieval: RetrievalConfig{ + TopK: 5, + SimilarityThreshold: 0.7, + HybridWeight: 0.7, + }, + }, } } + +// KnowledgeConfig 知识库配置 +type KnowledgeConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用知识检索 + BasePath string `yaml:"base_path" json:"base_path"` // 知识库路径 + Embedding EmbeddingConfig `yaml:"embedding" json:"embedding"` + Retrieval RetrievalConfig `yaml:"retrieval" json:"retrieval"` +} + +// EmbeddingConfig 嵌入配置 +type EmbeddingConfig struct { + Provider string `yaml:"provider" json:"provider"` // 嵌入模型提供商 + Model string `yaml:"model" json:"model"` // 模型名称 + BaseURL string `yaml:"base_url" json:"base_url"` // API Base URL + APIKey string `yaml:"api_key" json:"api_key"` // API Key(从OpenAI配置继承) +} + +// RetrievalConfig 检索配置 +type RetrievalConfig struct { + TopK int `yaml:"top_k" json:"top_k"` // 检索Top-K + SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 相似度阈值 + HybridWeight float64 `yaml:"hybrid_weight" json:"hybrid_weight"` // 向量检索权重(0-1) +} diff --git a/internal/database/database.go b/internal/database/database.go index 55676264..b0221e22 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -131,6 +131,20 @@ func (db *DB) initTables() error { FOREIGN KEY (target_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE );` + // 创建知识检索日志表(保留在会话数据库中,因为有外键关联) + createKnowledgeRetrievalLogsTable := ` + CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs ( + id TEXT PRIMARY KEY, + conversation_id TEXT, + message_id TEXT, + query TEXT NOT NULL, + risk_type TEXT, + retrieved_items TEXT, + created_at DATETIME NOT NULL, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL, + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL + );` + // 创建索引 createIndexes := ` CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id); @@ -144,6 +158,9 @@ func (db *DB) initTables() error { CREATE INDEX IF NOT EXISTS idx_chain_edges_conversation ON attack_chain_edges(conversation_id); CREATE INDEX IF NOT EXISTS idx_chain_edges_source ON attack_chain_edges(source_node_id); CREATE INDEX IF NOT EXISTS idx_chain_edges_target ON attack_chain_edges(target_node_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at); ` if _, err := db.Exec(createConversationsTable); err != nil { @@ -174,6 +191,10 @@ func (db *DB) initTables() error { return fmt.Errorf("创建attack_chain_edges表失败: %w", err) } + if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil { + return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err) + } + if _, err := db.Exec(createIndexes); err != nil { return fmt.Errorf("创建索引失败: %w", err) } @@ -182,8 +203,98 @@ func (db *DB) initTables() error { return nil } +// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) +func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { + sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1") + if err != nil { + return nil, fmt.Errorf("打开知识库数据库失败: %w", err) + } + + if err := sqlDB.Ping(); err != nil { + return nil, fmt.Errorf("连接知识库数据库失败: %w", err) + } + + database := &DB{ + DB: sqlDB, + logger: logger, + } + + // 初始化知识库表 + if err := database.initKnowledgeTables(); err != nil { + return nil, fmt.Errorf("初始化知识库表失败: %w", err) + } + + return database, nil +} + +// initKnowledgeTables 初始化知识库数据库表(只包含知识库相关的表) +func (db *DB) initKnowledgeTables() error { + // 创建知识库项表 + createKnowledgeBaseItemsTable := ` + CREATE TABLE IF NOT EXISTS knowledge_base_items ( + id TEXT PRIMARY KEY, + category TEXT NOT NULL, + title TEXT NOT NULL, + file_path TEXT NOT NULL, + content TEXT, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + );` + + // 创建知识库向量表 + createKnowledgeEmbeddingsTable := ` + CREATE TABLE IF NOT EXISTS knowledge_embeddings ( + id TEXT PRIMARY KEY, + item_id TEXT NOT NULL, + chunk_index INTEGER NOT NULL, + chunk_text TEXT NOT NULL, + embedding TEXT NOT NULL, + created_at DATETIME NOT NULL, + FOREIGN KEY (item_id) REFERENCES knowledge_base_items(id) ON DELETE CASCADE + );` + + // 创建知识检索日志表(在独立知识库数据库中,不使用外键约束,因为conversations和messages表可能不在这个数据库中) + createKnowledgeRetrievalLogsTable := ` + CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs ( + id TEXT PRIMARY KEY, + conversation_id TEXT, + message_id TEXT, + query TEXT NOT NULL, + risk_type TEXT, + retrieved_items TEXT, + created_at DATETIME NOT NULL + );` + + // 创建索引 + createIndexes := ` + CREATE INDEX IF NOT EXISTS idx_knowledge_items_category ON knowledge_base_items(category); + CREATE INDEX IF NOT EXISTS idx_knowledge_embeddings_item_id ON knowledge_embeddings(item_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at); + ` + + if _, err := db.Exec(createKnowledgeBaseItemsTable); err != nil { + return fmt.Errorf("创建knowledge_base_items表失败: %w", err) + } + + if _, err := db.Exec(createKnowledgeEmbeddingsTable); err != nil { + return fmt.Errorf("创建knowledge_embeddings表失败: %w", err) + } + + if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil { + return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err) + } + + if _, err := db.Exec(createIndexes); err != nil { + return fmt.Errorf("创建索引失败: %w", err) + } + + db.logger.Info("知识库数据库表初始化完成") + return nil +} + // Close 关闭数据库连接 func (db *DB) Close() error { return db.DB.Close() } - diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 939311f7..d43b31af 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/http" + "strings" "time" "cyberstrike-ai/internal/agent" @@ -17,10 +18,13 @@ import ( // AgentHandler Agent处理器 type AgentHandler struct { - agent *agent.Agent - db *database.DB - logger *zap.Logger - tasks *AgentTaskManager + agent *agent.Agent + db *database.DB + logger *zap.Logger + tasks *AgentTaskManager + knowledgeManager interface { // 知识库管理器接口 + LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error + } } // NewAgentHandler 创建新的Agent处理器 @@ -33,6 +37,13 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *A } } +// SetKnowledgeManager 设置知识库管理器(用于记录检索日志) +func (h *AgentHandler) SetKnowledgeManager(manager interface { + LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error +}) { + h.knowledgeManager = manager +} + // ChatRequest 聊天请求 type ChatRequest struct { Message string `json:"message" binding:"required"` @@ -271,9 +282,136 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { assistantMessageID = assistantMsg.ID } + // 用于保存tool_call事件中的参数,以便在tool_result时使用 + toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments + progressCallback := func(eventType, message string, data interface{}) { sendEvent(eventType, message, data) + // 保存tool_call事件中的参数 + if eventType == "tool_call" { + if dataMap, ok := data.(map[string]interface{}); ok { + toolName, _ := dataMap["toolName"].(string) + if toolName == "search_knowledge_base" { + if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" { + if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { + toolCallCache[toolCallId] = argumentsObj + } + } + } + } + } + + // 处理知识检索日志记录 + if eventType == "tool_result" && h.knowledgeManager != nil { + if dataMap, ok := data.(map[string]interface{}); ok { + toolName, _ := dataMap["toolName"].(string) + if toolName == "search_knowledge_base" { + // 提取检索信息 + query := "" + riskType := "" + var retrievedItems []string + + // 首先尝试从tool_call缓存中获取参数 + if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" { + if cachedArgs, exists := toolCallCache[toolCallId]; exists { + if q, ok := cachedArgs["query"].(string); ok && q != "" { + query = q + } + if rt, ok := cachedArgs["risk_type"].(string); ok && rt != "" { + riskType = rt + } + // 使用后清理缓存 + delete(toolCallCache, toolCallId) + } + } + + // 如果缓存中没有,尝试从argumentsObj中提取 + if query == "" { + if arguments, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { + if q, ok := arguments["query"].(string); ok && q != "" { + query = q + } + if rt, ok := arguments["risk_type"].(string); ok && rt != "" { + riskType = rt + } + } + } + + // 如果query仍然为空,尝试从result中提取(从结果文本的第一行) + if query == "" { + if result, ok := dataMap["result"].(string); ok && result != "" { + // 尝试从结果中提取查询内容(如果结果包含"未找到与查询 'xxx' 相关的知识") + if strings.Contains(result, "未找到与查询 '") { + start := strings.Index(result, "未找到与查询 '") + len("未找到与查询 '") + end := strings.Index(result[start:], "'") + if end > 0 { + query = result[start : start+end] + } + } + } + // 如果还是为空,使用默认值 + if query == "" { + query = "未知查询" + } + } + + // 从工具结果中提取检索到的知识项ID + // 结果格式:"找到 X 条相关知识:\n\n--- 结果 1 (相似度: XX.XX%) ---\n来源: [分类] 标题\n...\n" + if result, ok := dataMap["result"].(string); ok && result != "" { + // 尝试从元数据中提取知识项ID + metadataMatch := strings.Index(result, "") + if metadataEnd > 0 { + metadataJSON := result[metadataStart : metadataStart+metadataEnd] + var metadata map[string]interface{} + if err := json.Unmarshal([]byte(metadataJSON), &metadata); err == nil { + if meta, ok := metadata["_metadata"].(map[string]interface{}); ok { + if ids, ok := meta["retrievedItemIDs"].([]interface{}); ok { + retrievedItems = make([]string, 0, len(ids)) + for _, id := range ids { + if idStr, ok := id.(string); ok { + retrievedItems = append(retrievedItems, idStr) + } + } + } + } + } + } + } + + // 如果没有从元数据中提取到,但结果包含"找到 X 条",至少标记为有结果 + if len(retrievedItems) == 0 && strings.Contains(result, "找到") && !strings.Contains(result, "未找到") { + // 有结果,但无法准确提取ID,使用特殊标记 + retrievedItems = []string{"_has_results"} + } + } + + // 记录检索日志(异步,不阻塞) + go func() { + if err := h.knowledgeManager.LogRetrieval(conversationID, assistantMessageID, query, riskType, retrievedItems); err != nil { + h.logger.Warn("记录知识检索日志失败", zap.Error(err)) + } + }() + + // 添加知识检索事件到processDetails + if assistantMessageID != "" { + retrievalData := map[string]interface{}{ + "query": query, + "riskType": riskType, + "toolName": toolName, + } + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "knowledge_retrieval", fmt.Sprintf("检索知识: %s", query), retrievalData); err != nil { + h.logger.Warn("保存知识检索详情失败", zap.Error(err)) + } + } + } + } + } + // 保存过程详情到数据库(排除response和done事件,它们会在后面单独处理) if assistantMessageID != "" && eventType != "response" && eventType != "done" { if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil { diff --git a/internal/handler/config.go b/internal/handler/config.go index 6cd08e0d..0d3de996 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -20,17 +20,21 @@ import ( "gopkg.in/yaml.v3" ) +// KnowledgeToolRegistrar 知识库工具注册器接口 +type KnowledgeToolRegistrar func() error + // ConfigHandler 配置处理器 type ConfigHandler struct { - configPath string - config *config.Config - mcpServer *mcp.Server - executor *security.Executor - agent AgentUpdater // Agent接口,用于更新Agent配置 - attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 - externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 - logger *zap.Logger - mu sync.RWMutex + configPath string + config *config.Config + mcpServer *mcp.Server + executor *security.Executor + agent AgentUpdater // Agent接口,用于更新Agent配置 + attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 + externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 + knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选) + logger *zap.Logger + mu sync.RWMutex } // AttackChainUpdater 攻击链处理器更新接口 @@ -47,23 +51,31 @@ type AgentUpdater interface { // NewConfigHandler 创建新的配置处理器 func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler { return &ConfigHandler{ - configPath: configPath, - config: cfg, - mcpServer: mcpServer, - executor: executor, - agent: agent, + configPath: configPath, + config: cfg, + mcpServer: mcpServer, + executor: executor, + agent: agent, attackChainHandler: attackChainHandler, externalMCPMgr: externalMCPMgr, - logger: logger, + logger: logger, } } +// SetKnowledgeToolRegistrar 设置知识库工具注册器 +func (h *ConfigHandler) SetKnowledgeToolRegistrar(registrar KnowledgeToolRegistrar) { + h.mu.Lock() + defer h.mu.Unlock() + h.knowledgeToolRegistrar = registrar +} + // GetConfigResponse 获取配置响应 type GetConfigResponse struct { - OpenAI config.OpenAIConfig `json:"openai"` - MCP config.MCPConfig `json:"mcp"` - Tools []ToolConfigInfo `json:"tools"` - Agent config.AgentConfig `json:"agent"` + OpenAI config.OpenAIConfig `json:"openai"` + MCP config.MCPConfig `json:"mcp"` + Tools []ToolConfigInfo `json:"tools"` + Agent config.AgentConfig `json:"agent"` + Knowledge config.KnowledgeConfig `json:"knowledge"` } // ToolConfigInfo 工具配置信息 @@ -81,8 +93,11 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { defer h.mu.RUnlock() // 获取工具列表(包含内部和外部工具) + // 首先从配置文件获取工具 + configToolMap := make(map[string]bool) tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) for _, tool := range h.config.Security.Tools { + configToolMap[tool.Name] = true tools = append(tools, ToolConfigInfo{ Name: tool.Name, Description: tool.ShortDescription, @@ -98,6 +113,31 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { tools[len(tools)-1].Description = desc } } + + // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) + if h.mcpServer != nil { + mcpTools := h.mcpServer.GetAllTools() + for _, mcpTool := range mcpTools { + // 跳过已经在配置文件中的工具(避免重复) + if configToolMap[mcpTool.Name] { + continue + } + // 添加直接注册到MCP服务器的工具(如知识检索工具) + description := mcpTool.ShortDescription + if description == "" { + description = mcpTool.Description + } + if len(description) > 100 { + description = description[:100] + "..." + } + tools = append(tools, ToolConfigInfo{ + Name: mcpTool.Name, + Description: description, + Enabled: true, // 直接注册的工具默认启用 + IsExternal: false, + }) + } + } // 获取外部MCP工具 if h.externalMCPMgr != nil { @@ -159,10 +199,11 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { } c.JSON(http.StatusOK, GetConfigResponse{ - OpenAI: h.config.OpenAI, - MCP: h.config.MCP, - Tools: tools, - Agent: h.config.Agent, + OpenAI: h.config.OpenAI, + MCP: h.config.MCP, + Tools: tools, + Agent: h.config.Agent, + Knowledge: h.config.Knowledge, }) } @@ -202,8 +243,10 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { } // 获取所有内部工具并应用搜索过滤 + configToolMap := make(map[string]bool) allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) for _, tool := range h.config.Security.Tools { + configToolMap[tool.Name] = true toolInfo := ToolConfigInfo{ Name: tool.Name, Description: tool.ShortDescription, @@ -230,6 +273,43 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { allTools = append(allTools, toolInfo) } + + // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) + if h.mcpServer != nil { + mcpTools := h.mcpServer.GetAllTools() + for _, mcpTool := range mcpTools { + // 跳过已经在配置文件中的工具(避免重复) + if configToolMap[mcpTool.Name] { + continue + } + + description := mcpTool.ShortDescription + if description == "" { + description = mcpTool.Description + } + if len(description) > 100 { + description = description[:100] + "..." + } + + toolInfo := ToolConfigInfo{ + Name: mcpTool.Name, + Description: description, + Enabled: true, // 直接注册的工具默认启用 + IsExternal: false, + } + + // 如果有关键词,进行搜索过滤 + if searchTermLower != "" { + nameLower := strings.ToLower(toolInfo.Name) + descLower := strings.ToLower(toolInfo.Description) + if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) { + continue // 不匹配,跳过 + } + } + + allTools = append(allTools, toolInfo) + } + } // 获取外部MCP工具 if h.externalMCPMgr != nil { @@ -337,10 +417,11 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { // UpdateConfigRequest 更新配置请求 type UpdateConfigRequest struct { - OpenAI *config.OpenAIConfig `json:"openai,omitempty"` - MCP *config.MCPConfig `json:"mcp,omitempty"` - Tools []ToolEnableStatus `json:"tools,omitempty"` - Agent *config.AgentConfig `json:"agent,omitempty"` + OpenAI *config.OpenAIConfig `json:"openai,omitempty"` + MCP *config.MCPConfig `json:"mcp,omitempty"` + Tools []ToolEnableStatus `json:"tools,omitempty"` + Agent *config.AgentConfig `json:"agent,omitempty"` + Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"` } // ToolEnableStatus 工具启用状态 @@ -389,6 +470,19 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { ) } + // 更新Knowledge配置 + if req.Knowledge != nil { + h.config.Knowledge = *req.Knowledge + h.logger.Info("更新Knowledge配置", + zap.Bool("enabled", h.config.Knowledge.Enabled), + zap.String("base_path", h.config.Knowledge.BasePath), + zap.String("embedding_model", h.config.Knowledge.Embedding.Model), + zap.Int("retrieval_top_k", h.config.Knowledge.Retrieval.TopK), + zap.Float64("similarity_threshold", h.config.Knowledge.Retrieval.SimilarityThreshold), + zap.Float64("hybrid_weight", h.config.Knowledge.Retrieval.HybridWeight), + ) + } + // 更新工具启用状态 if req.Tools != nil { // 分离内部工具和外部工具 @@ -519,8 +613,18 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { // 清空MCP服务器中的工具 h.mcpServer.ClearTools() - // 重新注册工具 + // 重新注册安全工具 h.executor.RegisterTools(h.mcpServer) + + // 如果知识库启用,重新注册知识库工具 + if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil { + h.logger.Info("重新注册知识库工具") + if err := h.knowledgeToolRegistrar(); err != nil { + h.logger.Error("重新注册知识库工具失败", zap.Error(err)) + } else { + h.logger.Info("知识库工具已重新注册") + } + } // 更新Agent的OpenAI配置 if h.agent != nil { @@ -565,6 +669,7 @@ func (h *ConfigHandler) saveConfig() error { updateAgentConfig(root, h.config.Agent.MaxIterations) updateMCPConfig(root, h.config.MCP) updateOpenAIConfig(root, h.config.OpenAI) + updateKnowledgeConfig(root, h.config.Knowledge) // 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用) // 读取原始配置以保持向后兼容 originalConfigs := make(map[string]map[string]bool) @@ -708,6 +813,30 @@ func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) { setStringInMap(openaiNode, "model", cfg.Model) } +func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) { + root := doc.Content[0] + knowledgeNode := ensureMap(root, "knowledge") + setBoolInMap(knowledgeNode, "enabled", cfg.Enabled) + setStringInMap(knowledgeNode, "base_path", cfg.BasePath) + + // 更新嵌入配置 + embeddingNode := ensureMap(knowledgeNode, "embedding") + setStringInMap(embeddingNode, "provider", cfg.Embedding.Provider) + setStringInMap(embeddingNode, "model", cfg.Embedding.Model) + if cfg.Embedding.BaseURL != "" { + setStringInMap(embeddingNode, "base_url", cfg.Embedding.BaseURL) + } + if cfg.Embedding.APIKey != "" { + setStringInMap(embeddingNode, "api_key", cfg.Embedding.APIKey) + } + + // 更新检索配置 + retrievalNode := ensureMap(knowledgeNode, "retrieval") + setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK) + setFloatInMap(retrievalNode, "similarity_threshold", cfg.Retrieval.SimilarityThreshold) + setFloatInMap(retrievalNode, "hybrid_weight", cfg.Retrieval.HybridWeight) +} + func ensureMap(parent *yaml.Node, path ...string) *yaml.Node { current := parent for _, key := range path { @@ -818,4 +947,12 @@ func setBoolInMap(mapNode *yaml.Node, key string, value bool) { } } +func setFloatInMap(mapNode *yaml.Node, key string, value float64) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.ScalarNode + valueNode.Tag = "!!float" + valueNode.Style = 0 + valueNode.Value = fmt.Sprintf("%g", value) +} + diff --git a/internal/handler/knowledge.go b/internal/handler/knowledge.go new file mode 100644 index 00000000..020f6fd9 --- /dev/null +++ b/internal/handler/knowledge.go @@ -0,0 +1,248 @@ +package handler + +import ( + "context" + "fmt" + "net/http" + + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/knowledge" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// KnowledgeHandler 知识库处理器 +type KnowledgeHandler struct { + manager *knowledge.Manager + retriever *knowledge.Retriever + indexer *knowledge.Indexer + db *database.DB + logger *zap.Logger +} + +// NewKnowledgeHandler 创建新的知识库处理器 +func NewKnowledgeHandler( + manager *knowledge.Manager, + retriever *knowledge.Retriever, + indexer *knowledge.Indexer, + db *database.DB, + logger *zap.Logger, +) *KnowledgeHandler { + return &KnowledgeHandler{ + manager: manager, + retriever: retriever, + indexer: indexer, + db: db, + logger: logger, + } +} + +// GetCategories 获取所有分类 +func (h *KnowledgeHandler) GetCategories(c *gin.Context) { + categories, err := h.manager.GetCategories() + if err != nil { + h.logger.Error("获取分类失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"categories": categories}) +} + +// GetItems 获取知识项列表 +func (h *KnowledgeHandler) GetItems(c *gin.Context) { + category := c.Query("category") + + items, err := h.manager.GetItems(category) + if err != nil { + h.logger.Error("获取知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"items": items}) +} + +// GetItem 获取单个知识项 +func (h *KnowledgeHandler) GetItem(c *gin.Context) { + id := c.Param("id") + + item, err := h.manager.GetItem(id) + if err != nil { + h.logger.Error("获取知识项失败", zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, item) +} + +// CreateItem 创建知识项 +func (h *KnowledgeHandler) CreateItem(c *gin.Context) { + var req struct { + Category string `json:"category" binding:"required"` + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + item, err := h.manager.CreateItem(req.Category, req.Title, req.Content) + if err != nil { + h.logger.Error("创建知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 异步索引 + go func() { + ctx := context.Background() + if err := h.indexer.IndexItem(ctx, item.ID); err != nil { + h.logger.Warn("索引知识项失败", zap.String("itemId", item.ID), zap.Error(err)) + } + }() + + c.JSON(http.StatusOK, item) +} + +// UpdateItem 更新知识项 +func (h *KnowledgeHandler) UpdateItem(c *gin.Context) { + id := c.Param("id") + + var req struct { + Category string `json:"category" binding:"required"` + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + item, err := h.manager.UpdateItem(id, req.Category, req.Title, req.Content) + if err != nil { + h.logger.Error("更新知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 异步重新索引 + go func() { + ctx := context.Background() + if err := h.indexer.IndexItem(ctx, item.ID); err != nil { + h.logger.Warn("重新索引知识项失败", zap.String("itemId", item.ID), zap.Error(err)) + } + }() + + c.JSON(http.StatusOK, item) +} + +// DeleteItem 删除知识项 +func (h *KnowledgeHandler) DeleteItem(c *gin.Context) { + id := c.Param("id") + + if err := h.manager.DeleteItem(id); err != nil { + h.logger.Error("删除知识项失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) +} + +// RebuildIndex 重建索引 +func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) { + // 异步重建索引 + go func() { + ctx := context.Background() + if err := h.indexer.RebuildIndex(ctx); err != nil { + h.logger.Error("重建索引失败", zap.Error(err)) + } + }() + + c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"}) +} + +// ScanKnowledgeBase 扫描知识库 +func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) { + if err := h.manager.ScanKnowledgeBase(); err != nil { + h.logger.Error("扫描知识库失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 异步重建索引 + go func() { + ctx := context.Background() + if err := h.indexer.RebuildIndex(ctx); err != nil { + h.logger.Error("重建索引失败", zap.Error(err)) + } + }() + + c.JSON(http.StatusOK, gin.H{"message": "扫描完成,索引重建已开始"}) +} + +// GetRetrievalLogs 获取检索日志 +func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) { + conversationID := c.Query("conversationId") + messageID := c.Query("messageId") + limit := 50 // 默认50条 + + if limitStr := c.Query("limit"); limitStr != "" { + if parsed, err := parseInt(limitStr); err == nil && parsed > 0 { + limit = parsed + } + } + + logs, err := h.manager.GetRetrievalLogs(conversationID, messageID, limit) + if err != nil { + h.logger.Error("获取检索日志失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"logs": logs}) +} + +// GetIndexStatus 获取索引状态 +func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) { + status, err := h.manager.GetIndexStatus() + if err != nil { + h.logger.Error("获取索引状态失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, status) +} + +// Search 搜索知识库(用于API调用,Agent内部使用Retriever) +func (h *KnowledgeHandler) Search(c *gin.Context) { + var req knowledge.SearchRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + results, err := h.retriever.Search(c.Request.Context(), &req) + if err != nil { + h.logger.Error("搜索知识库失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"results": results}) +} + +// 辅助函数:解析整数 +func parseInt(s string) (int, error) { + var result int + _, err := fmt.Sscanf(s, "%d", &result) + return result, err +} + diff --git a/internal/knowledge/embedder.go b/internal/knowledge/embedder.go new file mode 100644 index 00000000..2f27ab9d --- /dev/null +++ b/internal/knowledge/embedder.go @@ -0,0 +1,205 @@ +package knowledge + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/openai" + + "go.uber.org/zap" +) + +// Embedder 文本嵌入器 +type Embedder struct { + openAIClient *openai.Client + config *config.KnowledgeConfig + openAIConfig *config.OpenAIConfig // 用于获取API Key + logger *zap.Logger +} + +// NewEmbedder 创建新的嵌入器 +func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, openAIClient *openai.Client, logger *zap.Logger) *Embedder { + return &Embedder{ + openAIClient: openAIClient, + config: cfg, + openAIConfig: openAIConfig, + logger: logger, + } +} + +// EmbeddingRequest OpenAI嵌入请求 +type EmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` +} + +// EmbeddingResponse OpenAI嵌入响应 +type EmbeddingResponse struct { + Data []EmbeddingData `json:"data"` + Error *EmbeddingError `json:"error,omitempty"` +} + +// EmbeddingData 嵌入数据 +type EmbeddingData struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +// EmbeddingError 嵌入错误 +type EmbeddingError struct { + Message string `json:"message"` + Type string `json:"type"` +} + +// EmbedText 对文本进行嵌入 +func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) { + if e.openAIClient == nil { + return nil, fmt.Errorf("OpenAI客户端未初始化") + } + + // 使用配置的嵌入模型 + model := e.config.Embedding.Model + if model == "" { + model = "text-embedding-3-small" + } + + req := EmbeddingRequest{ + Model: model, + Input: []string{text}, + } + + // 清理baseURL:去除前后空格和尾部斜杠 + baseURL := strings.TrimSpace(e.config.Embedding.BaseURL) + baseURL = strings.TrimSuffix(baseURL, "/") + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + + // 构建请求 + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("序列化请求失败: %w", err) + } + + requestURL := baseURL + "/embeddings" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + + // 使用配置的API Key,如果没有则使用OpenAI配置的 + apiKey := strings.TrimSpace(e.config.Embedding.APIKey) + if apiKey == "" && e.openAIConfig != nil { + apiKey = e.openAIConfig.APIKey + } + if apiKey == "" { + return nil, fmt.Errorf("API Key未配置") + } + httpReq.Header.Set("Authorization", "Bearer "+apiKey) + + // 发送请求 + httpClient := &http.Client{ + Timeout: 30 * time.Second, + } + resp, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("发送请求失败: %w", err) + } + defer resp.Body.Close() + + // 读取响应体以便在错误时输出详细信息 + bodyBytes := make([]byte, 0) + buf := make([]byte, 4096) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + bodyBytes = append(bodyBytes, buf[:n]...) + } + if err != nil { + break + } + } + + // 记录请求和响应信息(用于调试) + requestBodyPreview := string(body) + if len(requestBodyPreview) > 200 { + requestBodyPreview = requestBodyPreview[:200] + "..." + } + e.logger.Debug("嵌入API请求", + zap.String("url", httpReq.URL.String()), + zap.String("model", model), + zap.String("requestBody", requestBodyPreview), + zap.Int("status", resp.StatusCode), + zap.Int("bodySize", len(bodyBytes)), + zap.String("contentType", resp.Header.Get("Content-Type")), + ) + + var embeddingResp EmbeddingResponse + if err := json.Unmarshal(bodyBytes, &embeddingResp); err != nil { + // 输出详细的错误信息 + bodyPreview := string(bodyBytes) + if len(bodyPreview) > 500 { + bodyPreview = bodyPreview[:500] + "..." + } + return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码: %d, 响应长度: %d字节): %w\n请求体: %s\n响应内容预览: %s", + requestURL, resp.StatusCode, len(bodyBytes), err, requestBodyPreview, bodyPreview) + } + + if embeddingResp.Error != nil { + return nil, fmt.Errorf("OpenAI API错误 (状态码: %d): 类型=%s, 消息=%s", + resp.StatusCode, embeddingResp.Error.Type, embeddingResp.Error.Message) + } + + if resp.StatusCode != http.StatusOK { + bodyPreview := string(bodyBytes) + if len(bodyPreview) > 500 { + bodyPreview = bodyPreview[:500] + "..." + } + return nil, fmt.Errorf("HTTP请求失败 (URL: %s, 状态码: %d): 响应内容=%s", requestURL, resp.StatusCode, bodyPreview) + } + + if len(embeddingResp.Data) == 0 { + bodyPreview := string(bodyBytes) + if len(bodyPreview) > 500 { + bodyPreview = bodyPreview[:500] + "..." + } + return nil, fmt.Errorf("未收到嵌入数据 (状态码: %d, 响应长度: %d字节)\n响应内容: %s", + resp.StatusCode, len(bodyBytes), bodyPreview) + } + + // 转换为float32 + embedding := make([]float32, len(embeddingResp.Data[0].Embedding)) + for i, v := range embeddingResp.Data[0].Embedding { + embedding[i] = float32(v) + } + + return embedding, nil +} + +// EmbedTexts 批量嵌入文本 +func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { + if len(texts) == 0 { + return nil, nil + } + + // OpenAI API支持批量,但为了简单起见,我们逐个处理 + // 实际可以使用批量API以提高效率 + embeddings := make([][]float32, len(texts)) + for i, text := range texts { + embedding, err := e.EmbedText(ctx, text) + if err != nil { + return nil, fmt.Errorf("嵌入文本[%d]失败: %w", i, err) + } + embeddings[i] = embedding + } + + return embeddings, nil +} + diff --git a/internal/knowledge/indexer.go b/internal/knowledge/indexer.go new file mode 100644 index 00000000..d691f19a --- /dev/null +++ b/internal/knowledge/indexer.go @@ -0,0 +1,247 @@ +package knowledge + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "regexp" + "strings" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Indexer 索引器,负责将知识项分块并向量化 +type Indexer struct { + db *sql.DB + embedder *Embedder + logger *zap.Logger + chunkSize int // 每个块的最大token数(估算) + overlap int // 块之间的重叠token数 +} + +// NewIndexer 创建新的索引器 +func NewIndexer(db *sql.DB, embedder *Embedder, logger *zap.Logger) *Indexer { + return &Indexer{ + db: db, + embedder: embedder, + logger: logger, + chunkSize: 512, // 默认512 tokens + overlap: 50, // 默认50 tokens重叠 + } +} + +// ChunkText 将文本分块 +func (idx *Indexer) ChunkText(text string) []string { + // 按Markdown标题分割 + chunks := idx.splitByMarkdownHeaders(text) + + // 如果块太大,进一步分割 + result := make([]string, 0) + for _, chunk := range chunks { + if idx.estimateTokens(chunk) <= idx.chunkSize { + result = append(result, chunk) + } else { + // 按段落分割 + subChunks := idx.splitByParagraphs(chunk) + for _, subChunk := range subChunks { + if idx.estimateTokens(subChunk) <= idx.chunkSize { + result = append(result, subChunk) + } else { + // 按句子分割 + sentences := idx.splitBySentences(subChunk) + currentChunk := "" + for _, sentence := range sentences { + testChunk := currentChunk + if testChunk != "" { + testChunk += "\n" + } + testChunk += sentence + + if idx.estimateTokens(testChunk) > idx.chunkSize && currentChunk != "" { + result = append(result, currentChunk) + currentChunk = sentence + } else { + currentChunk = testChunk + } + } + if currentChunk != "" { + result = append(result, currentChunk) + } + } + } + } + } + + return result +} + +// splitByMarkdownHeaders 按Markdown标题分割 +func (idx *Indexer) splitByMarkdownHeaders(text string) []string { + // 匹配Markdown标题 (# ## ### 等) + headerRegex := regexp.MustCompile(`(?m)^#{1,6}\s+.+$`) + + // 找到所有标题位置 + matches := headerRegex.FindAllStringIndex(text, -1) + if len(matches) == 0 { + return []string{text} + } + + chunks := make([]string, 0) + lastPos := 0 + + for _, match := range matches { + start := match[0] + if start > lastPos { + chunks = append(chunks, strings.TrimSpace(text[lastPos:start])) + } + lastPos = start + } + + // 添加最后一部分 + if lastPos < len(text) { + chunks = append(chunks, strings.TrimSpace(text[lastPos:])) + } + + // 过滤空块 + result := make([]string, 0) + for _, chunk := range chunks { + if strings.TrimSpace(chunk) != "" { + result = append(result, chunk) + } + } + + if len(result) == 0 { + return []string{text} + } + + return result +} + +// splitByParagraphs 按段落分割 +func (idx *Indexer) splitByParagraphs(text string) []string { + paragraphs := strings.Split(text, "\n\n") + result := make([]string, 0) + for _, p := range paragraphs { + if strings.TrimSpace(p) != "" { + result = append(result, strings.TrimSpace(p)) + } + } + return result +} + +// splitBySentences 按句子分割 +func (idx *Indexer) splitBySentences(text string) []string { + // 简单的句子分割(按句号、问号、感叹号) + sentenceRegex := regexp.MustCompile(`[.!?]+\s+`) + sentences := sentenceRegex.Split(text, -1) + result := make([]string, 0) + for _, s := range sentences { + if strings.TrimSpace(s) != "" { + result = append(result, strings.TrimSpace(s)) + } + } + return result +} + +// estimateTokens 估算token数(简单估算:1 token ≈ 4字符) +func (idx *Indexer) estimateTokens(text string) int { + return len([]rune(text)) / 4 +} + +// IndexItem 索引知识项(分块并向量化) +func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { + // 获取知识项 + var content string + err := idx.db.QueryRow("SELECT content FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content) + if err != nil { + return fmt.Errorf("获取知识项失败: %w", err) + } + + // 删除旧的向量 + _, err = idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID) + if err != nil { + return fmt.Errorf("删除旧向量失败: %w", err) + } + + // 分块 + chunks := idx.ChunkText(content) + idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks))) + + // 向量化每个块 + for i, chunk := range chunks { + chunkPreview := chunk + if len(chunkPreview) > 200 { + chunkPreview = chunkPreview[:200] + "..." + } + embedding, err := idx.embedder.EmbedText(ctx, chunk) + if err != nil { + idx.logger.Warn("向量化失败", + zap.String("itemId", itemID), + zap.Int("chunkIndex", i), + zap.Int("chunkLength", len(chunk)), + zap.String("chunkPreview", chunkPreview), + zap.Error(err), + ) + continue + } + + // 保存向量 + chunkID := uuid.New().String() + embeddingJSON, _ := json.Marshal(embedding) + + _, err = idx.db.Exec( + "INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, created_at) VALUES (?, ?, ?, ?, ?, datetime('now'))", + chunkID, itemID, i, chunk, string(embeddingJSON), + ) + if err != nil { + idx.logger.Warn("保存向量失败", zap.String("itemId", itemID), zap.Int("chunkIndex", i), zap.Error(err)) + continue + } + } + + idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks))) + return nil +} + +// HasIndex 检查是否存在索引 +func (idx *Indexer) HasIndex() (bool, error) { + var count int + err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count) + if err != nil { + return false, fmt.Errorf("检查索引失败: %w", err) + } + return count > 0, nil +} + +// RebuildIndex 重建所有索引 +func (idx *Indexer) RebuildIndex(ctx context.Context) error { + rows, err := idx.db.Query("SELECT id FROM knowledge_base_items") + if err != nil { + return fmt.Errorf("查询知识项失败: %w", err) + } + defer rows.Close() + + var itemIDs []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return fmt.Errorf("扫描知识项ID失败: %w", err) + } + itemIDs = append(itemIDs, id) + } + + idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs))) + + for i, itemID := range itemIDs { + if err := idx.IndexItem(ctx, itemID); err != nil { + idx.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) + continue + } + idx.logger.Debug("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs))) + } + + idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs))) + return nil +} diff --git a/internal/knowledge/manager.go b/internal/knowledge/manager.go new file mode 100644 index 00000000..e4a4cf2b --- /dev/null +++ b/internal/knowledge/manager.go @@ -0,0 +1,447 @@ +package knowledge + +import ( + "database/sql" + "encoding/json" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Manager 知识库管理器 +type Manager struct { + db *sql.DB + basePath string + logger *zap.Logger +} + +// NewManager 创建新的知识库管理器 +func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager { + return &Manager{ + db: db, + basePath: basePath, + logger: logger, + } +} + +// ScanKnowledgeBase 扫描知识库目录,更新数据库 +func (m *Manager) ScanKnowledgeBase() error { + if m.basePath == "" { + return fmt.Errorf("知识库路径未配置") + } + + // 确保目录存在 + if err := os.MkdirAll(m.basePath, 0755); err != nil { + return fmt.Errorf("创建知识库目录失败: %w", err) + } + + // 遍历知识库目录 + return filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // 跳过目录和非markdown文件 + if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") { + return nil + } + + // 计算相对路径和分类 + relPath, err := filepath.Rel(m.basePath, path) + if err != nil { + return err + } + + // 第一个目录名作为分类(风险类型) + parts := strings.Split(relPath, string(filepath.Separator)) + category := "未分类" + if len(parts) > 1 { + category = parts[0] + } + + // 文件名为标题 + title := strings.TrimSuffix(filepath.Base(path), ".md") + + // 读取文件内容 + content, err := os.ReadFile(path) + if err != nil { + m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err)) + return nil // 继续处理其他文件 + } + + // 检查是否已存在 + var existingID string + err = m.db.QueryRow( + "SELECT id FROM knowledge_base_items WHERE file_path = ?", + path, + ).Scan(&existingID) + + if err == sql.ErrNoRows { + // 创建新项 + id := uuid.New().String() + now := time.Now() + _, err = m.db.Exec( + "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, category, title, path, string(content), now, now, + ) + if err != nil { + return fmt.Errorf("插入知识项失败: %w", err) + } + m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category)) + } else if err == nil { + // 更新现有项 + _, err = m.db.Exec( + "UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?", + category, title, string(content), time.Now(), existingID, + ) + if err != nil { + return fmt.Errorf("更新知识项失败: %w", err) + } + m.logger.Debug("更新知识项", zap.String("id", existingID), zap.String("title", title)) + } else { + return fmt.Errorf("查询知识项失败: %w", err) + } + + return nil + }) +} + +// GetCategories 获取所有分类(风险类型) +func (m *Manager) GetCategories() ([]string, error) { + rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category") + if err != nil { + return nil, fmt.Errorf("查询分类失败: %w", err) + } + defer rows.Close() + + var categories []string + for rows.Next() { + var category string + if err := rows.Scan(&category); err != nil { + return nil, fmt.Errorf("扫描分类失败: %w", err) + } + categories = append(categories, category) + } + + return categories, nil +} + +// GetItems 获取知识项列表 +func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) { + var rows *sql.Rows + var err error + + if category != "" { + rows, err = m.db.Query( + "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE category = ? ORDER BY title", + category, + ) + } else { + rows, err = m.db.Query( + "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items ORDER BY category, title", + ) + } + if err != nil { + return nil, fmt.Errorf("查询知识项失败: %w", err) + } + defer rows.Close() + + var items []*KnowledgeItem + for rows.Next() { + item := &KnowledgeItem{} + var createdAt, updatedAt string + if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描知识项失败: %w", err) + } + + // 解析时间 + item.CreatedAt, _ = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if item.CreatedAt.IsZero() { + item.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + } + item.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if item.UpdatedAt.IsZero() { + item.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + } + + items = append(items, item) + } + + return items, nil +} + +// GetItem 获取单个知识项 +func (m *Manager) GetItem(id string) (*KnowledgeItem, error) { + item := &KnowledgeItem{} + var createdAt, updatedAt string + err := m.db.QueryRow( + "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?", + id, + ).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt) + + if err == sql.ErrNoRows { + return nil, fmt.Errorf("知识项不存在") + } + if err != nil { + return nil, fmt.Errorf("查询知识项失败: %w", err) + } + + // 解析时间 + item.CreatedAt, _ = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if item.CreatedAt.IsZero() { + item.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + } + item.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if item.UpdatedAt.IsZero() { + item.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + } + + return item, nil +} + +// CreateItem 创建知识项 +func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) { + id := uuid.New().String() + now := time.Now() + + // 构建文件路径 + filePath := filepath.Join(m.basePath, category, title+".md") + + // 确保目录存在 + if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { + return nil, fmt.Errorf("创建目录失败: %w", err) + } + + // 写入文件 + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + return nil, fmt.Errorf("写入文件失败: %w", err) + } + + // 插入数据库 + _, err := m.db.Exec( + "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, category, title, filePath, content, now, now, + ) + if err != nil { + return nil, fmt.Errorf("插入知识项失败: %w", err) + } + + return &KnowledgeItem{ + ID: id, + Category: category, + Title: title, + FilePath: filePath, + Content: content, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +// UpdateItem 更新知识项 +func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) { + // 获取现有项 + item, err := m.GetItem(id) + if err != nil { + return nil, err + } + + // 构建新文件路径 + newFilePath := filepath.Join(m.basePath, category, title+".md") + + // 如果路径改变,需要移动文件 + if item.FilePath != newFilePath { + // 确保新目录存在 + if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil { + return nil, fmt.Errorf("创建目录失败: %w", err) + } + + // 移动文件 + if err := os.Rename(item.FilePath, newFilePath); err != nil { + return nil, fmt.Errorf("移动文件失败: %w", err) + } + + // 删除旧目录(如果为空) + oldDir := filepath.Dir(item.FilePath) + if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 { + os.Remove(oldDir) + } + } + + // 写入文件 + if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil { + return nil, fmt.Errorf("写入文件失败: %w", err) + } + + // 更新数据库 + _, err = m.db.Exec( + "UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?", + category, title, newFilePath, content, time.Now(), id, + ) + if err != nil { + return nil, fmt.Errorf("更新知识项失败: %w", err) + } + + // 删除旧的向量嵌入(需要重新索引) + _, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id) + if err != nil { + m.logger.Warn("删除旧向量嵌入失败", zap.Error(err)) + } + + return m.GetItem(id) +} + +// DeleteItem 删除知识项 +func (m *Manager) DeleteItem(id string) error { + // 获取文件路径 + var filePath string + err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath) + if err != nil { + return fmt.Errorf("查询知识项失败: %w", err) + } + + // 删除文件 + if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) { + m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err)) + } + + // 删除数据库记录(级联删除向量) + _, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id) + if err != nil { + return fmt.Errorf("删除知识项失败: %w", err) + } + + return nil +} + +// LogRetrieval 记录检索日志 +func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error { + id := uuid.New().String() + itemsJSON, _ := json.Marshal(retrievedItems) + + _, err := m.db.Exec( + "INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(), + ) + return err +} + +// GetIndexStatus 获取索引状态 +func (m *Manager) GetIndexStatus() (map[string]interface{}, error) { + // 获取总知识项数 + var totalItems int + err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems) + if err != nil { + return nil, fmt.Errorf("查询总知识项数失败: %w", err) + } + + // 获取已索引的知识项数(有向量嵌入的) + var indexedItems int + err = m.db.QueryRow(` + SELECT COUNT(DISTINCT item_id) + FROM knowledge_embeddings + `).Scan(&indexedItems) + if err != nil { + return nil, fmt.Errorf("查询已索引项数失败: %w", err) + } + + // 计算进度百分比 + var progressPercent float64 + if totalItems > 0 { + progressPercent = float64(indexedItems) / float64(totalItems) * 100 + } else { + progressPercent = 100.0 + } + + // 判断是否完成 + isComplete := indexedItems >= totalItems && totalItems > 0 + + return map[string]interface{}{ + "total_items": totalItems, + "indexed_items": indexedItems, + "progress_percent": progressPercent, + "is_complete": isComplete, + }, nil +} + +// GetRetrievalLogs 获取检索日志 +func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) { + var rows *sql.Rows + var err error + + if messageID != "" { + rows, err = m.db.Query( + "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?", + messageID, limit, + ) + } else if conversationID != "" { + rows, err = m.db.Query( + "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?", + conversationID, limit, + ) + } else { + rows, err = m.db.Query( + "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?", + limit, + ) + } + + if err != nil { + return nil, fmt.Errorf("查询检索日志失败: %w", err) + } + defer rows.Close() + + var logs []*RetrievalLog + for rows.Next() { + log := &RetrievalLog{} + var createdAt string + var itemsJSON sql.NullString + if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil { + return nil, fmt.Errorf("扫描检索日志失败: %w", err) + } + + // 解析时间 - 支持多种格式 + var err error + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + for _, format := range timeFormats { + log.CreatedAt, err = time.Parse(format, createdAt) + if err == nil && !log.CreatedAt.IsZero() { + break + } + } + + // 如果所有格式都失败,记录警告但继续处理 + if log.CreatedAt.IsZero() { + m.logger.Warn("解析检索日志时间失败", + zap.String("timeStr", createdAt), + zap.Error(err), + ) + // 使用当前时间作为fallback + log.CreatedAt = time.Now() + } + + // 解析检索项 + if itemsJSON.Valid { + json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems) + } + + logs = append(logs, log) + } + + return logs, nil +} + diff --git a/internal/knowledge/retriever.go b/internal/knowledge/retriever.go new file mode 100644 index 00000000..be2d5311 --- /dev/null +++ b/internal/knowledge/retriever.go @@ -0,0 +1,230 @@ +package knowledge + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "math" + "strings" + + "go.uber.org/zap" +) + +// Retriever 检索器 +type Retriever struct { + db *sql.DB + embedder *Embedder + config *RetrievalConfig + logger *zap.Logger +} + +// RetrievalConfig 检索配置 +type RetrievalConfig struct { + TopK int + SimilarityThreshold float64 + HybridWeight float64 +} + +// NewRetriever 创建新的检索器 +func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever { + return &Retriever{ + db: db, + embedder: embedder, + config: config, + logger: logger, + } +} + +// cosineSimilarity 计算余弦相似度 +func cosineSimilarity(a, b []float32) float64 { + if len(a) != len(b) { + return 0.0 + } + + var dotProduct, normA, normB float64 + for i := range a { + dotProduct += float64(a[i] * b[i]) + normA += float64(a[i] * a[i]) + normB += float64(b[i] * b[i]) + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) +} + +// bm25Score 计算BM25分数(简化版) +func (r *Retriever) bm25Score(query, text string) float64 { + queryTerms := strings.Fields(strings.ToLower(query)) + textLower := strings.ToLower(text) + textTerms := strings.Fields(textLower) + + score := 0.0 + for _, term := range queryTerms { + termFreq := 0 + for _, textTerm := range textTerms { + if textTerm == term { + termFreq++ + } + } + if termFreq > 0 { + // 简化的BM25公式 + score += float64(termFreq) / float64(len(textTerms)) + } + } + + return score / float64(len(queryTerms)) +} + +// Search 搜索知识库 +func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { + if req.Query == "" { + return nil, fmt.Errorf("查询不能为空") + } + + topK := req.TopK + if topK <= 0 { + topK = r.config.TopK + } + if topK == 0 { + topK = 5 + } + + threshold := req.Threshold + if threshold <= 0 { + threshold = r.config.SimilarityThreshold + } + if threshold == 0 { + threshold = 0.7 + } + + // 向量化查询 + queryEmbedding, err := r.embedder.EmbedText(ctx, req.Query) + if err != nil { + return nil, fmt.Errorf("向量化查询失败: %w", err) + } + + // 查询所有向量(或按风险类型过滤) + var rows *sql.Rows + if req.RiskType != "" { + rows, err = r.db.Query(` + SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title + FROM knowledge_embeddings e + JOIN knowledge_base_items i ON e.item_id = i.id + WHERE i.category = ? + `, req.RiskType) + } else { + rows, err = r.db.Query(` + SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title + FROM knowledge_embeddings e + JOIN knowledge_base_items i ON e.item_id = i.id + `) + } + if err != nil { + return nil, fmt.Errorf("查询向量失败: %w", err) + } + defer rows.Close() + + // 计算相似度 + type candidate struct { + chunk *KnowledgeChunk + item *KnowledgeItem + similarity float64 + bm25Score float64 + } + + candidates := make([]candidate, 0) + + for rows.Next() { + var chunkID, itemID, chunkText, embeddingJSON, category, title string + var chunkIndex int + + if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &category, &title); err != nil { + r.logger.Warn("扫描向量失败", zap.Error(err)) + continue + } + + // 解析向量 + var embedding []float32 + if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil { + r.logger.Warn("解析向量失败", zap.Error(err)) + continue + } + + // 计算余弦相似度 + similarity := cosineSimilarity(queryEmbedding, embedding) + + // 计算BM25分数 + bm25Score := r.bm25Score(req.Query, chunkText) + + // 过滤低相似度结果 + if similarity < threshold { + continue + } + + chunk := &KnowledgeChunk{ + ID: chunkID, + ItemID: itemID, + ChunkIndex: chunkIndex, + ChunkText: chunkText, + Embedding: embedding, + } + + item := &KnowledgeItem{ + ID: itemID, + Category: category, + Title: title, + } + + candidates = append(candidates, candidate{ + chunk: chunk, + item: item, + similarity: similarity, + bm25Score: bm25Score, + }) + } + + // 混合排序(向量相似度 + BM25) + hybridWeight := r.config.HybridWeight + if hybridWeight == 0 { + hybridWeight = 0.7 + } + + // 按混合分数排序(简化:主要按相似度,BM25作为次要因素) + // 这里我们主要使用相似度,因为BM25分数可能不稳定 + // 实际可以使用更复杂的混合策略 + + // 选择Top-K + if len(candidates) > topK { + // 简单排序(按相似度) + for i := 0; i < len(candidates)-1; i++ { + for j := i + 1; j < len(candidates); j++ { + if candidates[i].similarity < candidates[j].similarity { + candidates[i], candidates[j] = candidates[j], candidates[i] + } + } + } + candidates = candidates[:topK] + } + + // 转换为结果 + results := make([]*RetrievalResult, len(candidates)) + for i, cand := range candidates { + // 计算混合分数 + normalizedBM25 := math.Min(cand.bm25Score, 1.0) + hybridScore := hybridWeight*cand.similarity + (1-hybridWeight)*normalizedBM25 + + results[i] = &RetrievalResult{ + Chunk: cand.chunk, + Item: cand.item, + Similarity: cand.similarity, + Score: hybridScore, + } + } + + return results, nil +} + diff --git a/internal/knowledge/tool.go b/internal/knowledge/tool.go new file mode 100644 index 00000000..f1d97618 --- /dev/null +++ b/internal/knowledge/tool.go @@ -0,0 +1,191 @@ +package knowledge + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "cyberstrike-ai/internal/mcp" + + "go.uber.org/zap" +) + +// RegisterKnowledgeTool 注册知识检索工具到MCP服务器 +func RegisterKnowledgeTool( + mcpServer *mcp.Server, + retriever *Retriever, + manager *Manager, + logger *zap.Logger, +) { + // manager 和 retriever 在 handler 中直接使用参数 + _ = manager // 保留参数,可能将来用于日志记录等 + tool := mcp.Tool{ + Name: "search_knowledge_base", + Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。", + ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "搜索查询内容,描述你想要了解的安全知识主题", + }, + "risk_type": map[string]interface{}{ + "type": "string", + "description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等),如果不指定则搜索所有类型", + }, + }, + "required": []string{"query"}, + }, + } + + handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + query, ok := args["query"].(string) + if !ok || query == "" { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: 查询参数不能为空", + }, + }, + IsError: true, + }, nil + } + + riskType := "" + if rt, ok := args["risk_type"].(string); ok && rt != "" { + riskType = rt + } + + logger.Info("执行知识库检索", + zap.String("query", query), + zap.String("riskType", riskType), + ) + + // 执行检索 + searchReq := &SearchRequest{ + Query: query, + RiskType: riskType, + TopK: 5, + } + + results, err := retriever.Search(ctx, searchReq) + if err != nil { + logger.Error("知识库检索失败", zap.Error(err)) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("检索失败: %v", err), + }, + }, + IsError: true, + }, nil + } + + if len(results) == 0 { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query), + }, + }, + }, nil + } + + // 格式化结果 + var resultText strings.Builder + resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识:\n\n", len(results))) + + // 收集检索到的知识项ID(用于日志) + retrievedItemIDs := make([]string, 0, len(results)) + + for i, result := range results { + resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", i+1, result.Similarity*100)) + resultText.WriteString(fmt.Sprintf("来源: [%s] %s\n", result.Item.Category, result.Item.Title)) + resultText.WriteString(fmt.Sprintf("内容:\n%s\n\n", result.Chunk.ChunkText)) + + if !contains(retrievedItemIDs, result.Item.ID) { + retrievedItemIDs = append(retrievedItemIDs, result.Item.ID) + } + } + + // 在结果末尾添加元数据(JSON格式,用于提取知识项ID) + // 使用特殊标记,避免影响AI阅读结果 + if len(retrievedItemIDs) > 0 { + metadataJSON, _ := json.Marshal(map[string]interface{}{ + "_metadata": map[string]interface{}{ + "retrievedItemIDs": retrievedItemIDs, + }, + }) + resultText.WriteString(fmt.Sprintf("\n", string(metadataJSON))) + } + + // 记录检索日志(异步,不阻塞) + // 注意:这里没有conversationID和messageID,需要在Agent层面记录 + // 实际的日志记录应该在Agent的progressCallback中完成 + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: resultText.String(), + }, + }, + }, nil + } + + mcpServer.RegisterTool(tool, handler) + logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name)) +} + +// contains 检查切片是否包含元素 +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +// GetRetrievalMetadata 从工具调用中提取检索元数据(用于日志记录) +func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) { + if q, ok := args["query"].(string); ok { + query = q + } + if rt, ok := args["risk_type"].(string); ok { + riskType = rt + } + return +} + +// FormatRetrievalResults 格式化检索结果为字符串(用于日志) +func FormatRetrievalResults(results []*RetrievalResult) string { + if len(results) == 0 { + return "未找到相关结果" + } + + var builder strings.Builder + builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results))) + + itemIDs := make(map[string]bool) + for i, result := range results { + builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n", + i+1, result.Item.Category, result.Item.Title, result.Similarity*100)) + itemIDs[result.Item.ID] = true + } + + // 返回知识项ID列表(JSON格式) + ids := make([]string, 0, len(itemIDs)) + for id := range itemIDs { + ids = append(ids, id) + } + idsJSON, _ := json.Marshal(ids) + builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON))) + + return builder.String() +} diff --git a/internal/knowledge/types.go b/internal/knowledge/types.go new file mode 100644 index 00000000..6608f1e7 --- /dev/null +++ b/internal/knowledge/types.go @@ -0,0 +1,67 @@ +package knowledge + +import ( + "encoding/json" + "time" +) + +// KnowledgeItem 知识库项 +type KnowledgeItem struct { + ID string `json:"id"` + Category string `json:"category"` // 风险类型(文件夹名) + Title string `json:"title"` // 标题(文件名) + FilePath string `json:"filePath"` // 文件路径 + Content string `json:"content"` // 文件内容 + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// KnowledgeChunk 知识块(用于向量化) +type KnowledgeChunk struct { + ID string `json:"id"` + ItemID string `json:"itemId"` + ChunkIndex int `json:"chunkIndex"` + ChunkText string `json:"chunkText"` + Embedding []float32 `json:"-"` // 向量嵌入,不序列化到JSON + CreatedAt time.Time `json:"createdAt"` +} + +// RetrievalResult 检索结果 +type RetrievalResult struct { + Chunk *KnowledgeChunk `json:"chunk"` + Item *KnowledgeItem `json:"item"` + Similarity float64 `json:"similarity"` // 相似度分数 + Score float64 `json:"score"` // 综合分数(混合检索) +} + +// RetrievalLog 检索日志 +type RetrievalLog struct { + ID string `json:"id"` + ConversationID string `json:"conversationId,omitempty"` + MessageID string `json:"messageId,omitempty"` + Query string `json:"query"` + RiskType string `json:"riskType,omitempty"` + RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项ID列表 + CreatedAt time.Time `json:"createdAt"` +} + +// MarshalJSON 自定义JSON序列化,确保时间格式正确 +func (r *RetrievalLog) MarshalJSON() ([]byte, error) { + type Alias RetrievalLog + return json.Marshal(&struct { + *Alias + CreatedAt string `json:"createdAt"` + }{ + Alias: (*Alias)(r), + CreatedAt: r.CreatedAt.Format(time.RFC3339), + }) +} + +// SearchRequest 搜索请求 +type SearchRequest struct { + Query string `json:"query"` + RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型 + TopK int `json:"topK,omitempty"` // 返回Top-K结果,默认5 + Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认0.7 +} + diff --git a/web/static/css/style.css b/web/static/css/style.css index d5fc3d60..8bf60144 100644 --- a/web/static/css/style.css +++ b/web/static/css/style.css @@ -2440,6 +2440,28 @@ header { margin-bottom: 16px; } +.settings-subsection-header { + margin-top: 24px; + margin-bottom: 16px; + padding-top: 16px; + border-top: 1px solid var(--border-color); +} + +.settings-subsection-header h5 { + font-size: 0.9375rem; + font-weight: 600; + color: var(--text-primary); + margin-bottom: 12px; +} + +.form-hint { + display: block; + font-size: 0.8125rem; + color: var(--text-secondary); + margin-top: 4px; + line-height: 1.4; +} + .settings-description { font-size: 0.875rem; color: var(--text-secondary); @@ -2482,7 +2504,8 @@ header { color: var(--text-primary); } -.form-group input { +.form-group input, +.form-group select { padding: 10px 12px; border: 1px solid var(--border-color); border-radius: 6px; @@ -2490,24 +2513,138 @@ header { background: var(--bg-primary); color: var(--text-primary); transition: border-color 0.2s; + font-family: inherit; + width: 100%; + box-sizing: border-box; } -.form-group input:focus { +.form-group select { + cursor: pointer; + appearance: none; + background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%23666' d='M6 9L1 4h10z'/%3E%3C/svg%3E"); + background-repeat: no-repeat; + background-position: right 12px center; + padding-right: 36px; +} + +.form-group input:focus, +.form-group select:focus { outline: none; border-color: var(--accent-color); box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.1); } -.form-group input.error { +.form-group input:hover, +.form-group select:hover { + border-color: var(--accent-color); +} + +.form-group input.error, +.form-group select.error { border-color: var(--error-color); box-shadow: 0 0 0 3px rgba(220, 53, 69, 0.1); } -.form-group input.error:focus { +.form-group input.error:focus, +.form-group select.error:focus { border-color: var(--error-color); box-shadow: 0 0 0 3px rgba(220, 53, 69, 0.2); } +/* 现代化复选框样式 */ +.checkbox-label { + display: flex !important; + align-items: center; + gap: 12px; + margin: 0; + justify-content: flex-start; + width: 100%; + cursor: pointer; + user-select: none; + padding: 8px 0; + transition: all 0.2s ease; +} + +.checkbox-label:hover { + opacity: 0.9; +} + +.modern-checkbox { + position: absolute; + opacity: 0; + width: 0; + height: 0; + pointer-events: none; +} + +.checkbox-custom { + position: relative; + width: 44px; + height: 24px; + background: var(--bg-tertiary); + border: 2px solid var(--border-color); + border-radius: 12px; + transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); + flex-shrink: 0; + box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.1); +} + +.checkbox-custom::before { + content: ''; + position: absolute; + top: 2px; + left: 2px; + width: 16px; + height: 16px; + background: white; + border-radius: 50%; + transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2); +} + +.modern-checkbox:checked + .checkbox-custom { + background: var(--accent-color); + border-color: var(--accent-color); + box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.15), inset 0 2px 4px rgba(0, 0, 0, 0.1); +} + +.modern-checkbox:checked + .checkbox-custom::before { + transform: translateX(20px); + box-shadow: 0 2px 6px rgba(0, 0, 0, 0.25); +} + +.modern-checkbox:focus + .checkbox-custom { + outline: none; + box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.2), inset 0 2px 4px rgba(0, 0, 0, 0.1); +} + +.modern-checkbox:checked:focus + .checkbox-custom { + box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.25), inset 0 2px 4px rgba(0, 0, 0, 0.1); +} + +.checkbox-text { + font-size: 0.9375rem; + font-weight: 500; + color: var(--text-primary); + line-height: 1.5; + transition: color 0.2s ease; +} + +.checkbox-label:hover .checkbox-text { + color: var(--accent-color); +} + +.modern-checkbox:disabled + .checkbox-custom { + opacity: 0.5; + cursor: not-allowed; + background: var(--bg-tertiary); +} + +.modern-checkbox:disabled + .checkbox-custom + .checkbox-text { + opacity: 0.5; + cursor: not-allowed; +} + .form-actions { display: flex; justify-content: flex-end; @@ -3723,3 +3860,879 @@ header { font-size: 1rem; padding: 20px; } + +/* ==================== 知识管理样式 ==================== */ + +.knowledge-controls { + display: flex; + flex-direction: column; + gap: 16px; + margin-bottom: 24px; +} + +.knowledge-stats-bar { + display: flex; + gap: 24px; + padding: 16px 20px; + background: linear-gradient(135deg, rgba(0, 102, 255, 0.05) 0%, rgba(0, 102, 255, 0.02) 100%); + border: 1px solid rgba(0, 102, 255, 0.15); + border-radius: 12px; + box-shadow: var(--shadow-sm); +} + +.knowledge-stat-item { + display: flex; + flex-direction: column; + gap: 4px; +} + +.knowledge-stat-label { + font-size: 0.8125rem; + color: var(--text-secondary); + font-weight: 500; +} + +.knowledge-stat-value { + font-size: 1.5rem; + font-weight: 600; + color: var(--accent-color); +} + +/* 索引进度显示 */ +.knowledge-index-progress { + padding: 16px 20px; + background: linear-gradient(135deg, rgba(255, 193, 7, 0.1) 0%, rgba(255, 193, 7, 0.05) 100%); + border: 1px solid rgba(255, 193, 7, 0.3); + border-radius: 12px; + box-shadow: var(--shadow-sm); +} + +.knowledge-index-progress-complete { + padding: 16px 20px; + background: linear-gradient(135deg, rgba(40, 167, 69, 0.1) 0%, rgba(40, 167, 69, 0.05) 100%); + border: 1px solid rgba(40, 167, 69, 0.3); + border-radius: 12px; + box-shadow: var(--shadow-sm); + display: flex; + align-items: center; + gap: 12px; +} + +.knowledge-index-progress .progress-header { + display: flex; + align-items: center; + gap: 12px; + margin-bottom: 12px; +} + +.knowledge-index-progress .progress-icon { + font-size: 1.25rem; + flex-shrink: 0; +} + +.knowledge-index-progress .progress-text { + font-size: 0.9375rem; + font-weight: 500; + color: var(--text-primary); + flex: 1; +} + +.knowledge-index-progress-complete .progress-icon { + font-size: 1.25rem; +} + +.knowledge-index-progress-complete .progress-text { + font-size: 0.9375rem; + font-weight: 500; + color: var(--text-primary); +} + +.progress-bar-container { + width: 100%; + height: 8px; + background: rgba(0, 0, 0, 0.1); + border-radius: 4px; + overflow: hidden; + margin-bottom: 8px; +} + +.progress-bar { + height: 100%; + background: linear-gradient(90deg, #ffc107 0%, #ff9800 100%); + border-radius: 4px; + transition: width 0.3s ease; + animation: progress-pulse 2s ease-in-out infinite; +} + +@keyframes progress-pulse { + 0%, 100% { + opacity: 1; + } + 50% { + opacity: 0.8; + } +} + +.progress-hint { + font-size: 0.8125rem; + color: var(--text-secondary); + margin-top: 4px; +} + +.knowledge-filters { + display: flex; + gap: 16px; + align-items: flex-end; + flex-wrap: wrap; +} + +.knowledge-filters label { + display: flex; + flex-direction: column; + gap: 6px; + font-size: 0.875rem; + font-weight: 500; + color: var(--text-primary); +} + +.knowledge-filters select { + padding: 8px 12px; + border: 1px solid var(--border-color); + border-radius: 6px; + background: var(--bg-primary); + color: var(--text-primary); + font-size: 0.875rem; + cursor: pointer; + transition: all 0.2s; + min-width: 160px; +} + +.knowledge-filters select:focus { + outline: none; + border-color: var(--accent-color); + box-shadow: 0 0 0 2px rgba(0, 102, 255, 0.1); +} + +/* 自定义下拉组件样式 */ +.custom-select-wrapper { + position: relative; + min-width: 160px; +} + +.custom-select { + position: relative; +} + +.custom-select-trigger { + display: flex; + align-items: center; + justify-content: space-between; + padding: 8px 12px; + border: 1px solid var(--border-color); + border-radius: 6px; + background: var(--bg-primary); + color: var(--text-primary); + font-size: 0.875rem; + cursor: pointer; + transition: all 0.2s; + min-width: 160px; +} + +.custom-select-trigger:hover { + border-color: var(--accent-color); +} + +.custom-select-trigger svg { + transition: transform 0.2s; + flex-shrink: 0; + margin-left: 8px; +} + +.custom-select.open .custom-select-trigger { + border-color: var(--accent-color); + box-shadow: 0 0 0 2px rgba(0, 102, 255, 0.1); +} + +.custom-select.open .custom-select-trigger svg { + transform: rotate(180deg); +} + +.custom-select-dropdown { + position: absolute; + top: calc(100% + 4px); + left: 0; + right: 0; + background: var(--bg-primary); + border: 1px solid var(--border-color); + border-radius: 6px; + box-shadow: var(--shadow-lg); + z-index: 1000; + max-height: 300px; + overflow-y: auto; + overflow-x: hidden; + display: none; +} + +.custom-select.open .custom-select-dropdown { + display: block; +} + +.custom-select-option { + padding: 10px 12px; + font-size: 0.875rem; + color: var(--text-primary); + cursor: pointer; + transition: background-color 0.15s; + border-bottom: 1px solid var(--border-color); +} + +.custom-select-option:last-child { + border-bottom: none; +} + +.custom-select-option:hover { + background: var(--bg-secondary); +} + +.custom-select-option.selected { + background: var(--accent-color); + color: white; +} + +.custom-select-option.selected:hover { + background: var(--accent-hover); +} + +/* 自定义下拉组件滚动条样式 */ +.custom-select-dropdown::-webkit-scrollbar { + width: 6px; +} + +.custom-select-dropdown::-webkit-scrollbar-track { + background: transparent; +} + +.custom-select-dropdown::-webkit-scrollbar-thumb { + background: rgba(0, 0, 0, 0.2); + border-radius: 3px; +} + +.custom-select-dropdown::-webkit-scrollbar-thumb:hover { + background: rgba(0, 0, 0, 0.3); +} + +.knowledge-items-list { + min-height: 200px; +} + +.knowledge-categories-container { + display: flex; + flex-direction: column; + gap: 32px; +} + +.knowledge-category-section { + background: var(--bg-primary); + border: 1px solid var(--border-color); + border-radius: 16px; + padding: 24px; + box-shadow: var(--shadow-sm); + transition: all 0.2s ease; +} + +.knowledge-category-section:hover { + box-shadow: var(--shadow-md); + border-color: rgba(0, 102, 255, 0.2); +} + +.knowledge-category-header { + margin-bottom: 20px; + padding-bottom: 16px; + border-bottom: 2px solid var(--border-color); +} + +.knowledge-category-info { + display: flex; + align-items: center; + gap: 12px; + flex-wrap: wrap; +} + +.knowledge-category-title { + margin: 0; + font-size: 1.25rem; + font-weight: 600; + color: var(--text-primary); + display: flex; + align-items: center; + gap: 8px; +} + +.knowledge-category-title::before { + content: '📁'; + font-size: 1.1rem; +} + +.knowledge-category-count { + padding: 4px 12px; + background: rgba(0, 102, 255, 0.1); + border: 1px solid rgba(0, 102, 255, 0.2); + border-radius: 12px; + font-size: 0.8125rem; + font-weight: 600; + color: var(--accent-color); +} + +.knowledge-items-grid { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); + gap: 16px; +} + +.knowledge-item-card { + background: var(--bg-secondary); + border: 1px solid var(--border-color); + border-radius: 12px; + padding: 16px; + transition: all 0.2s ease; + display: flex; + flex-direction: column; + gap: 12px; + cursor: pointer; + position: relative; + overflow: hidden; +} + +.knowledge-item-card::before { + content: ''; + position: absolute; + top: 0; + left: 0; + width: 4px; + height: 100%; + background: var(--accent-color); + opacity: 0; + transition: opacity 0.2s ease; +} + +.knowledge-item-card:hover { + background: var(--bg-primary); + border-color: var(--accent-color); + box-shadow: var(--shadow-md); + transform: translateY(-2px); +} + +.knowledge-item-card:hover::before { + opacity: 1; +} + +.knowledge-item-card-header { + display: flex; + flex-direction: column; + gap: 8px; +} + +.knowledge-item-card-title-row { + display: flex; + align-items: flex-start; + justify-content: space-between; + gap: 12px; +} + +.knowledge-item-card-title { + margin: 0; + font-size: 1rem; + font-weight: 600; + color: var(--text-primary); + flex: 1; + line-height: 1.4; + overflow: hidden; + text-overflow: ellipsis; + display: -webkit-box; + -webkit-line-clamp: 2; + -webkit-box-orient: vertical; +} + +.knowledge-item-card-actions { + display: flex; + gap: 4px; + flex-shrink: 0; + opacity: 0; + transition: opacity 0.2s ease; +} + +.knowledge-item-card:hover .knowledge-item-card-actions { + opacity: 1; +} + +.knowledge-item-action-btn { + width: 32px; + height: 32px; + display: flex; + align-items: center; + justify-content: center; + border: 1px solid var(--border-color); + border-radius: 6px; + background: var(--bg-primary); + color: var(--text-secondary); + cursor: pointer; + transition: all 0.2s ease; + padding: 0; +} + +.knowledge-item-action-btn:hover { + background: var(--bg-tertiary); + border-color: var(--accent-color); + color: var(--accent-color); + transform: scale(1.05); +} + +.knowledge-item-delete-btn:hover { + background: rgba(220, 53, 69, 0.1); + border-color: var(--error-color); + color: var(--error-color); +} + +.knowledge-item-action-btn svg { + width: 16px; + height: 16px; + stroke: currentColor; +} + +.knowledge-item-path { + font-size: 0.75rem; + color: var(--text-muted); + display: flex; + align-items: center; + gap: 4px; + font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.knowledge-item-card-content { + flex: 1; + min-height: 60px; +} + +.knowledge-item-preview { + margin: 0; + font-size: 0.875rem; + color: var(--text-secondary); + line-height: 1.6; + overflow: hidden; + text-overflow: ellipsis; + display: -webkit-box; + -webkit-line-clamp: 3; + -webkit-box-orient: vertical; +} + +.knowledge-item-card-footer { + display: flex; + justify-content: space-between; + align-items: center; + gap: 8px; + padding-top: 12px; + border-top: 1px solid var(--border-color); + flex-wrap: wrap; +} + +.knowledge-item-meta { + display: flex; + align-items: center; + gap: 8px; + flex: 1; + min-width: 0; +} + +.knowledge-item-time { + font-size: 0.75rem; + color: var(--text-muted); + white-space: nowrap; +} + +.knowledge-item-badge-new { + padding: 2px 6px; + background: rgba(40, 167, 69, 0.15); + border: 1px solid rgba(40, 167, 69, 0.3); + border-radius: 4px; + font-size: 0.6875rem; + font-weight: 600; + color: var(--success-color); + white-space: nowrap; +} + +.knowledge-item-updated { + font-size: 0.75rem; + color: var(--text-muted); + white-space: nowrap; +} + +.empty-state { + text-align: center; + padding: 48px 24px; + color: var(--text-muted); + font-size: 0.9375rem; + background: var(--bg-secondary); + border: 2px dashed var(--border-color); + border-radius: 12px; +} + +/* ==================== 检索历史样式 ==================== */ + +.retrieval-logs-controls { + display: flex; + flex-direction: column; + gap: 16px; + margin-bottom: 24px; +} + +.retrieval-stats-bar { + display: flex; + gap: 24px; + padding: 16px 20px; + background: linear-gradient(135deg, rgba(0, 102, 255, 0.05) 0%, rgba(0, 102, 255, 0.02) 100%); + border: 1px solid rgba(0, 102, 255, 0.15); + border-radius: 12px; + box-shadow: var(--shadow-sm); + flex-wrap: wrap; +} + +.retrieval-stat-item { + display: flex; + flex-direction: column; + gap: 4px; +} + +.retrieval-stat-label { + font-size: 0.8125rem; + color: var(--text-secondary); + font-weight: 500; +} + +.retrieval-stat-value { + font-size: 1.5rem; + font-weight: 600; + color: var(--accent-color); +} + +.retrieval-stat-value.text-success { + color: var(--success-color); +} + +.retrieval-logs-filters { + display: flex; + gap: 16px; + align-items: flex-end; + flex-wrap: wrap; + padding: 16px; + background: var(--bg-secondary); + border: 1px solid var(--border-color); + border-radius: 12px; +} + +.retrieval-logs-filters label { + display: flex; + flex-direction: column; + gap: 6px; + font-size: 0.875rem; + font-weight: 500; + color: var(--text-primary); + flex: 1; + min-width: 200px; +} + +.retrieval-logs-filters input { + padding: 8px 12px; + border: 1px solid var(--border-color); + border-radius: 6px; + background: var(--bg-primary); + color: var(--text-primary); + font-size: 0.875rem; + transition: all 0.2s; + font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; +} + +.retrieval-logs-filters input:focus { + outline: none; + border-color: var(--accent-color); + box-shadow: 0 0 0 2px rgba(0, 102, 255, 0.1); +} + +.retrieval-logs-list { + display: flex; + flex-direction: column; + gap: 16px; +} + +.retrieval-log-card { + background: var(--bg-primary); + border: 1px solid var(--border-color); + border-radius: 12px; + padding: 20px; + transition: all 0.2s ease; + box-shadow: var(--shadow-sm); + position: relative; + overflow: hidden; +} + +.retrieval-log-card::before { + content: ''; + position: absolute; + top: 0; + left: 0; + width: 4px; + height: 100%; + background: var(--border-color); + transition: all 0.2s ease; +} + +.retrieval-log-card.has-results::before { + background: var(--success-color); +} + +.retrieval-log-card.no-results::before { + background: var(--warning-color); +} + +.retrieval-log-card:hover { + box-shadow: var(--shadow-md); + border-color: var(--accent-color); + transform: translateY(-2px); +} + +.retrieval-log-card-header { + display: flex; + align-items: flex-start; + gap: 16px; + margin-bottom: 16px; +} + +.retrieval-log-icon { + font-size: 1.5rem; + flex-shrink: 0; + width: 40px; + height: 40px; + display: flex; + align-items: center; + justify-content: center; + background: var(--bg-secondary); + border-radius: 8px; + border: 1px solid var(--border-color); +} + +.retrieval-log-main-info { + flex: 1; + min-width: 0; +} + +.retrieval-log-query { + font-size: 1rem; + font-weight: 600; + color: var(--text-primary); + margin-bottom: 8px; + line-height: 1.5; + word-break: break-word; +} + +.retrieval-log-meta { + display: flex; + align-items: center; + gap: 12px; + flex-wrap: wrap; +} + +.retrieval-log-time { + font-size: 0.8125rem; + color: var(--text-secondary); + display: flex; + align-items: center; + gap: 4px; +} + +.retrieval-log-risk-type { + font-size: 0.8125rem; + padding: 4px 10px; + background: rgba(0, 102, 255, 0.1); + border: 1px solid rgba(0, 102, 255, 0.2); + border-radius: 12px; + color: var(--accent-color); + font-weight: 500; +} + +.retrieval-log-result-badge { + flex-shrink: 0; + padding: 6px 14px; + border-radius: 20px; + font-size: 0.875rem; + font-weight: 600; + white-space: nowrap; +} + +.retrieval-log-result-badge.success { + background: rgba(40, 167, 69, 0.15); + border: 1px solid rgba(40, 167, 69, 0.3); + color: var(--success-color); +} + +.retrieval-log-result-badge.empty { + background: rgba(255, 193, 7, 0.15); + border: 1px solid rgba(255, 193, 7, 0.3); + color: #b8860b; +} + +.retrieval-log-card-body { + padding-top: 16px; + border-top: 1px solid var(--border-color); +} + +.retrieval-log-details-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); + gap: 12px; + margin-bottom: 12px; +} + +.retrieval-log-detail-item { + display: flex; + flex-direction: column; + gap: 4px; + padding: 10px 12px; + background: var(--bg-secondary); + border-radius: 8px; + border: 1px solid var(--border-color); +} + +.detail-label { + font-size: 0.75rem; + color: var(--text-secondary); + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; +} + +.detail-value { + font-size: 0.875rem; + color: var(--text-primary); + font-weight: 500; + word-break: break-all; + font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; +} + +.detail-value.text-success { + color: var(--success-color); +} + +.detail-value.text-muted { + color: var(--text-muted); +} + +.retrieval-log-items-preview { + margin-top: 12px; + padding-top: 12px; + border-top: 1px solid var(--border-color); +} + +.retrieval-log-items-label { + font-size: 0.8125rem; + color: var(--text-secondary); + font-weight: 500; + margin-bottom: 8px; +} + +.retrieval-log-items-list { + display: flex; + flex-wrap: wrap; + gap: 8px; +} + +.retrieval-log-item-tag { + display: inline-flex; + align-items: center; + justify-content: center; + width: 32px; + height: 32px; + background: rgba(0, 102, 255, 0.1); + border: 1px solid rgba(0, 102, 255, 0.2); + border-radius: 8px; + font-size: 0.8125rem; + font-weight: 600; + color: var(--accent-color); +} + +.retrieval-log-item-tag.more { + background: var(--bg-tertiary); + border-color: var(--border-color); + color: var(--text-secondary); +} + +/* 响应式设计 */ +@media (max-width: 768px) { + .knowledge-items-grid { + grid-template-columns: 1fr; + } + + .knowledge-stats-bar { + flex-direction: column; + gap: 12px; + } + + .knowledge-stat-item { + flex-direction: row; + justify-content: space-between; + align-items: center; + } + + .knowledge-filters { + flex-direction: column; + align-items: stretch; + } + + .knowledge-filters select { + min-width: 100%; + } + + .retrieval-stats-bar { + flex-direction: column; + gap: 12px; + } + + .retrieval-stat-item { + flex-direction: row; + justify-content: space-between; + align-items: center; + } + + .retrieval-logs-filters { + flex-direction: column; + align-items: stretch; + } + + .retrieval-logs-filters label { + min-width: 100%; + } + + .retrieval-log-details-grid { + grid-template-columns: 1fr; + } + + .retrieval-log-card-header { + flex-wrap: wrap; + } + + .retrieval-log-result-badge { + width: 100%; + text-align: center; + } +} + +/* 旋转动画 */ +@keyframes spin { + from { + transform: rotate(0deg); + } + to { + transform: rotate(360deg); + } +} diff --git a/web/static/js/chat.js b/web/static/js/chat.js index 002f8e44..47155452 100644 --- a/web/static/js/chat.js +++ b/web/static/js/chat.js @@ -908,6 +908,13 @@ function renderProcessDetails(messageId, processDetails) { const success = data.success !== false; const statusIcon = success ? '✅' : '❌'; itemTitle = `${statusIcon} 工具 ${escapeHtml(toolName)} 执行${success ? '完成' : '失败'}`; + + // 如果是知识检索工具,添加特殊标记 + if (toolName === 'search_knowledge_base' && success) { + itemTitle = `📚 ${itemTitle} - 知识检索`; + } + } else if (eventType === 'knowledge_retrieval') { + itemTitle = '📚 知识检索'; } else if (eventType === 'error') { itemTitle = '❌ 错误'; } else if (eventType === 'cancelled') { diff --git a/web/static/js/knowledge.js b/web/static/js/knowledge.js new file mode 100644 index 00000000..1c067288 --- /dev/null +++ b/web/static/js/knowledge.js @@ -0,0 +1,1558 @@ +// 知识库管理相关功能 +let knowledgeCategories = []; +let knowledgeItems = []; +let currentEditingItemId = null; +let isSavingKnowledgeItem = false; // 防止重复提交 +let retrievalLogsData = []; // 存储检索日志数据,用于详情查看 + +// 加载知识分类 +async function loadKnowledgeCategories() { + try { + // 添加时间戳参数避免缓存 + const timestamp = Date.now(); + const response = await apiFetch(`/api/knowledge/categories?_t=${timestamp}`, { + method: 'GET', + headers: { + 'Cache-Control': 'no-cache, no-store, must-revalidate', + 'Pragma': 'no-cache', + 'Expires': '0' + } + }); + if (!response.ok) { + throw new Error('获取分类失败'); + } + const data = await response.json(); + knowledgeCategories = data.categories || []; + + // 更新分类筛选下拉框 + const filterDropdown = document.getElementById('knowledge-category-filter-dropdown'); + if (filterDropdown) { + filterDropdown.innerHTML = '
${escapeHtml(previewText || '无内容预览')}
+${escapeHtml(log.conversationId)}
+ ${escapeHtml(log.messageId)}
+ ${escapeHtml(log.conversationId)}
+ ${escapeHtml(log.messageId)}
+