Compare commits

...

17 Commits

Author SHA1 Message Date
公明 8a2177ffab Update version to v1.3.19 in config.yaml 2026-03-08 04:02:59 +08:00
公明 3a7bbfbb88 Delete internal/handler/wecom_test.go 2026-03-08 04:02:05 +08:00
公明 7c01641de9 Add files via upload 2026-03-08 04:01:33 +08:00
公明 1c1086eea4 Merge pull request #53 from 04cb/fix/ensure-user-message-after-compression
Fix Qwen model error by ensuring user message is kept after memory compression
2026-03-07 14:20:37 +08:00
04cb 8f4f40f894 Fix Qwen model error by ensuring user message is kept after memory compression
Qwen models require a user message in the message array, otherwise they return
'No user query found in messages' error. The adjustRecentStartForToolCalls
function now ensures at least one user message is included in recent messages
after compression to prevent this validation error.
2026-03-07 13:31:32 +08:00
公明 7f16ba706a Add files via upload 2026-03-07 13:19:46 +08:00
公明 0b950f95db Add files via upload 2026-03-07 00:17:02 +08:00
公明 d36984a1c1 Add files via upload 2026-03-06 23:21:16 +08:00
公明 da2109a970 Update version number to v1.3.18 2026-03-06 23:18:49 +08:00
公明 1866aa8089 Add files via upload 2026-03-06 22:51:18 +08:00
公明 5af06e539d Update config.yaml 2026-03-06 22:42:19 +08:00
公明 7493e70686 Add files via upload 2026-03-06 22:39:30 +08:00
公明 81f7a601b7 Update config.yaml 2026-03-06 21:06:42 +08:00
公明 27830d1399 Add files via upload 2026-03-06 20:11:22 +08:00
公明 d9a0178f80 Merge pull request #47 from chhs1129/fix-bug-logger-missing-error
Fix: logger shows empty error msg
2026-03-06 10:20:44 +08:00
chhs1129 1dd8cc7f50 Fix: logger shows empty error msg 2026-03-05 09:40:47 -08:00
公明 55045dd4e0 Add files via upload 2026-03-04 00:18:29 +08:00
27 changed files with 5751 additions and 325 deletions
+17 -1
View File
@@ -10,7 +10,7 @@
# ============================================
# 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.3.16"
version: "v1.3.19"
# 服务器配置
server:
@@ -116,6 +116,22 @@ knowledge:
top_k: 5 # 检索返回的Top-K结果数量
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0表示纯向量检索,0.0表示纯关键词检索
# ============================================
# 索引配置(用于解决 API 限制问题)
# ============================================
indexing:
# 分块配置
chunk_size: 512 # 每个块的最大 token 数(默认 512),长文本会被分割成多个块
chunk_overlap: 50 # 块之间的重叠 token 数(默认 50),保持上下文连贯性
max_chunks_per_item: 0 # 单个知识项的最大块数量(0 表示不限制),防止单个文件消耗过多 API 配额
# 速率限制配置(解决 429 错误)
max_rpm: 0 # 每分钟最大请求数(默认 0 表示不限制),如 OpenAI 默认 200 RPM
rate_limit_delay_ms: 300 # 请求间隔毫秒数(默认 300),用于避免 API 速率限制,设为 0 不限制
# 建议值:200 次/分钟≈300ms, 100 次/分钟≈600ms
# 重试配置
max_retries: 3 # 最大重试次数(默认 3),遇到速率限制或服务器错误时自动重试
retry_delay_ms: 1000 # 重试间隔毫秒数(默认 1000),每次重试会递增延迟
# ============================================
# 机器人配置(企业微信、钉钉、飞书)
+38 -6
View File
@@ -2,7 +2,7 @@
[English](robot_en.md)
本文档说明如何通过**钉钉**、**飞书**与 CyberStrikeAI 对话(长连接模式),在手机端即可使用,无需在服务器上打开网页。按下面步骤操作可避免常见弯路。
本文档说明如何通过**钉钉**、**飞书**与 **企业微信** CyberStrikeAI 对话(长连接 / 回调模式),在手机端即可使用,无需在服务器上打开网页。按下面步骤操作可避免常见弯路。
---
@@ -19,12 +19,13 @@
---
## 二、支持的平台(长连接)
## 二、支持的平台(长连接 / 回调
| 平台 | 说明 |
|------|------|
| 钉钉 | 使用 Stream 长连接,程序主动连接钉钉接收消息 |
| 飞书 | 使用长连接,程序主动连接飞书接收消息 |
| 平台 | 说明 |
|----------|------|
| 钉钉 | 使用 Stream 长连接,程序主动连接钉钉接收消息 |
| 飞书 | 使用长连接,程序主动连接飞书接收消息 |
| 企业微信 | 使用 HTTP 回调接收消息,被动回包 + 主动调用企业微信发送消息 API |
下面第三节会按平台写清:在开放平台要做什么、要复制哪些字段、填到 CyberStrikeAI 的哪一栏。
@@ -101,6 +102,37 @@
---
### 3.3 企业微信 (WeCom)
> 企业微信目前采用「HTTP 回调 + 主动发送消息 API」的方式工作:
> - 用户发消息 → 企业微信以加密 XML **回调到你的服务器**(本程序的 `/api/robot/wecom`);
> - CyberStrikeAI 解密并调用 AI → 使用企业微信的 `message/send` 接口**主动发消息给用户**。
**配置概览:**
- 在企业微信管理后台创建或选择一个**自建应用**。
- 在该应用的「接收消息」处配置回调 URL、Token、EncodingAESKey。
- 在 CyberStrikeAI 的 `config.yaml` 中填入:
- `robots.wecom.corp_id`:企业 IDCorpID
- `robots.wecom.agent_id`:应用的 AgentId
- `robots.wecom.token`:消息回调使用的 Token
- `robots.wecom.encoding_aes_key`:消息回调使用的 EncodingAESKey
- `robots.wecom.secret`:该应用的 Secret(用于调用企业微信主动发送消息接口)
> **重要:IP 白名单(errcode 60020**
> CyberStrikeAI 使用 `https://qyapi.weixin.qq.com/cgi-bin/message/send` 主动发送 AI 回复。
> 若企业微信日志或本程序日志中出现 `errcode 60020 not allow to access from your ip`
>
> - 说明你的服务器出口 IP **没有加入企业微信的 IP 白名单**;
> - 请在企业微信管理后台中找到该自建应用的**「安全设置 / IP 白名单」**(具体入口可能因版本略有不同),将运行 CyberStrikeAI 的服务器公网 IP(如 `110.xxx.xxx.xxx`)加入白名单;
> - 保存后等待生效,再次发送消息测试。
>
> 如果 IP 未加入白名单,企业微信会拒绝主动发送消息,表现为:
> - 回调接口 `/api/robot/wecom` 能正常收到并处理消息;
> - 但手机端**始终收不到 AI 回复**,日志中有 `not allow to access from your ip` 提示。
---
## 四、机器人命令
在钉钉/飞书中向机器人发送以下**文本命令**(仅支持文本):
+36 -6
View File
@@ -2,7 +2,7 @@
[中文](robot.md)
This document explains how to chat with CyberStrikeAI from **DingTalk** and **Lark (Feishu)** using long-lived connections—no need to open a browser on the server. Following the steps below helps avoid common mistakes.
This document explains how to chat with CyberStrikeAI from **DingTalk**, **Lark (Feishu)**, and **WeCom (Enterprise WeChat)** using long-lived connections or HTTP callbacks—no need to open a browser on the server. Following the steps below helps avoid common mistakes.
---
@@ -19,12 +19,13 @@ Settings are written to the `robots` section of `config.yaml`; you can also edit
---
## 2. Supported platforms (long-lived connection)
## 2. Supported platforms (long-lived / callback)
| Platform | Description |
|----------|-------------|
| DingTalk | Stream long-lived connection; the app connects to DingTalk to receive messages |
| Lark (Feishu) | Long-lived connection; the app connects to Lark to receive messages |
| Platform | Description |
|----------------|-------------|
| DingTalk | Stream long-lived connection; the app connects to DingTalk to receive messages |
| Lark (Feishu) | Long-lived connection; the app connects to Lark to receive messages |
| WeCom (Qiye WX)| HTTP callback to receive messages; CyberStrikeAI replies via WeComs message sending API |
Section 3 below describes, per platform, what to do in the developer console and which fields to copy into CyberStrikeAI.
@@ -100,6 +101,35 @@ If you only have a **custom bot** Webhook URL (`oapi.dingtalk.com/robot/send?acc
---
### 3.3 WeCom (Enterprise WeChat)
> WeCom uses a **“HTTP callback + active message send API”** model:
> - User sends a message → WeCom sends an **encrypted XML callback** to your server (CyberStrikeAIs `/api/robot/wecom`).
> - CyberStrikeAI decrypts it, calls the AI, then uses WeComs `message/send` API to **actively push the reply** to the user.
**Configuration overview:**
- In the WeCom admin console, create or select a **custom app** (自建应用).
- In that apps settings, configure the message **callback URL**, **Token**, and **EncodingAESKey**.
- In CyberStrikeAIs `config.yaml`, fill in:
- `robots.wecom.corp_id`: your CorpID (企业 ID)
- `robots.wecom.agent_id`: the apps AgentId
- `robots.wecom.token`: the Token used for message callbacks
- `robots.wecom.encoding_aes_key`: the EncodingAESKey used for callbacks
- `robots.wecom.secret`: the apps Secret (used when calling WeCom APIs to send messages)
> **Important: IP allowlist (errcode 60020)**
> CyberStrikeAI calls `https://qyapi.weixin.qq.com/cgi-bin/message/send` to actively send AI replies.
> If logs show `errcode 60020 not allow to access from your ip`:
>
> - Your servers outbound IP is **not in WeComs IP allowlist**.
> - In the WeCom admin console, open the custom apps **Security / IP allowlist** settings (name may vary slightly), and add the public IP of the machine running CyberStrikeAI (e.g. `110.xxx.xxx.xxx`).
> - Save and wait for it to take effect, then test again.
>
> If the IP is not whitelisted, WeCom will reject active message sending. You will see that `/api/robot/wecom` receives and processes callbacks, but users **never see AI replies**, and logs contain `not allow to access from your ip`.
---
## 4. Bot commands
Send these **text commands** to the bot in DingTalk or Lark (text only):
+2 -1
View File
@@ -1,6 +1,6 @@
module cyberstrike-ai
go 1.23.0
go 1.24.0
toolchain go1.24.4
@@ -15,6 +15,7 @@ require (
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/pkoukk/tiktoken-go v0.1.8
go.uber.org/zap v1.26.0
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1
)
+2
View File
@@ -129,6 +129,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+22 -1
View File
@@ -345,8 +345,29 @@ func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, re
adjusted--
}
// Ensure at least one user message is included in recent messages to avoid Qwen model error
// Qwen models require a user message in the message array, otherwise they return:
// "No user query found in messages"
hasUserMessage := false
for i := adjusted; i < len(msgs); i++ {
if strings.EqualFold(msgs[i].Role, "user") {
hasUserMessage = true
break
}
}
// If no user message in recent messages, adjust backwards to include one
if !hasUserMessage {
for adjusted > 0 {
adjusted--
if strings.EqualFold(msgs[adjusted].Role, "user") {
break
}
}
}
if adjusted != recentStart {
mc.logger.Debug("adjusted recent window to keep tool call context",
mc.logger.Debug("adjusted recent window to keep tool call context and user message",
zap.Int("original_recent_start", recentStart),
zap.Int("adjusted_recent_start", adjusted),
)
+2 -2
View File
@@ -198,7 +198,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger)
// 创建索引器
knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger)
knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger, &cfg.Knowledge.Indexing)
// 注册知识检索工具到MCP服务器
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger)
@@ -1102,7 +1102,7 @@ func initializeKnowledge(
knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger)
// 创建索引器
knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger)
knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger, &cfg.Knowledge.Indexing)
// 注册知识检索工具到MCP服务器
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger)
+30 -1
View File
@@ -582,9 +582,18 @@ func Default() *Config {
},
Retrieval: RetrievalConfig{
TopK: 5,
SimilarityThreshold: 0.7,
SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检
HybridWeight: 0.7,
},
Indexing: IndexingConfig{
ChunkSize: 768, // 增加到 768,更好的上下文保持
ChunkOverlap: 50,
MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额
MaxRPM: 100, // 默认 100 RPM,避免 429 错误
RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM
MaxRetries: 3,
RetryDelayMs: 1000,
},
},
}
}
@@ -595,6 +604,26 @@ type KnowledgeConfig struct {
BasePath string `yaml:"base_path" json:"base_path"` // 知识库路径
Embedding EmbeddingConfig `yaml:"embedding" json:"embedding"`
Retrieval RetrievalConfig `yaml:"retrieval" json:"retrieval"`
Indexing IndexingConfig `yaml:"indexing,omitempty" json:"indexing,omitempty"` // 索引构建配置
}
// IndexingConfig 索引构建配置(用于控制知识库索引构建时的行为)
type IndexingConfig struct {
// 分块配置
ChunkSize int `yaml:"chunk_size,omitempty" json:"chunk_size,omitempty"` // 每个块的最大 token 数(估算),默认 512
ChunkOverlap int `yaml:"chunk_overlap,omitempty" json:"chunk_overlap,omitempty"` // 块之间的重叠 token 数,默认 50
MaxChunksPerItem int `yaml:"max_chunks_per_item,omitempty" json:"max_chunks_per_item,omitempty"` // 单个知识项的最大块数量,0 表示不限制
// 速率限制配置(用于避免 API 速率限制)
RateLimitDelayMs int `yaml:"rate_limit_delay_ms,omitempty" json:"rate_limit_delay_ms,omitempty"` // 请求间隔时间(毫秒),0 表示不使用固定延迟
MaxRPM int `yaml:"max_rpm,omitempty" json:"max_rpm,omitempty"` // 每分钟最大请求数,0 表示不限制
// 重试配置(用于处理临时错误)
MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // 最大重试次数,默认 3
RetryDelayMs int `yaml:"retry_delay_ms,omitempty" json:"retry_delay_ms,omitempty"` // 重试间隔(毫秒),默认 1000
// 批处理配置(用于批量嵌入,当前未使用,保留扩展)
BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"` // 批量处理大小,0 表示逐个处理
}
// EmbeddingConfig 嵌入配置
+2 -1
View File
@@ -1444,7 +1444,8 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
// 单个子任务超时时间:从30分钟调整为6小时,适配长时间渗透/扫描任务
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Hour)
// 存储取消函数,以便在取消队列时能够取消当前任务
h.batchTaskManager.SetTaskCancel(queueID, cancel)
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
+10
View File
@@ -1062,6 +1062,16 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK)
setFloatInMap(retrievalNode, "similarity_threshold", cfg.Retrieval.SimilarityThreshold)
setFloatInMap(retrievalNode, "hybrid_weight", cfg.Retrieval.HybridWeight)
// 更新索引配置
indexingNode := ensureMap(knowledgeNode, "indexing")
setIntInMap(indexingNode, "chunk_size", cfg.Indexing.ChunkSize)
setIntInMap(indexingNode, "chunk_overlap", cfg.Indexing.ChunkOverlap)
setIntInMap(indexingNode, "max_chunks_per_item", cfg.Indexing.MaxChunksPerItem)
setIntInMap(indexingNode, "max_rpm", cfg.Indexing.MaxRPM)
setIntInMap(indexingNode, "rate_limit_delay_ms", cfg.Indexing.RateLimitDelayMs)
setIntInMap(indexingNode, "max_retries", cfg.Indexing.MaxRetries)
setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs)
}
func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
+32 -10
View File
@@ -75,7 +75,7 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
groupedByCategory[cat] = append(groupedByCategory[cat], item)
}
// 转换为CategoryWithItems格式
// 转换为 CategoryWithItems 格式
categoriesWithItems := make([]*knowledge.CategoryWithItems, 0, len(groupedByCategory))
for cat, catItems := range groupedByCategory {
categoriesWithItems = append(categoriesWithItems, &knowledge.CategoryWithItems{
@@ -107,7 +107,7 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页
// 分页参数
limit := 50 // 默认每页50条(分类分页时为分类数,项分页时为项数)
limit := 50 // 默认每页 50 条(分类分页时为分类数,项分页时为项数)
offset := 0
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 && parsed <= 500 {
@@ -120,7 +120,7 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
}
}
// 如果指定了category参数,且使用分类分页模式,则只返回该分类
// 如果指定了 category 参数,且使用分类分页模式,则只返回该分类
if category != "" && categoryPageMode {
// 单分类模式:返回该分类的所有知识项(不分页)
items, total, err := h.manager.GetItemsSummary(category, 0, 0)
@@ -150,9 +150,9 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
if categoryPageMode {
// 按分类分页模式(默认)
// limit表示每页分类数,推荐5-10个分类
// limit 表示每页分类数,推荐 5-10 个分类
if limit <= 0 || limit > 100 {
limit = 10 // 默认每页10个分类
limit = 10 // 默认每页 10 个分类
}
categoriesWithItems, totalCategories, err := h.manager.GetCategoriesWithItems(limit, offset)
@@ -172,7 +172,7 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
}
// 按项分页模式(向后兼容)
// 是否包含完整内容(默认false,只返回摘要)
// 是否包含完整内容(默认 false,只返回摘要)
includeContent := c.Query("includeContent") == "true"
if includeContent {
@@ -358,7 +358,7 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
)
}
// 如果连续失败2次,立即停止增量索引
// 如果连续失败 2 次,立即停止增量索引
if consecutiveFailures >= 2 {
h.logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
@@ -397,7 +397,7 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) {
conversationID := c.Query("conversationId")
messageID := c.Query("messageId")
limit := 50 // 默认50条
limit := 50 // 默认 50
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
@@ -441,18 +441,40 @@ func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
if h.indexer != nil {
lastError, lastErrorTime := h.indexer.GetLastError()
if lastError != "" {
// 如果错误是最近发生的(5分钟内),则返回错误信息
// 如果错误是最近发生的(5 分钟内),则返回错误信息
if time.Since(lastErrorTime) < 5*time.Minute {
status["last_error"] = lastError
status["last_error_time"] = lastErrorTime.Format(time.RFC3339)
}
}
// 获取重建索引状态
isRebuilding, totalItems, current, failed, lastItemID, lastChunks, startTime := h.indexer.GetRebuildStatus()
if isRebuilding {
status["is_rebuilding"] = true
status["rebuild_total"] = totalItems
status["rebuild_current"] = current
status["rebuild_failed"] = failed
status["rebuild_start_time"] = startTime.Format(time.RFC3339)
if lastItemID != "" {
status["rebuild_last_item_id"] = lastItemID
}
if lastChunks > 0 {
status["rebuild_last_chunks"] = lastChunks
}
// 重建中时,is_complete 为 false
status["is_complete"] = false
// 计算重建进度百分比
if totalItems > 0 {
status["progress_percent"] = float64(current) / float64(totalItems) * 100
}
}
}
c.JSON(http.StatusOK, status)
}
// Search 搜索知识库(用于API调用,Agent内部使用Retriever
// Search 搜索知识库(用于 API 调用,Agent 内部使用 Retriever
func (h *KnowledgeHandler) Search(c *gin.Context) {
var req knowledge.SearchRequest
if err := c.ShouldBindJSON(&req); err != nil {
+381 -77
View File
@@ -1,11 +1,15 @@
package handler
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
@@ -141,56 +145,9 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
return "请输入内容或发送「帮助」/ help 查看命令。"
}
// 命令分发(支持中英文)
switch {
case text == robotCmdHelp || text == "help" || text == "" || text == "?":
return h.cmdHelp()
case text == robotCmdList || text == robotCmdListAlt || text == "list":
return h.cmdList()
case strings.HasPrefix(text, robotCmdSwitch+" ") || strings.HasPrefix(text, robotCmdContinue+" ") || strings.HasPrefix(text, "switch ") || strings.HasPrefix(text, "continue "):
var id string
switch {
case strings.HasPrefix(text, robotCmdSwitch+" "):
id = strings.TrimSpace(text[len(robotCmdSwitch)+1:])
case strings.HasPrefix(text, robotCmdContinue+" "):
id = strings.TrimSpace(text[len(robotCmdContinue)+1:])
case strings.HasPrefix(text, "switch "):
id = strings.TrimSpace(text[7:])
default:
id = strings.TrimSpace(text[9:])
}
return h.cmdSwitch(platform, userID, id)
case text == robotCmdNew || text == "new":
return h.cmdNew(platform, userID)
case text == robotCmdClear || text == "clear":
return h.cmdClear(platform, userID)
case text == robotCmdCurrent || text == "current":
return h.cmdCurrent(platform, userID)
case text == robotCmdStop || text == "stop":
return h.cmdStop(platform, userID)
case text == robotCmdRoles || text == robotCmdRolesList || text == "roles":
return h.cmdRoles()
case strings.HasPrefix(text, robotCmdRoles+" ") || strings.HasPrefix(text, robotCmdSwitchRole+" ") || strings.HasPrefix(text, "role "):
var roleName string
switch {
case strings.HasPrefix(text, robotCmdRoles+" "):
roleName = strings.TrimSpace(text[len(robotCmdRoles)+1:])
case strings.HasPrefix(text, robotCmdSwitchRole+" "):
roleName = strings.TrimSpace(text[len(robotCmdSwitchRole)+1:])
default:
roleName = strings.TrimSpace(text[5:])
}
return h.cmdSwitchRole(platform, userID, roleName)
case strings.HasPrefix(text, robotCmdDelete+" ") || strings.HasPrefix(text, "delete "):
var convID string
if strings.HasPrefix(text, robotCmdDelete+" ") {
convID = strings.TrimSpace(text[len(robotCmdDelete)+1:])
} else {
convID = strings.TrimSpace(text[7:])
}
return h.cmdDelete(platform, userID, convID)
case text == robotCmdVersion || text == "version":
return h.cmdVersion()
// 先尝试作为命令处理(支持中英文)
if cmdReply, ok := h.handleRobotCommand(platform, userID, text); ok {
return cmdReply
}
// 普通消息:走 Agent
@@ -404,6 +361,62 @@ func (h *RobotHandler) cmdVersion() string {
return "CyberStrikeAI " + v
}
// handleRobotCommand 处理机器人内置命令;若匹配到命令返回 (回复内容, true),否则返回 ("", false)
func (h *RobotHandler) handleRobotCommand(platform, userID, text string) (string, bool) {
switch {
case text == robotCmdHelp || text == "help" || text == "" || text == "?":
return h.cmdHelp(), true
case text == robotCmdList || text == robotCmdListAlt || text == "list":
return h.cmdList(), true
case strings.HasPrefix(text, robotCmdSwitch+" ") || strings.HasPrefix(text, robotCmdContinue+" ") || strings.HasPrefix(text, "switch ") || strings.HasPrefix(text, "continue "):
var id string
switch {
case strings.HasPrefix(text, robotCmdSwitch+" "):
id = strings.TrimSpace(text[len(robotCmdSwitch)+1:])
case strings.HasPrefix(text, robotCmdContinue+" "):
id = strings.TrimSpace(text[len(robotCmdContinue)+1:])
case strings.HasPrefix(text, "switch "):
id = strings.TrimSpace(text[7:])
default:
id = strings.TrimSpace(text[9:])
}
return h.cmdSwitch(platform, userID, id), true
case text == robotCmdNew || text == "new":
return h.cmdNew(platform, userID), true
case text == robotCmdClear || text == "clear":
return h.cmdClear(platform, userID), true
case text == robotCmdCurrent || text == "current":
return h.cmdCurrent(platform, userID), true
case text == robotCmdStop || text == "stop":
return h.cmdStop(platform, userID), true
case text == robotCmdRoles || text == robotCmdRolesList || text == "roles":
return h.cmdRoles(), true
case strings.HasPrefix(text, robotCmdRoles+" ") || strings.HasPrefix(text, robotCmdSwitchRole+" ") || strings.HasPrefix(text, "role "):
var roleName string
switch {
case strings.HasPrefix(text, robotCmdRoles+" "):
roleName = strings.TrimSpace(text[len(robotCmdRoles)+1:])
case strings.HasPrefix(text, robotCmdSwitchRole+" "):
roleName = strings.TrimSpace(text[len(robotCmdSwitchRole)+1:])
default:
roleName = strings.TrimSpace(text[5:])
}
return h.cmdSwitchRole(platform, userID, roleName), true
case strings.HasPrefix(text, robotCmdDelete+" ") || strings.HasPrefix(text, "delete "):
var convID string
if strings.HasPrefix(text, robotCmdDelete+" ") {
convID = strings.TrimSpace(text[len(robotCmdDelete)+1:])
} else {
convID = strings.TrimSpace(text[7:])
}
return h.cmdDelete(platform, userID, convID), true
case text == robotCmdVersion || text == "version":
return h.cmdVersion(), true
default:
return "", false
}
}
// —————— 企业微信 ——————
// wecomXML 企业微信回调 XML(明文模式下的简化结构;加密模式需先解密再解析)
@@ -418,14 +431,14 @@ type wecomXML struct {
Encrypt string `xml:"Encrypt"` // 加密模式下消息在此
}
// wecomReplyXML 被动回复 XML
// wecomReplyXML 被动回复 XML(仅用于兼容,当前使用手动构造 XML)
type wecomReplyXML struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
}
// HandleWecomGET 企业微信 URL 校验(GET
@@ -434,15 +447,51 @@ func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
c.String(http.StatusNotFound, "")
return
}
// Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串
echostr := c.Query("echostr")
msgSignature := c.Query("msg_signature")
timestamp := c.Query("timestamp")
nonce := c.Query("nonce")
// 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1
signature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, echostr)
if signature != msgSignature {
h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature))
c.String(http.StatusBadRequest, "invalid signature")
return
}
if echostr == "" {
c.String(http.StatusBadRequest, "missing echostr")
return
}
// 明文模式时企业微信可能直接传 echostr,先直接返回以通过校验
// 如果配置了 EncodingAESKey,说明是加密模式,需要解密 echostr
if h.config.Robots.Wecom.EncodingAESKey != "" {
decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, echostr)
if err != nil {
h.logger.Warn("企业微信 echostr 解密失败", zap.Error(err))
c.String(http.StatusBadRequest, "decrypt failed")
return
}
c.String(http.StatusOK, string(decrypted))
return
}
// 明文模式直接返回 echostr
c.String(http.StatusOK, echostr)
}
// signWecomRequest 生成企业微信请求签名
// 企业微信签名算法:将 token、timestamp、nonce、echostr 四个值排序后拼接成字符串,再计算 SHA1
func (h *RobotHandler) signWecomRequest(token, timestamp, nonce, echostr string) string {
strs := []string{token, timestamp, nonce, echostr}
sort.Strings(strs)
s := strings.Join(strs, "")
hash := sha1.Sum([]byte(s))
return fmt.Sprintf("%x", hash)
}
// wecomDecrypt 企业微信消息解密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID)
func wecomDecrypt(encodingAESKey, encryptedB64 string) ([]byte, error) {
key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
@@ -484,54 +533,228 @@ func wecomDecrypt(encodingAESKey, encryptedB64 string) ([]byte, error) {
return plain[20 : 20+msgLen], nil
}
// wecomEncrypt 企业微信消息加密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID)
func wecomEncrypt(encodingAESKey, message, corpID string) (string, error) {
key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
if err != nil {
return "", err
}
if len(key) != 32 {
return "", fmt.Errorf("encoding_aes_key 解码后应为 32 字节")
}
// 构造明文:16 字节随机 + 4 字节长度 (大端) + 消息 + corpID
random := make([]byte, 16)
if _, err := rand.Read(random); err != nil {
// 降级方案:使用时间戳生成随机数
for i := range random {
random[i] = byte(time.Now().UnixNano() % 256)
}
}
msgLen := len(message)
msgBytes := []byte(message)
corpBytes := []byte(corpID)
plain := make([]byte, 16+4+msgLen+len(corpBytes))
copy(plain[:16], random)
binary.BigEndian.PutUint32(plain[16:20], uint32(msgLen))
copy(plain[20:20+msgLen], msgBytes)
copy(plain[20+msgLen:], corpBytes)
// PKCS7 填充
padding := aes.BlockSize - len(plain)%aes.BlockSize
pad := bytes.Repeat([]byte{byte(padding)}, padding)
plain = append(plain, pad...)
// AES-256-CBC 加密
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
iv := key[:16]
ciphertext := make([]byte, len(plain))
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(ciphertext, plain)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// HandleWecomPOST 企业微信消息回调(POST),支持明文与加密模式
func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
if !h.config.Robots.Wecom.Enabled {
h.logger.Debug("企业微信机器人未启用,跳过请求")
c.String(http.StatusOK, "")
return
}
bodyRaw, _ := io.ReadAll(c.Request.Body)
// 从 URL 获取签名参数(加密模式回复时需要用到)
timestamp := c.Query("timestamp")
nonce := c.Query("nonce")
msgSignature := c.Query("msg_signature")
// 先读取请求体,后续解析/签名验证都会用到
bodyRaw, err := io.ReadAll(c.Request.Body)
if err != nil {
h.logger.Warn("企业微信 POST 读取请求体失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw)))
// 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段
if msgSignature != "" {
var tmp wecomXML
if err := xml.Unmarshal(bodyRaw, &tmp); err == nil {
expected := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, tmp.Encrypt)
if expected != msgSignature {
h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature))
c.String(http.StatusOK, "")
return
}
}
}
var body wecomXML
if err := xml.Unmarshal(bodyRaw, &body); err != nil {
h.logger.Debug("企业微信 POST 解析 XML 失败", zap.Error(err))
h.logger.Warn("企业微信 POST 解析 XML 失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信 XML 解析成功", zap.String("ToUserName", body.ToUserName), zap.String("FromUserName", body.FromUserName), zap.String("MsgType", body.MsgType), zap.String("Content", body.Content), zap.String("Encrypt", body.Encrypt))
// 保存企业 ID(用于明文模式回复)
enterpriseID := body.ToUserName
// 加密模式:先解密再解析内层 XML
if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" {
h.logger.Debug("企业微信进入加密模式解密流程")
decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, body.Encrypt)
if err != nil {
h.logger.Warn("企业微信消息解密失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信解密成功", zap.String("decrypted", string(decrypted)))
if err := xml.Unmarshal(decrypted, &body); err != nil {
h.logger.Warn("企业微信解密后 XML 解析失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信内层 XML 解析成功", zap.String("FromUserName", body.FromUserName), zap.String("Content", body.Content))
}
if body.MsgType != "text" {
c.XML(http.StatusOK, wecomReplyXML{
ToUserName: body.FromUserName,
FromUserName: body.ToUserName,
CreateTime: time.Now().Unix(),
MsgType: "text",
Content: "暂仅支持文本消息,请发送文字。",
})
return
}
userID := body.FromUserName
text := strings.TrimSpace(body.Content)
reply := h.HandleMessage("wecom", userID, text)
// 加密模式需加密回复(此处简化为明文回复;若企业要求加密需再实现加密
c.XML(http.StatusOK, wecomReplyXML{
ToUserName: body.FromUserName,
FromUserName: body.ToUserName,
CreateTime: time.Now().Unix(),
MsgType: "text",
Content: reply,
})
// 限制回复内容长度(企业微信限制 2048 字节
maxReplyLen := 2000
limitReply := func(s string) string {
if len(s) > maxReplyLen {
return s[:maxReplyLen] + "\n\n(内容过长,已截断)"
}
return s
}
if body.MsgType != "text" {
h.logger.Debug("企业微信收到非文本消息", zap.String("MsgType", body.MsgType))
h.sendWecomReply(c, userID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce)
return
}
// 文本消息:先判断是否为内置命令(如 帮助/列表/新对话 等),这类命令处理很快,可以直接走被动回复,避免依赖主动发送 API。
if cmdReply, ok := h.handleRobotCommand("wecom", userID, text); ok {
h.logger.Debug("企业微信收到命令消息,走被动回复", zap.String("userID", userID), zap.String("text", text))
h.sendWecomReply(c, userID, enterpriseID, limitReply(cmdReply), timestamp, nonce)
return
}
h.logger.Debug("企业微信开始处理消息(异步 AI)", zap.String("userID", userID), zap.String("text", text))
// 企业微信被动回复有 5 秒超时限制,而 AI 调用通常超过该时长。
// 这里采用推荐做法:立即返回 success(或空串),然后通过主动发送接口推送完整回复。
c.String(http.StatusOK, "success")
// 异步处理消息并通过企业微信主动消息接口发送结果
go func() {
reply := h.HandleMessage("wecom", userID, text)
reply = limitReply(reply)
h.logger.Debug("企业微信消息处理完成", zap.String("userID", userID), zap.String("reply", reply))
// 调用企业微信 API 主动发送消息
h.sendWecomMessageViaAPI(userID, enterpriseID, reply)
}()
}
// sendWecomReply 发送企业微信回复(加密模式自动加密)
// 参数:toUser=用户 ID, fromUser=企业 ID(明文模式)/CorpID(加密模式), content=回复内容,timestamp/nonce=请求参数
func (h *RobotHandler) sendWecomReply(c *gin.Context, toUser, fromUser, content, timestamp, nonce string) {
// 加密模式:判断 EncodingAESKey 是否配置
if h.config.Robots.Wecom.EncodingAESKey != "" {
// 加密模式使用 CorpID 进行加密
corpID := h.config.Robots.Wecom.CorpID
if corpID == "" {
h.logger.Warn("企业微信加密模式缺少 CorpID 配置")
c.String(http.StatusOK, "")
return
}
// 构造完整的明文 XML 回复(格式严格按企业微信文档要求)
plainResp := fmt.Sprintf(`<xml>
<ToUserName><![CDATA[%s]]></ToUserName>
<FromUserName><![CDATA[%s]]></FromUserName>
<CreateTime>%d</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[%s]]></Content>
</xml>`, toUser, fromUser, time.Now().Unix(), content)
encrypted, err := wecomEncrypt(h.config.Robots.Wecom.EncodingAESKey, plainResp, corpID)
if err != nil {
h.logger.Warn("企业微信回复加密失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
// 使用请求中的 timestamp/nonce 生成签名(企业微信要求回复时使用与请求相同的 timestamp 和 nonce
msgSignature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, encrypted)
h.logger.Debug("企业微信发送加密回复",
zap.String("Encrypt", encrypted[:50]+"..."),
zap.String("MsgSignature", msgSignature),
zap.String("TimeStamp", timestamp),
zap.String("Nonce", nonce))
// 加密模式仅返回 4 个核心字段(企业微信官方要求)
xmlResp := fmt.Sprintf(`<xml><Encrypt><![CDATA[%s]]></Encrypt><MsgSignature><![CDATA[%s]]></MsgSignature><TimeStamp><![CDATA[%s]]></TimeStamp><Nonce><![CDATA[%s]]></Nonce></xml>`, encrypted, msgSignature, timestamp, nonce)
// also log the final response body so we can cross-check with the
// network traffic or developer console
h.logger.Debug("企业微信加密回复包", zap.String("xml", xmlResp))
// for additional confidence, decrypt the payload ourselves and log it
if dec, err2 := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, encrypted); err2 == nil {
h.logger.Debug("企业微信加密回复解密检查", zap.String("plain", string(dec)))
} else {
h.logger.Warn("企业微信加密回复解密检查失败", zap.Error(err2))
}
// 使用 c.Writer.Write 直接写入响应,避免 c.String 的转义问题
c.Writer.WriteHeader(http.StatusOK)
// use text/xml as that's what WeCom examples show
c.Writer.Header().Set("Content-Type", "text/xml; charset=utf-8")
_, _ = c.Writer.Write([]byte(xmlResp))
h.logger.Debug("企业微信加密回复已发送")
return
}
// 明文模式
h.logger.Debug("企业微信发送明文回复", zap.String("ToUserName", toUser), zap.String("FromUserName", fromUser), zap.String("Content", content[:50]+"..."))
// 手动构造 XML 响应(使用 CDATA 包裹所有字段,并包含 AgentID)
xmlResp := fmt.Sprintf(`<xml>
<ToUserName><![CDATA[%s]]></ToUserName>
<FromUserName><![CDATA[%s]]></FromUserName>
<CreateTime>%d</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[%s]]></Content>
</xml>`, toUser, fromUser, time.Now().Unix(), content)
// log the exact plaintext response for debugging
h.logger.Debug("企业微信明文回复包", zap.String("xml", xmlResp))
// use text/xml as recommended by WeCom docs
c.Header("Content-Type", "text/xml; charset=utf-8")
c.String(http.StatusOK, xmlResp)
h.logger.Debug("企业微信明文回复已发送")
}
// —————— 测试接口(需登录,用于验证机器人逻辑,无需钉钉/飞书客户端) ——————
@@ -562,6 +785,87 @@ func (h *RobotHandler) HandleRobotTest(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"reply": reply})
}
// sendWecomMessageViaAPI 通过企业微信 API 主动发送消息(用于异步处理后的结果发送)
func (h *RobotHandler) sendWecomMessageViaAPI(toUser, toParty, content string) {
if !h.config.Robots.Wecom.Enabled {
return
}
secret := h.config.Robots.Wecom.Secret
corpID := h.config.Robots.Wecom.CorpID
agentID := h.config.Robots.Wecom.AgentID
if secret == "" || corpID == "" {
h.logger.Warn("企业微信主动 API 缺少 secret 或 corpID 配置")
return
}
// 第 1 步:获取 access_token
tokenURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", corpID, secret)
resp, err := http.Get(tokenURL)
if err != nil {
h.logger.Warn("企业微信获取 token 失败", zap.Error(err))
return
}
defer resp.Body.Close()
var tokenResp struct {
AccessToken string `json:"access_token"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
h.logger.Warn("企业微信 token 响应解析失败", zap.Error(err))
return
}
if tokenResp.ErrCode != 0 {
h.logger.Warn("企业微信 token 获取错误", zap.String("errmsg", tokenResp.ErrMsg), zap.Int("errcode", tokenResp.ErrCode))
return
}
// 第 2 步:构造发送消息请求
msgReq := map[string]interface{}{
"touser": toUser,
"msgtype": "text",
"agentid": agentID,
"text": map[string]interface{}{
"content": content,
},
}
msgBody, err := json.Marshal(msgReq)
if err != nil {
h.logger.Warn("企业微信消息序列化失败", zap.Error(err))
return
}
// 第 3 步:发送消息
sendURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", tokenResp.AccessToken)
msgResp, err := http.Post(sendURL, "application/json", bytes.NewReader(msgBody))
if err != nil {
h.logger.Warn("企业微信主动发送消息失败", zap.Error(err))
return
}
defer msgResp.Body.Close()
var sendResp struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
InvalidUser string `json:"invaliduser"`
MsgID string `json:"msgid"`
}
if err := json.NewDecoder(msgResp.Body).Decode(&sendResp); err != nil {
h.logger.Warn("企业微信发送响应解析失败", zap.Error(err))
return
}
if sendResp.ErrCode == 0 {
h.logger.Debug("企业微信主动发送消息成功", zap.String("msgid", sendResp.MsgID))
} else {
h.logger.Warn("企业微信主动发送消息失败", zap.String("errmsg", sendResp.ErrMsg), zap.Int("errcode", sendResp.ErrCode), zap.String("invaliduser", sendResp.InvalidUser))
}
}
// —————— 钉钉 ——————
// HandleDingtalkPOST 钉钉事件回调(流式接入等);当前为占位,返回 200
+2 -2
View File
@@ -96,7 +96,7 @@ func (h *TerminalHandler) RunCommand(c *gin.Context) {
} else {
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
// 无 TTY 时设置 COLUMNS/TERM,使 ping 等工具的 usage 排版与真实终端一致
cmd.Env = append(os.Environ(), "COLUMNS=120", "LINES=40", "TERM=xterm-256color")
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
}
if req.Cwd != "" {
@@ -218,7 +218,7 @@ func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr)
} else {
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
cmd.Env = append(os.Environ(), "COLUMNS=120", "LINES=40", "TERM=xterm-256color")
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
}
if req.Cwd != "" {
absCwd, err := filepath.Abs(req.Cwd)
+1 -1
View File
@@ -11,7 +11,7 @@ import (
"github.com/creack/pty"
)
const ptyCols = 120
const ptyCols = 256
const ptyRows = 40
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
+2 -2
View File
@@ -37,7 +37,7 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
}
cmd := exec.Command(shell)
cmd.Env = append(os.Environ(),
"COLUMNS=120",
"COLUMNS=256",
"LINES=40",
"TERM=xterm-256color",
)
@@ -55,7 +55,7 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
for {
n, err := ptmx.Read(buf)
if n > 0 {
_ = conn.WriteMessage(websocket.TextMessage, buf[:n])
_ = conn.WriteMessage(websocket.BinaryMessage, buf[:n])
}
if err != nil {
break
+149 -31
View File
@@ -6,39 +6,75 @@ import (
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/openai"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
// Embedder 文本嵌入器
type Embedder struct {
openAIClient *openai.Client
config *config.KnowledgeConfig
openAIConfig *config.OpenAIConfig // 用于获取API Key
logger *zap.Logger
openAIClient *openai.Client
config *config.KnowledgeConfig
openAIConfig *config.OpenAIConfig // 用于获取 API Key
logger *zap.Logger
rateLimiter *rate.Limiter // 速率限制器
rateLimitDelay time.Duration // 请求间隔时间
maxRetries int // 最大重试次数
retryDelay time.Duration // 重试间隔
mu sync.Mutex // 保护 rateLimiter
}
// NewEmbedder 创建新的嵌入器
func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, openAIClient *openai.Client, logger *zap.Logger) *Embedder {
// 初始化速率限制器
var rateLimiter *rate.Limiter
var rateLimitDelay time.Duration
// 如果配置了 MaxRPM,根据 RPM 计算速率限制
if cfg.Indexing.MaxRPM > 0 {
rpm := cfg.Indexing.MaxRPM
rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm)
logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
} else if cfg.Indexing.RateLimitDelayMs > 0 {
// 如果没有配置 MaxRPM 但配置了固定延迟,使用固定延迟模式
rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond
logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
}
// 重试配置
maxRetries := 3
retryDelay := 1000 * time.Millisecond
if cfg.Indexing.MaxRetries > 0 {
maxRetries = cfg.Indexing.MaxRetries
}
if cfg.Indexing.RetryDelayMs > 0 {
retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond
}
return &Embedder{
openAIClient: openAIClient,
config: cfg,
openAIConfig: openAIConfig,
logger: logger,
openAIClient: openAIClient,
config: cfg,
openAIConfig: openAIConfig,
logger: logger,
rateLimiter: rateLimiter,
rateLimitDelay: rateLimitDelay,
maxRetries: maxRetries,
retryDelay: retryDelay,
}
}
// EmbeddingRequest OpenAI嵌入请求
// EmbeddingRequest OpenAI 嵌入请求
type EmbeddingRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
}
// EmbeddingResponse OpenAI嵌入响应
// EmbeddingResponse OpenAI 嵌入响应
type EmbeddingResponse struct {
Data []EmbeddingData `json:"data"`
Error *EmbeddingError `json:"error,omitempty"`
@@ -56,12 +92,69 @@ type EmbeddingError struct {
Type string `json:"type"`
}
// EmbedText 对文本进行嵌入
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
if e.openAIClient == nil {
return nil, fmt.Errorf("OpenAI客户端未初始化")
// waitRateLimiter 等待速率限制器
func (e *Embedder) waitRateLimiter() {
e.mu.Lock()
defer e.mu.Unlock()
if e.rateLimiter != nil {
// 等待令牌
ctx := context.Background()
if err := e.rateLimiter.Wait(ctx); err != nil {
e.logger.Warn("速率限制器等待失败", zap.Error(err))
}
}
if e.rateLimitDelay > 0 {
time.Sleep(e.rateLimitDelay)
}
}
// EmbedText 对文本进行嵌入(带重试和速率限制)
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
if e.openAIClient == nil {
return nil, fmt.Errorf("OpenAI 客户端未初始化")
}
var lastErr error
for attempt := 0; attempt < e.maxRetries; attempt++ {
// 速率限制
if attempt > 0 {
// 重试时等待更长时间
waitTime := e.retryDelay * time.Duration(attempt)
e.logger.Debug("重试前等待", zap.Int("attempt", attempt+1), zap.Duration("waitTime", waitTime))
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(waitTime):
}
} else {
e.waitRateLimiter()
}
result, err := e.doEmbedText(ctx, text)
if err == nil {
return result, nil
}
lastErr = err
// 检查是否是可重试的错误(429 速率限制、5xx 服务器错误、网络错误)
if !e.isRetryableError(err) {
return nil, err
}
e.logger.Debug("嵌入请求失败,准备重试",
zap.Int("attempt", attempt+1),
zap.Int("maxRetries", e.maxRetries),
zap.Error(err))
}
return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
}
// doEmbedText 执行实际的嵌入请求(内部方法)
func (e *Embedder) doEmbedText(ctx context.Context, text string) ([]float32, error) {
// 使用配置的嵌入模型
model := e.config.Embedding.Model
if model == "" {
@@ -73,7 +166,7 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
Input: []string{text},
}
// 清理baseURL:去除前后空格和尾部斜杠
// 清理 baseURL:去除前后空格和尾部斜杠
baseURL := strings.TrimSpace(e.config.Embedding.BaseURL)
baseURL = strings.TrimSuffix(baseURL, "/")
if baseURL == "" {
@@ -83,24 +176,24 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
// 构建请求
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
return nil, fmt.Errorf("序列化请求失败%w", err)
}
requestURL := baseURL + "/embeddings"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
return nil, fmt.Errorf("创建请求失败%w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
// 使用配置的API Key,如果没有则使用OpenAI配置的
// 使用配置的 API Key,如果没有则使用 OpenAI 配置的
apiKey := strings.TrimSpace(e.config.Embedding.APIKey)
if apiKey == "" && e.openAIConfig != nil {
apiKey = e.openAIConfig.APIKey
}
if apiKey == "" {
return nil, fmt.Errorf("API Key未配置")
return nil, fmt.Errorf("API Key 未配置")
}
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
@@ -110,7 +203,7 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
}
resp, err := httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %w", err)
return nil, fmt.Errorf("发送请求失败%w", err)
}
defer resp.Body.Close()
@@ -132,7 +225,7 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
if len(requestBodyPreview) > 200 {
requestBodyPreview = requestBodyPreview[:200] + "..."
}
e.logger.Debug("嵌入API请求",
e.logger.Debug("嵌入 API 请求",
zap.String("url", httpReq.URL.String()),
zap.String("model", model),
zap.String("requestBody", requestBodyPreview),
@@ -148,12 +241,12 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
if len(bodyPreview) > 500 {
bodyPreview = bodyPreview[:500] + "..."
}
return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码: %d, 响应长度: %d字节): %w\n请求体: %s\n响应内容预览: %s",
return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码%d, 响应长度%d字节): %w\n请求体%s\n响应内容预览%s",
requestURL, resp.StatusCode, len(bodyBytes), err, requestBodyPreview, bodyPreview)
}
if embeddingResp.Error != nil {
return nil, fmt.Errorf("OpenAI API错误 (状态码: %d): 类型=%s, 消息=%s",
return nil, fmt.Errorf("OpenAI API 错误 (状态码%d): 类型=%s, 消息=%s",
resp.StatusCode, embeddingResp.Error.Type, embeddingResp.Error.Message)
}
@@ -162,7 +255,7 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
if len(bodyPreview) > 500 {
bodyPreview = bodyPreview[:500] + "..."
}
return nil, fmt.Errorf("HTTP请求失败 (URL: %s, 状态码: %d): 响应内容=%s", requestURL, resp.StatusCode, bodyPreview)
return nil, fmt.Errorf("HTTP 请求失败 (URL: %s, 状态码%d): 响应内容=%s", requestURL, resp.StatusCode, bodyPreview)
}
if len(embeddingResp.Data) == 0 {
@@ -170,11 +263,11 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
if len(bodyPreview) > 500 {
bodyPreview = bodyPreview[:500] + "..."
}
return nil, fmt.Errorf("未收到嵌入数据 (状态码: %d, 响应长度: %d字节)\n响应内容: %s",
return nil, fmt.Errorf("未收到嵌入数据 (状态码%d, 响应长度%d字节)\n响应内容%s",
resp.StatusCode, len(bodyBytes), bodyPreview)
}
// 转换为float32
// 转换为 float32
embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
for i, v := range embeddingResp.Data[0].Embedding {
embedding[i] = float32(v)
@@ -183,23 +276,48 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
return embedding, nil
}
// isRetryableError 判断是否是可重试的错误
func (e *Embedder) isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
// 429 速率限制错误
if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
return true
}
// 5xx 服务器错误
if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
return true
}
// 网络错误
if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
return true
}
return false
}
// EmbedTexts 批量嵌入文本
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
if len(texts) == 0 {
return nil, nil
}
// OpenAI API支持批量,但为了简单起见,我们逐个处理
// 实际可以使用批量API以提高效率
embeddings := make([][]float32, len(texts))
for i, text := range texts {
embedding, err := e.EmbedText(ctx, text)
if err != nil {
return nil, fmt.Errorf("嵌入文本[%d]失败: %w", i, err)
return nil, fmt.Errorf("嵌入文本 [%d] 失败%w", i, err)
}
embeddings[i] = embedding
}
return embeddings, nil
}
+382 -98
View File
@@ -10,56 +10,133 @@ import (
"sync"
"time"
"cyberstrike-ai/internal/config"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Indexer 索引器,负责将知识项分块并向量化
type Indexer struct {
db *sql.DB
embedder *Embedder
logger *zap.Logger
chunkSize int // 每个块的最大token数(估算)
overlap int // 块之间的重叠token数
db *sql.DB
embedder *Embedder
logger *zap.Logger
chunkSize int // 每个块的最大 token 数(估算)
overlap int // 块之间的重叠 token
maxChunks int // 单个知识项的最大块数量(0 表示不限制)
// 错误跟踪
mu sync.RWMutex
lastError string // 最近一次错误信息
mu sync.RWMutex
lastError string // 最近一次错误信息
lastErrorTime time.Time // 最近一次错误时间
errorCount int // 连续错误计数
errorCount int // 连续错误计数
// 重建索引状态跟踪
rebuildMu sync.RWMutex
isRebuilding bool // 是否正在重建索引
rebuildTotalItems int // 重建总项数
rebuildCurrent int // 当前已处理项数
rebuildFailed int // 重建失败项数
rebuildStartTime time.Time // 重建开始时间
rebuildLastItemID string // 最近处理的项 ID
rebuildLastChunks int // 最近处理的项的分块数
}
// NewIndexer 创建新的索引器
func NewIndexer(db *sql.DB, embedder *Embedder, logger *zap.Logger) *Indexer {
func NewIndexer(db *sql.DB, embedder *Embedder, logger *zap.Logger, indexingCfg *config.IndexingConfig) *Indexer {
chunkSize := 512
overlap := 50
maxChunks := 0
if indexingCfg != nil {
if indexingCfg.ChunkSize > 0 {
chunkSize = indexingCfg.ChunkSize
}
if indexingCfg.ChunkOverlap >= 0 {
overlap = indexingCfg.ChunkOverlap
}
if indexingCfg.MaxChunksPerItem > 0 {
maxChunks = indexingCfg.MaxChunksPerItem
}
}
return &Indexer{
db: db,
embedder: embedder,
logger: logger,
chunkSize: 512, // 默认512 tokens
overlap: 50, // 默认50 tokens重叠
chunkSize: chunkSize,
overlap: overlap,
maxChunks: maxChunks,
}
}
// ChunkText 将文本分块(支持重叠)
// ChunkText 将文本分块(支持重叠,保留标题上下文
func (idx *Indexer) ChunkText(text string) []string {
// 按Markdown标题分割
chunks := idx.splitByMarkdownHeaders(text)
// 按 Markdown 标题分割,获取带标题的块
sections := idx.splitByMarkdownHeadersWithContent(text)
// 如果块太大,进一步分割
// 处理每个块
result := make([]string, 0)
for _, chunk := range chunks {
if idx.estimateTokens(chunk) <= idx.chunkSize {
result = append(result, chunk)
for _, section := range sections {
// 构建父级标题路径(不包含最后一级标题,因为内容中已经包含)
// 例如:["# A", "## B", "### C"] -> "[# A > ## B]"
var parentHeaderPath string
if len(section.HeaderPath) > 1 {
parentHeaderPath = strings.Join(section.HeaderPath[:len(section.HeaderPath)-1], " > ")
}
// 提取内容的第一行作为标题(如 "# Prompt Injection"
firstLine, remainingContent := extractFirstLine(section.Content)
// 如果剩余内容为空或只有空白,说明这个块只有标题没有正文,跳过
if strings.TrimSpace(remainingContent) == "" {
continue
}
// 如果块太大,进一步分割
if idx.estimateTokens(section.Content) <= idx.chunkSize {
// 块大小合适,添加父级标题前缀
if parentHeaderPath != "" {
result = append(result, fmt.Sprintf("[%s] %s", parentHeaderPath, section.Content))
} else {
result = append(result, section.Content)
}
} else {
// 按段落分割
subChunks := idx.splitByParagraphs(chunk)
for _, subChunk := range subChunks {
if idx.estimateTokens(subChunk) <= idx.chunkSize {
result = append(result, subChunk)
} else {
// 按句子分割(支持重叠)
chunksWithOverlap := idx.splitBySentencesWithOverlap(subChunk)
result = append(result, chunksWithOverlap...)
// 块太大,按子标题或段落分割,保持标题上下文
// 首先尝试按子标题分割(保留子标题结构)
subSections := idx.splitBySubHeaders(section.Content, firstLine, parentHeaderPath)
if len(subSections) > 1 {
// 成功按子标题分割,递归处理每个子块
for _, sub := range subSections {
if idx.estimateTokens(sub) <= idx.chunkSize {
result = append(result, sub)
} else {
// 子块仍然太大,按段落分割(保留标题前缀)
paragraphs := idx.splitByParagraphsWithHeader(sub, parentHeaderPath)
for _, para := range paragraphs {
if idx.estimateTokens(para) <= idx.chunkSize {
result = append(result, para)
} else {
// 段落仍太大,按句子分割
sentenceChunks := idx.splitBySentencesWithOverlap(para)
for _, chunk := range sentenceChunks {
result = append(result, chunk)
}
}
}
}
}
} else {
// 没有子标题,按段落分割(保留标题前缀)
paragraphs := idx.splitByParagraphsWithHeader(section.Content, parentHeaderPath)
for _, para := range paragraphs {
if idx.estimateTokens(para) <= idx.chunkSize {
result = append(result, para)
} else {
// 段落仍太大,按句子分割
sentenceChunks := idx.splitBySentencesWithOverlap(para)
for _, chunk := range sentenceChunks {
result = append(result, chunk)
}
}
}
}
}
@@ -68,43 +145,183 @@ func (idx *Indexer) ChunkText(text string) []string {
return result
}
// splitByMarkdownHeaders 按Markdown标题分割
func (idx *Indexer) splitByMarkdownHeaders(text string) []string {
// 匹配Markdown标题 (# ## ### 等)
// extractFirstLine 提取第一行内容和剩余内容
func extractFirstLine(content string) (firstLine, remaining string) {
lines := strings.SplitN(content, "\n", 2)
if len(lines) == 0 {
return "", ""
}
if len(lines) == 1 {
return lines[0], ""
}
return lines[0], lines[1]
}
// splitBySubHeaders 尝试按子标题分割内容(用于处理大块内容)
// headerPrefix 是父级标题路径,用于添加到每个子块
func (idx *Indexer) splitBySubHeaders(content, headerPrefix, parentPath string) []string {
// 匹配 Markdown 子标题(## 及以上)
subHeaderRegex := regexp.MustCompile(`(?m)^#{2,6}\s+.+$`)
matches := subHeaderRegex.FindAllStringIndex(content, -1)
if len(matches) == 0 {
// 没有子标题,返回原始内容
return []string{content}
}
result := make([]string, 0, len(matches))
for i, match := range matches {
start := match[0]
nextStart := len(content)
if i+1 < len(matches) {
nextStart = matches[i+1][0]
}
subContent := strings.TrimSpace(content[start:nextStart])
// 添加父级路径前缀
if parentPath != "" {
result = append(result, fmt.Sprintf("[%s] %s", parentPath, subContent))
} else {
result = append(result, subContent)
}
}
return result
}
// splitByParagraphsWithHeader 按段落分割,每个段落添加标题前缀(用于保持上下文)
func (idx *Indexer) splitByParagraphsWithHeader(content, parentPath string) []string {
// 提取第一行作为标题
firstLine, _ := extractFirstLine(content)
paragraphs := strings.Split(content, "\n\n")
result := make([]string, 0)
for i, p := range paragraphs {
trimmed := strings.TrimSpace(p)
if trimmed == "" {
continue
}
// 过滤掉只有标题的段落(没有实际内容)
if strings.TrimSpace(trimmed) == strings.TrimSpace(firstLine) {
continue
}
// 第一个段落已经包含标题,不需要重复添加
if i == 0 && strings.Contains(trimmed, firstLine) {
if parentPath != "" {
result = append(result, fmt.Sprintf("[%s] %s", parentPath, trimmed))
} else {
result = append(result, trimmed)
}
} else {
// 其他段落添加标题前缀以保持上下文
if parentPath != "" {
result = append(result, fmt.Sprintf("[%s] %s\n%s", parentPath, firstLine, trimmed))
} else {
result = append(result, fmt.Sprintf("%s\n%s", firstLine, trimmed))
}
}
}
return result
}
// Section 表示一个带标题路径的文本块
type Section struct {
HeaderPath []string // 标题路径(如 ["# SQL 注入", "## 检测方法"]
Content string // 块内容
}
// splitByMarkdownHeadersWithContent 按 Markdown 标题分割,返回带标题路径的块
// 每个块的内容包含自己的标题,用于向量化检索
//
// 例如,对于以下 Markdown:
// # Prompt Injection
// 引言内容
// ## Summary
// 目录内容
//
// 返回:
// [{HeaderPath: ["# Prompt Injection"], Content: "# Prompt Injection\n引言内容"},
// {HeaderPath: ["# Prompt Injection", "## Summary"], Content: "## Summary\n目录内容"}]
func (idx *Indexer) splitByMarkdownHeadersWithContent(text string) []Section {
// 匹配 Markdown 标题 (# ## ### 等)
headerRegex := regexp.MustCompile(`(?m)^#{1,6}\s+.+$`)
// 找到所有标题位置
matches := headerRegex.FindAllStringIndex(text, -1)
if len(matches) == 0 {
return []string{text}
// 没有标题,返回整个文本
return []Section{{HeaderPath: []string{}, Content: text}}
}
chunks := make([]string, 0)
lastPos := 0
sections := make([]Section, 0, len(matches))
currentHeaderPath := []string{}
for _, match := range matches {
for i, match := range matches {
start := match[0]
if start > lastPos {
chunks = append(chunks, strings.TrimSpace(text[lastPos:start]))
}
lastPos = start
}
end := match[1]
nextStart := len(text)
// 添加最后一部分
if lastPos < len(text) {
chunks = append(chunks, strings.TrimSpace(text[lastPos:]))
// 找到下一个标题的位置
if i+1 < len(matches) {
nextStart = matches[i+1][0]
}
// 提取当前标题
headerLine := strings.TrimSpace(text[start:end])
// 计算标题层级(# 的数量)
level := 0
for _, ch := range headerLine {
if ch == '#' {
level++
} else {
break
}
}
// 更新标题路径:移除比当前层级深或等于的子标题,然后添加当前标题
newPath := make([]string, 0, len(currentHeaderPath)+1)
for _, h := range currentHeaderPath {
hLevel := 0
for _, ch := range h {
if ch == '#' {
hLevel++
} else {
break
}
}
if hLevel < level {
newPath = append(newPath, h)
}
}
newPath = append(newPath, headerLine)
currentHeaderPath = newPath
// 提取当前标题到下一个标题之间的内容(包含当前标题)
content := strings.TrimSpace(text[start:nextStart])
// 创建块,使用当前标题路径(包含当前标题)
sections = append(sections, Section{
HeaderPath: append([]string(nil), currentHeaderPath...),
Content: content,
})
}
// 过滤空块
result := make([]string, 0)
for _, chunk := range chunks {
if strings.TrimSpace(chunk) != "" {
result = append(result, chunk)
result := make([]Section, 0, len(sections))
for _, section := range sections {
if strings.TrimSpace(section.Content) != "" {
result = append(result, section)
}
}
if len(result) == 0 {
return []string{text}
return []Section{{HeaderPath: []string{}, Content: text}}
}
return result
@@ -124,8 +341,12 @@ func (idx *Indexer) splitByParagraphs(text string) []string {
// splitBySentences 按句子分割(用于内部,不包含重叠逻辑)
func (idx *Indexer) splitBySentences(text string) []string {
// 简单的句子分割(按句号、问号、感叹号)
sentenceRegex := regexp.MustCompile(`[.!?]+\s+`)
// 简单的句子分割(按句号、问号、感叹号,支持中英文
// . ! ? = 英文标点
// \u3002 = 。(中文句号)
// \uFF01 = (中文叹号)
// \uFF1F = (中文问号)
sentenceRegex := regexp.MustCompile(`[.!?\x{3002}\x{FF01}\x{FF1F}]+`)
sentences := sentenceRegex.Split(text, -1)
result := make([]string, 0)
for _, s := range sentences {
@@ -221,13 +442,13 @@ func (idx *Indexer) splitBySentencesSimple(text string) []string {
return result
}
// extractLastTokens 从文本末尾提取指定token数量的内容
// extractLastTokens 从文本末尾提取指定 token 数量的内容
func (idx *Indexer) extractLastTokens(text string, tokenCount int) string {
if tokenCount <= 0 || text == "" {
return ""
}
// 估算字符数(1 token ≈ 4字符)
// 估算字符数(1 token ≈ 4 字符)
charCount := tokenCount * 4
runes := []rune(text)
@@ -236,12 +457,11 @@ func (idx *Indexer) extractLastTokens(text string, tokenCount int) string {
}
// 从末尾提取指定数量的字符
// 尝试在句子边界处截断,避免截断句子中间
startPos := len(runes) - charCount
extracted := string(runes[startPos:])
// 尝试找到第一个句子边界(句号、问号、感叹号后的空格
sentenceBoundary := regexp.MustCompile(`[.!?]+\s+`)
// 尝试找到第一个句子边界(支持中英文标点
sentenceBoundary := regexp.MustCompile(`[.!?\x{3002}\x{FF01}\x{FF1F}]+`)
matches := sentenceBoundary.FindStringIndex(extracted)
if len(matches) > 0 && matches[0] > 0 {
// 在句子边界处截断,保留完整句子
@@ -251,41 +471,51 @@ func (idx *Indexer) extractLastTokens(text string, tokenCount int) string {
return strings.TrimSpace(extracted)
}
// estimateTokens 估算token数(简单估算:1 token ≈ 4字符)
// estimateTokens 估算 token 数(简单估算:1 token ≈ 4 字符)
func (idx *Indexer) estimateTokens(text string) int {
return len([]rune(text)) / 4
}
// IndexItem 索引知识项(分块并向量化)
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
// 获取知识项(包含categorytitle,用于向量化)
// 获取知识项(包含 categorytitle,用于向量化)
var content, category, title string
err := idx.db.QueryRow("SELECT content, category, title FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title)
if err != nil {
return fmt.Errorf("获取知识项失败: %w", err)
return fmt.Errorf("获取知识项失败%w", err)
}
// 删除旧的向量(在 RebuildIndex 中已经统一清空,这里保留是为了单独调用 IndexItem 时的兼容性)
_, err = idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID)
if err != nil {
return fmt.Errorf("删除旧向量失败: %w", err)
return fmt.Errorf("删除旧向量失败%w", err)
}
// 分块
chunks := idx.ChunkText(content)
// 应用最大块数限制
if idx.maxChunks > 0 && len(chunks) > idx.maxChunks {
idx.logger.Info("知识项块数量超过限制,已截断",
zap.String("itemId", itemID),
zap.Int("originalChunks", len(chunks)),
zap.Int("maxChunks", idx.maxChunks))
chunks = chunks[:idx.maxChunks]
}
idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
// 跟踪该知识项的错误
itemErrorCount := 0
var firstError error
firstErrorChunkIndex := -1
// 向量化每个块(包含categorytitle信息,以便向量检索时能匹配到风险类型)
// 向量化每个块(包含 categorytitle 信息,以便向量检索时能匹配到风险类型)
for i, chunk := range chunks {
// 将categorytitle信息包含到向量化的文本中
// 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}"
// 这样向量嵌入就会包含风险类型信息,即使SQL过滤失败,向量相似度也能帮助匹配
textForEmbedding := fmt.Sprintf("[风险类型: %s] [标题: %s]\n%s", category, title, chunk)
// 将 categorytitle 信息包含到向量化的文本中
// 格式:"[风险类型{category}] [标题{title}]\n{chunk 内容}"
// 这样向量嵌入就会包含风险类型信息,即使 SQL 过滤失败,向量相似度也能帮助匹配
textForEmbedding := fmt.Sprintf("[风险类型%s] [标题%s]\n%s", category, title, chunk)
embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding)
if err != nil {
@@ -305,18 +535,30 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
zap.String("chunkPreview", chunkPreview),
zap.Error(err),
)
// 更新全局错误跟踪
errorMsg := fmt.Sprintf("向量化失败 (知识项: %s): %v", itemID, err)
errorMsg := fmt.Sprintf("向量化失败 (知识项%s): %v", itemID, err)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
}
// 如果连续失败2个块,立即停止处理该知识项(降低阈值,更快停止)
// 这样可以避免继续浪费API调用,同时也能更快地检测到配置问题
if itemErrorCount >= 2 {
// 如果连续失败 5 个块,立即停止处理该知识项
// 这样可以避免继续浪费 API 调用,同时也能更快地检测到配置问题
// 对于大文档(超过 10 个块),允许失败比例不超过 50%
maxConsecutiveFailures := 5
if len(chunks) > 10 && itemErrorCount > len(chunks)/2 {
idx.logger.Error("知识项向量化失败比例过高,停止处理",
zap.String("itemId", itemID),
zap.Int("totalChunks", len(chunks)),
zap.Int("failedChunks", itemErrorCount),
zap.Int("firstErrorChunkIndex", firstErrorChunkIndex),
zap.Error(firstError),
)
return fmt.Errorf("知识项向量化失败比例过高 (%d/%d个块失败): %v", itemErrorCount, len(chunks), firstError)
}
if itemErrorCount >= maxConsecutiveFailures {
idx.logger.Error("知识项连续向量化失败,停止处理",
zap.String("itemId", itemID),
zap.Int("totalChunks", len(chunks)),
@@ -344,6 +586,13 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
}
idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
// 更新重建状态中的最近处理信息
idx.rebuildMu.Lock()
idx.rebuildLastItemID = itemID
idx.rebuildLastChunks = len(chunks)
idx.rebuildMu.Unlock()
return nil
}
@@ -352,23 +601,38 @@ func (idx *Indexer) HasIndex() (bool, error) {
var count int
err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count)
if err != nil {
return false, fmt.Errorf("检查索引失败: %w", err)
return false, fmt.Errorf("检查索引失败%w", err)
}
return count > 0, nil
}
// RebuildIndex 重建所有索引
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
// 设置重建状态
idx.rebuildMu.Lock()
idx.isRebuilding = true
idx.rebuildTotalItems = 0
idx.rebuildCurrent = 0
idx.rebuildFailed = 0
idx.rebuildStartTime = time.Now()
idx.rebuildLastItemID = ""
idx.rebuildLastChunks = 0
idx.rebuildMu.Unlock()
// 重置错误跟踪
idx.mu.Lock()
idx.lastError = ""
idx.lastErrorTime = time.Time{}
idx.errorCount = 0
idx.mu.Unlock()
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
if err != nil {
return fmt.Errorf("查询知识项失败: %w", err)
// 重置重建状态
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("查询知识项失败:%w", err)
}
defer rows.Close()
@@ -376,34 +640,36 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return fmt.Errorf("扫描知识项ID失败: %w", err)
// 重置重建状态
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("扫描知识项 ID 失败:%w", err)
}
itemIDs = append(itemIDs, id)
}
idx.rebuildMu.Lock()
idx.rebuildTotalItems = len(itemIDs)
idx.rebuildMu.Unlock()
idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))
// 在开始重建前,先清空所有旧的向量,确保进度从0开始
// 这样 GetIndexStatus 可以准确反映重建进度
_, err = idx.db.Exec("DELETE FROM knowledge_embeddings")
if err != nil {
idx.logger.Warn("清空旧索引失败", zap.Error(err))
// 继续执行,即使清空失败也尝试重建
} else {
idx.logger.Info("已清空旧索引,开始重建")
}
// 注意:不再清空所有旧索引,而是按增量方式更新
// 每个知识项在 IndexItem 中会先删除自己的旧向量,然后插入新向量
// 这样配置更新后只重新索引变化的知识项,保留其他知识项的索引
failedCount := 0
consecutiveFailures := 0
maxConsecutiveFailures := 2 // 连续失败2次后立即停止(降低阈值,更快停止
maxConsecutiveFailures := 5 // 连续失败 5 次后立即停止(允许偶尔的临时错误
firstFailureItemID := ""
var firstFailureError error
for i, itemID := range itemIDs {
if err := idx.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
// 只在第一个失败时记录详细日志
if consecutiveFailures == 1 {
firstFailureItemID = itemID
@@ -414,15 +680,15 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
zap.Error(err),
)
}
// 如果连续失败过多,可能是配置问题,立即停止索引
if consecutiveFailures >= maxConsecutiveFailures {
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API密钥无效、余额不足等)。第一个失败项: %s, 错误: %v", consecutiveFailures, firstFailureItemID, firstFailureError)
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项%s, 错误%v", consecutiveFailures, firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("连续索引失败次数过多,立即停止索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemIDs)),
@@ -430,17 +696,17 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
return fmt.Errorf("连续索引失败次数过多: %v", firstFailureError)
return fmt.Errorf("连续索引失败次数过多%v", firstFailureError)
}
// 如果失败的知识项过多,记录警告但继续处理(降低阈值到30%)
// 如果失败的知识项过多,记录警告但继续处理(降低阈值到 30%
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项: %s, 错误: %v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项%s, 错误%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
zap.Int("failedCount", failedCount),
zap.Int("totalItems", len(itemIDs)),
@@ -450,20 +716,31 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
}
continue
}
// 成功时重置连续失败计数和第一个失败信息
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
// 减少进度日志频率(每10个或每10%记录一次)
// 更新重建进度
idx.rebuildMu.Lock()
idx.rebuildCurrent = i + 1
idx.rebuildFailed = failedCount
idx.rebuildMu.Unlock()
// 减少进度日志频率(每 10 个或每 10% 记录一次)
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
}
}
// 重置重建状态
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
return nil
}
@@ -474,3 +751,10 @@ func (idx *Indexer) GetLastError() (string, time.Time) {
defer idx.mu.RUnlock()
return idx.lastError, idx.lastErrorTime
}
// GetRebuildStatus 获取重建索引状态
func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) {
idx.rebuildMu.RLock()
defer idx.rebuildMu.RUnlock()
return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime
}
+17 -2
View File
@@ -657,7 +657,7 @@ func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeIte
// 删除旧目录(如果为空)
oldDir := filepath.Dir(item.FilePath)
if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 {
if isEmpty, _ := isEmptyDir(oldDir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if oldDir != m.basePath {
if err := os.Remove(oldDir); err != nil {
@@ -712,7 +712,7 @@ func (m *Manager) DeleteItem(id string) error {
// 删除空目录(如果为空)
dir := filepath.Dir(filePath)
if entries, err := os.ReadDir(dir); err == nil && len(entries) == 0 {
if isEmpty, _ := isEmptyDir(dir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if dir != m.basePath {
if err := os.Remove(dir); err != nil {
@@ -724,6 +724,21 @@ func (m *Manager) DeleteItem(id string) error {
return nil
}
// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件)
func isEmptyDir(dir string) (bool, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return false, err
}
for _, entry := range entries {
// 忽略隐藏文件(以 . 开头)
if !strings.HasPrefix(entry.Name(), ".") {
return false, nil
}
}
return true, nil
}
// LogRetrieval 记录检索日志
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
id := uuid.New().String()
+52 -37
View File
@@ -69,8 +69,8 @@ func cosineSimilarity(a, b []float32) float64 {
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}
// bm25Score 计算BM25分数(改进版,更接近标准BM25
// 注意:这是单文档版本的BM25,缺少全局IDF,但比之前的简化版本更准确
// bm25Score 计算 BM25 分数(带缓存的改进版本
// 注意:由于缺少全局文档统计,使用简化 IDF 计算
func (r *Retriever) bm25Score(query, text string) float64 {
queryTerms := strings.Fields(strings.ToLower(query))
if len(queryTerms) == 0 {
@@ -83,44 +83,56 @@ func (r *Retriever) bm25Score(query, text string) float64 {
return 0.0
}
// BM25参数
k1 := 1.5 // 词频饱和度参数
b := 0.75 // 长度归一化参数
avgDocLength := 100.0 // 估算的平均文档长度(用于归一化
// BM25 参数(标准值)
k1 := 1.2 // 词频饱和度参数(标准范围 1.2-2.0
b := 0.75 // 长度归一化参数(标准值)
avgDocLength := 150.0 // 估算的平均文档长度(基于典型知识块大小
docLength := float64(len(textTerms))
score := 0.0
for _, term := range queryTerms {
// 计算词频(TF
termFreq := 0
for _, textTerm := range textTerms {
if textTerm == term {
termFreq++
}
}
if termFreq > 0 {
// BM25公式的核心部分
// TF部分:termFreq / (termFreq + k1 * (1 - b + b * (docLength / avgDocLength)))
tf := float64(termFreq)
lengthNorm := 1 - b + b*(docLength/avgDocLength)
tfScore := tf / (tf + k1*lengthNorm)
// 简化IDF:使用词长度作为权重(短词通常更重要)
// 实际BM25需要全局文档统计,这里用简化版本
idfWeight := 1.0
if len(term) > 2 {
// 长词稍微降低权重(但实际BM25中,罕见词IDF更高)
idfWeight = 1.0 + math.Log(1.0+float64(len(term))/10.0)
}
score += tfScore * idfWeight
}
// 计算词频映射
textTermFreq := make(map[string]int, len(textTerms))
for _, term := range textTerms {
textTermFreq[term]++
}
// 归一化到0-1范围
score := 0.0
matchedQueryTerms := 0
for _, term := range queryTerms {
termFreq, exists := textTermFreq[term]
if !exists || termFreq == 0 {
continue
}
matchedQueryTerms++
// BM25 TF 计算公式
tf := float64(termFreq)
lengthNorm := 1 - b + b*(docLength/avgDocLength)
tfScore := tf / (tf + k1*lengthNorm)
// 改进的 IDF 计算:使用词长度和出现频率估算
// 短词(2-3 字符)通常更重要,长词 IDF 略低
idfWeight := 1.0
termLen := len(term)
if termLen <= 2 {
// 极短词(如 go, js)给予更高权重
idfWeight = 1.2 + math.Log(1.0+float64(termFreq)/20.0)
} else if termLen <= 4 {
// 短词(4 字符)标准权重
idfWeight = 1.0 + math.Log(1.0+float64(termFreq)/15.0)
} else {
// 长词稍微降低权重
idfWeight = 0.9 + math.Log(1.0+float64(termFreq)/10.0)
}
score += tfScore * idfWeight
}
// 归一化:考虑匹配的查询词比例
if len(queryTerms) > 0 {
score = score / float64(len(queryTerms))
// 使用匹配比例作为额外因子
matchRatio := float64(matchedQueryTerms) / float64(len(queryTerms))
score = (score / float64(len(queryTerms))) * (1 + matchRatio) / 2
}
return math.Min(score, 1.0)
@@ -173,7 +185,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
FROM knowledge_embeddings e
JOIN knowledge_base_items i ON e.item_id = i.id
WHERE i.category = ? COLLATE NOCASE
WHERE TRIM(i.category) = TRIM(?) COLLATE NOCASE
`, req.RiskType)
} else {
rows, err = r.db.Query(`
@@ -357,7 +369,10 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
zap.Float64("threshold", threshold),
zap.Float64("maxSimilarity", maxSimilarity),
)
} else if len(filteredCandidates) > topK {
}
// 统一在最终返回前严格限制 Top-K 数量
if len(filteredCandidates) > topK {
// 如果过滤后结果太多,只取Top-K
filteredCandidates = filteredCandidates[:topK]
}
+24 -42
View File
@@ -5,6 +5,14 @@ import (
"time"
)
// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串
func formatTime(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format(time.RFC3339)
}
// KnowledgeItem 知识库项
type KnowledgeItem struct {
ID string `json:"id"`
@@ -22,12 +30,12 @@ type KnowledgeItemSummary struct {
Category string `json:"category"`
Title string `json:"title"`
FilePath string `json:"filePath"`
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前150字符)
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符)
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// MarshalJSON 自定义JSON序列化,确保时间格式正确
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItemSummary
aux := &struct {
@@ -37,25 +45,12 @@ func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
}{
Alias: (*Alias)(k),
}
// 格式化创建时间
if k.CreatedAt.IsZero() {
aux.CreatedAt = ""
} else {
aux.CreatedAt = k.CreatedAt.Format(time.RFC3339)
}
// 格式化更新时间
if k.UpdatedAt.IsZero() {
aux.UpdatedAt = ""
} else {
aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339)
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
// MarshalJSON 自定义JSON序列化,确保时间格式正确
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItem
aux := &struct {
@@ -65,21 +60,8 @@ func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
}{
Alias: (*Alias)(k),
}
// 格式化创建时间
if k.CreatedAt.IsZero() {
aux.CreatedAt = ""
} else {
aux.CreatedAt = k.CreatedAt.Format(time.RFC3339)
}
// 格式化更新时间
if k.UpdatedAt.IsZero() {
aux.UpdatedAt = ""
} else {
aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339)
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
@@ -89,7 +71,7 @@ type KnowledgeChunk struct {
ItemID string `json:"itemId"`
ChunkIndex int `json:"chunkIndex"`
ChunkText string `json:"chunkText"`
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到JSON
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON
CreatedAt time.Time `json:"createdAt"`
}
@@ -108,11 +90,11 @@ type RetrievalLog struct {
MessageID string `json:"messageId,omitempty"`
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"`
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项ID列表
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表
CreatedAt time.Time `json:"createdAt"`
}
// MarshalJSON 自定义JSON序列化,确保时间格式正确
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
type Alias RetrievalLog
return json.Marshal(&struct {
@@ -120,21 +102,21 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
CreatedAt string `json:"createdAt"`
}{
Alias: (*Alias)(r),
CreatedAt: r.CreatedAt.Format(time.RFC3339),
CreatedAt: formatTime(r.CreatedAt),
})
}
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
type CategoryWithItems struct {
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
}
// SearchRequest 搜索请求
type SearchRequest struct {
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型
TopK int `json:"topK,omitempty"` // 返回Top-K结果,默认5
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认0.7
TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7
}
+10 -2
View File
@@ -55,6 +55,14 @@ func New(level, output string) *Logger {
}
func (l *Logger) Fatal(msg string, fields ...interface{}) {
l.Logger.Fatal(msg, zap.Any("fields", fields))
zapFields := make([]zap.Field, 0, len(fields))
for _, f := range fields {
switch v := f.(type) {
case error:
zapFields = append(zapFields, zap.Error(v))
default:
zapFields = append(zapFields, zap.Any("field", v))
}
}
l.Logger.Fatal(msg, zapFields...)
}
+42
View File
@@ -459,6 +459,9 @@ async function updateIndexProgress() {
const isComplete = status.is_complete || false;
const lastError = status.last_error || '';
// 检查是否正在重建索引(优先使用重建状态)
const isRebuilding = status.is_rebuilding || false;
if (totalItems === 0) {
// 没有知识项,隐藏进度条
progressContainer.style.display = 'none';
@@ -524,6 +527,45 @@ async function updateIndexProgress() {
return;
}
// 优先处理重建状态
if (isRebuilding) {
const rebuildTotal = status.rebuild_total || totalItems;
const rebuildCurrent = status.rebuild_current || 0;
const rebuildFailed = status.rebuild_failed || 0;
const rebuildLastItemID = status.rebuild_last_item_id || '';
const rebuildLastChunks = status.rebuild_last_chunks || 0;
const rebuildStartTime = status.rebuild_start_time || '';
// 计算进度百分比(使用重建进度)
let rebuildProgress = progressPercent;
if (rebuildTotal > 0) {
rebuildProgress = (rebuildCurrent / rebuildTotal) * 100;
}
progressContainer.innerHTML = `
<div class="knowledge-index-progress">
<div class="progress-header">
<span class="progress-icon">🔨</span>
<span class="progress-text">正在重建索引:${rebuildCurrent}/${rebuildTotal} (${rebuildProgress.toFixed(1)}%) - 失败:${rebuildFailed}</span>
</div>
<div class="progress-bar-container">
<div class="progress-bar" style="width: ${rebuildProgress}%"></div>
</div>
<div class="progress-hint">
${rebuildLastItemID ? `正在处理:${escapeHtml(rebuildLastItemID.substring(0, 36))}... (${rebuildLastChunks} chunks)` : '正在处理...'}
${rebuildStartTime ? `<br>开始时间:${new Date(rebuildStartTime).toLocaleString()}` : ''}
</div>
</div>
`;
// 重建中时继续轮询
if (!indexProgressInterval) {
indexProgressInterval = setInterval(updateIndexProgress, 2000);
}
return;
}
if (isComplete) {
progressContainer.innerHTML = `
<div class="knowledge-index-progress-complete">
File diff suppressed because it is too large Load Diff
+46
View File
@@ -172,6 +172,43 @@ async function loadConfig(loadTools = true) {
// 允许0.0值,只有undefined/null时才使用默认值
retrievalWeightInput.value = (hybridWeight !== undefined && hybridWeight !== null) ? hybridWeight : 0.7;
}
// 索引配置
const indexing = knowledge.indexing || {};
const chunkSizeInput = document.getElementById('knowledge-indexing-chunk-size');
if (chunkSizeInput) {
chunkSizeInput.value = indexing.chunk_size || 512;
}
const chunkOverlapInput = document.getElementById('knowledge-indexing-chunk-overlap');
if (chunkOverlapInput) {
chunkOverlapInput.value = indexing.chunk_overlap ?? 50;
}
const maxChunksPerItemInput = document.getElementById('knowledge-indexing-max-chunks-per-item');
if (maxChunksPerItemInput) {
maxChunksPerItemInput.value = indexing.max_chunks_per_item ?? 0;
}
const maxRpmInput = document.getElementById('knowledge-indexing-max-rpm');
if (maxRpmInput) {
maxRpmInput.value = indexing.max_rpm ?? 0;
}
const rateLimitDelayInput = document.getElementById('knowledge-indexing-rate-limit-delay-ms');
if (rateLimitDelayInput) {
rateLimitDelayInput.value = indexing.rate_limit_delay_ms ?? 300;
}
const maxRetriesInput = document.getElementById('knowledge-indexing-max-retries');
if (maxRetriesInput) {
maxRetriesInput.value = indexing.max_retries ?? 3;
}
const retryDelayInput = document.getElementById('knowledge-indexing-retry-delay-ms');
if (retryDelayInput) {
retryDelayInput.value = indexing.retry_delay_ms ?? 1000;
}
}
// 填充机器人配置
@@ -728,6 +765,15 @@ async function applySettings() {
const val = parseFloat(document.getElementById('knowledge-retrieval-hybrid-weight')?.value);
return isNaN(val) ? 0.7 : val; // 允许0.0值,只有NaN时才使用默认值
})()
},
indexing: {
chunk_size: parseInt(document.getElementById("knowledge-indexing-chunk-size")?.value) || 512,
chunk_overlap: parseInt(document.getElementById("knowledge-indexing-chunk-overlap")?.value) ?? 50,
max_chunks_per_item: parseInt(document.getElementById("knowledge-indexing-max-chunks-per-item")?.value) ?? 0,
max_rpm: parseInt(document.getElementById("knowledge-indexing-max-rpm")?.value) ?? 0,
rate_limit_delay_ms: parseInt(document.getElementById("knowledge-indexing-rate-limit-delay-ms")?.value) ?? 300,
max_retries: parseInt(document.getElementById("knowledge-indexing-max-retries")?.value) ?? 3,
retry_delay_ms: parseInt(document.getElementById("knowledge-indexing-retry-delay-ms")?.value) ?? 1000
}
};
+16 -1
View File
@@ -100,7 +100,22 @@
ws.onmessage = function (ev) {
if (!tab.term) return;
tab.term.write(ev.data);
// 处理二进制消息和文本消息
if (ev.data instanceof ArrayBuffer) {
var decoder = new TextDecoder('utf-8');
tab.term.write(decoder.decode(ev.data));
} else if (ev.data instanceof Blob) {
// Blob 类型,需要异步读取
var reader = new FileReader();
reader.onload = function () {
var decoder = new TextDecoder('utf-8');
tab.term.write(decoder.decode(reader.result));
};
reader.readAsArrayBuffer(ev.data);
} else {
// 字符串类型
tab.term.write(ev.data);
}
};
ws.onclose = function () {
+38 -1
View File
@@ -1203,7 +1203,44 @@
<small class="form-hint">向量检索的权重(0-1),1.0表示纯向量检索,0.0表示纯关键词检索</small>
</div>
</div>
</div>
<div class="settings-subsection-header">
<h5>索引配置</h5>
</div>
<div class="form-group">
<label for="knowledge-indexing-chunk-size">分块大小(Chunk Size</label>
<input type="number" id="knowledge-indexing-chunk-size" min="128" max="4096" placeholder="512" />
<small class="form-hint">每个块的最大 token 数(默认 512),长文本会被分割成多个块</small>
</div>
<div class="form-group">
<label for="knowledge-indexing-chunk-overlap">分块重叠(Chunk Overlap</label>
<input type="number" id="knowledge-indexing-chunk-overlap" min="0" max="512" placeholder="50" />
<small class="form-hint">块之间的重叠 token 数(默认 50),保持上下文连贯性</small>
</div>
<div class="form-group">
<label for="knowledge-indexing-max-chunks-per-item">单个知识项最大块数</label>
<input type="number" id="knowledge-indexing-max-chunks-per-item" min="0" max="1000" placeholder="0" />
<small class="form-hint">单个知识项的最大块数量(0 表示不限制),防止单个文件消耗过多 API 配额</small>
</div>
<div class="form-group">
<label for="knowledge-indexing-max-rpm">每分钟最大请求数(Max RPM</label>
<input type="number" id="knowledge-indexing-max-rpm" min="0" max="1000" placeholder="0" />
<small class="form-hint">每分钟最大请求数(默认 0 表示不限制),如 OpenAI 默认 200 RPM</small>
</div>
<div class="form-group">
<label for="knowledge-indexing-rate-limit-delay-ms">请求间隔延迟(毫秒)</label>
<input type="number" id="knowledge-indexing-rate-limit-delay-ms" min="0" max="10000" placeholder="300" />
<small class="form-hint">请求间隔毫秒数(默认 300),用于避免 API 速率限制,设为 0 不限制</small>
</div>
<div class="form-group">
<label for="knowledge-indexing-max-retries">最大重试次数</label>
<input type="number" id="knowledge-indexing-max-retries" min="0" max="10" placeholder="3" />
<small class="form-hint">最大重试次数(默认 3),遇到速率限制或服务器错误时自动重试</small>
</div>
<div class="form-group">
<label for="knowledge-indexing-retry-delay-ms">重试间隔(毫秒)</label>
<input type="number" id="knowledge-indexing-retry-delay-ms" min="0" max="10000" placeholder="1000" />
<small class="form-hint">重试间隔毫秒数(默认 1000),每次重试会递增延迟</small>
</div> </div>
<div class="settings-actions">
<button class="btn-primary" onclick="applySettings()">应用配置</button>
File diff suppressed because it is too large Load Diff